modify pipeline for "update_by_data";

fix bug in speciate. currently, node_delete and conn_delete can successfully work
This commit is contained in:
wls2002
2024-05-31 15:32:56 +08:00
parent 3ea9986bd4
commit 6aa9011043
12 changed files with 132 additions and 45 deletions

View File

@@ -19,9 +19,15 @@ class BaseAlgorithm:
"""transform the genome into a neural network"""
raise NotImplementedError
def restore(self, state, transformed):
raise NotImplementedError
def forward(self, state, inputs, transformed):
raise NotImplementedError
def update_by_batch(self, state, batch_input, transformed):
raise NotImplementedError
@property
def num_inputs(self):
raise NotImplementedError

View File

@@ -178,18 +178,25 @@ class DefaultMutation(BaseMutation):
def no(key_, nodes_, conns_):
return nodes_, conns_
nodes, conns = jax.lax.cond(
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
)
nodes, conns = jax.lax.cond(
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
)
nodes, conns = jax.lax.cond(
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
)
nodes, conns = jax.lax.cond(
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
)
if self.node_add > 0:
nodes, conns = jax.lax.cond(
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
)
if self.node_delete > 0:
nodes, conns = jax.lax.cond(
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
)
if self.conn_add > 0:
nodes, conns = jax.lax.cond(
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
)
if self.conn_delete > 0:
nodes, conns = jax.lax.cond(
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
)
return nodes, conns

View File

@@ -117,7 +117,9 @@ class DefaultGenome(BaseGenome):
def hit():
batch_ins, new_conn_attrs = jax.vmap(
self.conn_gene.update_by_batch, in_axes=(None, 1, 1), out_axes=(1, 1)
self.conn_gene.update_by_batch,
in_axes=(None, 1, 1),
out_axes=(1, 1),
)(state, u_conns_[:, :, i], batch_values)
batch_z, new_node_attrs = self.node_gene.update_by_batch(
state,
@@ -132,12 +134,12 @@ class DefaultGenome(BaseGenome):
u_conns_.at[:, :, i].set(new_conn_attrs),
)
# the val of input nodes is obtained by the task, not by calculation
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
jnp.isin(i, self.input_idx),
lambda: (batch_values, nodes_attrs_, u_conns_),
hit,
)
# the val of input nodes is obtained by the task, not by calculation
return batch_values, nodes_attrs_, u_conns_, idx + 1

View File

@@ -44,9 +44,15 @@ class NEAT(BaseAlgorithm):
nodes, conns = individual
return self.genome.transform(state, nodes, conns)
def restore(self, state, transformed):
return self.genome.restore(state, transformed)
def forward(self, state, inputs, transformed):
return self.genome.forward(state, inputs, transformed)
def update_by_batch(self, state, batch_input, transformed):
return self.genome.update_by_batch(state, batch_input, transformed)
@property
def num_inputs(self):
return self.genome.num_inputs

View File

@@ -113,6 +113,9 @@ class DefaultSpecies(BaseSpecies):
return state.pop_nodes, state.pop_conns
def update_species(self, state, fitness):
# set nan to -inf
fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
# update the fitness of each species
state, species_fitness = self.update_species_fitness(state, fitness)
@@ -121,6 +124,7 @@ class DefaultSpecies(BaseSpecies):
# sort species_info by their fitness. (also push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
state = state.update(
species_keys=state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices],