Perfect!
Next is to connect with Evox!
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
import numpy as np
|
||||
from jax import jit, vmap
|
||||
|
||||
from .genome.forward import create_forward
|
||||
from .genome.utils import unflatten_connections
|
||||
from .genome.graph import topological_sort
|
||||
|
||||
from .genome import create_forward, topological_sort, unflatten_connections
|
||||
from .operations import create_next_generation_then_speciate
|
||||
|
||||
def hash_symbols(symbols):
|
||||
return symbols['P'], symbols['N'], symbols['C'], symbols['S']
|
||||
@@ -15,8 +13,10 @@ class FunctionFactory:
|
||||
Creates and compiles functions used in the NEAT pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, jit_config):
|
||||
self.config = config
|
||||
self.jit_config = jit_config
|
||||
|
||||
self.func_dict = {}
|
||||
self.function_info = {}
|
||||
|
||||
@@ -78,6 +78,24 @@ class FunctionFactory:
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32},
|
||||
{'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
|
||||
]
|
||||
},
|
||||
|
||||
'create_next_generation_then_speciate': {
|
||||
'func': create_next_generation_then_speciate,
|
||||
'lowers': [
|
||||
{'shape': (2, ), 'type': np.uint32}, # rand_key
|
||||
{'shape': ('P', 'N', 5), 'type': np.float32}, # pop_nodes
|
||||
{'shape': ('P', 'C', 4), 'type': np.float32}, # pop_cons
|
||||
{'shape': ('P', ), 'type': np.int32}, # winner
|
||||
{'shape': ('P', ), 'type': np.int32}, # loser
|
||||
{'shape': ('P', ), 'type': bool}, # elite_mask
|
||||
{'shape': ('P',), 'type': np.int32}, # new_node_keys
|
||||
{'shape': ('S', 'N', 5), 'type': np.float32}, # center_nodes
|
||||
{'shape': ('S', 'C', 4), 'type': np.float32}, # center_cons
|
||||
{'shape': ('S', ), 'type': np.int32}, # species_keys
|
||||
{'shape': (), 'type': np.int32}, # new_species_key_start
|
||||
"jit_config"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,12 +112,19 @@ class FunctionFactory:
|
||||
# prepare lower operands
|
||||
lowers_operands = []
|
||||
for lower in self.function_info[name]['lowers']:
|
||||
shape = list(lower['shape'])
|
||||
for i, s in enumerate(shape):
|
||||
if s in symbols:
|
||||
shape[i] = symbols[s]
|
||||
assert isinstance(shape[i], int)
|
||||
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
|
||||
if isinstance(lower, dict):
|
||||
shape = list(lower['shape'])
|
||||
for i, s in enumerate(shape):
|
||||
if s in symbols:
|
||||
shape[i] = symbols[s]
|
||||
assert isinstance(shape[i], int)
|
||||
lowers_operands.append(np.zeros(shape, dtype=lower['type']))
|
||||
|
||||
elif lower == "jit_config":
|
||||
lowers_operands.append(self.jit_config)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid lower operand")
|
||||
|
||||
# compile
|
||||
compiled_func = jit(func).lower(*lowers_operands).compile()
|
||||
|
||||
Reference in New Issue
Block a user