update hyperneat and related examples

This commit is contained in:
root
2024-07-11 15:08:02 +08:00
parent 9bad577d89
commit 3cb5fbf581
7 changed files with 102 additions and 136 deletions

View File

@@ -1,12 +1,12 @@
from typing import Callable
import jax, jax.numpy as jnp
import jax
from jax import vmap, numpy as jnp
from tensorneat.common import State, Act, Agg
from .. import BaseAlgorithm, NEAT
from ..neat.gene import BaseNodeGene, BaseConnGene
from ..neat.genome import RecurrentGenome
from .substrate import *
from tensorneat.common import State, Act, Agg
from tensorneat.algorithm import BaseAlgorithm, NEAT
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
class HyperNEAT(BaseAlgorithm):
@@ -14,64 +14,65 @@ class HyperNEAT(BaseAlgorithm):
self,
substrate: BaseSubstrate,
neat: NEAT,
below_threshold: float = 0.3,
weight_threshold: float = 0.3,
max_weight: float = 5.0,
aggregation=Agg.sum,
activation=Act.sigmoid,
aggregation: Callable = Agg.sum,
activation: Callable = Act.sigmoid,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
output_transform: Callable = Act.standard_sigmoid,
):
assert (
substrate.query_coors.shape[1] == neat.num_inputs
), "Substrate input size should be equal to NEAT input size"
), "Query coors of Substrate should be equal to NEAT input size"
self.substrate = substrate
self.neat = neat
self.below_threshold = below_threshold
self.weight_threshold = weight_threshold
self.max_weight = max_weight
self.hyper_genome = RecurrentGenome(
num_inputs=substrate.num_inputs,
num_outputs=substrate.num_outputs,
max_nodes=substrate.nodes_cnt,
max_conns=substrate.conns_cnt,
node_gene=HyperNodeGene(aggregation, activation),
conn_gene=HyperNEATConnGene(),
node_gene=HyperNEATNode(aggregation, activation),
conn_gene=HyperNEATConn(),
activate_time=activate_time,
output_transform=output_transform,
)
self.pop_size = neat.pop_size
def setup(self, state=State()):
state = self.neat.setup(state)
state = self.substrate.setup(state)
return self.hyper_genome.setup(state)
def ask(self, state: State):
def ask(self, state):
return self.neat.ask(state)
def tell(self, state: State, fitness):
def tell(self, state, fitness):
state = self.neat.tell(state, fitness)
return state
def transform(self, state, individual):
transformed = self.neat.transform(state, individual)
query_res = jax.vmap(self.neat.forward, in_axes=(None, None, 0))(
query_res = vmap(self.neat.forward, in_axes=(None, None, 0))(
state, transformed, self.substrate.query_coors
)
# mute the connection with weight below threshold
# mute the connection with weight below the weight threshold
query_res = jnp.where(
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
(-self.weight_threshold < query_res) & (query_res < self.weight_threshold),
0.0,
query_res,
)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(
query_res > 0, query_res - self.below_threshold, query_res
query_res > 0, query_res - self.weight_threshold, query_res
)
query_res = jnp.where(
query_res < 0, query_res + self.below_threshold, query_res
query_res < 0, query_res + self.weight_threshold, query_res
)
query_res = query_res / (1 - self.below_threshold) * self.max_weight
query_res = query_res / (1 - self.weight_threshold) * self.max_weight
h_nodes, h_conns = self.substrate.make_nodes(
query_res
@@ -79,11 +80,11 @@ class HyperNEAT(BaseAlgorithm):
return self.hyper_genome.transform(state, h_nodes, h_conns)
def forward(self, state, inputs, transformed):
def forward(self, state, transformed, inputs):
# add bias
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
res = self.hyper_genome.forward(state, inputs_with_bias, transformed)
res = self.hyper_genome.forward(state, transformed, inputs_with_bias)
return res
@property
@@ -94,18 +95,11 @@ class HyperNEAT(BaseAlgorithm):
def num_outputs(self):
return self.substrate.num_outputs
@property
def pop_size(self):
return self.neat.pop_size
def member_count(self, state: State):
return self.neat.member_count(state)
def generation(self, state: State):
return self.neat.generation(state)
def show_details(self, state, fitness):
return self.neat.show_details(state, fitness)
class HyperNodeGene(BaseNodeGene):
class HyperNEATNode(BaseNode):
def __init__(
self,
aggregation=Agg.sum,
@@ -123,7 +117,7 @@ class HyperNodeGene(BaseNodeGene):
)
class HyperNEATConnGene(BaseConnGene):
class HyperNEATConn(BaseConn):
custom_attrs = ["weight"]
def forward(self, state, attrs, inputs):