All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -9,13 +9,13 @@ class BaseGene:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def setup(self, key, state=State()):
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def new_attrs(self, state, key):
|
||||
def new_attrs(self, state):
|
||||
raise NotImplementedError
|
||||
|
||||
def mutate(self, state, key, gene):
|
||||
def mutate(self, state, gene):
|
||||
raise NotImplementedError
|
||||
|
||||
def distance(self, state, gene1, gene2):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import jax.numpy as jnp
|
||||
import jax.random
|
||||
|
||||
from utils import mutate_float
|
||||
from . import BaseConnGene
|
||||
@@ -24,14 +25,15 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_mutate_rate = weight_mutate_rate
|
||||
self.weight_replace_rate = weight_replace_rate
|
||||
|
||||
def new_attrs(self, state, key):
|
||||
return jnp.array([self.weight_init_mean])
|
||||
def new_attrs(self, state):
|
||||
return state, jnp.array([self.weight_init_mean])
|
||||
|
||||
def mutate(self, state, key, conn):
|
||||
def mutate(self, state, conn):
|
||||
randkey_, randkey = jax.random.split(state.randkey, 2)
|
||||
input_index = conn[0]
|
||||
output_index = conn[1]
|
||||
enabled = conn[2]
|
||||
weight = mutate_float(key,
|
||||
weight = mutate_float(randkey_,
|
||||
conn[3],
|
||||
self.weight_init_mean,
|
||||
self.weight_init_std,
|
||||
@@ -40,11 +42,11 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_replace_rate
|
||||
)
|
||||
|
||||
return jnp.array([input_index, output_index, enabled, weight])
|
||||
return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight])
|
||||
|
||||
def distance(self, state, attrs1, attrs2):
|
||||
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||
return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||
|
||||
def forward(self, state, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
return state, inputs * weight
|
||||
|
||||
@@ -56,13 +56,13 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||
self.aggregation_replace_rate = aggregation_replace_rate
|
||||
|
||||
def new_attrs(self, state, key):
|
||||
return jnp.array(
|
||||
def new_attrs(self, state):
|
||||
return state, jnp.array(
|
||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||
)
|
||||
|
||||
def mutate(self, state, key, node):
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
def mutate(self, state, node):
|
||||
k1, k2, k3, k4, randkey = jax.random.split(state.randkey, num=5)
|
||||
index = node[0]
|
||||
|
||||
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
|
||||
@@ -75,10 +75,10 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
|
||||
|
||||
return jnp.array([index, bias, res, act, agg])
|
||||
return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg])
|
||||
|
||||
def distance(self, state, node1, node2):
|
||||
return (
|
||||
return state, (
|
||||
jnp.abs(node1[1] - node2[1]) +
|
||||
jnp.abs(node1[2] - node2[2]) +
|
||||
(node1[3] != node2[3]) +
|
||||
@@ -98,4 +98,4 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
lambda: act(act_idx, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
return state, z
|
||||
|
||||
Reference in New Issue
Block a user