modify pipeline for "update_by_data";
fix bug in speciate. currently, node_delete and conn_delete can successfully work
This commit is contained in:
@@ -19,9 +19,15 @@ class BaseAlgorithm:
|
||||
"""transform the genome into a neural network"""
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, state, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -178,18 +178,25 @@ class DefaultMutation(BaseMutation):
|
||||
def no(key_, nodes_, conns_):
|
||||
return nodes_, conns_
|
||||
|
||||
nodes, conns = jax.lax.cond(
|
||||
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
|
||||
)
|
||||
nodes, conns = jax.lax.cond(
|
||||
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
|
||||
)
|
||||
nodes, conns = jax.lax.cond(
|
||||
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
|
||||
)
|
||||
nodes, conns = jax.lax.cond(
|
||||
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
|
||||
)
|
||||
if self.node_add > 0:
|
||||
nodes, conns = jax.lax.cond(
|
||||
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
|
||||
)
|
||||
|
||||
if self.node_delete > 0:
|
||||
nodes, conns = jax.lax.cond(
|
||||
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
|
||||
)
|
||||
|
||||
if self.conn_add > 0:
|
||||
nodes, conns = jax.lax.cond(
|
||||
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
|
||||
)
|
||||
|
||||
if self.conn_delete > 0:
|
||||
nodes, conns = jax.lax.cond(
|
||||
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
|
||||
)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
@@ -117,7 +117,9 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
def hit():
|
||||
batch_ins, new_conn_attrs = jax.vmap(
|
||||
self.conn_gene.update_by_batch, in_axes=(None, 1, 1), out_axes=(1, 1)
|
||||
self.conn_gene.update_by_batch,
|
||||
in_axes=(None, 1, 1),
|
||||
out_axes=(1, 1),
|
||||
)(state, u_conns_[:, :, i], batch_values)
|
||||
batch_z, new_node_attrs = self.node_gene.update_by_batch(
|
||||
state,
|
||||
@@ -132,12 +134,12 @@ class DefaultGenome(BaseGenome):
|
||||
u_conns_.at[:, :, i].set(new_conn_attrs),
|
||||
)
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
|
||||
jnp.isin(i, self.input_idx),
|
||||
lambda: (batch_values, nodes_attrs_, u_conns_),
|
||||
hit,
|
||||
)
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
|
||||
return batch_values, nodes_attrs_, u_conns_, idx + 1
|
||||
|
||||
|
||||
@@ -44,9 +44,15 @@ class NEAT(BaseAlgorithm):
|
||||
nodes, conns = individual
|
||||
return self.genome.transform(state, nodes, conns)
|
||||
|
||||
def restore(self, state, transformed):
|
||||
return self.genome.restore(state, transformed)
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
return self.genome.forward(state, inputs, transformed)
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
return self.genome.update_by_batch(state, batch_input, transformed)
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.genome.num_inputs
|
||||
|
||||
@@ -113,6 +113,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
return state.pop_nodes, state.pop_conns
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
# set nan to -inf
|
||||
fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
|
||||
|
||||
# update the fitness of each species
|
||||
state, species_fitness = self.update_species_fitness(state, fitness)
|
||||
|
||||
@@ -121,6 +124,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
# sort species_info by their fitness. (also push nan to the end)
|
||||
sort_indices = jnp.argsort(species_fitness)[::-1]
|
||||
|
||||
state = state.update(
|
||||
species_keys=state.species_keys[sort_indices],
|
||||
best_fitness=state.best_fitness[sort_indices],
|
||||
|
||||
Reference in New Issue
Block a user