modify the behavior for mutate_add_node and mutate_add_conn. Currently, this two mutation will just change the structure of the network, but not influence the output for the network.
This commit is contained in:
@@ -9,12 +9,9 @@ class BaseConnGene(BaseGene):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def crossover(self, state, randkey, gene1, gene2):
|
||||
return jnp.where(
|
||||
jax.random.normal(randkey, gene1.shape) > 0,
|
||||
gene1,
|
||||
gene2,
|
||||
)
|
||||
def new_zero_attrs(self, state):
|
||||
# the attrs which make the least influence on the network, used in mutate add conn
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state, attrs, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -25,8 +25,11 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_mutate_rate = weight_mutate_rate
|
||||
self.weight_replace_rate = weight_replace_rate
|
||||
|
||||
def new_custom_attrs(self, state):
|
||||
return jnp.array([self.weight_init_mean])
|
||||
def new_zero_attrs(self, state):
|
||||
return jnp.array([0.0]) # weight = 0
|
||||
|
||||
def new_identity_attrs(self, state):
|
||||
return jnp.array([1.0]) # weight = 1
|
||||
|
||||
def new_random_attrs(self, state, randkey):
|
||||
weight = (
|
||||
@@ -35,12 +38,11 @@ class DefaultConnGene(BaseConnGene):
|
||||
)
|
||||
return jnp.array([weight])
|
||||
|
||||
def mutate(self, state, randkey, conn):
|
||||
input_index = conn[0]
|
||||
output_index = conn[1]
|
||||
def mutate(self, state, randkey, attrs):
|
||||
weight = attrs[0]
|
||||
weight = mutate_float(
|
||||
randkey,
|
||||
conn[2],
|
||||
weight,
|
||||
self.weight_init_mean,
|
||||
self.weight_init_std,
|
||||
self.weight_mutate_power,
|
||||
@@ -48,10 +50,12 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_replace_rate,
|
||||
)
|
||||
|
||||
return jnp.array([input_index, output_index, weight])
|
||||
return jnp.array([weight])
|
||||
|
||||
def distance(self, state, attrs1, attrs2):
|
||||
return jnp.abs(attrs1[0] - attrs2[0])
|
||||
weight1 = attrs1[0]
|
||||
weight2 = attrs2[0]
|
||||
return jnp.abs(weight1 - weight2)
|
||||
|
||||
def forward(self, state, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
|
||||
Reference in New Issue
Block a user