1. Add origin_node and origin_conn.
2. Change the behavior of crossover and mutation. Now, TensorNEAT will use all fix_attrs(include historical marker if it has one) as identifier for gene in crossover and distance calculation.
3. Other slightly change.
4. Add two related examples: xor_origin and hopper_origin
5. Add related test file.
This commit is contained in:
wls2002
2024-12-18 16:20:34 +08:00
parent e9a8110af5
commit ee1a2a8271
18 changed files with 667 additions and 204 deletions

View File

@@ -1,2 +1,3 @@
from .base import BaseConn
from .default import DefaultConn
from .origin import OriginConn

View File

@@ -0,0 +1,60 @@
import jax, jax.numpy as jnp
from .default import DefaultConn
class OriginConn(DefaultConn):
"""
Implementation of connections in origin NEAT Paper.
Details at https://github.com/EMI-Group/tensorneat/issues/11.
"""
# add historical_marker into fixed_attrs
fixed_attrs = ["input_index", "output_index", "historical_marker"]
custom_attrs = ["weight"]
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
def crossover(self, state, randkey, attrs1, attrs2):
# random pick one of attrs, without attrs exchange
return jnp.where(
# origin code, generate multiple random numbers, without attrs exchange
# jax.random.normal(randkey, attrs1.shape) > 0,
jax.random.normal(randkey)
> 0, # generate one random number, without attrs exchange
attrs1,
attrs2,
)
def get_historical_marker(self, state, gene_array):
return gene_array[2]
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
in_idx, out_idx, historical_marker, weight = conn
in_idx = int(in_idx)
out_idx = int(out_idx)
historical_marker = int(historical_marker)
weight = round(float(weight), precision)
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, historical_marker: {:<{idx_width}}, weight: {:<{float_width}})".format(
self.__class__.__name__,
in_idx,
out_idx,
historical_marker,
weight,
idx_width=idx_width,
float_width=precision + 3,
)
def to_dict(self, state, conn):
return {
"in": int(conn[0]),
"out": int(conn[1]),
"historical_marker": int(conn[2]),
"weight": jnp.float32(conn[3]),
}

View File

@@ -1,3 +1,4 @@
from .base import BaseNode
from .default import DefaultNode
from .bias import BiasNode
from .origin import OriginNode

View File

@@ -0,0 +1,27 @@
import jax, jax.numpy as jnp
from .default import DefaultNode
class OriginNode(DefaultNode):
"""
Implementation of nodes in origin NEAT Paper.
Details at https://github.com/EMI-Group/tensorneat/issues/11.
"""
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
def crossover(self, state, randkey, attrs1, attrs2):
# random pick one of attrs, without attrs exchange
return jnp.where(
# origin code, generate multiple random numbers, without attrs exchange
# jax.random.normal(randkey, attrs1.shape) > 0,
jax.random.normal(randkey)
> 0, # generate one random number, without attrs exchange
attrs1,
attrs2,
)