Files
tensorneat-mend/tensorneat/algorithm/hyperneat/hyperneat.py
2024-07-11 15:08:02 +08:00

126 lines
3.9 KiB
Python

from typing import Callable
import jax
from jax import vmap, numpy as jnp
from .substrate import *
from tensorneat.common import State, Act, Agg
from tensorneat.algorithm import BaseAlgorithm, NEAT
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
class HyperNEAT(BaseAlgorithm):
    """HyperNEAT built on top of an existing NEAT algorithm.

    ``neat`` evolves a population of CPPNs; ``ask``, ``tell`` and
    ``show_details`` are delegated to it. To evaluate an individual, its CPPN
    is queried once per row of ``substrate.query_coors``. The responses are
    converted into connection weights and loaded into a ``RecurrentGenome``
    sized from the substrate, and that network is what ``forward`` runs.
    """

    def __init__(
        self,
        substrate: BaseSubstrate,
        neat: NEAT,
        weight_threshold: float = 0.3,
        max_weight: float = 5.0,
        aggregation: Callable = Agg.sum,
        activation: Callable = Act.sigmoid,
        activate_time: int = 10,
        output_transform: Callable = Act.standard_sigmoid,
    ):
        """Create the algorithm.

        Args:
            substrate: Supplies the query coordinates, the node/connection
                counts and the ``make_nodes`` / ``make_conn`` builders for the
                substrate network.
            neat: NEAT instance that evolves the CPPNs. Its ``num_inputs``
                must equal ``substrate.query_coors.shape[1]``.
            weight_threshold: CPPN responses whose magnitude is below this
                value yield a zero weight. Must lie in ``[0, 1)``.
            max_weight: Weight magnitude that a CPPN response of +/-1 maps to.
            aggregation: Aggregation used by every substrate node.
            activation: Activation used by non-output substrate nodes.
            activate_time: Forwarded to ``RecurrentGenome``.
            output_transform: Forwarded to ``RecurrentGenome``.

        Raises:
            ValueError: If ``weight_threshold`` is outside ``[0, 1)``. A value
                of 1 would make ``transform`` divide by zero and produce NaN
                weights.
        """
        assert (
            substrate.query_coors.shape[1] == neat.num_inputs
        ), "Query coors of Substrate should be equal to NEAT input size"
        if not 0.0 <= weight_threshold < 1.0:
            raise ValueError(
                f"weight_threshold must be in [0, 1), got {weight_threshold}"
            )
        self.substrate = substrate
        self.neat = neat
        self.weight_threshold = weight_threshold
        self.max_weight = max_weight

        # The network that is actually evaluated: its shape comes from the
        # substrate, its weights are painted by the CPPN in ``transform``.
        self.hyper_genome = RecurrentGenome(
            num_inputs=substrate.num_inputs,
            num_outputs=substrate.num_outputs,
            max_nodes=substrate.nodes_cnt,
            max_conns=substrate.conns_cnt,
            node_gene=HyperNEATNode(aggregation, activation),
            conn_gene=HyperNEATConn(),
            activate_time=activate_time,
            output_transform=output_transform,
        )
        self.pop_size = neat.pop_size

    def setup(self, state=None):
        """Run NEAT, substrate and substrate-genome setup, threading ``state``.

        When ``state`` is omitted, a fresh ``State()`` is created for this
        call instead of sharing one instance built at definition time.
        """
        if state is None:
            state = State()
        state = self.neat.setup(state)
        state = self.substrate.setup(state)
        return self.hyper_genome.setup(state)

    def ask(self, state):
        """Return the current CPPN population from NEAT."""
        return self.neat.ask(state)

    def tell(self, state, fitness):
        """Pass the population's fitness to NEAT and return the new state."""
        return self.neat.tell(state, fitness)

    def transform(self, state, individual):
        """Build the substrate network encoded by one NEAT individual.

        The CPPN is evaluated at every row of ``substrate.query_coors``. Each
        response ``r`` is clipped to ``[-1, 1]`` and mapped to a weight, with
        ``t = weight_threshold``:

            0                                           if |r| <= t
            sign(r) * (|r| - t) / (1 - t) * max_weight  otherwise

        so every weight lies in ``[-max_weight, max_weight]``.
        """
        cppn = self.neat.transform(state, individual)
        query_res = vmap(self.neat.forward, in_axes=(None, None, 0))(
            state, cppn, self.substrate.query_coors
        )
        # Enforce the documented weight range even when the CPPN's output
        # transform is not bounded to [-1, 1].
        query_res = jnp.clip(query_res, -1.0, 1.0)
        # Mute weak responses. Shift the rest toward zero by the threshold so
        # the mapping is continuous at |r| == t, then rescale so |r| == 1
        # maps to max_weight.
        magnitude = jnp.maximum(jnp.abs(query_res) - self.weight_threshold, 0.0)
        query_res = (
            jnp.sign(query_res)
            * magnitude
            / (1 - self.weight_threshold)
            * self.max_weight
        )
        h_nodes = self.substrate.make_nodes(query_res)
        h_conns = self.substrate.make_conn(query_res)
        return self.hyper_genome.transform(state, h_nodes, h_conns)

    def forward(self, state, transformed, inputs):
        """Run the substrate network on ``inputs`` plus a constant bias of 1."""
        inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
        return self.hyper_genome.forward(state, transformed, inputs_with_bias)

    @property
    def num_inputs(self):
        """Number of user inputs; excludes the bias appended in ``forward``."""
        return self.substrate.num_inputs - 1  # remove bias

    @property
    def num_outputs(self):
        """Number of substrate outputs."""
        return self.substrate.num_outputs

    def show_details(self, state, fitness):
        """Delegate progress reporting to NEAT."""
        return self.neat.show_details(state, fitness)
class HyperNEATNode(BaseNode):
    """Substrate node gene with a fixed aggregation and activation.

    Every node aggregates its incoming signals. Hidden nodes then apply the
    activation; output nodes return the raw aggregate.
    """

    def __init__(self, aggregation=Agg.sum, activation=Act.sigmoid):
        super().__init__()
        self.aggregation = aggregation
        self.activation = activation

    def forward(self, state, attrs, inputs, is_output_node=False):
        """Aggregate ``inputs``; apply the activation unless this is an output node."""

        def aggregate_only():
            # Output nodes skip the activation.
            return self.aggregation(inputs)

        def aggregate_then_activate():
            return self.activation(self.aggregation(inputs))

        return jax.lax.cond(is_output_node, aggregate_only, aggregate_then_activate)
class HyperNEATConn(BaseConn):
    """Connection gene whose only attribute is a scalar ``weight``.

    The signal sent along the connection is the input scaled by that weight.
    """

    custom_attrs = ["weight"]

    def forward(self, state, attrs, inputs):
        # Slot 0 of ``attrs`` is the weight (presumably ordered like
        # ``custom_attrs``, which has only that entry).
        w = attrs[0]
        return w * inputs