51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
import jax.numpy as jnp
|
|
|
|
from utils import mutate_float
|
|
from . import BaseConnGene
|
|
|
|
|
|
class DefaultConnGene(BaseConnGene):
    """Default connection gene, with the same behavior as in NEAT-python.

    The only custom attribute carried by this gene is ``weight``.  The
    constructor stores the weight initialisation/mutation settings; they are
    handed unchanged to ``mutate_float`` whenever a connection is mutated.
    """

    # Names of the gene-specific attributes managed by this class.
    custom_attrs = ["weight"]

    def __init__(
        self,
        weight_init_mean: float = 0.0,
        weight_init_std: float = 1.0,
        weight_mutate_power: float = 0.5,
        weight_mutate_rate: float = 0.8,
        weight_replace_rate: float = 0.1,
    ):
        """Store the weight hyper-parameters on the instance.

        Args:
            weight_init_mean: Value used as the weight of a freshly created
                connection; also forwarded to ``mutate_float``.
            weight_init_std: Forwarded to ``mutate_float``.
            weight_mutate_power: Forwarded to ``mutate_float``.
            weight_mutate_rate: Forwarded to ``mutate_float``.
            weight_replace_rate: Forwarded to ``mutate_float``.

        NOTE(review): the precise meaning of the std/power/rate settings is
        defined by ``utils.mutate_float`` -- confirm there.
        """
        super().__init__()
        self.weight_init_mean = weight_init_mean
        self.weight_init_std = weight_init_std
        self.weight_mutate_power = weight_mutate_power
        self.weight_mutate_rate = weight_mutate_rate
        self.weight_replace_rate = weight_replace_rate

    def new_custom_attrs(self):
        """Return the custom attributes of a new connection.

        The result is a one-element array holding ``weight_init_mean``; no
        random sampling happens here and ``weight_init_std`` is not used.
        """
        initial_values = [self.weight_init_mean]
        return jnp.array(initial_values)

    def mutate(self, key, conn):
        """Return a copy of ``conn`` whose weight has been mutated.

        Args:
            key: Passed through to ``mutate_float`` (presumably a JAX PRNG
                key -- confirm against ``mutate_float``).
            conn: Indexable connection record read as
                ``[input_index, output_index, enabled, weight]``.

        Returns:
            A new array ``[input_index, output_index, enabled, weight]`` in
            which only the weight entry comes from ``mutate_float``; the
            other three entries are copied from ``conn``.
        """
        mutated_weight = mutate_float(
            key,
            conn[3],
            self.weight_init_mean,
            self.weight_init_std,
            self.weight_mutate_power,
            self.weight_mutate_rate,
            self.weight_replace_rate,
        )

        # Indices 0..2 (input, output, enabled) are carried over untouched.
        src_index, dst_index, enabled_flag = conn[0], conn[1], conn[2]
        return jnp.array([src_index, dst_index, enabled_flag, mutated_weight])

    def distance(self, attrs1, attrs2):
        """Return the distance between two connection records.

        The distance is the sum of two terms: whether the entries at index 2
        (enabled flag) differ, and the absolute difference of the entries at
        index 3 (weight).
        """
        enabled_differs = attrs1[2] != attrs2[2]
        weight_gap = jnp.abs(attrs1[3] - attrs2[3])
        return enabled_differs + weight_gap  # enable + weight

    def forward(self, attrs, inputs):
        """Scale ``inputs`` by the connection weight.

        NOTE(review): the weight is read from ``attrs[0]`` here, whereas
        ``mutate`` and ``distance`` read it from index 3.  Presumably
        ``attrs`` holds only the custom attributes in this call -- confirm
        against the caller.
        """
        return inputs * attrs[0]
|