Add an ipynb test that checks that adding a node or a connection does not change the network's output.
46 lines
1.2 KiB
Python
46 lines
1.2 KiB
Python
import jax, jax.numpy as jnp
|
|
from utils import State, StatefulBaseClass
|
|
|
|
|
|
class BaseGene(StatefulBaseClass):
    """Base class for node genes or connection genes."""

    # Two class-level lists whose combined size is what `length` reports.
    # NOTE(review): presumably overridden by concrete gene classes with their
    # attribute names -- confirm against the subclasses.
    fixed_attrs = []
    custom_attrs = []

    def __init__(self):
        """Nothing is initialised at this level."""

    def new_identity_attrs(self, state):
        """Return the attrs which do identity transformation.

        Used in mutate add node. Not implemented in the base class.
        """
        raise NotImplementedError

    def new_random_attrs(self, state, randkey):
        """Return random attributes of the gene.

        Used in initialization. Not implemented in the base class.
        """
        raise NotImplementedError

    def mutate(self, state, randkey, attrs):
        """Mutation hook; not implemented in the base class."""
        raise NotImplementedError

    def crossover(self, state, randkey, attrs1, attrs2):
        """Mix two attribute arrays element by element.

        One standard-normal sample is drawn per element of ``attrs1`` using
        ``randkey``; where the sample is positive the element of ``attrs1``
        is kept, otherwise the element of ``attrs2`` is used. ``state`` is
        not read here.
        """
        noise = jax.random.normal(randkey, attrs1.shape)
        take_first = noise > 0
        return jnp.where(take_first, attrs1, attrs2)

    def distance(self, state, attrs1, attrs2):
        """Distance hook; not implemented in the base class."""
        raise NotImplementedError

    def forward(self, state, attrs, inputs):
        """Forward hook; not implemented in the base class."""
        raise NotImplementedError

    def update_by_batch(self, state, attrs, batch_inputs):
        """Batch-update hook; not implemented in the base class."""
        raise NotImplementedError

    @property
    def length(self):
        """Number of entries in ``fixed_attrs`` plus ``custom_attrs``."""
        n_fixed = len(self.fixed_attrs)
        n_custom = len(self.custom_attrs)
        return n_fixed + n_custom

    def repr(self, state, gene, precision=2):
        """Representation hook; not implemented in the base class."""
        raise NotImplementedError
|