Next is to connect with Evox!
This commit is contained in:
wls2002
2023-06-25 02:57:45 +08:00
parent 0cb2f9473d
commit ba369db0b2
14 changed files with 392 additions and 268 deletions

View File

@@ -1,10 +1,8 @@
import numpy as np
from jax import jit, vmap
from .genome.forward import create_forward
from .genome.utils import unflatten_connections
from .genome.graph import topological_sort
from .genome import create_forward, topological_sort, unflatten_connections
from .operations import create_next_generation_then_speciate
def hash_symbols(symbols):
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
@@ -15,8 +13,10 @@ class FunctionFactory:
Creates and compiles functions used in the NEAT pipeline.
"""
def __init__(self, config):
def __init__(self, config, jit_config):
self.config = config
self.jit_config = jit_config
self.func_dict = {}
self.function_info = {}
@@ -78,6 +78,24 @@ class FunctionFactory:
{'shape': ('P', 'N', 5), 'type': np.float32},
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
]
},
'create_next_generation_then_speciate': {
'func': create_next_generation_then_speciate,
'lowers': [
{'shape': (2, ), 'type': np.uint32}, # rand_key
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
{'shape': ('P', ), 'type': np.int32}, # winner
{'shape': ('P', ), 'type': np.int32}, # loser
{'shape': ('P', ), 'type': bool}, # elite_mask
{'shape': ('P',), 'type': np.int32}, # new_node_keys
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
{'shape': ('S', ), 'type': np.int32}, # species_keys
{'shape': (), 'type': np.int32}, # new_species_key_start
"jit_config"
]
}
}
@@ -94,12 +112,19 @@ class FunctionFactory:
# prepare lower operands
lowers_operands = []
for lower in self.function_info[name]['lowers']:
shape = list(lower['shape'])
for i, s in enumerate(shape):
if s in symbols:
shape[i] = symbols[s]
assert isinstance(shape[i], int)
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
if isinstance(lower, dict):
shape = list(lower['shape'])
for i, s in enumerate(shape):
if s in symbols:
shape[i] = symbols[s]
assert isinstance(shape[i], int)
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
elif lower == "jit_config":
lowers_operands.append(self.jit_config)
else:
raise ValueError("Invalid lower operand")
# compile
compiled_func = jit(func).lower(*lowers_operands).compile()