From 4ad9f0a85a53fd6ed88f7a4eb38b99dc6db61486 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 31 May 2024 22:06:25 +0800 Subject: [PATCH] remove attr enable for conn --- .../algorithm/neat/ga/mutation/default.py | 36 ++++++-------- tensorneat/algorithm/neat/gene/conn/base.py | 21 ++------ .../algorithm/neat/gene/conn/default.py | 9 ++-- tensorneat/algorithm/neat/genome/base.py | 4 +- tensorneat/algorithm/neat/genome/default.py | 11 ++--- tensorneat/algorithm/neat/genome/recurrent.py | 8 ---- tensorneat/examples/func_fit/xor.py | 8 ++-- tensorneat/test/test_genome.py | 48 ++++--------------- tensorneat/utils/tools.py | 6 +-- 9 files changed, 43 insertions(+), 108 deletions(-) diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index 54f492f..823e0dd 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -45,8 +45,8 @@ class DefaultMutation(BaseMutation): i_key, o_key, idx = self.choice_connection_key(key_, conns_) def successful_add_node(): - # disable the connection - new_conns = conns_.at[idx, 2].set(False) + # remove the original connection + new_conns = delete_conn_by_pos(conns_, idx) # add a new node new_nodes = add_node( @@ -58,14 +58,12 @@ class DefaultMutation(BaseMutation): new_conns, i_key, new_node_key, - True, genome.conn_gene.new_custom_attrs(state), ) new_conns = add_conn( new_conns, new_node_key, o_key, - True, genome.conn_gene.new_custom_attrs(state), ) @@ -140,27 +138,26 @@ class DefaultMutation(BaseMutation): def successful(): return nodes_, add_conn( - conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state) + conns_, i_key, o_key, genome.conn_gene.new_custom_attrs(state) ) - def already_exist(): - return nodes_, conns_.at[conn_pos, 2].set(True) - if genome.network_type == "feedforward": u_cons = unflatten_conns(nodes_, conns_) - cons_exist = ~jnp.isnan(u_cons[0, :, :]) - is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx) + conns_exist = ~jnp.isnan(u_cons[0, :, :]) + is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx) return jax.lax.cond( - is_already_exist, - already_exist, - lambda: jax.lax.cond( - is_cycle & (remain_conn_space < 1), nothing, successful - ), + is_already_exist | is_cycle | (remain_conn_space < 1), + nothing, + successful, ) elif genome.network_type == "recurrent": - return jax.lax.cond(is_already_exist, already_exist, successful) + return jax.lax.cond( + is_already_exist | (remain_conn_space < 1), + nothing, + successful, + ) else: raise ValueError(f"Invalid network type: {genome.network_type}") @@ -169,19 +166,16 @@ class DefaultMutation(BaseMutation): # randomly choose a connection i_key, o_key, idx = self.choice_connection_key(key_, conns_) - def successfully_delete_connection(): - return nodes_, delete_conn_by_pos(conns_, idx) - return jax.lax.cond( idx == I_INF, lambda: (nodes_, conns_), # nothing - successfully_delete_connection, + lambda: (nodes_, delete_conn_by_pos(conns_, idx)), # success ) k1, k2, k3, k4 = jax.random.split(randkey, num=4) r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) - def no(key_, nodes_, conns_): + def no(_, nodes_, conns_): return nodes_, conns_ if self.node_add > 0: diff --git a/tensorneat/algorithm/neat/gene/conn/base.py b/tensorneat/algorithm/neat/gene/conn/base.py index 17e6cb8..017b671 100644 --- a/tensorneat/algorithm/neat/gene/conn/base.py +++ b/tensorneat/algorithm/neat/gene/conn/base.py @@ -4,27 +4,16 @@ from .. import BaseGene class BaseConnGene(BaseGene): "Base class for connection genes." - fixed_attrs = ["input_index", "output_index", "enabled"] + fixed_attrs = ["input_index", "output_index"] def __init__(self): super().__init__() def crossover(self, state, randkey, gene1, gene2): - def crossover_attr(): - return jnp.where( - jax.random.normal(randkey, gene1.shape) > 0, - gene1, - gene2, - ) - - return jax.lax.cond( - gene1[2] == gene2[2], # if both genes are enabled or disabled - crossover_attr, # then randomly pick attributes from gene1 or gene2 - lambda: jnp.where( # one gene is enabled and the other is disabled - gene1[2], # if gene1 is enabled - gene1, # then return gene1 - gene2, # else return gene2 - ), + return jnp.where( + jax.random.normal(randkey, gene1.shape) > 0, + gene1, + gene2, ) def forward(self, state, attrs, inputs): diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py index 2f2ed04..e645d36 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -38,10 +38,9 @@ class DefaultConnGene(BaseConnGene): def mutate(self, state, randkey, conn): input_index = conn[0] output_index = conn[1] - enabled = conn[2] weight = mutate_float( randkey, - conn[3], + conn[2], self.weight_init_mean, self.weight_init_std, self.weight_mutate_power, @@ -49,12 +48,10 @@ class DefaultConnGene(BaseConnGene): self.weight_replace_rate, ) - return jnp.array([input_index, output_index, enabled, weight]) + return jnp.array([input_index, output_index, weight]) def distance(self, state, attrs1, attrs2): - return (attrs1[2] != attrs2[2]) + jnp.abs( - attrs1[3] - attrs2[3] - ) # enable + weight + return jnp.abs(attrs1[0] - attrs2[0]) def forward(self, state, attrs, inputs): weight = attrs[0] diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index f52a210..e7807e9 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -106,21 +106,19 @@ class BaseGenome: self.input_idx, jnp.full_like(self.input_idx, new_node_key) ] conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys - conns = conns.at[self.input_idx, 2].set(True) # enable # output-hidden connections output_conns = jnp.c_[ jnp.full_like(self.output_idx, new_node_key), self.output_idx ] conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys - conns = conns.at[self.output_idx, 2].set(True) # enable conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx)) # generate random attributes for conns random_conn_attrs = jax.vmap( self.conn_gene.new_random_attrs, in_axes=(None, 0) )(state, conn_keys) - conns = conns.at[: len(conn_keys), 3:].set(random_conn_attrs) + conns = conns.at[: len(conn_keys), 2:].set(random_conn_attrs) return nodes, conns diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index fb2eb9d..982b991 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -45,19 +45,15 @@ class DefaultGenome(BaseGenome): def transform(self, state, nodes, conns): u_conns = unflatten_conns(nodes, conns) - conn_enable = u_conns[0] == 1 + conn_exist = ~jnp.isnan(u_conns[0]) - # remove enable attr - u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) - seqs = topological_sort(nodes, conn_enable) + seqs = topological_sort(nodes, conn_exist) return seqs, nodes, u_conns def restore(self, state, transformed): seqs, nodes, u_conns = transformed conns = flatten_conns(nodes, u_conns, C=self.max_conns) - # restore enable - conns = jnp.insert(conns, obj=2, values=1, axis=1) return nodes, conns def forward(self, state, inputs, transformed): @@ -79,14 +75,15 @@ class DefaultGenome(BaseGenome): ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))( state, u_conns[:, :, i], values ) + z = self.node_gene.forward( state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx), ) - new_values = values.at[i].set(z) + new_values = values.at[i].set(z) return new_values # the val of input nodes is obtained by the task, not by calculation diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 4670eae..9bd1880 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -47,19 +47,11 @@ class RecurrentGenome(BaseGenome): def transform(self, state, nodes, conns): u_conns = unflatten_conns(nodes, conns) - - # remove un-enable connections and remove enable attr - conn_enable = u_conns[0] == 1 - u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) - return nodes, u_conns def restore(self, state, transformed): nodes, u_conns = transformed conns = flatten_conns(nodes, u_conns, C=self.max_conns) - - # restore enable - conns = jnp.insert(conns, obj=2, values=1, axis=1) return nodes, conns def forward(self, state, inputs, transformed): diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index 652c68c..c2ba018 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -11,8 +11,8 @@ if __name__ == "__main__": genome=DefaultGenome( num_inputs=3, num_outputs=1, - max_nodes=5, - max_conns=10, + max_nodes=50, + max_conns=100, node_gene=DefaultNodeGene( activation_default=Act.tanh, activation_options=(Act.tanh,), @@ -21,8 +21,8 @@ if __name__ == "__main__": mutation=DefaultMutation( node_add=0.1, conn_add=0.1, - node_delete=0.1, - conn_delete=0.1, + node_delete=0.05, + conn_delete=0.05, ), ), pop_size=1000, diff --git a/tensorneat/test/test_genome.py b/tensorneat/test/test_genome.py index b7da2c9..5a269fb 100644 --- a/tensorneat/test/test_genome.py +++ b/tensorneat/test/test_genome.py @@ -21,10 +21,10 @@ def test_default(): # in_node, out_node, enable, weight conns = jnp.array( [ - [0, 3, 1, 0.5], # in[0] -> hidden[0] - [1, 4, 1, 0.5], # in[1] -> hidden[1] - [3, 2, 1, 0.5], # hidden[0] -> out[0] - [4, 2, 1, 0.5], # hidden[1] -> out[0] + [0, 3, 0.5], # in[0] -> hidden[0] + [1, 4, 0.5], # in[1] -> hidden[1] + [3, 2, 0.5], # hidden[0] -> out[0] + [4, 2, 0.5], # hidden[1] -> out[0] ] ) @@ -54,22 +54,6 @@ def test_default(): assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) # expected: [[0.5], [0.75], [0.75], [1]] - print("\n-------------------------------------------------------\n") - - conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] - print(conns) - - transformed = genome.transform(state, nodes, conns) - print(*transformed, sep="\n") - - inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) - outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))( - state, inputs, transformed - ) - print(outputs) - assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) - # expected: [[0.5], [0.75], [0.5], [0.75]] - def test_recurrent(): @@ -87,10 +71,10 @@ def test_recurrent(): # in_node, out_node, enable, weight conns = jnp.array( [ - [0, 3, 1, 0.5], # in[0] -> hidden[0] - [1, 4, 1, 0.5], # in[1] -> hidden[1] - [3, 2, 1, 0.5], # hidden[0] -> out[0] - [4, 2, 1, 0.5], # hidden[1] -> out[0] + [0, 3, 0.5], # in[0] -> hidden[0] + [1, 4, 0.5], # in[1] -> hidden[1] + [3, 2, 0.5], # hidden[0] -> out[0] + [4, 2, 0.5], # hidden[1] -> out[0] ] ) @@ -121,22 +105,6 @@ def test_recurrent(): assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) # expected: [[0.5], [0.75], [0.75], [1]] - print("\n-------------------------------------------------------\n") - - conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] - print(conns) - - transformed = genome.transform(state, nodes, conns) - print(*transformed, sep="\n") - - inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) - outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))( - state, inputs, transformed - ) - print(outputs) - assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) - # expected: [[0.5], [0.75], [0.5], [0.75]] - def test_random_initialize(): genome = DefaultGenome( diff --git a/tensorneat/utils/tools.py b/tensorneat/utils/tools.py index f6f1de1..d0030ce 100644 --- a/tensorneat/utils/tools.py +++ b/tensorneat/utils/tools.py @@ -168,15 +168,15 @@ def delete_node_by_pos(nodes, pos): return nodes.at[pos].set(jnp.nan) -def add_conn(conns, i_key, o_key, enable: bool, attrs): +def add_conn(conns, i_key, o_key, attrs): """ Add a new connection to the genome. The new connection will place at the first NaN row. """ con_keys = conns[:, 0] pos = fetch_first(jnp.isnan(con_keys)) - new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable])) - return new_conns.at[pos, 3:].set(attrs) + new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key])) + return new_conns.at[pos, 2:].set(attrs) def delete_conn_by_pos(conns, pos):