make fully stateful in module gene.

This commit is contained in:
wls2002
2024-05-25 16:13:41 +08:00
parent 3b2f917aee
commit 625c261a49
5 changed files with 21 additions and 36 deletions

View File

@@ -1,3 +1,6 @@
from utils import State
class BaseGene: class BaseGene:
"Base class for node genes or connection genes." "Base class for node genes or connection genes."
fixed_attrs = [] fixed_attrs = []
@@ -6,16 +9,19 @@ class BaseGene:
def __init__(self): def __init__(self):
pass pass
def new_custom_attrs(self): def setup(self, state=State()):
return state
def new_attrs(self, state, key):
raise NotImplementedError raise NotImplementedError
def mutate(self, randkey, gene): def mutate(self, state, key, gene):
raise NotImplementedError raise NotImplementedError
def distance(self, gene1, gene2): def distance(self, state, gene1, gene2):
raise NotImplementedError raise NotImplementedError
def forward(self, attrs, inputs): def forward(self, state, attrs, inputs):
raise NotImplementedError raise NotImplementedError
@property @property

View File

@@ -8,5 +8,5 @@ class BaseConnGene(BaseGene):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward(self, attrs, inputs): def forward(self, state, attrs, inputs):
raise NotImplementedError raise NotImplementedError

View File

@@ -24,21 +24,10 @@ class DefaultConnGene(BaseConnGene):
self.weight_mutate_rate = weight_mutate_rate self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate self.weight_replace_rate = weight_replace_rate
def new_custom_attrs(self): def new_attrs(self, state, key):
return jnp.array([self.weight_init_mean]) return jnp.array([self.weight_init_mean])
def new_random_attrs(self, key): def mutate(self, state, key, conn):
return jnp.array([mutate_float(key,
self.weight_init_mean,
self.weight_init_mean,
1.0,
0,
0,
1.0,
)
])
def mutate(self, key, conn):
input_index = conn[0] input_index = conn[0]
output_index = conn[1] output_index = conn[1]
enabled = conn[2] enabled = conn[2]
@@ -53,9 +42,9 @@ class DefaultConnGene(BaseConnGene):
return jnp.array([input_index, output_index, enabled, weight]) return jnp.array([input_index, output_index, enabled, weight])
def distance(self, attrs1, attrs2): def distance(self, state, attrs1, attrs2):
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
def forward(self, attrs, inputs): def forward(self, state, attrs, inputs):
weight = attrs[0] weight = attrs[0]
return inputs * weight return inputs * weight

View File

@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward(self, attrs, inputs, is_output_node=False): def forward(self, state, attrs, inputs, is_output_node=False):
raise NotImplementedError raise NotImplementedError

View File

@@ -56,22 +56,12 @@ class DefaultNodeGene(BaseNodeGene):
self.aggregation_indices = jnp.arange(len(aggregation_options)) self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate self.aggregation_replace_rate = aggregation_replace_rate
def new_custom_attrs(self): def new_attrs(self, state, key):
return jnp.array( return jnp.array(
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default] [self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
) )
def new_random_attrs(self, key): def mutate(self, state, key, node):
return jnp.array([
mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std,
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate),
mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std,
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate),
self.activation_default,
self.aggregation_default,
])
def mutate(self, key, node):
k1, k2, k3, k4 = jax.random.split(key, num=4) k1, k2, k3, k4 = jax.random.split(key, num=4)
index = node[0] index = node[0]
@@ -87,7 +77,7 @@ class DefaultNodeGene(BaseNodeGene):
return jnp.array([index, bias, res, act, agg]) return jnp.array([index, bias, res, act, agg])
def distance(self, node1, node2): def distance(self, state, node1, node2):
return ( return (
jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) + jnp.abs(node1[2] - node2[2]) +
@@ -95,7 +85,7 @@ class DefaultNodeGene(BaseNodeGene):
(node1[4] != node2[4]) (node1[4] != node2[4])
) )
def forward(self, attrs, inputs, is_output_node=False): def forward(self, state, attrs, inputs, is_output_node=False):
bias, res, act_idx, agg_idx = attrs bias, res, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options) z = agg(agg_idx, inputs, self.aggregation_options)