All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -1,4 +1,5 @@
import jax.numpy as jnp
import jax.random
from utils import mutate_float
from . import BaseConnGene
@@ -24,14 +25,15 @@ class DefaultConnGene(BaseConnGene):
self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate
def new_attrs(self, state, key):
return jnp.array([self.weight_init_mean])
def new_attrs(self, state):
return state, jnp.array([self.weight_init_mean])
def mutate(self, state, key, conn):
def mutate(self, state, conn):
randkey_, randkey = jax.random.split(state.randkey, 2)
input_index = conn[0]
output_index = conn[1]
enabled = conn[2]
weight = mutate_float(key,
weight = mutate_float(randkey_,
conn[3],
self.weight_init_mean,
self.weight_init_std,
@@ -40,11 +42,11 @@ class DefaultConnGene(BaseConnGene):
self.weight_replace_rate
)
return jnp.array([input_index, output_index, enabled, weight])
return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight])
def distance(self, state, attrs1, attrs2):
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
def forward(self, state, attrs, inputs):
weight = attrs[0]
return inputs * weight
return state, inputs * weight