from typing import Callable

import jax
from jax import vmap, numpy as jnp

from .substrate import *
from tensorneat.common import State, Act, Agg
from tensorneat.algorithm import BaseAlgorithm, NEAT
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome


class HyperNEAT(BaseAlgorithm):
    """HyperNEAT algorithm built on top of a wrapped NEAT instance.

    The wrapped ``neat`` algorithm evolves the networks; each of them is
    evaluated on every row of ``substrate.query_coors`` and the results are
    turned into the nodes/connections of a ``RecurrentGenome`` whose sizes are
    taken from the substrate.
    """

    def __init__(
        self,
        substrate: BaseSubstrate,
        neat: NEAT,
        weight_threshold: float = 0.3,
        max_weight: float = 5.0,
        aggregation: Callable = Agg.sum,
        activation: Callable = Act.sigmoid,
        activate_time: int = 10,
        output_transform: Callable = Act.standard_sigmoid,
    ):
        """Build the HyperNEAT wrapper and its substrate genome.

        Args:
            substrate: Provides ``query_coors``, the substrate sizes and the
                ``make_nodes`` / ``make_conn`` builders.
            neat: The NEAT algorithm whose networks are queried.
            weight_threshold: Query results with magnitude strictly below this
                value are zeroed. Must satisfy ``0 <= weight_threshold < 1``.
            max_weight: Scale factor applied after the threshold shift.
            aggregation: Aggregation function used by every substrate node.
            activation: Activation function used by non-output substrate nodes.
            activate_time: Forwarded to ``RecurrentGenome``.
            output_transform: Forwarded to ``RecurrentGenome``.

        Raises:
            AssertionError: If the width of ``substrate.query_coors`` differs
                from ``neat.num_inputs``.
            ValueError: If ``weight_threshold`` is outside ``[0, 1)``.
        """
        assert (
            substrate.query_coors.shape[1] == neat.num_inputs
        ), "Query coors of Substrate should be equal to NEAT input size"

        # transform() divides by (1 - weight_threshold): a threshold of 1 would
        # divide by zero and a larger one would flip the sign of every weight.
        if not 0 <= weight_threshold < 1:
            raise ValueError(
                f"weight_threshold must be in [0, 1), got {weight_threshold}"
            )

        self.substrate = substrate
        self.neat = neat
        self.weight_threshold = weight_threshold
        self.max_weight = max_weight
        self.hyper_genome = RecurrentGenome(
            num_inputs=substrate.num_inputs,
            num_outputs=substrate.num_outputs,
            max_nodes=substrate.nodes_cnt,
            max_conns=substrate.conns_cnt,
            node_gene=HyperNEATNode(aggregation, activation),
            conn_gene=HyperNEATConn(),
            activate_time=activate_time,
            output_transform=output_transform,
        )
        self.pop_size = neat.pop_size

    def setup(self, state=State()):
        """Run setup of the NEAT algorithm, the substrate and the genome, in that order."""
        state = self.neat.setup(state)
        state = self.substrate.setup(state)
        return self.hyper_genome.setup(state)

    def ask(self, state):
        """Delegate to the wrapped NEAT algorithm's ``ask``."""
        return self.neat.ask(state)

    def tell(self, state, fitness):
        """Delegate to the wrapped NEAT algorithm's ``tell`` and return the new state."""
        state = self.neat.tell(state, fitness)
        return state

    def transform(self, state, individual):
        """Turn one NEAT individual into a transformed substrate genome.

        The individual's network is evaluated on every row of
        ``substrate.query_coors``; the results are thresholded, rescaled and
        handed to the substrate to build the genome's nodes and connections.
        """
        transformed = self.neat.transform(state, individual)
        query_res = vmap(self.neat.forward, in_axes=(None, None, 0))(
            state, transformed, self.substrate.query_coors
        )

        # mute connections whose weight magnitude is below the threshold
        query_res = jnp.where(
            (-self.weight_threshold < query_res) & (query_res < self.weight_threshold),
            0.0,
            query_res,
        )

        # shift the surviving values towards zero by the threshold, then scale;
        # this lands in [-max_weight, max_weight] provided the raw query
        # results lie in [-1, 1] (assumption about the NEAT output — confirm)
        query_res = jnp.where(
            query_res > 0, query_res - self.weight_threshold, query_res
        )
        query_res = jnp.where(
            query_res < 0, query_res + self.weight_threshold, query_res
        )
        query_res = query_res / (1 - self.weight_threshold) * self.max_weight

        # NOTE(review): `make_nodes` / `make_conn` naming is asymmetric;
        # verify both names against BaseSubstrate.
        h_nodes = self.substrate.make_nodes(query_res)
        h_conns = self.substrate.make_conn(query_res)

        return self.hyper_genome.transform(state, h_nodes, h_conns)

    def forward(self, state, transformed, inputs):
        """Run the substrate genome on ``inputs`` with a constant 1 appended as bias."""
        # add bias
        inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])

        res = self.hyper_genome.forward(state, transformed, inputs_with_bias)
        return res

    @property
    def num_inputs(self):
        """Number of inputs callers supply: substrate inputs minus the bias slot."""
        return self.substrate.num_inputs - 1  # remove bias

    @property
    def num_outputs(self):
        """Number of outputs, as reported by the substrate."""
        return self.substrate.num_outputs

    def show_details(self, state, fitness):
        """Delegate to the wrapped NEAT algorithm's ``show_details``."""
        return self.neat.show_details(state, fitness)


class HyperNEATNode(BaseNode):
    """Substrate node gene with a fixed aggregation and activation function."""

    def __init__(
        self,
        aggregation=Agg.sum,
        activation=Act.sigmoid,
    ):
        super().__init__()
        self.aggregation = aggregation
        self.activation = activation

    def forward(self, state, attrs, inputs, is_output_node=False):
        """Aggregate ``inputs``; apply the activation unless this is an output node."""
        return jax.lax.cond(
            is_output_node,
            lambda: self.aggregation(inputs),  # output node does not need activation
            lambda: self.activation(self.aggregation(inputs)),
        )


class HyperNEATConn(BaseConn):
    """Substrate connection gene carrying a single ``weight`` attribute."""

    custom_attrs = ["weight"]

    def forward(self, state, attrs, inputs):
        """Return ``inputs`` multiplied by the connection weight (``attrs[0]``)."""
        weight = attrs[0]
        return inputs * weight