1. Add origin_node and origin_conn.
2. Change the behavior of crossover and mutation. Now, TensorNEAT will use all fix_attrs(include historical marker if it has one) as identifier for gene in crossover and distance calculation.
3. Other slightly change.
4. Add two related examples: xor_origin and hopper_origin
5. Add related test file.
This commit is contained in:
wls2002
2024-12-18 16:20:34 +08:00
parent e9a8110af5
commit ee1a2a8271
18 changed files with 667 additions and 204 deletions

View File

@@ -2,6 +2,7 @@ import jax
from jax import vmap, numpy as jnp
import numpy as np
from .gene import BaseGene
from tensorneat.common import fetch_first, I_INF
@@ -38,49 +39,27 @@ def valid_cnt(nodes_or_conns):
return jnp.sum(~jnp.isnan(nodes_or_conns[:, 0]))
def extract_node_attrs(node):
def extract_gene_attrs(gene: BaseGene, gene_array):
"""
node: Array(NL, )
extract the attributes of a node
extract the custom attributes of the gene
"""
return node[1:] # 0 is for idx
return gene_array[len(gene.fixed_attrs) :]
def set_node_attrs(node, attrs):
def set_gene_attrs(gene: BaseGene, gene_array, attrs):
"""
node: Array(NL, )
attrs: Array(NL-1, )
set the attributes of a node
set the custom attributes of the gene
"""
return node.at[1:].set(attrs) # 0 is for idx
return gene_array.at[len(gene.fixed_attrs) :].set(attrs)
def extract_conn_attrs(conn):
"""
conn: Array(CL, )
extract the attributes of a connection
"""
return conn[2:] # 0, 1 is for in-idx and out-idx
def set_conn_attrs(conn, attrs):
"""
conn: Array(CL, )
attrs: Array(CL-2, )
set the attributes of a connection
"""
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
def add_node(nodes, new_key: int, attrs):
def add_node(nodes, fix_attrs, custom_attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
pos = fetch_first(jnp.isnan(nodes[:, 0]))
return nodes.at[pos].set(jnp.concatenate((fix_attrs, custom_attrs)))
def delete_node_by_pos(nodes, pos):
@@ -91,15 +70,13 @@ def delete_node_by_pos(nodes, pos):
return nodes.at[pos].set(jnp.nan)
def add_conn(conns, i_key, o_key, attrs):
def add_conn(conns, fix_attrs, custom_attrs):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key]))
return new_conns.at[pos, 2:].set(attrs)
pos = fetch_first(jnp.isnan(conns[:, 0]))
return conns.at[pos].set(jnp.concatenate((fix_attrs, custom_attrs)))
def delete_conn_by_pos(conns, pos):