1. Add origin_node and origin_conn.
2. Change the behavior of crossover and mutation. Now, TensorNEAT will use all fix_attrs(include historical marker if it has one) as identifier for gene in crossover and distance calculation.
3. Other slightly change.
4. Add two related examples: xor_origin and hopper_origin
5. Add related test file.
This commit is contained in:
wls2002
2024-12-18 16:20:34 +08:00
parent e9a8110af5
commit ee1a2a8271
18 changed files with 667 additions and 204 deletions

View File

@@ -116,6 +116,25 @@ class NEAT(BaseAlgorithm):
next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key
# find next conn historical markers for mutation if needed
if "historical_marker" in self.genome.conn_gene.fixed_attrs:
all_conns_markers = vmap(
self.genome.conn_gene.get_historical_marker, in_axes=(None, 0)
)(state, state.pop_conns)
max_conn_markers = jnp.max(
all_conns_markers, where=~jnp.isnan(all_conns_markers), initial=0
)
next_conn_markers = max_conn_markers + 1
new_conn_markers = (
jnp.arange(self.pop_size * 3).reshape(self.pop_size, 3)
+ next_conn_markers
)
else:
# no need to generate new conn historical markers
# use 0
new_conn_markers = jnp.full((self.pop_size, 3), 0)
# prepare random keys
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, self.pop_size)
@@ -133,9 +152,9 @@ class NEAT(BaseAlgorithm):
# batch mutation
m_n_nodes, m_n_conns = vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
state, mutate_randkeys, n_nodes, n_conns, new_node_keys, new_conn_markers
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate