Add an ipynb test checking that adding a node or a connection does not change the network's output.
79 lines
2.2 KiB
Python
79 lines
2.2 KiB
Python
import jax.numpy as jnp
|
|
import jax.random
|
|
|
|
from utils import mutate_float
|
|
from . import BaseConnGene
|
|
|
|
|
|
class DefaultConnGene(BaseConnGene):
    """Connection gene whose single custom attribute is a scalar ``weight``.

    Intended to reproduce the default connection-gene behaviour of
    NEAT-python. Every attribute vector this class builds or returns has
    exactly one element, laid out as ``[weight]``.
    """

    # Names of the gene-specific attributes, in storage order.
    custom_attrs = ["weight"]

    def __init__(
        self,
        weight_init_mean: float = 0.0,
        weight_init_std: float = 1.0,
        weight_mutate_power: float = 0.5,
        weight_mutate_rate: float = 0.8,
        weight_replace_rate: float = 0.1,
    ):
        """Store the weight hyper-parameters.

        Args:
            weight_init_mean: Mean of the normal distribution that
                ``new_random_attrs`` samples from; also forwarded to
                ``mutate_float``.
            weight_init_std: Standard deviation of that distribution; also
                forwarded to ``mutate_float``.
            weight_mutate_power: Forwarded to ``mutate_float``.
            weight_mutate_rate: Forwarded to ``mutate_float``.
            weight_replace_rate: Forwarded to ``mutate_float``.
        """
        super().__init__()
        # Initialisation parameters.
        self.weight_init_mean = weight_init_mean
        self.weight_init_std = weight_init_std
        # Mutation parameters (only consumed by mutate_float).
        self.weight_mutate_power = weight_mutate_power
        self.weight_mutate_rate = weight_mutate_rate
        self.weight_replace_rate = weight_replace_rate

    def new_zero_attrs(self, state):
        """Return an attribute vector with the weight set to 0.0."""
        zero_weight = 0.0
        return jnp.array([zero_weight])

    def new_identity_attrs(self, state):
        """Return an attribute vector with the weight set to 1.0.

        With this weight, ``forward`` returns ``inputs * 1.0``.
        """
        unit_weight = 1.0
        return jnp.array([unit_weight])

    def new_random_attrs(self, state, randkey):
        """Sample a weight from N(weight_init_mean, weight_init_std ** 2).

        Draws a scalar standard-normal sample, then scales and shifts it.
        """
        standard_sample = jax.random.normal(randkey, ())
        sampled_weight = standard_sample * self.weight_init_std + self.weight_init_mean
        return jnp.array([sampled_weight])

    def mutate(self, state, randkey, attrs):
        """Return a new one-element attribute vector holding the mutated weight.

        The perturb/replace rule itself lives in the project helper
        ``mutate_float``; this method only forwards the current weight
        (``attrs[0]``) and the configured hyper-parameters to it.
        """
        mutated_weight = mutate_float(
            randkey,
            attrs[0],
            self.weight_init_mean,
            self.weight_init_std,
            self.weight_mutate_power,
            self.weight_mutate_rate,
            self.weight_replace_rate,
        )
        return jnp.array([mutated_weight])

    def distance(self, state, attrs1, attrs2):
        """Return the absolute difference between the two genes' weights."""
        return jnp.abs(attrs1[0] - attrs2[0])

    def forward(self, state, attrs, inputs):
        """Scale ``inputs`` by this connection's weight (``attrs[0]``)."""
        return inputs * attrs[0]

    def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
        """Render a connection as a fixed-width, human-readable string.

        ``conn`` must unpack into exactly three items: input index,
        output index and weight. The indices are shown as ints, left-aligned
        in ``idx_width`` columns. The weight is rounded to ``precision``
        decimals and left-aligned in ``precision + 3`` columns.

        ``func_width`` is accepted but not used here (presumably kept for
        signature parity with other genes' ``repr`` — confirm).
        """
        src, dst, raw_weight = conn
        return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, weight: {:<{float_width}})".format(
            type(self).__name__,
            int(src),
            int(dst),
            round(float(raw_weight), precision),
            idx_width=idx_width,
            float_width=precision + 3,
        )
|