diff --git a/algorithm/default_config.ini b/algorithm/default_config.ini index c62905c..2131d21 100644 --- a/algorithm/default_config.ini +++ b/algorithm/default_config.ini @@ -1,13 +1,14 @@ [basic] num_inputs = 2 num_outputs = 1 -maximum_nodes = 100 -maximum_connections = 100 -maximum_species = 100 +maximum_nodes = 50 +maximum_conns = 100 +maximum_species = 10 forward_way = "pop" batch_size = 4 random_seed = 0 -network_type = 'feedforward' +network_type = "feedforward" +activate_times = 10 [population] fitness_threshold = 3.9999 @@ -18,11 +19,11 @@ pop_size = 1000 [genome] compatibility_disjoint = 1.0 compatibility_weight = 0.5 -conn_add_prob = 0.5 +conn_add_prob = 0.4 conn_add_trials = 1 -conn_delete_prob = 0.5 +conn_delete_prob = 0 node_add_prob = 0.2 -node_delete_prob = 0.2 +node_delete_prob = 0 [species] compatibility_threshold = 3.0 diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index 87eba79..bff8b1c 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -1,3 +1,3 @@ from .neat import NEAT -from .gene import NormalGene +from .gene import NormalGene, RecurrentGene from .pipeline import Pipeline diff --git a/algorithm/neat/gene/__init__.py b/algorithm/neat/gene/__init__.py index dc4c9db..e1188c1 100644 --- a/algorithm/neat/gene/__init__.py +++ b/algorithm/neat/gene/__init__.py @@ -2,4 +2,5 @@ from .base import BaseGene from .normal import NormalGene from .activation import Activation from .aggregation import Aggregation +from .recurrent import RecurrentGene diff --git a/algorithm/neat/gene/normal.py b/algorithm/neat/gene/normal.py index c5bcd97..d7fda5f 100644 --- a/algorithm/neat/gene/normal.py +++ b/algorithm/neat/gene/normal.py @@ -86,10 +86,12 @@ class NormalGene(BaseGene): @staticmethod def forward_transform(nodes, conns): u_conns = unflatten_connections(nodes, conns) - u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan - u_conns = u_conns[1:, :] # remove enable attr - conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0) - seqs = topological_sort(nodes, conn_exist) + conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) + + # remove enable attr + u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) + seqs = topological_sort(nodes, conn_enable) + return seqs, nodes, u_conns @staticmethod @@ -167,18 +169,8 @@ class NormalGene(BaseGene): # the val of input nodes is obtained by the task, not by calculation values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit) - # if jnp.isin(i, input_idx): - # values = miss() - # else: - # values = hit() - return values, idx + 1 - # carry = (ini_vals, 0) - # while cond_fun(carry): - # carry = body_func(carry) - # vals, _ = carry - vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) return vals[output_idx] @@ -216,7 +208,3 @@ class NormalGene(BaseGene): ) return val - - - - diff --git a/algorithm/neat/gene/recurrent.py b/algorithm/neat/gene/recurrent.py new file mode 100644 index 0000000..9d73b96 --- /dev/null +++ b/algorithm/neat/gene/recurrent.py @@ -0,0 +1,90 @@ +import jax +from jax import Array, numpy as jnp, vmap + +from .normal import NormalGene +from .activation import Activation +from .aggregation import Aggregation +from ..utils import unflatten_connections, I_INT + + +class RecurrentGene(NormalGene): + + @staticmethod + def forward_transform(nodes, conns): + u_conns = unflatten_connections(nodes, conns) + + # remove un-enable connections and remove enable attr + conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) + u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) + + return nodes, u_conns + + @staticmethod + def create_forward(config): + config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']] + config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']] + + def act(idx, z): + """ + calculate activation function for each node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + # change idx from float to int + res = jax.lax.switch(idx, config['activation_funcs'], z) + return res + + def agg(idx, z): + """ + calculate activation function for inputs of node + """ + idx = jnp.asarray(idx, dtype=jnp.int32) + + def all_nan(): + return 0. + + def not_all_nan(): + return jax.lax.switch(idx, config['aggregation_funcs'], z) + + return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) + + batch_act, batch_agg = vmap(act), vmap(agg) + + def forward(inputs, transform) -> Array: + """ + jax forward for single input shaped (input_num, ) + nodes, connections are a single genome + + :argument inputs: (input_num, ) + :argument cal_seqs: (N, ) + :argument nodes: (N, 5) + :argument connections: (2, N, N) + + :return (output_num, ) + """ + + nodes, cons = transform + + input_idx = config['input_idx'] + output_idx = config['output_idx'] + + N = nodes.shape[0] + vals = jnp.full((N,), 0.) + + weights = cons[0, :] + + def body_func(i, values): + values = values.at[input_idx].set(inputs) + nodes_ins = values * weights.T + values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins) + values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias + values = batch_act(nodes[:, 3], values) # z = act(z) + return values + + # for i in range(config['activate_times']): + # vals = body_func(i, vals) + # + # return vals[output_idx] + vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals) + return vals[output_idx] + + return forward diff --git a/algorithm/neat/genome/basic.py b/algorithm/neat/genome/basic.py index a71cca8..76b7022 100644 --- a/algorithm/neat/genome/basic.py +++ b/algorithm/neat/genome/basic.py @@ -11,7 +11,7 @@ from ..utils import fetch_first def initialize_genomes(state: State, gene_type: Type[BaseGene]): o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes - o_conns = np.full((state.N, state.CL), np.nan, dtype=np.float32) # original connections + o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections input_idx = state.input_idx output_idx = state.output_idx diff --git a/algorithm/neat/genome/graph.py b/algorithm/neat/genome/graph.py index 8fc9842..1f65feb 100644 --- a/algorithm/neat/genome/graph.py +++ b/algorithm/neat/genome/graph.py @@ -9,6 +9,7 @@ from jax import jit, Array, numpy as jnp from ..utils import fetch_first, I_INT +@jit def topological_sort(nodes: Array, conns: Array) -> Array: """ a jit-able version of topological_sort! that's crazy! @@ -60,21 +61,11 @@ def topological_sort(nodes: Array, conns: Array) -> Array: return res -def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array: +@jit +def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array: """ Check whether a new connection (from_idx -> to_idx) will cause a cycle. - :param nodes: JAX array - The array of nodes. - :param connections: JAX array - The array of connections. - :param from_idx: int - The index of the starting node. - :param to_idx: int - The index of the ending node. - :return: JAX array - An array indicating if there is a cycle caused by the new connection. - Example: nodes = jnp.array([ [0], @@ -83,28 +74,21 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra [3] ]) connections = jnp.array([ - [ [0, 0, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0] - ], - [ - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 1], - [0, 0, 0, 0] - ] ]) - check_cycles(nodes, connections, 3, 2) -> True - check_cycles(nodes, connections, 2, 3) -> False - check_cycles(nodes, connections, 0, 3) -> False - check_cycles(nodes, connections, 1, 0) -> False + check_cycles(nodes, conns, 3, 2) -> True + check_cycles(nodes, conns, 2, 3) -> False + check_cycles(nodes, conns, 0, 3) -> False + check_cycles(nodes, conns, 1, 0) -> False """ - connections_enable = ~jnp.isnan(connections[0, :, :]) - connections_enable = connections_enable.at[from_idx, to_idx].set(True) + conns = conns.at[from_idx, to_idx].set(True) + # conns_enable = ~jnp.isnan(conns[0, :, :]) + # conns_enable = conns_enable.at[from_idx, to_idx].set(True) visited = jnp.full(nodes.shape[0], False) new_visited = visited.at[to_idx].set(True) @@ -117,43 +101,42 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra def body_func(carry): _, visited_ = carry - new_visited_ = jnp.dot(visited_, connections_enable) + new_visited_ = jnp.dot(visited_, conns) new_visited_ = jnp.logical_or(visited_, new_visited_) return visited_, new_visited_ _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited)) return visited[from_idx] - -if __name__ == '__main__': - nodes = jnp.array([ - [0], - [1], - [2], - [3], - [jnp.nan] - ]) - connections = jnp.array([ - [ - [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], - [jnp.nan, jnp.nan, 1, 1, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] - ], - [ - [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], - [jnp.nan, jnp.nan, 1, 1, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], - [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] - ] - ] - ) - - print(topological_sort(nodes, connections)) - - print(check_cycles(nodes, connections, 3, 2)) - print(check_cycles(nodes, connections, 2, 3)) - print(check_cycles(nodes, connections, 0, 3)) - print(check_cycles(nodes, connections, 1, 0)) \ No newline at end of file +# if __name__ == '__main__': +# nodes = jnp.array([ +# [0], +# [1], +# [2], +# [3], +# [jnp.nan] +# ]) +# connections = jnp.array([ +# [ +# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], +# [jnp.nan, jnp.nan, 1, 1, jnp.nan], +# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], +# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], +# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] +# ], +# [ +# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], +# [jnp.nan, jnp.nan, 1, 1, jnp.nan], +# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], +# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], +# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] +# ] +# ] +# ) +# +# print(topological_sort(nodes, connections)) +# +# print(check_cycles(nodes, connections, 3, 2)) +# print(check_cycles(nodes, connections, 2, 3)) +# print(check_cycles(nodes, connections, 0, 3)) +# print(check_cycles(nodes, connections, 1, 0)) diff --git a/algorithm/neat/genome/mutate.py b/algorithm/neat/genome/mutate.py index c555c39..50db98a 100644 --- a/algorithm/neat/genome/mutate.py +++ b/algorithm/neat/genome/mutate.py @@ -91,7 +91,8 @@ def create_mutate(config: Dict, gene_type: Type[BaseGene]): if config['network_type'] == 'feedforward': u_cons = unflatten_connections(nodes_, conns_) - is_cycle = check_cycles(nodes_, u_cons, from_idx, to_idx) + cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False) + is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx) choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) return jax.lax.switch(choice, [already_exist, nothing, successful]) diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index 044377a..03dc7d1 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -26,7 +26,7 @@ class NEAT: state = State( P=self.config['pop_size'], N=self.config['maximum_nodes'], - C=self.config['maximum_connections'], + C=self.config['maximum_conns'], S=self.config['maximum_species'], NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes @@ -64,11 +64,15 @@ class NEAT: idx2species=idx2species, center_nodes=center_nodes, center_conns=center_conns, - generation=generation, - next_node_key=next_node_key, - next_species_key=next_species_key + # avoid jax auto cast from int to float. that would cause re-compilation. + generation=jnp.asarray(generation, dtype=jnp.int32), + next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32), + next_species_key=jnp.asarray(next_species_key) ) + # move to device + state = jax.device_put(state) + return state def step(self, state, fitness): diff --git a/algorithm/neat/pipeline.py b/algorithm/neat/pipeline.py index fba0391..a03d0fd 100644 --- a/algorithm/neat/pipeline.py +++ b/algorithm/neat/pipeline.py @@ -34,9 +34,6 @@ class Pipeline: def tell(self, fitness): self.state = self.algorithm.step(self.state, fitness) - from algorithm.neat.genome.basic import count - # print([count(self.state.pop_nodes[i], self.state.pop_conns[i]) for i in range(self.state.P)]) - def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): for _ in range(self.config['generation_limit']): diff --git a/algorithm/neat/population.py b/algorithm/neat/population.py index 4705834..462e44f 100644 --- a/algorithm/neat/population.py +++ b/algorithm/neat/population.py @@ -7,8 +7,8 @@ from .utils import rank_elements, fetch_first from .genome import create_mutate, create_distance, crossover from .gene import BaseGene -def create_tell(config, gene_type: Type[BaseGene]): +def create_tell(config, gene_type: Type[BaseGene]): mutate = create_mutate(config, gene_type) distance = create_distance(config, gene_type) @@ -36,7 +36,6 @@ def create_tell(config, gene_type: Type[BaseGene]): return state, winner, loser, elite_mask - def update_species_fitness(state, fitness): """ obtain the fitness of the species by the fitness of each individual. @@ -51,7 +50,6 @@ def create_tell(config, gene_type: Type[BaseGene]): return vmap(aux_func)(jnp.arange(state.species_info.shape[0])) - def stagnation(state, species_fitness): """ stagnation species. @@ -88,7 +86,6 @@ def create_tell(config, gene_type: Type[BaseGene]): return state, species_fitness - def cal_spawn_numbers(state): """ decide the number of members of each species by their fitness rank. @@ -106,7 +103,6 @@ def create_tell(config, gene_type: Type[BaseGene]): spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member - # jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number) # Avoid too much variation of numbers in a species previous_size = state.species_info[:, 3].astype(jnp.int32) @@ -118,11 +114,11 @@ def create_tell(config, gene_type: Type[BaseGene]): # must control the sum of spawn_number to be equal to pop_size error = state.P - jnp.sum(spawn_number) - spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number + spawn_number = spawn_number.at[0].add( + error) # add error to the first species to control the sum of spawn_number return spawn_number - def create_crossover_pair(state, randkey, spawn_number, fitness): species_size = state.species_info.shape[0] pop_size = fitness.shape[0] @@ -238,8 +234,8 @@ def create_tell(config, gene_type: Type[BaseGene]): return i + 1, i2s, cn, cc, o2c _, idx2specie, center_nodes, center_conns, o2c_distances = \ - jax.lax.while_loop(cond_func, body_func, (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances)) - + jax.lax.while_loop(cond_func, body_func, + (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances)) # part 2: assign members to each species def cond_func(carry): @@ -331,7 +327,7 @@ def create_tell(config, gene_type: Type[BaseGene]): species_info = species_info.at[:, 3].set(species_member_counts) return state.update( - idx2specie=idx2specie, + idx2species=idx2specie, center_nodes=center_nodes, center_conns=center_conns, species_info=species_info, @@ -358,11 +354,10 @@ def create_tell(config, gene_type: Type[BaseGene]): return state - return tell def argmin_with_mask(arr, mask): masked_arr = jnp.where(mask, arr, jnp.inf) min_idx = jnp.argmin(masked_arr) - return min_idx \ No newline at end of file + return min_idx diff --git a/examples/config_test.py b/examples/config_test.py deleted file mode 100644 index aeb50b1..0000000 --- a/examples/config_test.py +++ /dev/null @@ -1,4 +0,0 @@ -from algorithm.config import Configer - -config = Configer.load_config() -print(config) \ No newline at end of file diff --git a/examples/rnn_forward_test.py b/examples/rnn_forward_test.py new file mode 100644 index 0000000..0351e1b --- /dev/null +++ b/examples/rnn_forward_test.py @@ -0,0 +1,13 @@ +import numpy as np + + +vals = np.array([1, 2]) +weights = np.array([[0, 4], [5, 0]]) + +ins1 = vals * weights[:, 0] +ins2 = vals * weights[:, 1] +ins_all = vals * weights.T + +print(ins1) +print(ins2) +print(ins_all) \ No newline at end of file diff --git a/examples/xor.ini b/examples/xor.ini index 893fff7..af2d8b9 100644 --- a/examples/xor.ini +++ b/examples/xor.ini @@ -1,5 +1,7 @@ [basic] forward_way = "common" +network_type = "recurrent" +activate_times = 5 [population] fitness_threshold = 4 \ No newline at end of file diff --git a/examples/xor.py b/examples/xor.py index 94a3b8b..f3a3f67 100644 --- a/examples/xor.py +++ b/examples/xor.py @@ -2,7 +2,7 @@ import jax import numpy as np from algorithm import Configer, NEAT -from algorithm.neat import NormalGene, Pipeline +from algorithm.neat import NormalGene, RecurrentGene, Pipeline xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) @@ -15,16 +15,17 @@ def evaluate(forward_func): """ outs = forward_func(xor_inputs) outs = jax.device_get(outs) - # print(outs) fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) return fitnesses def main(): config = Configer.load_config("xor.ini") - algorithm = NEAT(config, NormalGene) + # algorithm = NEAT(config, NormalGene) + algorithm = NEAT(config, RecurrentGene) pipeline = Pipeline(config, algorithm) - pipeline.auto_run(evaluate) + best = pipeline.auto_run(evaluate) + print(best) if __name__ == '__main__': diff --git a/examples/xor_test.py b/examples/xor_test.py index 9ca0c40..33da7bf 100644 --- a/examples/xor_test.py +++ b/examples/xor_test.py @@ -2,31 +2,49 @@ import jax import numpy as np from algorithm.config import Configer -from algorithm.neat import NEAT, NormalGene, Pipeline +from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline from algorithm.neat.genome import create_mutate xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) + def single_genome(func, nodes, conns): - t = NormalGene.forward_transform(nodes, conns) + t = RecurrentGene.forward_transform(nodes, conns) out1 = func(xor_inputs[0], t) out2 = func(xor_inputs[1], t) out3 = func(xor_inputs[2], t) out4 = func(xor_inputs[3], t) print(out1, out2, out3, out4) + +def batch_genome(func, nodes, conns): + t = NormalGene.forward_transform(nodes, conns) + out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t) + print(out) + + +def pop_batch_genome(func, pop_nodes, pop_conns): + t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns) + func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0)) + out = func(xor_inputs, t) + print(out) + + if __name__ == '__main__': - config = Configer.load_config() - neat = NEAT(config, NormalGene) + config = Configer.load_config("xor.ini") + # neat = NEAT(config, NormalGene) + neat = NEAT(config, RecurrentGene) randkey = jax.random.PRNGKey(42) state = neat.setup(randkey) - forward_func = NormalGene.create_forward(config) - mutate_func = create_mutate(config, NormalGene) - + forward_func = RecurrentGene.create_forward(config) + mutate_func = create_mutate(config, RecurrentGene) nodes, conns = state.pop_nodes[0], state.pop_conns[0] single_genome(forward_func, nodes, conns) + # batch_genome(forward_func, nodes, conns) + nodes, conns = mutate_func(state, randkey, nodes, conns, 10000) single_genome(forward_func, nodes, conns) - + # batch_genome(forward_func, nodes, conns) + # diff --git a/test/unit/test_graphs.py b/test/unit/test_graphs.py new file mode 100644 index 0000000..0b2ff17 --- /dev/null +++ b/test/unit/test_graphs.py @@ -0,0 +1,32 @@ +import jax.numpy as jnp + +from algorithm.neat.genome.graph import topological_sort, check_cycles +from algorithm.neat.utils import I_INT + +nodes = jnp.array([ + [0], + [1], + [2], + [3], + [jnp.nan] +]) + +# {(0, 2), (1, 2), (1, 3), (2, 3)} +conns = jnp.array([ + [0, 0, 1, 0, 0], + [0, 0, 1, 1, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] +]) + + +def test_topological_sort(): + assert jnp.all(topological_sort(nodes, conns) == jnp.array([0, 1, 2, 3, I_INT])) + + +def test_check_cycles(): + assert check_cycles(nodes, conns, 3, 2) + assert ~check_cycles(nodes, conns, 2, 3) + assert ~check_cycles(nodes, conns, 0, 3) + assert ~check_cycles(nodes, conns, 1, 0) diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index b66c1aa..d313402 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -1,7 +1,5 @@ -import pytest -import jax - -from algorithm.neat.utils import * +import jax.numpy as jnp +from algorithm.neat.utils import unflatten_connections def test_unflatten(): @@ -13,7 +11,6 @@ def test_unflatten(): [jnp.nan, jnp.nan, jnp.nan, jnp.nan] ]) - conns = jnp.array([ [0, 1, True, 0.1, 0.11], [0, 2, False, 0.2, 0.22], @@ -33,4 +30,4 @@ def test_unflatten(): mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False) # Ensure all other places are jnp.nan - assert jnp.all(jnp.isnan(res[mask])) \ No newline at end of file + assert jnp.all(jnp.isnan(res[mask]))