update HyperNEAT;
All example can currently run!
This commit is contained in:
@@ -40,19 +40,22 @@ class HyperNEAT(BaseAlgorithm):
|
||||
output_transform=output_transform,
|
||||
)
|
||||
|
||||
def setup(self, randkey):
|
||||
return State(neat_state=self.neat.setup(randkey))
|
||||
def setup(self, state=State()):
|
||||
state = self.neat.setup(state)
|
||||
state = self.substrate.setup(state)
|
||||
return self.hyper_genome.setup(state)
|
||||
|
||||
def ask(self, state: State):
|
||||
return self.neat.ask(state.neat_state)
|
||||
return self.neat.ask(state)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
return state.update(neat_state=self.neat.tell(state.neat_state, fitness))
|
||||
state = self.neat.tell(state, fitness)
|
||||
return state
|
||||
|
||||
def transform(self, individual):
|
||||
transformed = self.neat.transform(individual)
|
||||
query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(
|
||||
self.substrate.query_coors, transformed
|
||||
def transform(self, state, individual):
|
||||
transformed = self.neat.transform(state, individual)
|
||||
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
|
||||
state, self.substrate.query_coors, transformed
|
||||
)
|
||||
|
||||
# mute the connection with weight below threshold
|
||||
@@ -74,12 +77,12 @@ class HyperNEAT(BaseAlgorithm):
|
||||
h_nodes, h_conns = self.substrate.make_nodes(
|
||||
query_res
|
||||
), self.substrate.make_conn(query_res)
|
||||
return self.hyper_genome.transform(h_nodes, h_conns)
|
||||
return self.hyper_genome.transform(state, h_nodes, h_conns)
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
def forward(self, state, inputs, transformed):
|
||||
# add bias
|
||||
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
||||
return self.hyper_genome.forward(inputs_with_bias, transformed)
|
||||
return self.hyper_genome.forward(state, inputs_with_bias, transformed)
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
@@ -94,10 +97,10 @@ class HyperNEAT(BaseAlgorithm):
|
||||
return self.neat.pop_size
|
||||
|
||||
def member_count(self, state: State):
|
||||
return self.neat.member_count(state.neat_state)
|
||||
return self.neat.member_count(state)
|
||||
|
||||
def generation(self, state: State):
|
||||
return self.neat.generation(state.neat_state)
|
||||
return self.neat.generation(state)
|
||||
|
||||
|
||||
class HyperNodeGene(BaseNodeGene):
|
||||
@@ -110,7 +113,7 @@ class HyperNodeGene(BaseNodeGene):
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
return jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: self.aggregation(inputs), # output node does not need activation
|
||||
@@ -121,6 +124,6 @@ class HyperNodeGene(BaseNodeGene):
|
||||
class HyperNEATConnGene(BaseConnGene):
|
||||
custom_attrs = ["weight"]
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
def forward(self, state, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
|
||||
Reference in New Issue
Block a user