fix bug in crossover: the child from two normal networks should always be normal.
This commit is contained in:
8
tensorneat/.idea/.gitignore
generated
vendored
Normal file
8
tensorneat/.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from .ga import *
|
||||||
from .gene import *
|
from .gene import *
|
||||||
from .genome import *
|
from .genome import *
|
||||||
from .species import *
|
from .species import *
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import jax, jax.numpy as jnp
|
|||||||
|
|
||||||
from .base import BaseCrossover
|
from .base import BaseCrossover
|
||||||
|
|
||||||
|
|
||||||
class DefaultCrossover(BaseCrossover):
|
class DefaultCrossover(BaseCrossover):
|
||||||
|
|
||||||
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
|
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||||
@@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover):
|
|||||||
# crossover nodes
|
# crossover nodes
|
||||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||||
# align the homologous genes in nodes2 with their counterparts in nodes1
|
# align the homologous genes in nodes2 with their counterparts in nodes1
|
||||||
nodes2 = self.align_array(keys1, keys2, nodes2, False)
|
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
|
||||||
|
|
||||||
# For non-homologous genes, use the value from nodes1 (the winner)
|
# For non-homologous genes, use the value from nodes1 (the winner)
|
||||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
|
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
||||||
|
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
|
||||||
|
|
||||||
# crossover connections
|
# crossover connections
|
||||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
|
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||||
|
|
||||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
|
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
||||||
|
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
|
||||||
|
|
||||||
return new_nodes, new_conns
|
return new_nodes, new_conns
|
||||||
|
|
||||||
@@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover):
|
|||||||
|
|
||||||
return refactor_ar2
|
return refactor_ar2
|
||||||
|
|
||||||
def crossover_gene(self, rand_key, g1, g2):
|
def crossover_gene(self, rand_key, g1, g2, is_conn):
|
||||||
"""
|
|
||||||
crossover two genes
|
|
||||||
:param rand_key:
|
|
||||||
:param g1:
|
|
||||||
:param g2:
|
|
||||||
:return:
|
|
||||||
only genes that share the same key are crossed over, so the key itself never needs to change
|
|
||||||
"""
|
|
||||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||||
return jnp.where(r > 0.5, g1, g2)
|
new_gene = jnp.where(r > 0.5, g1, g2)
|
||||||
|
if is_conn: # fix enabled
|
||||||
|
enabled = jnp.where(
|
||||||
|
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
|
||||||
|
1,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
new_gene = new_gene.at[:, 2].set(enabled)
|
||||||
|
return new_gene
|
||||||
|
|||||||
@@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation):
|
|||||||
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||||
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||||
|
|
||||||
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
|
||||||
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
|
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
|
||||||
|
|
||||||
# nan nodes not changed
|
# nan nodes not changed
|
||||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||||
|
|||||||
@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, attrs, inputs):
|
def forward(self, attrs, inputs, is_output_node=False):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -85,11 +85,14 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
(node1[4] != node2[4])
|
(node1[4] != node2[4])
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, attrs, inputs):
|
def forward(self, attrs, inputs, is_output_node=False):
|
||||||
bias, res, act_idx, agg_idx = attrs
|
bias, res, act_idx, agg_idx = attrs
|
||||||
|
|
||||||
z = agg(agg_idx, inputs, self.aggregation_options)
|
z = agg(agg_idx, inputs, self.aggregation_options)
|
||||||
z = bias + res * z
|
z = bias + res * z
|
||||||
|
|
||||||
|
# the last output node should not be activated
|
||||||
|
if not is_output_node:
|
||||||
z = act(act_idx, z, self.activation_options)
|
z = act(act_idx, z, self.activation_options)
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|||||||
@@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
if output_transform is not None:
|
if output_transform is not None:
|
||||||
try:
|
try:
|
||||||
aux = output_transform(jnp.zeros(num_outputs))
|
_ = output_transform(jnp.zeros(num_outputs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Output transform function failed: {e}")
|
raise ValueError(f"Output transform function failed: {e}")
|
||||||
self.output_transform = output_transform
|
self.output_transform = output_transform
|
||||||
|
|
||||||
def transform(self, nodes, conns):
|
def transform(self, nodes, conns):
|
||||||
u_conns = unflatten_conns(nodes, conns)
|
u_conns = unflatten_conns(nodes, conns)
|
||||||
|
|
||||||
# DONE: Seems like there is a bug in this line
|
|
||||||
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
|
||||||
# modified: exist conn and enable is true
|
|
||||||
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
|
|
||||||
# advanced modified: when and only when enabled is True
|
|
||||||
conn_enable = u_conns[0] == 1
|
conn_enable = u_conns[0] == 1
|
||||||
|
|
||||||
# remove enable attr
|
# remove enable attr
|
||||||
@@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
def hit():
|
def hit():
|
||||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
|
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
|
||||||
# ins = values * weights[:, i]
|
|
||||||
|
|
||||||
z = self.node_gene.forward(nodes_attrs[i], ins)
|
z = self.node_gene.forward(nodes_attrs[i], ins)
|
||||||
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
|
|
||||||
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
|
||||||
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
|
|
||||||
|
|
||||||
new_values = values.at[i].set(z)
|
new_values = values.at[i].set(z)
|
||||||
return new_values
|
return new_values
|
||||||
|
|
||||||
@@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
# the val of input nodes is obtained by the task, not by calculation
|
# the val of input nodes is obtained by the task, not by calculation
|
||||||
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
|
values = jax.lax.cond(
|
||||||
|
jnp.isin(i, self.input_idx),
|
||||||
|
miss,
|
||||||
|
hit
|
||||||
|
)
|
||||||
|
|
||||||
return values, idx + 1
|
return values, idx + 1
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ if __name__ == '__main__':
|
|||||||
activation_default=Act.tanh,
|
activation_default=Act.tanh,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
pop_size=10,
|
pop_size=10000,
|
||||||
species_size=10,
|
species_size=10,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class Pipeline:
|
|||||||
pop_transformed
|
pop_transformed
|
||||||
)
|
)
|
||||||
|
|
||||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
# fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
||||||
|
|
||||||
alg_state = self.algorithm.tell(state.alg, fitnesses)
|
alg_state = self.algorithm.tell(state.alg, fitnesses)
|
||||||
|
|
||||||
@@ -80,8 +80,10 @@ class Pipeline:
|
|||||||
|
|
||||||
def auto_run(self, ini_state):
|
def auto_run(self, ini_state):
|
||||||
state = ini_state
|
state = ini_state
|
||||||
|
print("start compile")
|
||||||
|
tic = time.time()
|
||||||
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
||||||
|
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
|
||||||
for _ in range(self.generation_limit):
|
for _ in range(self.generation_limit):
|
||||||
|
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = time.time()
|
||||||
@@ -91,11 +93,6 @@ class Pipeline:
|
|||||||
state, fitnesses = compiled_step(state)
|
state, fitnesses = compiled_step(state)
|
||||||
|
|
||||||
fitnesses = jax.device_get(fitnesses)
|
fitnesses = jax.device_get(fitnesses)
|
||||||
for idx, fitnesses_i in enumerate(fitnesses):
|
|
||||||
if np.isnan(fitnesses_i):
|
|
||||||
print("Fitness is nan")
|
|
||||||
print(previous_pop[0][idx], previous_pop[1][idx])
|
|
||||||
assert False
|
|
||||||
|
|
||||||
self.analysis(state, previous_pop, fitnesses)
|
self.analysis(state, previous_pop, fitnesses)
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from .. import BaseProblem
|
|||||||
class RLEnv(BaseProblem):
|
class RLEnv(BaseProblem):
|
||||||
jitable = True
|
jitable = True
|
||||||
|
|
||||||
# TODO: move output transform to algorithm
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user