71 lines
2.4 KiB
Python
71 lines
2.4 KiB
Python
from typing import Type
|
|
|
|
import jax
|
|
import numpy as np
|
|
|
|
from .substrate import BaseSubstrate, analysis_substrate
|
|
from .hyperneat_gene import HyperNEATGene
|
|
from algorithm import State, Algorithm, neat
|
|
|
|
|
|
class HyperNEAT(Algorithm):
    """HyperNEAT algorithm layered on top of an inner ``neat.NEAT`` instance.

    The constructor wires up three callables on the instance (``tell``,
    ``forward_transform`` and ``forward``); ``setup`` prepares the state
    with the substrate-derived arrays and then initialises the inner NEAT.
    """

    def __init__(self, config, gene_type: Type[neat.BaseGene], substrate: Type[BaseSubstrate]):
        """Store the configuration and build the inner NEAT plus its wrappers.

        Args:
            config: mapping-like configuration; it is read (and later written
                to) by ``setup`` and handed to the inner NEAT and gene helpers.
            gene_type: gene class forwarded to ``neat.NEAT``.
            substrate: substrate class whose ``setup(state, config)`` is
                invoked during ``setup``.
        """
        super().__init__()
        self.config, self.gene_type, self.substrate = config, gene_type, substrate

        # The inner NEAT instance is what tell / forward_transform delegate to.
        inner_neat = neat.NEAT(config, gene_type)
        self.neat = inner_neat

        self.tell = create_tell(inner_neat)
        self.forward_transform = create_forward_transform(config, inner_neat)
        self.forward = HyperNEATGene.create_forward(config)

    def setup(self, randkey, state=State()):
        """Populate ``state`` with HyperNEAT data and set up the inner NEAT.

        Args:
            randkey: random key forwarded unchanged to ``self.neat.setup``.
            state: starting state; the default ``State()`` is created once,
                when the method is defined.

        Returns:
            The state returned by ``self.neat.setup``.

        Side effects:
            Writes ``h_input_idx`` and ``h_output_idx`` into ``self.config``.
        """
        cfg = self.config
        state = state.update(
            below_threshold=cfg['below_threshold'],
            max_weight=cfg['max_weight'],
        )

        # Let the substrate add its own entries, then derive the index arrays.
        state = self.substrate.setup(state, cfg)
        (h_input_idx,
         h_output_idx,
         h_hidden_idx,
         query_coors,
         correspond_keys) = analysis_substrate(state)

        # Node table: every index (input, output, hidden) with a trailing axis.
        all_idx = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))
        h_nodes = np.expand_dims(all_idx, axis=-1)

        # Connection table: float32 with 3 columns; the first two are filled
        # from correspond_keys, the remaining column starts at zero.
        num_conns = correspond_keys.shape[0]
        h_conns = np.zeros((num_conns, 3), dtype=np.float32)
        h_conns[:, :2] = correspond_keys

        # The "h_" prefix stands for hyperneat.
        hyperneat_fields = {
            'h_input_idx': h_input_idx,
            'h_output_idx': h_output_idx,
            'h_hidden_idx': h_hidden_idx,
            'query_coors': query_coors,
            'correspond_keys': correspond_keys,
            'h_nodes': h_nodes,
            'h_conns': h_conns,
        }
        state = state.update(**hyperneat_fields)
        state = self.neat.setup(randkey, state=state)

        # Record the input / output indices in the shared config as well.
        cfg['h_input_idx'] = h_input_idx
        cfg['h_output_idx'] = h_output_idx

        return state
|
|
|
|
|
|
def create_tell(neat_instance):
    """Build a ``tell(state, fitness)`` function that delegates to ``neat_instance``.

    Args:
        neat_instance: object exposing a ``tell(state, fitness)`` method.

    Returns:
        A function taking ``(state, fitness)`` and returning whatever
        ``neat_instance.tell(state, fitness)`` returns.
    """

    def tell(state, fitness):
        # The attribute is resolved on every call, not captured up front.
        delegate = neat_instance.tell
        return delegate(state, fitness)

    return tell
|
|
|
|
|
|
def create_forward_transform(config, neat_instance):
    """Build the HyperNEAT ``forward_transform(state, nodes, conns)`` function.

    Args:
        config: accepted for signature compatibility; not read by this factory.
        neat_instance: inner NEAT object providing ``forward_transform`` and
            ``forward``.

    Returns:
        A function that transforms ``(nodes, conns)`` with the inner NEAT,
        evaluates the result at every entry of ``state.query_coors``, writes
        those outputs into ``state.h_conns`` from column 2 onward, and
        returns ``HyperNEATGene.forward_transform`` applied to
        ``state.h_nodes`` and the filled connection table.
    """

    def forward_transform(state, nodes, conns):
        transformed = neat_instance.forward_transform(state, nodes, conns)

        # Map the inner forward pass over the leading axis of the query
        # coordinates while sharing the transformed network across the batch.
        query_all = jax.vmap(neat_instance.forward, in_axes=(0, None))
        queried = query_all(state.query_coors, transformed)  # hyperneat connections

        substrate_nodes = state.h_nodes
        substrate_conns = state.h_conns.at[:, 2:].set(queried)
        return HyperNEATGene.forward_transform(state, substrate_nodes, substrate_conns)

    return forward_transform
|