import jax, jax.numpy as jnp

from utils import State, Act, Agg
from .. import BaseAlgorithm, NEAT
from ..neat.gene import BaseNodeGene, BaseConnGene
from ..neat.genome import RecurrentGenome
from .substrate import *


class HyperNEAT(BaseAlgorithm):
    """HyperNEAT: evolve a NEAT network that generates substrate weights.

    The evolutionary loop (setup/ask/tell) is delegated to ``neat``. To turn an
    individual into a runnable network, the NEAT network is queried at every
    row of ``substrate.query_coors``. The responses are thresholded and
    rescaled, then packed into a ``RecurrentGenome`` over the substrate.
    """

    def __init__(
        self,
        substrate: BaseSubstrate,
        neat: NEAT,
        below_threshold: float = 0.3,
        max_weight: float = 5.,
        activation=Act.sigmoid,
        aggregation=Agg.sum,
        activate_time: int = 10,
    ):
        """Build the wrapper and the substrate-level genome.

        Args:
            substrate: Provides ``query_coors``, node/connection counts and the
                ``make_nodes`` / ``make_conn`` builders.
            neat: The NEAT algorithm whose networks are queried.
            below_threshold: Responses with magnitude below this value are
                muted to 0. Must lie in [0, 1), because ``transform``
                divides by ``1 - below_threshold``.
            max_weight: Scale applied after thresholding.
            activation: Activation used by every substrate node.
            aggregation: Aggregation used by every substrate node.
            activate_time: Forwarded to ``RecurrentGenome``.

        Raises:
            ValueError: If the width of ``substrate.query_coors`` differs from
                ``neat.num_inputs``, or if ``below_threshold`` is outside
                [0, 1).
        """
        # Explicit raise instead of ``assert``: assertions vanish under ``-O``.
        if substrate.query_coors.shape[1] != neat.num_inputs:
            raise ValueError(
                "Substrate input size should be equal to NEAT input size"
            )
        # transform() divides by (1 - below_threshold). Outside [0, 1) the
        # result is inf/NaN (== 1), sign-flipped (> 1), or unmuted and
        # inflated (< 0).
        if not 0 <= below_threshold < 1:
            raise ValueError(
                f"below_threshold must be in [0, 1), got {below_threshold!r}"
            )

        self.substrate = substrate
        self.neat = neat
        self.below_threshold = below_threshold
        self.max_weight = max_weight

        self.hyper_genome = RecurrentGenome(
            num_inputs=substrate.num_inputs,
            num_outputs=substrate.num_outputs,
            max_nodes=substrate.nodes_cnt,
            max_conns=substrate.conns_cnt,
            node_gene=HyperNodeGene(activation, aggregation),
            conn_gene=HyperNEATConnGene(),
            activate_time=activate_time,
        )

    def setup(self, randkey):
        """Create the algorithm state, storing NEAT's state under ``neat_state``."""
        return State(neat_state=self.neat.setup(randkey))

    def ask(self, state: State):
        """Return the current population, as produced by ``neat.ask``."""
        return self.neat.ask(state.neat_state)

    def tell(self, state: State, fitness):
        """Pass ``fitness`` to NEAT and return the state with the new ``neat_state``."""
        return state.update(neat_state=self.neat.tell(state.neat_state, fitness))

    def transform(self, individual):
        """Decode a NEAT individual into a transformed substrate network.

        Steps:
        1. Query the NEAT network once per row of ``substrate.query_coors``,
           vectorized with ``jax.vmap``.
        2. Soft-threshold the responses.
        3. Rescale them by ``max_weight``.
        4. Build nodes and connections with the substrate.
        5. Transform the result with ``hyper_genome``.
        """
        transformed = self.neat.transform(individual)
        query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(
            self.substrate.query_coors, transformed
        )

        # Soft threshold: |x| < below_threshold is muted to 0. Any other x is
        # moved toward 0 by below_threshold, so kept weights grow continuously
        # from 0 instead of jumping to +-below_threshold.
        # Identical to the mute / subtract / add sequence of jnp.where calls.
        query_res = jnp.sign(query_res) * jnp.maximum(
            jnp.abs(query_res) - self.below_threshold, 0.0
        )
        # A response of magnitude 1 maps to max_weight. If NEAT outputs lie in
        # [-1, 1] (assumed, confirm against the NEAT output activation), the
        # weights lie in [-max_weight, max_weight].
        query_res = query_res / (1 - self.below_threshold) * self.max_weight

        h_nodes = self.substrate.make_nodes(query_res)
        h_conns = self.substrate.make_conn(query_res)
        return self.hyper_genome.transform(h_nodes, h_conns)

    def forward(self, inputs, transformed):
        """Run a transformed substrate network on ``inputs``.

        A constant 1 is appended to ``inputs`` as the bias input before calling
        ``hyper_genome.forward``.
        """
        inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
        return self.hyper_genome.forward(inputs_with_bias, transformed)

    @property
    def num_inputs(self):
        """User-facing input count: substrate inputs minus the appended bias."""
        return self.substrate.num_inputs - 1  # remove bias

    @property
    def num_outputs(self):
        """Output count of the substrate network."""
        return self.substrate.num_outputs

    @property
    def pop_size(self):
        """Population size of the underlying NEAT algorithm."""
        return self.neat.pop_size

    def member_count(self, state: State):
        """Delegate to ``neat.member_count`` with the stored NEAT state."""
        return self.neat.member_count(state.neat_state)

    def generation(self, state: State):
        """Delegate to ``neat.generation`` with the stored NEAT state."""
        return self.neat.generation(state.neat_state)


class HyperNodeGene(BaseNodeGene):
    """Substrate node gene: a fixed activation applied to an aggregation.

    The gene carries no per-node attributes of its own. The activation and
    aggregation are shared by every node.
    """

    def __init__(
        self,
        activation=Act.sigmoid,
        aggregation=Agg.sum,
    ):
        super().__init__()
        self.activation = activation
        self.aggregation = aggregation

    def forward(self, attrs, inputs):
        """Return ``activation(aggregation(inputs))``; ``attrs`` is unused."""
        return self.activation(self.aggregation(inputs))


class HyperNEATConnGene(BaseConnGene):
    """Substrate connection gene with a single ``weight`` attribute."""

    custom_attrs = ['weight']

    def forward(self, attrs, inputs):
        """Scale ``inputs`` by the connection weight stored in ``attrs[0]``."""
        weight = attrs[0]
        return inputs * weight