All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -9,13 +9,13 @@ class BaseGene:
def __init__(self):
pass
def setup(self, key, state=State()):
def setup(self, state=State()):
return state
def new_attrs(self, state, key):
def new_attrs(self, state):
raise NotImplementedError
def mutate(self, state, key, gene):
def mutate(self, state, gene):
raise NotImplementedError
def distance(self, state, gene1, gene2):

View File

@@ -1,4 +1,5 @@
import jax.numpy as jnp
import jax.random
from utils import mutate_float
from . import BaseConnGene
@@ -24,14 +25,15 @@ class DefaultConnGene(BaseConnGene):
self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate
def new_attrs(self, state, key):
return jnp.array([self.weight_init_mean])
def new_attrs(self, state):
return state, jnp.array([self.weight_init_mean])
def mutate(self, state, key, conn):
def mutate(self, state, conn):
randkey_, randkey = jax.random.split(state.randkey, 2)
input_index = conn[0]
output_index = conn[1]
enabled = conn[2]
weight = mutate_float(key,
weight = mutate_float(randkey_,
conn[3],
self.weight_init_mean,
self.weight_init_std,
@@ -40,11 +42,11 @@ class DefaultConnGene(BaseConnGene):
self.weight_replace_rate
)
return jnp.array([input_index, output_index, enabled, weight])
return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight])
def distance(self, state, attrs1, attrs2):
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
def forward(self, state, attrs, inputs):
weight = attrs[0]
return inputs * weight
return state, inputs * weight

View File

@@ -56,13 +56,13 @@ class DefaultNodeGene(BaseNodeGene):
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
def new_attrs(self, state, key):
return jnp.array(
def new_attrs(self, state):
return state, jnp.array(
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
)
def mutate(self, state, key, node):
k1, k2, k3, k4 = jax.random.split(key, num=4)
def mutate(self, state, node):
k1, k2, k3, k4, randkey = jax.random.split(state.randkey, num=5)
index = node[0]
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
@@ -75,10 +75,10 @@ class DefaultNodeGene(BaseNodeGene):
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
return jnp.array([index, bias, res, act, agg])
return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg])
def distance(self, state, node1, node2):
return (
return state, (
jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) +
(node1[3] != node2[3]) +
@@ -98,4 +98,4 @@ class DefaultNodeGene(BaseNodeGene):
lambda: act(act_idx, z, self.activation_options)
)
return z
return state, z