43 lines
1.1 KiB
Python
43 lines
1.1 KiB
Python
import jax, jax.numpy as jnp
|
|
from utils import State, StatefulBaseClass
|
|
|
|
|
|
class BaseGene(StatefulBaseClass):
    """Base class for node genes or connection genes.

    Most methods here are placeholders that raise ``NotImplementedError``;
    only ``crossover`` and ``length`` have a concrete default implementation.
    """

    # Both lists are counted by the ``length`` property below.
    # NOTE(review): presumably subclasses override these with their attribute
    # names -- confirm against the concrete gene classes.
    fixed_attrs = []
    custom_attrs = []

    def __init__(self):
        """No-op constructor; the base gene sets no instance state."""

    def new_identity_attrs(self, state):
        """Return the attrs which do identity transformation.

        Used in mutate add node. Not implemented in the base class.
        """
        raise NotImplementedError

    def new_random_attrs(self, state, randkey):
        """Return random attributes of the gene.

        Used in initialization. Not implemented in the base class.
        """
        raise NotImplementedError

    def mutate(self, state, randkey, attrs):
        """Mutation hook; not implemented in the base class."""
        raise NotImplementedError

    def crossover(self, state, randkey, attrs1, attrs2):
        """Mix two attribute arrays element by element.

        A standard-normal sample with the shape of ``attrs1`` is drawn from
        ``randkey``; wherever the sample is positive the element of
        ``attrs1`` is kept, otherwise the element of ``attrs2`` is used.
        ``state`` is not read by this default implementation.
        """
        noise = jax.random.normal(randkey, attrs1.shape)
        # Sign of a zero-mean normal draw acts as a fair per-element coin flip.
        pick_first = noise > 0
        return jnp.where(pick_first, attrs1, attrs2)

    def distance(self, state, attrs1, attrs2):
        """Distance hook; not implemented in the base class."""
        raise NotImplementedError

    def forward(self, state, attrs, inputs):
        """Forward hook; not implemented in the base class."""
        raise NotImplementedError

    def update_by_batch(self, state, attrs, batch_inputs):
        """Batch-update hook; not implemented in the base class."""
        raise NotImplementedError

    @property
    def length(self):
        """Total number of entries in ``fixed_attrs`` and ``custom_attrs``."""
        n_fixed = len(self.fixed_attrs)
        n_custom = len(self.custom_attrs)
        return n_fixed + n_custom
|