All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import jax.numpy as jnp
|
||||
import jax.random
|
||||
|
||||
from utils import mutate_float
|
||||
from . import BaseConnGene
|
||||
@@ -24,14 +25,15 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_mutate_rate = weight_mutate_rate
|
||||
self.weight_replace_rate = weight_replace_rate
|
||||
|
||||
def new_attrs(self, state, key):
|
||||
return jnp.array([self.weight_init_mean])
|
||||
def new_attrs(self, state):
|
||||
return state, jnp.array([self.weight_init_mean])
|
||||
|
||||
def mutate(self, state, key, conn):
|
||||
def mutate(self, state, conn):
|
||||
randkey_, randkey = jax.random.split(state.randkey, 2)
|
||||
input_index = conn[0]
|
||||
output_index = conn[1]
|
||||
enabled = conn[2]
|
||||
weight = mutate_float(key,
|
||||
weight = mutate_float(randkey_,
|
||||
conn[3],
|
||||
self.weight_init_mean,
|
||||
self.weight_init_std,
|
||||
@@ -40,11 +42,11 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_replace_rate
|
||||
)
|
||||
|
||||
return jnp.array([input_index, output_index, enabled, weight])
|
||||
return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight])
|
||||
|
||||
def distance(self, state, attrs1, attrs2):
|
||||
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||
return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||
|
||||
def forward(self, state, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
return state, inputs * weight
|
||||
|
||||
Reference in New Issue
Block a user