update HyperNEAT;

All example can currently run!
This commit is contained in:
wls2002
2024-05-26 19:51:22 +08:00
parent 18c3d44c79
commit 9f6154d128
15 changed files with 112 additions and 78 deletions

View File

@@ -40,19 +40,22 @@ class HyperNEAT(BaseAlgorithm):
output_transform=output_transform,
)
def setup(self, randkey):
return State(neat_state=self.neat.setup(randkey))
def setup(self, state=State()):
state = self.neat.setup(state)
state = self.substrate.setup(state)
return self.hyper_genome.setup(state)
def ask(self, state: State):
return self.neat.ask(state.neat_state)
return self.neat.ask(state)
def tell(self, state: State, fitness):
return state.update(neat_state=self.neat.tell(state.neat_state, fitness))
state = self.neat.tell(state, fitness)
return state
def transform(self, individual):
transformed = self.neat.transform(individual)
query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(
self.substrate.query_coors, transformed
def transform(self, state, individual):
transformed = self.neat.transform(state, individual)
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
state, self.substrate.query_coors, transformed
)
# mute the connection with weight below threshold
@@ -74,12 +77,12 @@ class HyperNEAT(BaseAlgorithm):
h_nodes, h_conns = self.substrate.make_nodes(
query_res
), self.substrate.make_conn(query_res)
return self.hyper_genome.transform(h_nodes, h_conns)
return self.hyper_genome.transform(state, h_nodes, h_conns)
def forward(self, inputs, transformed):
def forward(self, state, inputs, transformed):
# add bias
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
return self.hyper_genome.forward(inputs_with_bias, transformed)
return self.hyper_genome.forward(state, inputs_with_bias, transformed)
@property
def num_inputs(self):
@@ -94,10 +97,10 @@ class HyperNEAT(BaseAlgorithm):
return self.neat.pop_size
def member_count(self, state: State):
return self.neat.member_count(state.neat_state)
return self.neat.member_count(state)
def generation(self, state: State):
return self.neat.generation(state.neat_state)
return self.neat.generation(state)
class HyperNodeGene(BaseNodeGene):
@@ -110,7 +113,7 @@ class HyperNodeGene(BaseNodeGene):
self.activation = activation
self.aggregation = aggregation
def forward(self, attrs, inputs, is_output_node=False):
def forward(self, state, attrs, inputs, is_output_node=False):
return jax.lax.cond(
is_output_node,
lambda: self.aggregation(inputs), # output node does not need activation
@@ -121,6 +124,6 @@ class HyperNodeGene(BaseNodeGene):
class HyperNEATConnGene(BaseConnGene):
custom_attrs = ["weight"]
def forward(self, attrs, inputs):
def forward(self, state, attrs, inputs):
weight = attrs[0]
return inputs * weight