Files
tensorneat-mend/neat/function_factory.py
wls2002 ba369db0b2 Perfect!
Next is to connect with Evox!
2023-06-25 02:57:45 +08:00

134 lines
5.1 KiB
Python

import numpy as np
from jax import jit, vmap
from .genome import create_forward, topological_sort, unflatten_connections
from .operations import create_next_generation_then_speciate
def hash_symbols(symbols):
    """Return the hashable cache key derived from a symbol table.

    Despite the name, no hash is computed: the result is simply the tuple
    of the values stored under the keys 'P', 'N', 'C' and 'S', in that
    order, which can then be used as (part of) a dictionary key.

    Raises:
        KeyError: if any of the four keys is missing from ``symbols``.
    """
    return tuple(symbols[key] for key in ('P', 'N', 'C', 'S'))
class FunctionFactory:
    """
    Creates and compiles functions used in the NEAT pipeline.

    Every entry of ``function_info`` describes one function together with
    the operands ("lowers") used to lower and compile it ahead of time.
    Operand shapes may contain the symbolic dimensions 'P', 'N', 'C' and
    'S'; these are replaced by the concrete values given in the
    ``symbols`` mapping when the function is compiled.
    (Judging by the operand comments below, the symbols presumably stand
    for population size, max nodes, max connections and max species --
    confirm against the callers.)

    Compiled functions are cached in ``func_dict`` under the key
    ``(name, hash_symbols(symbols))``.
    """

    def __init__(self, config, jit_config):
        """Store the configuration and build the function descriptions.

        Args:
            config: mapping read here for 'batch_size' and 'num_inputs';
                it is also passed to ``create_forward``.
            jit_config: object passed as-is as an operand wherever a
                lowers list contains the "jit_config" placeholder.
        """
        self.config = config
        self.jit_config = jit_config

        # cache: (name, hash_symbols(symbols)) -> compiled function
        self.func_dict = {}

        # (inputs_nums, ) -> (outputs_nums, )
        forward = create_forward(config)  # input size (inputs_nums, )

        # (batch_size, inputs_nums) -> (batch_size, outputs_nums)
        batch_forward = vmap(forward, in_axes=(0, None, None, None))

        # (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
        pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))

        # (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
        common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))

        # name -> {'func': callable, 'lowers': operand descriptions}.
        # A lowers entry is either a {'shape', 'type'} dict (materialised
        # as a zero-filled array) or the string placeholder "jit_config".
        self.function_info = {
            "pop_unflatten_connections": {
                'func': vmap(unflatten_connections),
                'lowers': [
                    {'shape': ('P', 'N', 5), 'type': np.float32},
                    {'shape': ('P', 'C', 4), 'type': np.float32}
                ]
            },

            "pop_topological_sort": {
                'func': vmap(topological_sort),
                'lowers': [
                    {'shape': ('P', 'N', 5), 'type': np.float32},
                    {'shape': ('P', 2, 'N', 'N'), 'type': np.float32},
                ]
            },

            "batch_forward": {
                'func': batch_forward,
                'lowers': [
                    {'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
                    {'shape': ('N', ), 'type': np.int32},
                    {'shape': ('N', 5), 'type': np.float32},
                    {'shape': (2, 'N', 'N'), 'type': np.float32}
                ]
            },

            "pop_batch_forward": {
                'func': pop_batch_forward,
                'lowers': [
                    {'shape': ('P', config['batch_size'], config['num_inputs']), 'type': np.float32},
                    {'shape': ('P', 'N'), 'type': np.int32},
                    {'shape': ('P', 'N', 5), 'type': np.float32},
                    {'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
                ]
            },

            'common_forward': {
                'func': common_forward,
                'lowers': [
                    {'shape': (config['batch_size'], config['num_inputs']), 'type': np.float32},
                    {'shape': ('P', 'N'), 'type': np.int32},
                    {'shape': ('P', 'N', 5), 'type': np.float32},
                    {'shape': ('P', 2, 'N', 'N'), 'type': np.float32}
                ]
            },

            'create_next_generation_then_speciate': {
                'func': create_next_generation_then_speciate,
                'lowers': [
                    {'shape': (2, ), 'type': np.uint32},  # rand_key
                    {'shape': ('P', 'N', 5), 'type': np.float32},  # pop_nodes
                    {'shape': ('P', 'C', 4), 'type': np.float32},  # pop_cons
                    {'shape': ('P', ), 'type': np.int32},  # winner
                    {'shape': ('P', ), 'type': np.int32},  # loser
                    {'shape': ('P', ), 'type': bool},  # elite_mask
                    {'shape': ('P',), 'type': np.int32},  # new_node_keys
                    {'shape': ('S', 'N', 5), 'type': np.float32},  # center_nodes
                    {'shape': ('S', 'C', 4), 'type': np.float32},  # center_cons
                    {'shape': ('S', ), 'type': np.int32},  # species_keys
                    {'shape': (), 'type': np.int32},  # new_species_key_start
                    "jit_config"
                ]
            }
        }

    def get(self, name, symbols):
        """Return the compiled function ``name`` for the given symbols.

        The function is compiled on first request and served from
        ``func_dict`` afterwards.

        Args:
            name: key of the function in ``function_info``.
            symbols: mapping with the keys 'P', 'N', 'C' and 'S'.

        Raises:
            KeyError: if ``name`` is unknown or a symbol key is missing.
        """
        # Build the cache key once instead of re-hashing for every lookup.
        key = (name, hash_symbols(symbols))
        if key not in self.func_dict:
            self.compile(name, symbols)
        return self.func_dict[key]

    def compile(self, name, symbols):
        """Lower and compile function ``name`` and store it in the cache.

        Args:
            name: key of the function in ``function_info``.
            symbols: mapping from symbolic dimension names to concrete
                integer sizes; must contain 'P', 'N', 'C' and 'S'.

        Raises:
            KeyError: if ``name`` is not in ``function_info``.
            TypeError: if a shape dimension is not an integer after the
                symbols have been substituted.
            ValueError: if a lowers entry is neither a shape description
                nor the "jit_config" placeholder.
        """
        # prepare function prototype
        info = self.function_info[name]
        func = info['func']

        # prepare lower operands
        lowers_operands = []
        for lower in info['lowers']:
            if isinstance(lower, dict):
                shape = self._resolve_shape(lower['shape'], symbols)
                # Only shape and dtype matter for lowering, so a
                # zero-filled placeholder array is sufficient.
                lowers_operands.append(np.zeros(shape, dtype=lower['type']))
            elif lower == "jit_config":
                lowers_operands.append(self.jit_config)
            else:
                raise ValueError("Invalid lower operand")

        # compile
        compiled_func = jit(func).lower(*lowers_operands).compile()

        # save for reuse
        self.func_dict[name, hash_symbols(symbols)] = compiled_func

    @staticmethod
    def _resolve_shape(shape, symbols):
        """Replace symbolic dimensions in ``shape`` and return a list of ints."""
        resolved = []
        for dim in shape:
            if dim in symbols:
                dim = symbols[dim]
            # Explicit check instead of ``assert`` so the validation is not
            # stripped under ``python -O``; NumPy integer scalars are
            # accepted as well because ``np.zeros`` handles them.
            if not isinstance(dim, (int, np.integer)):
                raise TypeError(
                    f"Shape dimension {dim!r} is not an integer; "
                    f"known symbols: {sorted(map(str, symbols))}"
                )
            resolved.append(int(dim))
        return resolved