fix bug in crossover: the child from two normal networks should always be normal.

This commit is contained in:
wls2002
2024-05-22 10:27:32 +08:00
parent d1559317d1
commit 6a37563696
11 changed files with 46 additions and 43 deletions

8
tensorneat/.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@@ -1,3 +1,4 @@
from .ga import *
from .gene import * from .gene import *
from .genome import * from .genome import *
from .species import * from .species import *

View File

@@ -2,6 +2,7 @@ import jax, jax.numpy as jnp
from .base import BaseCrossover from .base import BaseCrossover
class DefaultCrossover(BaseCrossover): class DefaultCrossover(BaseCrossover):
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2): def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
@@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover):
# crossover nodes # crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1 # make homologous genes align in nodes2 align with nodes1
nodes2 = self.align_array(keys1, keys2, nodes2, False) nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
# For not homologous genes, use the value of nodes1(winner) # For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2 # For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2)) new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
# crossover connections # crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, True) conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2)) new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
return new_nodes, new_conns return new_nodes, new_conns
@@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover):
return refactor_ar2 return refactor_ar2
def crossover_gene(self, rand_key, g1, g2): def crossover_gene(self, rand_key, g1, g2, is_conn):
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape) r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2) new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled
enabled = jnp.where(
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
1,
0
)
new_gene = new_gene.at[:, 2].set(enabled)
return new_gene

View File

@@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation):
nodes_keys = jax.random.split(k1, num=nodes.shape[0]) nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0]) conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes) new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns) new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
# nan nodes not changed # nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward(self, attrs, inputs): def forward(self, attrs, inputs, is_output_node=False):
raise NotImplementedError raise NotImplementedError

View File

@@ -85,11 +85,14 @@ class DefaultNodeGene(BaseNodeGene):
(node1[4] != node2[4]) (node1[4] != node2[4])
) )
def forward(self, attrs, inputs): def forward(self, attrs, inputs, is_output_node=False):
bias, res, act_idx, agg_idx = attrs bias, res, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options) z = agg(agg_idx, inputs, self.aggregation_options)
z = bias + res * z z = bias + res * z
z = act(act_idx, z, self.activation_options)
# the last output node should not be activated
if not is_output_node:
z = act(act_idx, z, self.activation_options)
return z return z

View File

@@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome):
if output_transform is not None: if output_transform is not None:
try: try:
aux = output_transform(jnp.zeros(num_outputs)) _ = output_transform(jnp.zeros(num_outputs))
except Exception as e: except Exception as e:
raise ValueError(f"Output transform function failed: {e}") raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform self.output_transform = output_transform
def transform(self, nodes, conns): def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns) u_conns = unflatten_conns(nodes, conns)
# DONE: Seems like there is a bug in this line
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# modified: exist conn and enable is true
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
# advanced modified: when and only when enabled is True
conn_enable = u_conns[0] == 1 conn_enable = u_conns[0] == 1
# remove enable attr # remove enable attr
@@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome):
def hit(): def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values) ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
# ins = values * weights[:, i]
z = self.node_gene.forward(nodes_attrs[i], ins) z = self.node_gene.forward(nodes_attrs[i], ins)
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
new_values = values.at[i].set(z) new_values = values.at[i].set(z)
return new_values return new_values
@@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome):
return values return values
# the val of input nodes is obtained by the task, not by calculation # the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit) values = jax.lax.cond(
jnp.isin(i, self.input_idx),
miss,
hit
)
return values, idx + 1 return values, idx + 1

View File

@@ -18,7 +18,7 @@ if __name__ == '__main__':
activation_default=Act.tanh, activation_default=Act.tanh,
) )
), ),
pop_size=10, pop_size=10000,
species_size=10, species_size=10,
), ),
), ),

View File

@@ -69,7 +69,7 @@ class Pipeline:
pop_transformed pop_transformed
) )
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses) # fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
alg_state = self.algorithm.tell(state.alg, fitnesses) alg_state = self.algorithm.tell(state.alg, fitnesses)
@@ -80,8 +80,10 @@ class Pipeline:
def auto_run(self, ini_state): def auto_run(self, ini_state):
state = ini_state state = ini_state
print("start compile")
tic = time.time()
compiled_step = jax.jit(self.step).lower(ini_state).compile() compiled_step = jax.jit(self.step).lower(ini_state).compile()
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
for _ in range(self.generation_limit): for _ in range(self.generation_limit):
self.generation_timestamp = time.time() self.generation_timestamp = time.time()
@@ -91,11 +93,6 @@ class Pipeline:
state, fitnesses = compiled_step(state) state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses) fitnesses = jax.device_get(fitnesses)
for idx, fitnesses_i in enumerate(fitnesses):
if np.isnan(fitnesses_i):
print("Fitness is nan")
print(previous_pop[0][idx], previous_pop[1][idx])
assert False
self.analysis(state, previous_pop, fitnesses) self.analysis(state, previous_pop, fitnesses)

View File

@@ -8,7 +8,6 @@ from .. import BaseProblem
class RLEnv(BaseProblem): class RLEnv(BaseProblem):
jitable = True jitable = True
# TODO: move output transform to algorithm
def __init__(self): def __init__(self):
super().__init__() super().__init__()