hyper neat

This commit is contained in:
wls2002
2023-07-24 19:25:02 +08:00
parent ac295c1921
commit ebad574431
24 changed files with 542 additions and 103 deletions

View File

@@ -0,0 +1,2 @@
from .neat import *
from .hyper_neat import *

View File

@@ -0,0 +1,2 @@
from .hyper_neat import HyperNEAT
from .substrate import NormalSubstrate, NormalSubstrateConfig

View File

@@ -0,0 +1,122 @@
from typing import Type
import jax
from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome
from utils import Activation, Aggregation
from algorithm.neat import NEAT
from .substrate import analysis_substrate
class HyperNEAT(Algorithm):
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
self.config = config
self.neat = neat
self.substrate = substrate
self.forward_func = None
def setup(self, randkey, state=State()):
neat_key, randkey = jax.random.split(randkey)
state = state.update(
below_threshold=self.config.hyper_neat.below_threshold,
max_weight=self.config.hyper_neat.max_weight,
)
state = self.neat.setup(neat_key, state)
state = self.substrate.setup(self.config.substrate, state)
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
h_conns[:, 0:2] = correspond_keys
state = state.update(
h_input_idx=h_input_idx,
h_output_idx=h_output_idx,
h_hidden_idx=h_hidden_idx,
h_nodes=h_nodes,
h_conns=h_conns,
query_coors=query_coors,
)
self.forward_func = HyperNEATGene.create_forward(self.config.hyper_neat, state)
return state
def ask(self, state: State):
return state.pop_genomes
def tell(self, state: State, fitness):
return self.neat.tell(state, fitness)
def forward(self, inputs: Array, transformed: Array):
return self.forward_func(inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
t = self.neat.forward_transform(state, genome)
query_res = vmap(self.neat.forward, in_axes=(0, None))(state.query_coors, t)
# mute the connection with weight below threshold
query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res)
query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res)
query_res = query_res / (1 - state.below_threshold) * state.max_weight
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
class HyperNEATGene:
node_attrs = [] # no node attributes
conn_attrs = ['weight']
@staticmethod
def forward_transform(genome: Genome):
N = genome.nodes.shape[0]
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32)
out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32)
weights = genome.conns[:, 2]
u_conns = u_conns.at[in_keys, out_keys].set(weights)
return genome.nodes, u_conns
@staticmethod
def create_forward(config: HyperNeatConfig, state: State):
act = Activation.name2func[config.activation]
agg = Aggregation.name2func[config.aggregation]
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
def forward(inputs, transform):
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
nodes, weights = transform
input_idx = state.h_input_idx
output_idx = state.h_output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
def body_func(i, values):
values = values.at[input_idx].set(inputs_with_bias)
nodes_ins = values * weights.T
values = batch_agg(nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]
return forward

View File

@@ -0,0 +1,2 @@
from .normal import NormalSubstrate, NormalSubstrateConfig
from .tools import analysis_substrate

View File

@@ -0,0 +1,25 @@
from dataclasses import dataclass
from typing import Tuple
import numpy as np
from core import Substrate, State
from config import SubstrateConfig
@dataclass(frozen=True)
class NormalSubstrateConfig(SubstrateConfig):
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple[Tuple[float]] = ((0, 1), )
class NormalSubstrate(Substrate):
@staticmethod
def setup(config: NormalSubstrateConfig, state: State = State()):
return state.update(
input_coors=np.asarray(config.input_coors, dtype=np.float32),
output_coors=np.asarray(config.output_coors, dtype=np.float32),
hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32),
)

View File

@@ -0,0 +1,50 @@
from typing import Type
import numpy as np
def analysis_substrate(state):
cd = state.input_coors.shape[1] # coordinate dimensions
si = state.input_coors.shape[0] # input coordinate size
so = state.output_coors.shape[0] # output coordinate size
sh = state.hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys

View File

@@ -0,0 +1,2 @@
from .neat import NEAT
from .gene import *

View File

@@ -1 +1,2 @@
from .normal import NormalGene, NormalGeneConfig
from .recurrent import RecurrentGene, RecurrentGeneConfig

View File

@@ -0,0 +1,84 @@
from dataclasses import dataclass
import jax
from jax import Array, numpy as jnp, vmap
from .normal import NormalGene, NormalGeneConfig
from core import State, Genome
from utils import Activation, Aggregation, unflatten_conns
@dataclass(frozen=True)
class RecurrentGeneConfig(NormalGeneConfig):
activate_times: int = 10
def __post_init__(self):
super().__post_init__()
assert self.activate_times > 0
class RecurrentGene(NormalGene):
@staticmethod
def forward_transform(state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return genome.nodes, u_conns
@staticmethod
def create_forward(state: State, config: RecurrentGeneConfig):
activation_funcs = [Activation.name2func[name] for name in config.activation_options]
aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, activation_funcs, z)
return res
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, aggregation_funcs, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
batch_act, batch_agg = vmap(act), vmap(agg)
def forward(inputs, transform) -> Array:
nodes, cons = transform
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = cons[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]
return forward

View File

@@ -7,7 +7,7 @@ import numpy as np
from config import Config
from core import Algorithm, State, Gene, Genome
from .ga import crossover, create_mutate
from .species import update_species, create_speciate
from .species import SpeciesInfo, update_species, create_speciate
class NEAT(Algorithm):
@@ -22,9 +22,9 @@ class NEAT(Algorithm):
def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm"""
input_idx = np.arange(self.config.basic.num_inputs)
output_idx = np.arange(self.config.basic.num_inputs,
self.config.basic.num_inputs + self.config.basic.num_outputs)
input_idx = np.arange(self.config.neat.inputs)
output_idx = np.arange(self.config.neat.inputs,
self.config.neat.inputs + self.config.neat.outputs)
state = state.update(
P=self.config.basic.pop_size,
@@ -49,22 +49,13 @@ class NEAT(Algorithm):
state = self.gene_type.setup(self.config.gene, state)
pop_genomes = self._initialize_genomes(state)
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
member_count = np.full((state.S,), np.nan, dtype=np.float32)
species_info = SpeciesInfo.initialize(state)
idx2species = jnp.zeros(state.P, dtype=jnp.float32)
species_keys[0] = 0
best_fitness[0] = -np.inf
last_improved[0] = 0
member_count[0] = state.P
center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
center_nodes = center_nodes.at[0, :, :].set(pop_genomes.nodes[0, :, :])
center_conns = center_conns.at[0, :, :].set(pop_genomes.conns[0, :, :])
center_genomes = vmap(Genome)(center_nodes, center_conns)
center_genomes = Genome(center_nodes, center_conns)
center_genomes = center_genomes.set(0, pop_genomes[0])
generation = 0
next_node_key = max(*state.input_idx, *state.output_idx) + 2
@@ -73,10 +64,7 @@ class NEAT(Algorithm):
state = state.update(
randkey=randkey,
pop_genomes=pop_genomes,
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
species_info=species_info,
idx2species=idx2species,
center_genomes=center_genomes,
@@ -135,7 +123,7 @@ class NEAT(Algorithm):
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
pop_conns = np.tile(o_conns, (state.P, 1, 1))
return vmap(Genome)(pop_nodes, pop_conns)
return Genome(pop_nodes, pop_conns)
def _create_tell(self):
mutate = create_mutate(self.config.neat, self.gene_type)

View File

@@ -1 +1,2 @@
from .operations import update_species, create_speciate
from .species_info import SpeciesInfo

View File

@@ -6,6 +6,7 @@ from jax import numpy as jnp, vmap
from core import Gene, Genome
from utils import rank_elements, fetch_first
from .distance import create_distance
from .species_info import SpeciesInfo
def update_species(state, randkey, fitness):
@@ -18,15 +19,9 @@ def update_species(state, randkey, fitness):
# sort species_info by their fitness. (push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
center_nodes = state.center_genomes.nodes[sort_indices]
center_conns = state.center_genomes.conns[sort_indices]
state = state.update(
species_keys=state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices],
last_improved=state.last_improved[sort_indices],
member_count=state.member_count[sort_indices],
center_genomes=Genome(center_nodes, center_conns),
species_info=state.species_info[sort_indices],
center_genomes=state.center_genomes[sort_indices],
)
# decide the number of members of each species by their fitness
@@ -45,11 +40,11 @@ def update_species_fitness(state, fitness):
"""
def aux_func(idx):
s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf)
s_fitness = jnp.where(state.idx2species == state.species_info.species_keys[idx], fitness, -jnp.inf)
f = jnp.max(s_fitness)
return f
return vmap(aux_func)(jnp.arange(state.species_keys.shape[0]))
return vmap(aux_func)(jnp.arange(state.species_info.size()))
def stagnation(state, species_fitness):
@@ -61,7 +56,7 @@ def stagnation(state, species_fitness):
def aux_func(idx):
s_fitness = species_fitness[idx]
sk, bf, li = state.species_keys[idx], state.best_fitness[idx], state.last_improved[idx]
sk, bf, li, _ = state.species_info.get(idx)
st = (s_fitness <= bf) & (state.generation - li > state.max_stagnation)
li = jnp.where(s_fitness > bf, state.generation, li)
bf = jnp.where(s_fitness > bf, s_fitness, bf)
@@ -78,18 +73,19 @@ def stagnation(state, species_fitness):
species_keys = jnp.where(spe_st, jnp.nan, species_keys)
best_fitness = jnp.where(spe_st, jnp.nan, best_fitness)
last_improved = jnp.where(spe_st, jnp.nan, last_improved)
member_count = jnp.where(spe_st, jnp.nan, state.member_count)
member_count = jnp.where(spe_st, jnp.nan, state.species_info.member_count)
species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count)
# TODO: Simplify the coded
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.nodes)
center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.conns)
state = state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
center_genomes=state.center_genomes.update(center_nodes, center_conns)
species_info=species_info,
center_genomes=Genome(center_nodes, center_conns)
)
return state, species_fitness
@@ -103,18 +99,20 @@ def cal_spawn_numbers(state):
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
is_species_valid = ~jnp.isnan(state.species_keys)
species_keys = state.species_info.species_keys
is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - jnp.arange(state.species_keys.shape[0]) # obtain [3, 2, 1]
rank_score = valid_species_num - jnp.arange(species_keys.shape[0]) # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
# Avoid too much variation of numbers in a species
previous_size = state.member_count
previous_size = state.species_info.member_count
spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate
# jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number)
spawn_number = spawn_number.astype(jnp.int32)
@@ -127,14 +125,14 @@ def cal_spawn_numbers(state):
def create_crossover_pair(state, randkey, spawn_number, fitness):
species_size = state.species_keys.shape[0]
species_size = state.species_info.size()
pop_size = fitness.shape[0]
s_idx = jnp.arange(species_size)
p_idx = jnp.arange(pop_size)
# def aux_func(key, idx):
def aux_func(key, idx):
members = state.idx2species == state.species_keys[idx]
members = state.idx2species == state.species_info.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
@@ -176,7 +174,7 @@ def create_speciate(gene_type: Type[Gene]):
distance = create_distance(gene_type)
def speciate(state):
pop_size, species_size = state.idx2species.shape[0], state.species_keys.shape[0]
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population
@@ -191,25 +189,23 @@ def create_speciate(gene_type: Type[Gene]):
def cond_func(carry):
i, i2s, cgs, o2c = carry
return (i < species_size) & (~jnp.isnan(state.species_keys[i])) # current species is existing
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
def body_func(carry):
i, i2s, cgs, o2c = carry
distances = o2p_distance_func(state, Genome(cgs.nodes[i], cgs.conns[i]), state.pop_genomes)
distances = o2p_distance_func(state, cgs[i], state.pop_genomes)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
# jax.debug.print("closest_idx: {}", closest_idx)
i2s = i2s.at[closest_idx].set(state.species_keys[i])
cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[closest_idx])
cc = cgs.conns.at[i].set(state.pop_genomes.conns[closest_idx])
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cgs = cgs.set(i, state.pop_genomes[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, Genome(cn, cc), o2c
return i + 1, i2s, cgs, o2c
_, idx2species, center_genomes, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
@@ -247,15 +243,13 @@ def create_speciate(gene_type: Type[Gene]):
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, members_count]
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk)
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[idx])
cc = cgs.conns.at[i].set(state.pop_genomes.conns[idx])
cgs = Genome(cn, cc)
cgs = cgs.set(i, state.pop_genomes[idx])
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
@@ -273,8 +267,7 @@ def create_speciate(gene_type: Type[Gene]):
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
# distance between such center genome and ppo genomes
center = Genome(cgs.nodes[i], cgs.conns[i])
o2p_distance = o2p_distance_func(state, center, state.pop_genomes)
o2p_distance = o2p_distance_func(state, cgs[i], state.pop_genomes)
close_enough_mask = o2p_distance < state.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
@@ -294,32 +287,31 @@ def create_speciate(gene_type: Type[Gene]):
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_genomes, state.species_keys, o2c_distances, state.next_species_key)
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.last_improved)
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
# update members count
def count_members(idx):
key = species_keys[idx]
count = jnp.sum(idx2species == key)
count = jnp.sum(idx2species == key, dtype=jnp.float32)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
return count
member_count = vmap(count_members)(jnp.arange(species_size))
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
members_count=member_count,
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
idx2species=idx2species,
center_genomes=center_genomes,
next_species_key=next_species_key

View File

@@ -0,0 +1,55 @@
from jax.tree_util import register_pytree_node_class
import numpy as np
import jax.numpy as jnp
@register_pytree_node_class
class SpeciesInfo:
def __init__(self, species_keys, best_fitness, last_improved, member_count):
self.species_keys = species_keys
self.best_fitness = best_fitness
self.last_improved = last_improved
self.member_count = member_count
@classmethod
def initialize(cls, state):
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
member_count = np.full((state.S,), np.nan, dtype=np.float32)
species_keys[0] = 0
best_fitness[0] = -np.inf
last_improved[0] = 0
member_count[0] = state.P
return cls(species_keys, best_fitness, last_improved, member_count)
def __getitem__(self, i):
return SpeciesInfo(self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i])
def get(self, i):
return self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i]
def set(self, idx, value):
species_keys = self.species_keys.at[idx].set(value[0])
best_fitness = self.best_fitness.at[idx].set(value[1])
last_improved = self.last_improved.at[idx].set(value[2])
member_count = self.member_count.at[idx].set(value[3])
return SpeciesInfo(species_keys, best_fitness, last_improved, member_count)
def remove(self, idx):
return self.set(idx, jnp.array([jnp.nan] * 4))
def size(self):
return self.species_keys.shape[0]
def tree_flatten(self):
children = self.species_keys, self.best_fitness, self.last_improved, self.member_count
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)