make fully stateful in module gene.
This commit is contained in:
@@ -8,5 +8,5 @@ class BaseConnGene(BaseGene):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
def forward(self, state, attrs, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -24,21 +24,10 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_mutate_rate = weight_mutate_rate
|
||||
self.weight_replace_rate = weight_replace_rate
|
||||
|
||||
def new_custom_attrs(self):
|
||||
def new_attrs(self, state, key):
|
||||
return jnp.array([self.weight_init_mean])
|
||||
|
||||
def new_random_attrs(self, key):
|
||||
return jnp.array([mutate_float(key,
|
||||
self.weight_init_mean,
|
||||
self.weight_init_mean,
|
||||
1.0,
|
||||
0,
|
||||
0,
|
||||
1.0,
|
||||
)
|
||||
])
|
||||
|
||||
def mutate(self, key, conn):
|
||||
def mutate(self, state, key, conn):
|
||||
input_index = conn[0]
|
||||
output_index = conn[1]
|
||||
enabled = conn[2]
|
||||
@@ -53,9 +42,9 @@ class DefaultConnGene(BaseConnGene):
|
||||
|
||||
return jnp.array([input_index, output_index, enabled, weight])
|
||||
|
||||
def distance(self, attrs1, attrs2):
|
||||
def distance(self, state, attrs1, attrs2):
|
||||
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
def forward(self, state, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
|
||||
Reference in New Issue
Block a user