diff --git a/tensorneat/.idea/.gitignore b/tensorneat/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/tensorneat/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/tensorneat/algorithm/neat/__init__.py b/tensorneat/algorithm/neat/__init__.py index 97185ca..e64148a 100644 --- a/tensorneat/algorithm/neat/__init__.py +++ b/tensorneat/algorithm/neat/__init__.py @@ -1,3 +1,4 @@ +from .ga import * from .gene import * from .genome import * from .species import * diff --git a/tensorneat/algorithm/neat/ga/crossover/base.py b/tensorneat/algorithm/neat/ga/crossover/base.py index 9f638a2..9f84ac9 100644 --- a/tensorneat/algorithm/neat/ga/crossover/base.py +++ b/tensorneat/algorithm/neat/ga/crossover/base.py @@ -1,3 +1,3 @@ class BaseCrossover: def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/crossover/default.py b/tensorneat/algorithm/neat/ga/crossover/default.py index 4d01b41..deaf745 100644 --- a/tensorneat/algorithm/neat/ga/crossover/default.py +++ b/tensorneat/algorithm/neat/ga/crossover/default.py @@ -2,6 +2,7 @@ import jax, jax.numpy as jnp from .base import BaseCrossover + class DefaultCrossover(BaseCrossover): def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2): @@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover): # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] # make homologous genes align in nodes2 align with nodes1 - nodes2 = self.align_array(keys1, keys2, nodes2, False) + nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False) # For not homologous genes, use the value of nodes1(winner) # For homologous genes, use the crossover result between nodes1 and nodes2 - new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2)) + new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, + self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False)) # crossover connections con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] - conns2 = self.align_array(con_keys1, con_keys2, conns2, True) + conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) - new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2)) + new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, + self.crossover_gene(randkey_2, conns1, conns2, is_conn=True)) return new_nodes, new_conns @@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover): return refactor_ar2 - def crossover_gene(self, rand_key, g1, g2): - """ - crossover two genes - :param rand_key: - :param g1: - :param g2: - :return: - only gene with the same key will be crossover, thus don't need to consider change key - """ + def crossover_gene(self, rand_key, g1, g2, is_conn): r = jax.random.uniform(rand_key, shape=g1.shape) - return jnp.where(r > 0.5, g1, g2) + new_gene = jnp.where(r > 0.5, g1, g2) + if is_conn: # fix enabled + enabled = jnp.where( + g1[:, 2] + g2[:, 2] > 0, # any of them is enabled + 1, + 0 + ) + new_gene = new_gene.at[:, 2].set(enabled) + return new_gene diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index 33ba6fe..331dbf3 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation): nodes_keys = jax.random.split(k1, num=nodes.shape[0]) conns_keys = jax.random.split(k2, num=conns.shape[0]) - new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes) - new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns) + new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes) + new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns) # nan nodes not changed new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) diff --git a/tensorneat/algorithm/neat/gene/node/base.py b/tensorneat/algorithm/neat/gene/node/base.py index 465050a..2ebfd1b 100644 --- a/tensorneat/algorithm/neat/gene/node/base.py +++ b/tensorneat/algorithm/neat/gene/node/base.py @@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene): def __init__(self): super().__init__() - def forward(self, attrs, inputs): + def forward(self, attrs, inputs, is_output_node=False): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 3f43d4f..066ec5e 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -85,11 +85,14 @@ class DefaultNodeGene(BaseNodeGene): (node1[4] != node2[4]) ) - def forward(self, attrs, inputs): + def forward(self, attrs, inputs, is_output_node=False): bias, res, act_idx, agg_idx = attrs z = agg(agg_idx, inputs, self.aggregation_options) z = bias + res * z - z = act(act_idx, z, self.activation_options) + + # the last output node should not be activated + if not is_output_node: + z = act(act_idx, z, self.activation_options) return z diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 768145d..fcc273c 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome): if output_transform is not None: try: - aux = output_transform(jnp.zeros(num_outputs)) + _ = output_transform(jnp.zeros(num_outputs)) except Exception as e: raise ValueError(f"Output transform function failed: {e}") self.output_transform = output_transform def transform(self, nodes, conns): u_conns = unflatten_conns(nodes, conns) - - # DONE: Seems like there is a bug in this line - # conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) - # modified: exist conn and enable is true - # conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False) - # advanced modified: when and only when enabled is True conn_enable = u_conns[0] == 1 # remove enable attr @@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome): def hit(): ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values) - # ins = values * weights[:, i] - z = self.node_gene.forward(nodes_attrs[i], ins) - # z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins) - # z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias - # z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z) - new_values = values.at[i].set(z) return new_values @@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome): return values # the val of input nodes is obtained by the task, not by calculation - values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit) + values = jax.lax.cond( + jnp.isin(i, self.input_idx), + miss, + hit + ) return values, idx + 1 diff --git a/tensorneat/examples/brax/walker.py b/tensorneat/examples/brax/walker.py index 4cdd645..993567d 100644 --- a/tensorneat/examples/brax/walker.py +++ b/tensorneat/examples/brax/walker.py @@ -18,7 +18,7 @@ if __name__ == '__main__': activation_default=Act.tanh, ) ), - pop_size=10, + pop_size=10000, species_size=10, ), ), diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index ad33945..4d34322 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -69,7 +69,7 @@ class Pipeline: pop_transformed ) - fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses) + # fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses) alg_state = self.algorithm.tell(state.alg, fitnesses) @@ -80,8 +80,10 @@ class Pipeline: def auto_run(self, ini_state): state = ini_state + print("start compile") + tic = time.time() compiled_step = jax.jit(self.step).lower(ini_state).compile() - + print(f"compile finished, cost time: {time.time() - tic:.6f}s", ) for _ in range(self.generation_limit): self.generation_timestamp = time.time() @@ -91,11 +93,6 @@ class Pipeline: state, fitnesses = compiled_step(state) fitnesses = jax.device_get(fitnesses) - for idx, fitnesses_i in enumerate(fitnesses): - if np.isnan(fitnesses_i): - print("Fitness is nan") - print(previous_pop[0][idx], previous_pop[1][idx]) - assert False self.analysis(state, previous_pop, fitnesses) diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 128ebfb..b0e5e76 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -8,7 +8,6 @@ from .. import BaseProblem class RLEnv(BaseProblem): jitable = True - # TODO: move output transform to algorithm def __init__(self): super().__init__()