make fully stateful in module gene.
This commit is contained in:
@@ -1,3 +1,6 @@
|
|||||||
|
from utils import State
|
||||||
|
|
||||||
|
|
||||||
class BaseGene:
|
class BaseGene:
|
||||||
"Base class for node genes or connection genes."
|
"Base class for node genes or connection genes."
|
||||||
fixed_attrs = []
|
fixed_attrs = []
|
||||||
@@ -6,16 +9,19 @@ class BaseGene:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def new_custom_attrs(self):
|
def setup(self, state=State()):
|
||||||
|
return state
|
||||||
|
|
||||||
|
def new_attrs(self, state, key):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def mutate(self, randkey, gene):
|
def mutate(self, state, key, gene):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def distance(self, gene1, gene2):
|
def distance(self, state, gene1, gene2):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward(self, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -8,5 +8,5 @@ class BaseConnGene(BaseGene):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -24,21 +24,10 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
self.weight_mutate_rate = weight_mutate_rate
|
self.weight_mutate_rate = weight_mutate_rate
|
||||||
self.weight_replace_rate = weight_replace_rate
|
self.weight_replace_rate = weight_replace_rate
|
||||||
|
|
||||||
def new_custom_attrs(self):
|
def new_attrs(self, state, key):
|
||||||
return jnp.array([self.weight_init_mean])
|
return jnp.array([self.weight_init_mean])
|
||||||
|
|
||||||
def new_random_attrs(self, key):
|
def mutate(self, state, key, conn):
|
||||||
return jnp.array([mutate_float(key,
|
|
||||||
self.weight_init_mean,
|
|
||||||
self.weight_init_mean,
|
|
||||||
1.0,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
1.0,
|
|
||||||
)
|
|
||||||
])
|
|
||||||
|
|
||||||
def mutate(self, key, conn):
|
|
||||||
input_index = conn[0]
|
input_index = conn[0]
|
||||||
output_index = conn[1]
|
output_index = conn[1]
|
||||||
enabled = conn[2]
|
enabled = conn[2]
|
||||||
@@ -53,9 +42,9 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
|
|
||||||
return jnp.array([input_index, output_index, enabled, weight])
|
return jnp.array([input_index, output_index, enabled, weight])
|
||||||
|
|
||||||
def distance(self, attrs1, attrs2):
|
def distance(self, state, attrs1, attrs2):
|
||||||
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||||
|
|
||||||
def forward(self, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
weight = attrs[0]
|
weight = attrs[0]
|
||||||
return inputs * weight
|
return inputs * weight
|
||||||
|
|||||||
@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, attrs, inputs, is_output_node=False):
|
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -56,22 +56,12 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||||
self.aggregation_replace_rate = aggregation_replace_rate
|
self.aggregation_replace_rate = aggregation_replace_rate
|
||||||
|
|
||||||
def new_custom_attrs(self):
|
def new_attrs(self, state, key):
|
||||||
return jnp.array(
|
return jnp.array(
|
||||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||||
)
|
)
|
||||||
|
|
||||||
def new_random_attrs(self, key):
|
def mutate(self, state, key, node):
|
||||||
return jnp.array([
|
|
||||||
mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std,
|
|
||||||
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate),
|
|
||||||
mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std,
|
|
||||||
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate),
|
|
||||||
self.activation_default,
|
|
||||||
self.aggregation_default,
|
|
||||||
])
|
|
||||||
|
|
||||||
def mutate(self, key, node):
|
|
||||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||||
index = node[0]
|
index = node[0]
|
||||||
|
|
||||||
@@ -87,7 +77,7 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
|
|
||||||
return jnp.array([index, bias, res, act, agg])
|
return jnp.array([index, bias, res, act, agg])
|
||||||
|
|
||||||
def distance(self, node1, node2):
|
def distance(self, state, node1, node2):
|
||||||
return (
|
return (
|
||||||
jnp.abs(node1[1] - node2[1]) +
|
jnp.abs(node1[1] - node2[1]) +
|
||||||
jnp.abs(node1[2] - node2[2]) +
|
jnp.abs(node1[2] - node2[2]) +
|
||||||
@@ -95,7 +85,7 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
(node1[4] != node2[4])
|
(node1[4] != node2[4])
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, attrs, inputs, is_output_node=False):
|
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||||
bias, res, act_idx, agg_idx = attrs
|
bias, res, act_idx, agg_idx = attrs
|
||||||
|
|
||||||
z = agg(agg_idx, inputs, self.aggregation_options)
|
z = agg(agg_idx, inputs, self.aggregation_options)
|
||||||
|
|||||||
Reference in New Issue
Block a user