make fully stateful in module gene.

This commit is contained in:
wls2002
2024-05-25 16:13:41 +08:00
parent 3b2f917aee
commit 625c261a49
5 changed files with 21 additions and 36 deletions

View File

@@ -24,21 +24,10 @@ class DefaultConnGene(BaseConnGene):
self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate
def new_custom_attrs(self):
def new_attrs(self, state, key):
return jnp.array([self.weight_init_mean])
def new_random_attrs(self, key):
return jnp.array([mutate_float(key,
self.weight_init_mean,
self.weight_init_mean,
1.0,
0,
0,
1.0,
)
])
def mutate(self, key, conn):
def mutate(self, state, key, conn):
input_index = conn[0]
output_index = conn[1]
enabled = conn[2]
@@ -53,9 +42,9 @@ class DefaultConnGene(BaseConnGene):
return jnp.array([input_index, output_index, enabled, weight])
def distance(self, attrs1, attrs2):
def distance(self, state, attrs1, attrs2):
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
def forward(self, attrs, inputs):
def forward(self, state, attrs, inputs):
weight = attrs[0]
return inputs * weight