Add an ipynb test that checks that adding a node or a connection does not change the output of the network.
50 lines
1.4 KiB
Python
50 lines
1.4 KiB
Python
import jax, jax.numpy as jnp
|
|
from .. import BaseGene
|
|
|
|
|
|
class BaseNodeGene(BaseGene):
    """Common base for node genes.

    Concrete node genes are expected to override ``forward``; the remaining
    methods provide pass-through defaults built on top of it.
    """

    # Attribute names declared as fixed for every node gene.
    fixed_attrs = ["index"]

    def __init__(self):
        super().__init__()

    def forward(self, state, attrs, inputs, is_output_node=False):
        """Compute the node's output for a single sample.

        The base class has no implementation and always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def input_transform(self, state, attrs, inputs):
        """Hook for transforming values at an input node.

        The base implementation applies no transformation and hands
        ``inputs`` back unchanged.
        """
        return inputs

    def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
        """Run ``forward`` over a batch and return ``(batch_res, attrs)``.

        ``batch_inputs`` is mapped along its leading axis, while ``state``,
        ``attrs`` and ``is_output_node`` are shared by every sample. The
        base implementation does not modify ``attrs``; it is returned as
        received.
        """
        batched_forward = jax.vmap(self.forward, in_axes=(None, None, 0, None))
        batch_res = batched_forward(state, attrs, batch_inputs, is_output_node)
        return batch_res, attrs

    def update_input_transform(self, state, attrs, batch_inputs):
        """Run ``input_transform`` over a batch and return ``(result, attrs)``.

        ``batch_inputs`` is mapped along its leading axis, while ``state``
        and ``attrs`` are shared by every sample. The base implementation
        does not modify ``attrs``; it is returned as received.
        """
        batched_transform = jax.vmap(self.input_transform, in_axes=(None, None, 0))
        transformed = batched_transform(state, attrs, batch_inputs)
        return transformed, attrs

    def repr(self, state, node, precision=2, idx_width=3, func_width=8):
        """Return a short text description of ``node``.

        The first entry of ``node`` is converted to ``int`` and printed
        left-aligned in a field of ``idx_width`` characters, prefixed by the
        concrete class name. ``state``, ``precision`` and ``func_width`` are
        accepted but not used by this base implementation.
        """
        node_index = int(node[0])
        template = "{}(idx={:<{idx_width}})"
        return template.format(
            self.__class__.__name__, node_index, idx_width=idx_width
        )
|