This commit is related to issue: https://github.com/EMI-Group/tensorneat/issues/11
1. Add origin_node and origin_conn. 2. Change the behavior of crossover and mutation. Now, TensorNEAT will use all fix_attrs(include historical marker if it has one) as identifier for gene in crossover and distance calculation. 3. Other slightly change. 4. Add two related examples: xor_origin and hopper_origin 5. Add related test file.
This commit is contained in:
@@ -2,7 +2,7 @@ from jax import vmap
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseSubstrate
|
||||
from tensorneat.genome.utils import set_conn_attrs
|
||||
from tensorneat.genome.utils import set_gene_attrs
|
||||
|
||||
|
||||
class DefaultSubstrate(BaseSubstrate):
|
||||
@@ -21,7 +21,8 @@ class DefaultSubstrate(BaseSubstrate):
|
||||
|
||||
def make_conns(self, query_res):
|
||||
# change weight of conns
|
||||
return vmap(set_conn_attrs)(self.conns, query_res)
|
||||
# the last column is the weight
|
||||
return self.conns.at[:, -1].set(query_res)
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
|
||||
@@ -116,6 +116,25 @@ class NEAT(BaseAlgorithm):
|
||||
next_node_key = max_node_key + 1
|
||||
new_node_keys = jnp.arange(self.pop_size) + next_node_key
|
||||
|
||||
# find next conn historical markers for mutation if needed
|
||||
if "historical_marker" in self.genome.conn_gene.fixed_attrs:
|
||||
all_conns_markers = vmap(
|
||||
self.genome.conn_gene.get_historical_marker, in_axes=(None, 0)
|
||||
)(state, state.pop_conns)
|
||||
|
||||
max_conn_markers = jnp.max(
|
||||
all_conns_markers, where=~jnp.isnan(all_conns_markers), initial=0
|
||||
)
|
||||
next_conn_markers = max_conn_markers + 1
|
||||
new_conn_markers = (
|
||||
jnp.arange(self.pop_size * 3).reshape(self.pop_size, 3)
|
||||
+ next_conn_markers
|
||||
)
|
||||
else:
|
||||
# no need to generate new conn historical markers
|
||||
# use 0
|
||||
new_conn_markers = jnp.full((self.pop_size, 3), 0)
|
||||
|
||||
# prepare random keys
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_randkeys = jax.random.split(k1, self.pop_size)
|
||||
@@ -133,9 +152,9 @@ class NEAT(BaseAlgorithm):
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys, new_conn_markers
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
|
||||
Reference in New Issue
Block a user