Files
tensorneat-mend/src/tensorneat/genome/gene/conn/origin.py
wls2002 ee1a2a8271 This commit is related to issue: https://github.com/EMI-Group/tensorneat/issues/11
1. Add origin_node and origin_conn.
2. Change the behavior of crossover and mutation. TensorNEAT now uses all fixed_attrs (including the historical marker, if the gene has one) as the gene identifier in crossover and distance calculation.
3. Other minor changes.
4. Add two related examples: xor_origin and hopper_origin
5. Add related test file.
2024-12-18 16:20:34 +08:00

60 lines
1.8 KiB
Python

import jax, jax.numpy as jnp
from .default import DefaultConn
class OriginConn(DefaultConn):
    """
    Connection gene as described in the original NEAT paper.

    Compared with ``DefaultConn``, the historical marker is stored as an
    additional *fixed* attribute, so a gene row is read positionally as::

        [input_index, output_index, historical_marker, weight]

    Background: https://github.com/EMI-Group/tensorneat/issues/11.
    """

    # The historical marker joins the fixed attributes. get_historical_marker,
    # repr and to_dict below read the gene row by position and assume this
    # exact order (followed by the custom "weight").
    fixed_attrs = ["input_index", "output_index", "historical_marker"]
    custom_attrs = ["weight"]

    def __init__(self, *args, **kwargs):
        # No extra configuration of its own: forward everything to DefaultConn.
        super().__init__(*args, **kwargs)

    def crossover(self, state, randkey, attrs1, attrs2):
        """
        Inherit the attributes wholesale from one of the two parents.

        A single scalar is drawn from a standard normal distribution. Its sign
        acts as a fair coin that selects either ``attrs1`` or ``attrs2`` in
        full, so attributes are never mixed between the parents. Drawing one
        value per attribute would mix them instead.
        """
        take_first = jax.random.normal(randkey) > 0
        return jnp.where(take_first, attrs1, attrs2)

    def get_historical_marker(self, state, gene_array):
        """Return the historical marker stored in the gene row."""
        # Third slot: after input_index and output_index (see fixed_attrs).
        marker_position = 2
        return gene_array[marker_position]

    def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
        """
        Render one connection gene as a column-aligned, human-readable string.

        Args:
            state: unused here.
            conn: a gene row of exactly four values
                (input index, output index, historical marker, weight).
            precision: number of decimals the weight is rounded to.
            idx_width: left-aligned field width for the three integer fields.
            func_width: accepted but unused in this method (presumably kept
                for signature compatibility with the parent class — confirm).
        """
        raw_in, raw_out, raw_marker, raw_weight = conn
        values = (
            int(raw_in),
            int(raw_out),
            int(raw_marker),
            round(float(raw_weight), precision),
        )
        template = "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, historical_marker: {:<{idx_width}}, weight: {:<{float_width}})"
        return template.format(
            type(self).__name__,
            *values,
            idx_width=idx_width,
            float_width=precision + 3,
        )

    def to_dict(self, state, conn):
        """
        Convert one connection gene row into a plain dictionary.

        The index and marker fields become Python ``int``. The weight is
        returned as a ``jnp.float32`` scalar.
        """
        in_idx = int(conn[0])
        out_idx = int(conn[1])
        marker = int(conn[2])
        weight = jnp.float32(conn[3])
        return {
            "in": in_idx,
            "out": out_idx,
            "historical_marker": marker,
            "weight": weight,
        }