add input_transform and update_input_transform;

change the args for genome.forward.
Origin: (state, inputs, transformed)
New: (state, transformed, inputs)
This commit is contained in:
wls2002
2024-06-03 10:53:15 +08:00
parent a07a3b1cb2
commit edfb0596e7
16 changed files with 185 additions and 221 deletions

View File

@@ -22,7 +22,7 @@ class BaseAlgorithm:
def restore(self, state, transformed): def restore(self, state, transformed):
raise NotImplementedError raise NotImplementedError
def forward(self, state, inputs, transformed): def forward(self, state, transformed, inputs):
raise NotImplementedError raise NotImplementedError
def update_by_batch(self, state, batch_input, transformed): def update_by_batch(self, state, batch_input, transformed):

View File

@@ -54,8 +54,8 @@ class HyperNEAT(BaseAlgorithm):
def transform(self, state, individual): def transform(self, state, individual):
transformed = self.neat.transform(state, individual) transformed = self.neat.transform(state, individual)
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))( query_res = jax.vmap(self.neat.forward, in_axes=(None, None, 0))(
state, self.substrate.query_coors, transformed state, transformed, self.substrate.query_coors
) )
# mute the connection with weight below threshold # mute the connection with weight below threshold
query_res = jnp.where( query_res = jnp.where(

View File

@@ -163,8 +163,8 @@ class DefaultMutation(BaseMutation):
) )
if genome.network_type == "feedforward": if genome.network_type == "feedforward":
u_cons = unflatten_conns(nodes_, conns_) u_conns = unflatten_conns(nodes_, conns_)
conns_exist = ~jnp.isnan(u_cons[0, :, :]) conns_exist = (u_conns != I_INF)
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx) is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
return jax.lax.cond( return jax.lax.cond(

View File

@@ -12,6 +12,13 @@ class BaseNodeGene(BaseGene):
def forward(self, state, attrs, inputs, is_output_node=False): def forward(self, state, attrs, inputs, is_output_node=False):
raise NotImplementedError raise NotImplementedError
def input_transform(self, state, attrs, inputs):
"""
make transformation in the input node.
default: do nothing
"""
return inputs
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False): def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
# default: do not update attrs, but to calculate batch_res # default: do not update attrs, but to calculate batch_res
return ( return (
@@ -20,3 +27,15 @@ class BaseNodeGene(BaseGene):
), ),
attrs, attrs,
) )
def update_input_transform(self, state, attrs, batch_inputs):
"""
update the attrs for transformation in the input node.
default: do nothing
"""
return (
jax.vmap(self.input_transform, in_axes=(None, None, 0))(
state, attrs, batch_inputs
),
attrs,
)

View File

@@ -157,6 +157,15 @@ class NormalizedNode(BaseNodeGene):
return z return z
def input_transform(self, state, attrs, inputs):
"""
make transform in the input node.
the normalization also need be done in the first node.
"""
bias, agg, act, mean, std, alpha, beta = attrs
inputs = (inputs - mean) / (std + self.eps) * alpha + beta # normalization
return inputs
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False): def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
bias, agg, act, mean, std, alpha, beta = attrs bias, agg, act, mean, std, alpha, beta = attrs
@@ -192,3 +201,31 @@ class NormalizedNode(BaseNodeGene):
attrs = attrs.at[4].set(std) attrs = attrs.at[4].set(std)
return batch_z, attrs return batch_z, attrs
def update_input_transform(self, state, attrs, batch_inputs):
"""
update the attrs for transformation in the input node.
default: do nothing
"""
bias, agg, act, mean, std, alpha, beta = attrs
# calculate mean
valid_values_count = jnp.sum(~jnp.isnan(batch_inputs))
valid_values_sum = jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, batch_inputs))
mean = valid_values_sum / valid_values_count
# calculate std
std = jnp.sqrt(
jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, (batch_inputs - mean) ** 2))
/ valid_values_count
)
batch_inputs = (batch_inputs - mean) / (
std + self.eps
) * alpha + beta # normalization
# update mean and std to the attrs
attrs = attrs.at[3].set(mean)
attrs = attrs.at[4].set(std)
return batch_inputs, attrs

View File

@@ -42,7 +42,7 @@ class BaseGenome:
def restore(self, state, transformed): def restore(self, state, transformed):
raise NotImplementedError raise NotImplementedError
def forward(self, state, inputs, transformed): def forward(self, state, transformed, inputs):
raise NotImplementedError raise NotImplementedError
def execute_mutation(self, state, randkey, nodes, conns, new_node_key): def execute_mutation(self, state, randkey, nodes, conns, new_node_key):

View File

@@ -1,7 +1,16 @@
from typing import Callable from typing import Callable
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import unflatten_conns, flatten_conns, topological_sort, I_INF from utils import (
unflatten_conns,
topological_sort,
I_INF,
extract_node_attrs,
extract_conn_attrs,
set_node_attrs,
set_conn_attrs,
attach_with_inf,
)
from . import BaseGenome from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
@@ -45,23 +54,23 @@ class DefaultGenome(BaseGenome):
def transform(self, state, nodes, conns): def transform(self, state, nodes, conns):
u_conns = unflatten_conns(nodes, conns) u_conns = unflatten_conns(nodes, conns)
conn_exist = ~jnp.isnan(u_conns[0]) conn_exist = u_conns != I_INF
seqs = topological_sort(nodes, conn_exist) seqs = topological_sort(nodes, conn_exist)
return seqs, nodes, u_conns return seqs, nodes, conns, u_conns
def restore(self, state, transformed): def restore(self, state, transformed):
seqs, nodes, u_conns = transformed seqs, nodes, conns, u_conns = transformed
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
return nodes, conns return nodes, conns
def forward(self, state, inputs, transformed): def forward(self, state, transformed, inputs):
cal_seqs, nodes, u_conns = transformed cal_seqs, nodes, conns, u_conns = transformed
ini_vals = jnp.full((self.max_nodes,), jnp.nan) ini_vals = jnp.full((self.max_nodes,), jnp.nan)
ini_vals = ini_vals.at[self.input_idx].set(inputs) ini_vals = ini_vals.at[self.input_idx].set(inputs)
nodes_attrs = nodes[:, 1:] nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
def cond_fun(carry): def cond_fun(carry):
values, idx = carry values, idx = carry
@@ -71,9 +80,16 @@ class DefaultGenome(BaseGenome):
values, idx = carry values, idx = carry
i = cal_seqs[idx] i = cal_seqs[idx]
def hit(): def input_node():
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))( z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
state, u_conns[:, :, i], values new_values = values.at[i].set(z)
return new_values
def otherwise():
conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
state, hit_attrs, values
) )
z = self.node_gene.forward( z = self.node_gene.forward(
@@ -86,8 +102,7 @@ class DefaultGenome(BaseGenome):
new_values = values.at[i].set(z) new_values = values.at[i].set(z)
return new_values return new_values
# the val of input nodes is obtained by the task, not by calculation values = jax.lax.cond(jnp.isin(i, self.input_idx), input_node, otherwise)
values = jax.lax.cond(jnp.isin(i, self.input_idx), lambda: values, hit)
return values, idx + 1 return values, idx + 1
@@ -99,55 +114,72 @@ class DefaultGenome(BaseGenome):
return self.output_transform(vals[self.output_idx]) return self.output_transform(vals[self.output_idx])
def update_by_batch(self, state, batch_input, transformed): def update_by_batch(self, state, batch_input, transformed):
cal_seqs, nodes, u_conns = transformed cal_seqs, nodes, conns, u_conns = transformed
batch_size = batch_input.shape[0] batch_size = batch_input.shape[0]
batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan) batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan)
batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input) batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input)
nodes_attrs = nodes[:, 1:] nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
def cond_fun(carry): def cond_fun(carry):
batch_values, nodes_attrs_, u_conns_, idx = carry batch_values, nodes_attrs_, conns_attrs_, idx = carry
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF) return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
def body_func(carry): def body_func(carry):
batch_values, nodes_attrs_, u_conns_, idx = carry batch_values, nodes_attrs_, conns_attrs_, idx = carry
i = cal_seqs[idx] i = cal_seqs[idx]
def hit(): def input_node():
batch, new_attrs = self.node_gene.update_input_transform(
state, nodes_attrs_[i], batch_values[:, i]
)
return (
batch_values.at[:, i].set(batch),
nodes_attrs_.at[i].set(new_attrs),
conns_attrs_,
)
def otherwise():
conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
batch_ins, new_conn_attrs = jax.vmap( batch_ins, new_conn_attrs = jax.vmap(
self.conn_gene.update_by_batch, self.conn_gene.update_by_batch,
in_axes=(None, 1, 1), in_axes=(None, 0, 1),
out_axes=(1, 1), out_axes=(1, 0),
)(state, u_conns_[:, :, i], batch_values) )(state, hit_attrs, batch_values)
batch_z, new_node_attrs = self.node_gene.update_by_batch( batch_z, new_node_attrs = self.node_gene.update_by_batch(
state, state,
nodes_attrs[i], nodes_attrs_[i],
batch_ins, batch_ins,
is_output_node=jnp.isin(i, self.output_idx), is_output_node=jnp.isin(i, self.output_idx),
) )
new_batch_values = batch_values.at[:, i].set(batch_z)
return ( return (
new_batch_values, batch_values.at[:, i].set(batch_z),
nodes_attrs_.at[i].set(new_node_attrs), nodes_attrs_.at[i].set(new_node_attrs),
u_conns_.at[:, :, i].set(new_conn_attrs), conns_attrs_.at[conn_indices].set(new_conn_attrs),
) )
# the val of input nodes is obtained by the task, not by calculation # the val of input nodes is obtained by the task, not by calculation
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond( (batch_values, nodes_attrs_, conns_attrs_) = jax.lax.cond(
jnp.isin(i, self.input_idx), jnp.isin(i, self.input_idx),
lambda: (batch_values, nodes_attrs_, u_conns_), input_node,
hit, otherwise,
) )
return batch_values, nodes_attrs_, u_conns_, idx + 1 return batch_values, nodes_attrs_, conns_attrs_, idx + 1
batch_vals, nodes_attrs, u_conns, _ = jax.lax.while_loop( batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop(
cond_fun, body_func, (batch_ini_vals, nodes_attrs, u_conns, 0) cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0)
) )
nodes = nodes.at[:, 1:].set(nodes_attrs) nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs)
new_transformed = (cal_seqs, nodes, u_conns) conns = jax.vmap(set_conn_attrs)(conns, conns_attrs)
new_transformed = (cal_seqs, nodes, conns, u_conns)
if self.output_transform is None: if self.output_transform is None:
return batch_vals[:, self.output_idx], new_transformed return batch_vals[:, self.output_idx], new_transformed

View File

@@ -1,7 +1,7 @@
from typing import Callable from typing import Callable
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import unflatten_conns, flatten_conns from utils import unflatten_conns
from . import BaseGenome from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
@@ -47,11 +47,10 @@ class RecurrentGenome(BaseGenome):
def transform(self, state, nodes, conns): def transform(self, state, nodes, conns):
u_conns = unflatten_conns(nodes, conns) u_conns = unflatten_conns(nodes, conns)
return nodes, u_conns return nodes, conns, u_conns
def restore(self, state, transformed): def restore(self, state, transformed):
nodes, u_conns = transformed nodes, conns, u_conns = transformed
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
return nodes, conns return nodes, conns
def forward(self, state, inputs, transformed): def forward(self, state, inputs, transformed):

View File

@@ -47,8 +47,8 @@ class NEAT(BaseAlgorithm):
def restore(self, state, transformed): def restore(self, state, transformed):
return self.genome.restore(state, transformed) return self.genome.restore(state, transformed)
def forward(self, state, inputs, transformed): def forward(self, state, transformed, inputs):
return self.genome.forward(state, inputs, transformed) return self.genome.forward(state, transformed, inputs)
def update_by_batch(self, state, batch_input, transformed): def update_by_batch(self, state, batch_input, transformed):
return self.genome.update_by_batch(state, batch_input, transformed) return self.genome.update_by_batch(state, batch_input, transformed)

View File

@@ -2,7 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.func_fit import XOR3d from problem.func_fit import XOR3d
from utils import Act from utils import ACT_ALL, AGG_ALL, Act, Agg
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
@@ -15,17 +15,21 @@ if __name__ == "__main__":
max_conns=100, max_conns=100,
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_default=Act.tanh, activation_default=Act.tanh,
activation_options=(Act.tanh,), # activation_options=(Act.tanh,),
activation_options=ACT_ALL,
aggregation_default=Agg.sum,
# aggregation_options=(Agg.sum,),
aggregation_options=AGG_ALL,
), ),
output_transform=Act.sigmoid, # the activation function for output node output_transform=Act.sigmoid, # the activation function for output node
mutation=DefaultMutation( mutation=DefaultMutation(
node_add=0.1, node_add=0.1,
conn_add=0.1, conn_add=0.1,
node_delete=0.05, node_delete=0,
conn_delete=0.05, conn_delete=0,
), ),
), ),
pop_size=1000, pop_size=100000,
species_size=20, species_size=20,
compatibility_threshold=2, compatibility_threshold=2,
survival_threshold=0.01, # magic survival_threshold=0.01, # magic

View File

@@ -16,7 +16,7 @@ if __name__ == "__main__":
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
node_gene=KANNode(), node_gene=KANNode(),
conn_gene=BSplineConn(), conn_gene=BSplineConn(grid_cnt=10),
output_transform=Act.sigmoid, # the activation function for output node output_transform=Act.sigmoid, # the activation function for output node
mutation=DefaultMutation( mutation=DefaultMutation(
node_add=0.1, node_add=0.1,
@@ -25,7 +25,7 @@ if __name__ == "__main__":
conn_delete=0.05, conn_delete=0.05,
), ),
), ),
pop_size=1000, pop_size=10000,
species_size=20, species_size=20,
compatibility_threshold=1.5, compatibility_threshold=1.5,
survival_threshold=0.01, # magic survival_threshold=0.01, # magic
@@ -34,7 +34,7 @@ if __name__ == "__main__":
# problem=XOR3d(return_data=True), # problem=XOR3d(return_data=True),
problem=XOR3d(), problem=XOR3d(),
generation_limit=10000, generation_limit=10000,
fitness_target=-1e-8, fitness_target=-1e-5,
# update_batch_size=8, # update_batch_size=8,
# pre_update=True, # pre_update=True,
) )

View File

@@ -20,8 +20,8 @@ class FuncFit(BaseProblem):
def evaluate(self, state, randkey, act_func, params): def evaluate(self, state, randkey, act_func, params):
predict = jax.vmap(act_func, in_axes=(None, 0, None))( predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, self.inputs, params state, params, self.inputs
) )
if self.error_method == "mse": if self.error_method == "mse":
@@ -45,8 +45,8 @@ class FuncFit(BaseProblem):
return -loss return -loss
def show(self, state, randkey, act_func, params, *args, **kwargs): def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, 0, None))( predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, self.inputs, params state, params, self.inputs, params
) )
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data: if self.return_data:

View File

@@ -51,7 +51,7 @@ class BraxEnv(RLEnv):
def step(key, env_state, obs): def step(key, env_state, obs):
key, _ = jax.random.split(key) key, _ = jax.random.split(key)
action = act_func(obs, params) action = act_func(params, obs)
obs, env_state, r, done, _ = self.step(randkey, env_state, action) obs, env_state, r, done, _ = self.step(randkey, env_state, action)
return key, env_state, obs, r, done return key, env_state, obs, r, done

View File

@@ -36,7 +36,7 @@ class RLEnv(BaseProblem):
def body_func(carry): def body_func(carry):
obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward
action = act_func(state, obs, params) action = act_func(state, params, obs)
next_obs, next_env_state, reward, done, _ = self.step( next_obs, next_env_state, reward, done, _ = self.step(
rng, env_state, action rng, env_state, action
) )

View File

@@ -1,132 +1,27 @@
import jax
from algorithm.neat import * from algorithm.neat import *
from utils import Act, Agg, State
import jax, jax.numpy as jnp genome = DefaultGenome(
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse num_inputs=3,
num_outputs=1,
max_nodes=5,
max_conns=10,
)
def test_default(): def test_output_work():
randkey = jax.random.PRNGKey(0)
# index, bias, response, activation, aggregation
nodes = jnp.array(
[
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
]
)
# in_node, out_node, enable, weight
conns = jnp.array(
[
[0, 3, 0.5], # in[0] -> hidden[0]
[1, 4, 0.5], # in[1] -> hidden[1]
[3, 2, 0.5], # hidden[0] -> out[0]
[4, 2, 0.5], # hidden[1] -> out[0]
]
)
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum,),
),
)
state = genome.setup(State(randkey=jax.random.key(0)))
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
def test_recurrent():
# index, bias, response, activation, aggregation
nodes = jnp.array(
[
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
]
)
# in_node, out_node, enable, weight
conns = jnp.array(
[
[0, 3, 0.5], # in[0] -> hidden[0]
[1, 4, 0.5], # in[1] -> hidden[1]
[3, 2, 0.5], # hidden[0] -> out[0]
[4, 2, 0.5], # hidden[1] -> out[0]
]
)
genome = RecurrentGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum,),
),
activate_time=3,
)
state = genome.setup(State(randkey=jax.random.key(0)))
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
def test_random_initialize():
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=NodeGeneWithoutResponse(
activation_default=Act.identity,
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum,),
),
)
state = genome.setup() state = genome.setup()
key = jax.random.PRNGKey(0) nodes, conns = genome.initialize(state, randkey)
nodes, conns = genome.initialize(state, key)
transformed = genome.transform(state, nodes, conns) transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n") inputs = jax.random.normal(randkey, (3,))
output = genome.forward(state, transformed, inputs)
print(output)
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) batch_inputs = jax.random.normal(randkey, (10, 3))
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))( batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(
state, inputs, transformed state, transformed, batch_inputs
) )
print(outputs) print(batch_output)
assert True

View File

@@ -9,12 +9,12 @@ I_INF = np.iinfo(jnp.int32).max # infinite int
def unflatten_conns(nodes, conns): def unflatten_conns(nodes, conns):
""" """
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index), which CL means transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns
connection length, N means the number of nodes, C means the number of connections connection length, N means the number of nodes, C means the number of connections
returns the un_flattened connections with shape (CL-2, N, N) returns the unflatten connection indices with shape (N, N)
""" """
N = nodes.shape[0] # max_nodes N = nodes.shape[0] # max_nodes
CL = conns.shape[1] # connection length = (fix_attrs + custom_attrs) C = conns.shape[0] # max_conns
node_keys = nodes[:, 0] node_keys = nodes[:, 0]
i_keys, o_keys = conns[:, 0], conns[:, 1] i_keys, o_keys = conns[:, 0], conns[:, 1]
@@ -23,47 +23,25 @@ def unflatten_conns(nodes, conns):
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys) i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys) o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
unflatten = jnp.full((CL - 2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array # Is interesting that jax use clip when attach data in array
# however, it will do nothing set values in an array # however, it will do nothing when setting values in an array
# put all attributes include enable in res # put the index of connections in the unflatten array
unflatten = unflatten.at[:, i_idxs, o_idxs].set(conns[:, 2:].T) unflatten = (
assert unflatten.shape == (CL - 2, N, N) jnp.full((N, N), I_INF, dtype=jnp.int32)
.at[i_idxs, o_idxs]
.set(jnp.arange(C, dtype=jnp.int32))
)
return unflatten return unflatten
def flatten_conns(nodes, unflatten, C): def attach_with_inf(arr, idx):
""" expand_size = arr.ndim - idx.ndim
the inverse function of unflatten_conns expand_idx = jnp.expand_dims(
transform the unflatten conn (CL-2, N, N) to (C, CL) idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim))
""" )
N = nodes.shape[0] return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
CL = unflatten.shape[0] + 2
node_keys = nodes[:, 0]
def extract_conn(i, j):
return jnp.where(
jnp.isnan(unflatten[0, i, j]),
jnp.nan,
jnp.concatenate(
[jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]
),
)
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
conns = vmap(extract_conn)(x.flatten(), y.flatten())
assert conns.shape == (N * N, CL)
# put nan to the tail of the conns
sorted_idx = jnp.argsort(conns[:, 0])
sorted_conn = conns[sorted_idx]
# truncate the conns to the number of connections
conns = sorted_conn[:C]
assert conns.shape == (C, CL)
return conns
def extract_node_attrs(node): def extract_node_attrs(node):