From edfb0596e7539b0a6d9bab75b4c92b5c2d152e63 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Mon, 3 Jun 2024 10:53:15 +0800 Subject: [PATCH] add input_transform and update_input_transform; change the args for genome.forward. Origin: (state, inputs, transformed) New: (state, transformed, inputs) --- tensorneat/algorithm/base.py | 2 +- tensorneat/algorithm/hyperneat/hyperneat.py | 4 +- .../algorithm/neat/ga/mutation/default.py | 4 +- tensorneat/algorithm/neat/gene/node/base.py | 19 +++ .../algorithm/neat/gene/node/normalized.py | 37 +++++ tensorneat/algorithm/neat/genome/base.py | 2 +- tensorneat/algorithm/neat/genome/default.py | 98 ++++++++---- tensorneat/algorithm/neat/genome/recurrent.py | 7 +- tensorneat/algorithm/neat/neat.py | 4 +- tensorneat/examples/func_fit/xor.py | 14 +- tensorneat/examples/func_fit/xor_kan.py | 6 +- tensorneat/problem/func_fit/func_fit.py | 8 +- tensorneat/problem/rl_env/brax_env.py | 2 +- tensorneat/problem/rl_env/rl_jit.py | 2 +- tensorneat/test/test_genome.py | 143 +++--------------- tensorneat/utils/tools.py | 54 ++----- 16 files changed, 185 insertions(+), 221 deletions(-) diff --git a/tensorneat/algorithm/base.py b/tensorneat/algorithm/base.py index 23999c4..2c9ee50 100644 --- a/tensorneat/algorithm/base.py +++ b/tensorneat/algorithm/base.py @@ -22,7 +22,7 @@ class BaseAlgorithm: def restore(self, state, transformed): raise NotImplementedError - def forward(self, state, inputs, transformed): + def forward(self, state, transformed, inputs): raise NotImplementedError def update_by_batch(self, state, batch_input, transformed): diff --git a/tensorneat/algorithm/hyperneat/hyperneat.py b/tensorneat/algorithm/hyperneat/hyperneat.py index a802451..dce9cab 100644 --- a/tensorneat/algorithm/hyperneat/hyperneat.py +++ b/tensorneat/algorithm/hyperneat/hyperneat.py @@ -54,8 +54,8 @@ class HyperNEAT(BaseAlgorithm): def transform(self, state, individual): transformed = self.neat.transform(state, individual) - query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))( - state, self.substrate.query_coors, transformed + query_res = jax.vmap(self.neat.forward, in_axes=(None, None, 0))( + state, transformed, self.substrate.query_coors ) # mute the connection with weight below threshold query_res = jnp.where( diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index db3d64c..2786022 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -163,8 +163,8 @@ class DefaultMutation(BaseMutation): ) if genome.network_type == "feedforward": - u_cons = unflatten_conns(nodes_, conns_) - conns_exist = ~jnp.isnan(u_cons[0, :, :]) + u_conns = unflatten_conns(nodes_, conns_) + conns_exist = (u_conns != I_INF) is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx) return jax.lax.cond( diff --git a/tensorneat/algorithm/neat/gene/node/base.py b/tensorneat/algorithm/neat/gene/node/base.py index 64d242e..452bf91 100644 --- a/tensorneat/algorithm/neat/gene/node/base.py +++ b/tensorneat/algorithm/neat/gene/node/base.py @@ -12,6 +12,13 @@ class BaseNodeGene(BaseGene): def forward(self, state, attrs, inputs, is_output_node=False): raise NotImplementedError + def input_transform(self, state, attrs, inputs): + """ + make transformation in the input node. + default: do nothing + """ + return inputs + def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False): # default: do not update attrs, but to calculate batch_res return ( @@ -20,3 +27,15 @@ class BaseNodeGene(BaseGene): ), attrs, ) + + def update_input_transform(self, state, attrs, batch_inputs): + """ + update the attrs for transformation in the input node. + default: do nothing + """ + return ( + jax.vmap(self.input_transform, in_axes=(None, None, 0))( + state, attrs, batch_inputs + ), + attrs, + ) diff --git a/tensorneat/algorithm/neat/gene/node/normalized.py b/tensorneat/algorithm/neat/gene/node/normalized.py index 9e5e4b0..717aeb8 100644 --- a/tensorneat/algorithm/neat/gene/node/normalized.py +++ b/tensorneat/algorithm/neat/gene/node/normalized.py @@ -157,6 +157,15 @@ class NormalizedNode(BaseNodeGene): return z + def input_transform(self, state, attrs, inputs): + """ + make transform in the input node. + the normalization also need be done in the first node. + """ + bias, agg, act, mean, std, alpha, beta = attrs + inputs = (inputs - mean) / (std + self.eps) * alpha + beta # normalization + return inputs + def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False): bias, agg, act, mean, std, alpha, beta = attrs @@ -192,3 +201,31 @@ class NormalizedNode(BaseNodeGene): attrs = attrs.at[4].set(std) return batch_z, attrs + + def update_input_transform(self, state, attrs, batch_inputs): + """ + update the attrs for transformation in the input node. + default: do nothing + """ + bias, agg, act, mean, std, alpha, beta = attrs + + # calculate mean + valid_values_count = jnp.sum(~jnp.isnan(batch_inputs)) + valid_values_sum = jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, batch_inputs)) + mean = valid_values_sum / valid_values_count + + # calculate std + std = jnp.sqrt( + jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, (batch_inputs - mean) ** 2)) + / valid_values_count + ) + + batch_inputs = (batch_inputs - mean) / ( + std + self.eps + ) * alpha + beta # normalization + + # update mean and std to the attrs + attrs = attrs.at[3].set(mean) + attrs = attrs.at[4].set(std) + + return batch_inputs, attrs diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index e7807e9..e97d828 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -42,7 +42,7 @@ class BaseGenome: def restore(self, state, transformed): raise NotImplementedError - def forward(self, state, inputs, transformed): + def forward(self, state, transformed, inputs): raise NotImplementedError def execute_mutation(self, state, randkey, nodes, conns, new_node_key): diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 982b991..58c5150 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -1,7 +1,16 @@ from typing import Callable import jax, jax.numpy as jnp -from utils import unflatten_conns, flatten_conns, topological_sort, I_INF +from utils import ( + unflatten_conns, + topological_sort, + I_INF, + extract_node_attrs, + extract_conn_attrs, + set_node_attrs, + set_conn_attrs, + attach_with_inf, +) from . import BaseGenome from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene @@ -45,23 +54,23 @@ class DefaultGenome(BaseGenome): def transform(self, state, nodes, conns): u_conns = unflatten_conns(nodes, conns) - conn_exist = ~jnp.isnan(u_conns[0]) + conn_exist = u_conns != I_INF seqs = topological_sort(nodes, conn_exist) - return seqs, nodes, u_conns + return seqs, nodes, conns, u_conns def restore(self, state, transformed): - seqs, nodes, u_conns = transformed - conns = flatten_conns(nodes, u_conns, C=self.max_conns) + seqs, nodes, conns, u_conns = transformed return nodes, conns - def forward(self, state, inputs, transformed): - cal_seqs, nodes, u_conns = transformed + def forward(self, state, transformed, inputs): + cal_seqs, nodes, conns, u_conns = transformed ini_vals = jnp.full((self.max_nodes,), jnp.nan) ini_vals = ini_vals.at[self.input_idx].set(inputs) - nodes_attrs = nodes[:, 1:] + nodes_attrs = jax.vmap(extract_node_attrs)(nodes) + conns_attrs = jax.vmap(extract_conn_attrs)(conns) def cond_fun(carry): values, idx = carry @@ -71,9 +80,16 @@ class DefaultGenome(BaseGenome): values, idx = carry i = cal_seqs[idx] - def hit(): - ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))( - state, u_conns[:, :, i], values + def input_node(): + z = self.node_gene.input_transform(state, nodes_attrs[i], values[i]) + new_values = values.at[i].set(z) + return new_values + + def otherwise(): + conn_indices = u_conns[:, i] + hit_attrs = attach_with_inf(conns_attrs, conn_indices) + ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))( + state, hit_attrs, values ) z = self.node_gene.forward( @@ -86,8 +102,7 @@ class DefaultGenome(BaseGenome): new_values = values.at[i].set(z) return new_values - # the val of input nodes is obtained by the task, not by calculation - values = jax.lax.cond(jnp.isin(i, self.input_idx), lambda: values, hit) + values = jax.lax.cond(jnp.isin(i, self.input_idx), input_node, otherwise) return values, idx + 1 @@ -99,55 +114,72 @@ class DefaultGenome(BaseGenome): return self.output_transform(vals[self.output_idx]) def update_by_batch(self, state, batch_input, transformed): - cal_seqs, nodes, u_conns = transformed + cal_seqs, nodes, conns, u_conns = transformed batch_size = batch_input.shape[0] batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan) batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input) - nodes_attrs = nodes[:, 1:] + nodes_attrs = jax.vmap(extract_node_attrs)(nodes) + conns_attrs = jax.vmap(extract_conn_attrs)(conns) def cond_fun(carry): - batch_values, nodes_attrs_, u_conns_, idx = carry + batch_values, nodes_attrs_, conns_attrs_, idx = carry return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF) def body_func(carry): - batch_values, nodes_attrs_, u_conns_, idx = carry + batch_values, nodes_attrs_, conns_attrs_, idx = carry i = cal_seqs[idx] - def hit(): + def input_node(): + batch, new_attrs = self.node_gene.update_input_transform( + state, nodes_attrs_[i], batch_values[:, i] + ) + return ( + batch_values.at[:, i].set(batch), + nodes_attrs_.at[i].set(new_attrs), + conns_attrs_, + ) + + def otherwise(): + + conn_indices = u_conns[:, i] + hit_attrs = attach_with_inf(conns_attrs, conn_indices) batch_ins, new_conn_attrs = jax.vmap( self.conn_gene.update_by_batch, - in_axes=(None, 1, 1), - out_axes=(1, 1), - )(state, u_conns_[:, :, i], batch_values) + in_axes=(None, 0, 1), + out_axes=(1, 0), + )(state, hit_attrs, batch_values) + batch_z, new_node_attrs = self.node_gene.update_by_batch( state, - nodes_attrs[i], + nodes_attrs_[i], batch_ins, is_output_node=jnp.isin(i, self.output_idx), ) - new_batch_values = batch_values.at[:, i].set(batch_z) + return ( - new_batch_values, + batch_values.at[:, i].set(batch_z), nodes_attrs_.at[i].set(new_node_attrs), - u_conns_.at[:, :, i].set(new_conn_attrs), + conns_attrs_.at[conn_indices].set(new_conn_attrs), ) # the val of input nodes is obtained by the task, not by calculation - (batch_values, nodes_attrs_, u_conns_) = jax.lax.cond( + (batch_values, nodes_attrs_, conns_attrs_) = jax.lax.cond( jnp.isin(i, self.input_idx), - lambda: (batch_values, nodes_attrs_, u_conns_), - hit, + input_node, + otherwise, ) - return batch_values, nodes_attrs_, u_conns_, idx + 1 + return batch_values, nodes_attrs_, conns_attrs_, idx + 1 - batch_vals, nodes_attrs, u_conns, _ = jax.lax.while_loop( - cond_fun, body_func, (batch_ini_vals, nodes_attrs, u_conns, 0) + batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop( + cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0) ) - nodes = nodes.at[:, 1:].set(nodes_attrs) - new_transformed = (cal_seqs, nodes, u_conns) + nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs) + conns = jax.vmap(set_conn_attrs)(conns, conns_attrs) + + new_transformed = (cal_seqs, nodes, conns, u_conns) if self.output_transform is None: return batch_vals[:, self.output_idx], new_transformed diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 9bd1880..8b4898c 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -1,7 +1,7 @@ from typing import Callable import jax, jax.numpy as jnp -from utils import unflatten_conns, flatten_conns +from utils import unflatten_conns from . import BaseGenome from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene @@ -47,11 +47,10 @@ class RecurrentGenome(BaseGenome): def transform(self, state, nodes, conns): u_conns = unflatten_conns(nodes, conns) - return nodes, u_conns + return nodes, conns, u_conns def restore(self, state, transformed): - nodes, u_conns = transformed - conns = flatten_conns(nodes, u_conns, C=self.max_conns) + nodes, conns, u_conns = transformed return nodes, conns def forward(self, state, inputs, transformed): diff --git a/tensorneat/algorithm/neat/neat.py b/tensorneat/algorithm/neat/neat.py index fefc6fa..62ad5e0 100644 --- a/tensorneat/algorithm/neat/neat.py +++ b/tensorneat/algorithm/neat/neat.py @@ -47,8 +47,8 @@ class NEAT(BaseAlgorithm): def restore(self, state, transformed): return self.genome.restore(state, transformed) - def forward(self, state, inputs, transformed): - return self.genome.forward(state, inputs, transformed) + def forward(self, state, transformed, inputs): + return self.genome.forward(state, transformed, inputs) def update_by_batch(self, state, batch_input, transformed): return self.genome.update_by_batch(state, batch_input, transformed) diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index c2ba018..24ebc0f 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -2,7 +2,7 @@ from pipeline import Pipeline from algorithm.neat import * from problem.func_fit import XOR3d -from utils import Act +from utils import ACT_ALL, AGG_ALL, Act, Agg if __name__ == "__main__": pipeline = Pipeline( @@ -15,17 +15,21 @@ if __name__ == "__main__": max_conns=100, node_gene=DefaultNodeGene( activation_default=Act.tanh, - activation_options=(Act.tanh,), + # activation_options=(Act.tanh,), + activation_options=ACT_ALL, + aggregation_default=Agg.sum, + # aggregation_options=(Agg.sum,), + aggregation_options=AGG_ALL, ), output_transform=Act.sigmoid, # the activation function for output node mutation=DefaultMutation( node_add=0.1, conn_add=0.1, - node_delete=0.05, - conn_delete=0.05, + node_delete=0, + conn_delete=0, ), ), - pop_size=1000, + pop_size=100000, species_size=20, compatibility_threshold=2, survival_threshold=0.01, # magic diff --git a/tensorneat/examples/func_fit/xor_kan.py b/tensorneat/examples/func_fit/xor_kan.py index fff9920..5bdd1ed 100644 --- a/tensorneat/examples/func_fit/xor_kan.py +++ b/tensorneat/examples/func_fit/xor_kan.py @@ -16,7 +16,7 @@ if __name__ == "__main__": max_nodes=50, max_conns=100, node_gene=KANNode(), - conn_gene=BSplineConn(), + conn_gene=BSplineConn(grid_cnt=10), output_transform=Act.sigmoid, # the activation function for output node mutation=DefaultMutation( node_add=0.1, @@ -25,7 +25,7 @@ if __name__ == "__main__": conn_delete=0.05, ), ), - pop_size=1000, + pop_size=10000, species_size=20, compatibility_threshold=1.5, survival_threshold=0.01, # magic @@ -34,7 +34,7 @@ if __name__ == "__main__": # problem=XOR3d(return_data=True), problem=XOR3d(), generation_limit=10000, - fitness_target=-1e-8, + fitness_target=-1e-5, # update_batch_size=8, # pre_update=True, ) diff --git a/tensorneat/problem/func_fit/func_fit.py b/tensorneat/problem/func_fit/func_fit.py index e6cc70d..3d67dae 100644 --- a/tensorneat/problem/func_fit/func_fit.py +++ b/tensorneat/problem/func_fit/func_fit.py @@ -20,8 +20,8 @@ class FuncFit(BaseProblem): def evaluate(self, state, randkey, act_func, params): - predict = jax.vmap(act_func, in_axes=(None, 0, None))( - state, self.inputs, params + predict = jax.vmap(act_func, in_axes=(None, None, 0))( + state, params, self.inputs ) if self.error_method == "mse": @@ -45,8 +45,8 @@ class FuncFit(BaseProblem): return -loss def show(self, state, randkey, act_func, params, *args, **kwargs): - predict = jax.vmap(act_func, in_axes=(None, 0, None))( - state, self.inputs, params + predict = jax.vmap(act_func, in_axes=(None, None, 0))( + state, params, self.inputs, params ) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) if self.return_data: diff --git a/tensorneat/problem/rl_env/brax_env.py b/tensorneat/problem/rl_env/brax_env.py index 8f79c81..eff6338 100644 --- a/tensorneat/problem/rl_env/brax_env.py +++ b/tensorneat/problem/rl_env/brax_env.py @@ -51,7 +51,7 @@ class BraxEnv(RLEnv): def step(key, env_state, obs): key, _ = jax.random.split(key) - action = act_func(obs, params) + action = act_func(params, obs) obs, env_state, r, done, _ = self.step(randkey, env_state, action) return key, env_state, obs, r, done diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 06e020c..00dfcb3 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -36,7 +36,7 @@ class RLEnv(BaseProblem): def body_func(carry): obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward - action = act_func(state, obs, params) + action = act_func(state, params, obs) next_obs, next_env_state, reward, done, _ = self.step( rng, env_state, action ) diff --git a/tensorneat/test/test_genome.py b/tensorneat/test/test_genome.py index 5a269fb..704952b 100644 --- a/tensorneat/test/test_genome.py +++ b/tensorneat/test/test_genome.py @@ -1,132 +1,27 @@ +import jax from algorithm.neat import * -from utils import Act, Agg, State -import jax, jax.numpy as jnp -from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse +genome = DefaultGenome( + num_inputs=3, + num_outputs=1, + max_nodes=5, + max_conns=10, +) -def test_default(): - - # index, bias, response, activation, aggregation - nodes = jnp.array( - [ - [0, 0, 1, 0, 0], # in[0] - [1, 0, 1, 0, 0], # in[1] - [2, 0.5, 1, 0, 0], # out[0], - [3, 1, 1, 0, 0], # hidden[0], - [4, -1, 1, 0, 0], # hidden[1], - ] - ) - - # in_node, out_node, enable, weight - conns = jnp.array( - [ - [0, 3, 0.5], # in[0] -> hidden[0] - [1, 4, 0.5], # in[1] -> hidden[1] - [3, 2, 0.5], # hidden[0] -> out[0] - [4, 2, 0.5], # hidden[1] -> out[0] - ] - ) - - genome = DefaultGenome( - num_inputs=2, - num_outputs=1, - max_nodes=5, - max_conns=4, - node_gene=DefaultNodeGene( - activation_default=Act.identity, - activation_options=(Act.identity,), - aggregation_default=Agg.sum, - aggregation_options=(Agg.sum,), - ), - ) - - state = genome.setup(State(randkey=jax.random.key(0))) - - transformed = genome.transform(state, nodes, conns) - print(*transformed, sep="\n") - - inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) - outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))( - state, inputs, transformed - ) - print(outputs) - assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) - # expected: [[0.5], [0.75], [0.75], [1]] - - -def test_recurrent(): - - # index, bias, response, activation, aggregation - nodes = jnp.array( - [ - [0, 0, 1, 0, 0], # in[0] - [1, 0, 1, 0, 0], # in[1] - [2, 0.5, 1, 0, 0], # out[0], - [3, 1, 1, 0, 0], # hidden[0], - [4, -1, 1, 0, 0], # hidden[1], - ] - ) - - # in_node, out_node, enable, weight - conns = jnp.array( - [ - [0, 3, 0.5], # in[0] -> hidden[0] - [1, 4, 0.5], # in[1] -> hidden[1] - [3, 2, 0.5], # hidden[0] -> out[0] - [4, 2, 0.5], # hidden[1] -> out[0] - ] - ) - - genome = RecurrentGenome( - num_inputs=2, - num_outputs=1, - max_nodes=5, - max_conns=4, - node_gene=DefaultNodeGene( - activation_default=Act.identity, - activation_options=(Act.identity,), - aggregation_default=Agg.sum, - aggregation_options=(Agg.sum,), - ), - activate_time=3, - ) - - state = genome.setup(State(randkey=jax.random.key(0))) - - transformed = genome.transform(state, nodes, conns) - print(*transformed, sep="\n") - - inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) - outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))( - state, inputs, transformed - ) - print(outputs) - assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) - # expected: [[0.5], [0.75], [0.75], [1]] - - -def test_random_initialize(): - genome = DefaultGenome( - num_inputs=2, - num_outputs=1, - max_nodes=5, - max_conns=4, - node_gene=NodeGeneWithoutResponse( - activation_default=Act.identity, - activation_options=(Act.identity,), - aggregation_default=Agg.sum, - aggregation_options=(Agg.sum,), - ), - ) +def test_output_work(): + randkey = jax.random.PRNGKey(0) state = genome.setup() - key = jax.random.PRNGKey(0) - nodes, conns = genome.initialize(state, key) + nodes, conns = genome.initialize(state, randkey) transformed = genome.transform(state, nodes, conns) - print(*transformed, sep="\n") + inputs = jax.random.normal(randkey, (3,)) + output = genome.forward(state, transformed, inputs) + print(output) - inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) - outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))( - state, inputs, transformed + batch_inputs = jax.random.normal(randkey, (10, 3)) + batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))( + state, transformed, batch_inputs ) - print(outputs) + print(batch_output) + + assert True diff --git a/tensorneat/utils/tools.py b/tensorneat/utils/tools.py index b33e48f..9eceb51 100644 --- a/tensorneat/utils/tools.py +++ b/tensorneat/utils/tools.py @@ -9,12 +9,12 @@ I_INF = np.iinfo(jnp.int32).max # infinite int def unflatten_conns(nodes, conns): """ - transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index), which CL means + transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns connection length, N means the number of nodes, C means the number of connections - returns the un_flattened connections with shape (CL-2, N, N) + returns the unflatten connection indices with shape (N, N) """ N = nodes.shape[0] # max_nodes - CL = conns.shape[1] # connection length = (fix_attrs + custom_attrs) + C = conns.shape[0] # max_conns node_keys = nodes[:, 0] i_keys, o_keys = conns[:, 0], conns[:, 1] @@ -23,47 +23,25 @@ def unflatten_conns(nodes, conns): i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys) o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys) - unflatten = jnp.full((CL - 2, N, N), jnp.nan) # Is interesting that jax use clip when attach data in array - # however, it will do nothing set values in an array - # put all attributes include enable in res - unflatten = unflatten.at[:, i_idxs, o_idxs].set(conns[:, 2:].T) - assert unflatten.shape == (CL - 2, N, N) + # however, it will do nothing when setting values in an array + # put the index of connections in the unflatten array + unflatten = ( + jnp.full((N, N), I_INF, dtype=jnp.int32) + .at[i_idxs, o_idxs] + .set(jnp.arange(C, dtype=jnp.int32)) + ) return unflatten -def flatten_conns(nodes, unflatten, C): - """ - the inverse function of unflatten_conns - transform the unflatten conn (CL-2, N, N) to (C, CL) - """ - N = nodes.shape[0] - CL = unflatten.shape[0] + 2 - node_keys = nodes[:, 0] - - def extract_conn(i, j): - return jnp.where( - jnp.isnan(unflatten[0, i, j]), - jnp.nan, - jnp.concatenate( - [jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]] - ), - ) - - x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij") - conns = vmap(extract_conn)(x.flatten(), y.flatten()) - assert conns.shape == (N * N, CL) - - # put nan to the tail of the conns - sorted_idx = jnp.argsort(conns[:, 0]) - sorted_conn = conns[sorted_idx] - - # truncate the conns to the number of connections - conns = sorted_conn[:C] - assert conns.shape == (C, CL) - return conns +def attach_with_inf(arr, idx): + expand_size = arr.ndim - idx.ndim + expand_idx = jnp.expand_dims( + idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim)) + ) + return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx]) def extract_node_attrs(node):