Files
tensorneat-mend/tensorneat/algorithm/neat/gene/conn/default.py
wls2002 79d53ea7af All functions with state will update the state and return it.
Remove randkey args from functions with state, since the randkey can be obtained from the state.
2024-05-25 20:45:57 +08:00

53 lines
1.8 KiB
Python

import jax.numpy as jnp
import jax.random
from utils import mutate_float
from . import BaseConnGene
class DefaultConnGene(BaseConnGene):
    """Default connection gene, with the same behavior as in NEAT-python."""

    # Extra attribute(s) a connection carries beyond its endpoints/enabled flag.
    custom_attrs = ['weight']

    def __init__(
        self,
        weight_init_mean: float = 0.0,
        weight_init_std: float = 1.0,
        weight_mutate_power: float = 0.5,
        weight_mutate_rate: float = 0.8,
        weight_replace_rate: float = 0.1,
    ):
        super().__init__()
        # Weight initialization / mutation hyper-parameters, mirroring the
        # corresponding NEAT-python connection-gene config options.
        self.weight_init_mean = weight_init_mean
        self.weight_init_std = weight_init_std
        self.weight_mutate_power = weight_mutate_power
        self.weight_mutate_rate = weight_mutate_rate
        self.weight_replace_rate = weight_replace_rate

    def new_attrs(self, state):
        """Return (state, attrs) for a fresh connection; the weight starts at the init mean."""
        return state, jnp.array([self.weight_init_mean])

    def mutate(self, state, conn):
        """Mutate one connection's weight.

        Assumes ``conn`` is laid out as [input_index, output_index, enabled,
        weight] — TODO confirm against caller. Consumes ``state.randkey`` and
        stores a fresh key back into the returned state.
        """
        mutate_key, next_key = jax.random.split(state.randkey, 2)
        src = conn[0]
        dst = conn[1]
        enabled = conn[2]
        new_weight = mutate_float(
            mutate_key,
            conn[3],
            self.weight_init_mean,
            self.weight_init_std,
            self.weight_mutate_power,
            self.weight_mutate_rate,
            self.weight_replace_rate,
        )
        return state.update(randkey=next_key), jnp.array([src, dst, enabled, new_weight])

    def distance(self, state, attrs1, attrs2):
        """Gene distance: enabled mismatch (0/1) plus absolute weight difference.

        NOTE(review): this indexes positions 2 and 3, so it appears to expect
        full connection rows rather than the attrs array produced by
        ``new_attrs`` (where weight is index 0) — verify against caller.
        """
        # enable + weight
        return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3])

    def forward(self, state, attrs, inputs):
        """Scale the incoming activation by this connection's weight."""
        w = attrs[0]
        return state, inputs * w