make fully stateful in module gene.
This commit is contained in:
@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -56,22 +56,12 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||
self.aggregation_replace_rate = aggregation_replace_rate
|
||||
|
||||
def new_custom_attrs(self):
|
||||
def new_attrs(self, state, key):
|
||||
return jnp.array(
|
||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||
)
|
||||
|
||||
def new_random_attrs(self, key):
|
||||
return jnp.array([
|
||||
mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std,
|
||||
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate),
|
||||
mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std,
|
||||
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate),
|
||||
self.activation_default,
|
||||
self.aggregation_default,
|
||||
])
|
||||
|
||||
def mutate(self, key, node):
|
||||
def mutate(self, state, key, node):
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
index = node[0]
|
||||
|
||||
@@ -87,7 +77,7 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
return jnp.array([index, bias, res, act, agg])
|
||||
|
||||
def distance(self, node1, node2):
|
||||
def distance(self, state, node1, node2):
|
||||
return (
|
||||
jnp.abs(node1[1] - node2[1]) +
|
||||
jnp.abs(node1[2] - node2[2]) +
|
||||
@@ -95,7 +85,7 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
(node1[4] != node2[4])
|
||||
)
|
||||
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
bias, res, act_idx, agg_idx = attrs
|
||||
|
||||
z = agg(agg_idx, inputs, self.aggregation_options)
|
||||
|
||||
Reference in New Issue
Block a user