diff --git a/tensorneat/algorithm/__init__.py b/tensorneat/algorithm/__init__.py index b2f3695..deaf74c 100644 --- a/tensorneat/algorithm/__init__.py +++ b/tensorneat/algorithm/__init__.py @@ -1,2 +1,2 @@ from .base import BaseAlgorithm -from .neat import NEAT \ No newline at end of file +from .neat import NEAT diff --git a/tensorneat/algorithm/base.py b/tensorneat/algorithm/base.py index 93aeafe..d9677f2 100644 --- a/tensorneat/algorithm/base.py +++ b/tensorneat/algorithm/base.py @@ -2,8 +2,7 @@ from utils import State class BaseAlgorithm: - - def setup(self, randkey): + def setup(self, state=State()): """initialize the state of the algorithm""" raise NotImplementedError @@ -16,11 +15,11 @@ class BaseAlgorithm: """update the state of the algorithm""" raise NotImplementedError - def transform(self, individual): + def transform(self, state, individual): """transform the genome into a neural network""" raise NotImplementedError - def forward(self, inputs, transformed): + def forward(self, state, inputs, transformed): raise NotImplementedError @property @@ -42,4 +41,3 @@ class BaseAlgorithm: def generation(self, state: State): # to analysis the algorithm raise NotImplementedError - diff --git a/tensorneat/algorithm/hyperneat/__init__.py b/tensorneat/algorithm/hyperneat/__init__.py index 374e2aa..eef06a2 100644 --- a/tensorneat/algorithm/hyperneat/__init__.py +++ b/tensorneat/algorithm/hyperneat/__init__.py @@ -1,2 +1,2 @@ from .hyperneat import HyperNEAT -from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate \ No newline at end of file +from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate diff --git a/tensorneat/algorithm/hyperneat/hyperneat.py b/tensorneat/algorithm/hyperneat/hyperneat.py index 302a3c6..aa574ef 100644 --- a/tensorneat/algorithm/hyperneat/hyperneat.py +++ b/tensorneat/algorithm/hyperneat/hyperneat.py @@ -10,20 +10,20 @@ from .substrate import * class HyperNEAT(BaseAlgorithm): - def __init__( - self, - substrate: 
BaseSubstrate, - neat: NEAT, - below_threshold: float = 0.3, - max_weight: float = 5., - activation=Act.sigmoid, - aggregation=Agg.sum, - activate_time: int = 10, - output_transform: Callable = Act.sigmoid, + self, + substrate: BaseSubstrate, + neat: NEAT, + below_threshold: float = 0.3, + max_weight: float = 5.0, + activation=Act.sigmoid, + aggregation=Agg.sum, + activate_time: int = 10, + output_transform: Callable = Act.sigmoid, ): - assert substrate.query_coors.shape[1] == neat.num_inputs, \ - "Substrate input size should be equal to NEAT input size" + assert ( + substrate.query_coors.shape[1] == neat.num_inputs + ), "Substrate input size should be equal to NEAT input size" self.substrate = substrate self.neat = neat @@ -37,39 +37,43 @@ class HyperNEAT(BaseAlgorithm): node_gene=HyperNodeGene(activation, aggregation), conn_gene=HyperNEATConnGene(), activate_time=activate_time, - output_transform=output_transform + output_transform=output_transform, ) def setup(self, randkey): - return State( - neat_state=self.neat.setup(randkey) - ) + return State(neat_state=self.neat.setup(randkey)) def ask(self, state: State): return self.neat.ask(state.neat_state) def tell(self, state: State, fitness): - return state.update( - neat_state=self.neat.tell(state.neat_state, fitness) - ) + return state.update(neat_state=self.neat.tell(state.neat_state, fitness)) def transform(self, individual): transformed = self.neat.transform(individual) - query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(self.substrate.query_coors, transformed) + query_res = jax.vmap(self.neat.forward, in_axes=(0, None))( + self.substrate.query_coors, transformed + ) # mute the connection with weight below threshold query_res = jnp.where( (-self.below_threshold < query_res) & (query_res < self.below_threshold), - 0., - query_res + 0.0, + query_res, ) # make query res in range [-max_weight, max_weight] - query_res = jnp.where(query_res > 0, query_res - self.below_threshold, query_res) - query_res = 
jnp.where(query_res < 0, query_res + self.below_threshold, query_res) + query_res = jnp.where( + query_res > 0, query_res - self.below_threshold, query_res + ) + query_res = jnp.where( + query_res < 0, query_res + self.below_threshold, query_res + ) query_res = query_res / (1 - self.below_threshold) * self.max_weight - h_nodes, h_conns = self.substrate.make_nodes(query_res), self.substrate.make_conn(query_res) + h_nodes, h_conns = self.substrate.make_nodes( + query_res + ), self.substrate.make_conn(query_res) return self.hyper_genome.transform(h_nodes, h_conns) def forward(self, inputs, transformed): @@ -97,11 +101,11 @@ class HyperNEAT(BaseAlgorithm): class HyperNodeGene(BaseNodeGene): - - def __init__(self, - activation=Act.sigmoid, - aggregation=Agg.sum, - ): + def __init__( + self, + activation=Act.sigmoid, + aggregation=Agg.sum, + ): super().__init__() self.activation = activation self.aggregation = aggregation @@ -110,12 +114,12 @@ class HyperNodeGene(BaseNodeGene): return jax.lax.cond( is_output_node, lambda: self.aggregation(inputs), # output node does not need activation - lambda: self.activation(self.aggregation(inputs)) - + lambda: self.activation(self.aggregation(inputs)), ) + class HyperNEATConnGene(BaseConnGene): - custom_attrs = ['weight'] + custom_attrs = ["weight"] def forward(self, attrs, inputs): weight = attrs[0] diff --git a/tensorneat/algorithm/hyperneat/substrate/base.py b/tensorneat/algorithm/hyperneat/substrate/base.py index 3a15832..8a60756 100644 --- a/tensorneat/algorithm/hyperneat/substrate/base.py +++ b/tensorneat/algorithm/hyperneat/substrate/base.py @@ -1,5 +1,4 @@ class BaseSubstrate: - def make_nodes(self, query_res): raise NotImplementedError diff --git a/tensorneat/algorithm/hyperneat/substrate/default.py b/tensorneat/algorithm/hyperneat/substrate/default.py index a7273dc..bb8b118 100644 --- a/tensorneat/algorithm/hyperneat/substrate/default.py +++ b/tensorneat/algorithm/hyperneat/substrate/default.py @@ -3,7 +3,6 @@ from . 
import BaseSubstrate class DefaultSubstrate(BaseSubstrate): - def __init__(self, num_inputs, num_outputs, coors, nodes, conns): self.inputs = num_inputs self.outputs = num_outputs diff --git a/tensorneat/algorithm/hyperneat/substrate/full.py b/tensorneat/algorithm/hyperneat/substrate/full.py index 98ec869..31c4ed9 100644 --- a/tensorneat/algorithm/hyperneat/substrate/full.py +++ b/tensorneat/algorithm/hyperneat/substrate/full.py @@ -3,20 +3,16 @@ from .default import DefaultSubstrate class FullSubstrate(DefaultSubstrate): - - def __init__(self, - input_coors=((-1, -1), (0, -1), (1, -1)), - hidden_coors=((-1, 0), (0, 0), (1, 0)), - output_coors=((0, 1),), - ): - query_coors, nodes, conns = analysis_substrate(input_coors, output_coors, hidden_coors) - super().__init__( - len(input_coors), - len(output_coors), - query_coors, - nodes, - conns + def __init__( + self, + input_coors=((-1, -1), (0, -1), (1, -1)), + hidden_coors=((-1, 0), (0, 0), (1, 0)), + output_coors=((0, 1),), + ): + query_coors, nodes, conns = analysis_substrate( + input_coors, output_coors, hidden_coors ) + super().__init__(len(input_coors), len(output_coors), query_coors, nodes, conns) def analysis_substrate(input_coors, output_coors, hidden_coors): @@ -38,22 +34,30 @@ def analysis_substrate(input_coors, output_coors, hidden_coors): correspond_keys = np.zeros((total_conns, 2)) # connect input to hidden - aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, input_coors, hidden_coors) - query_coors[0: si * sh, :] = aux_coors - correspond_keys[0: si * sh, :] = aux_keys + aux_coors, aux_keys = cartesian_product( + input_idx, hidden_idx, input_coors, hidden_coors + ) + query_coors[0 : si * sh, :] = aux_coors + correspond_keys[0 : si * sh, :] = aux_keys # connect hidden to hidden - aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, hidden_coors, hidden_coors) - query_coors[si * sh: si * sh + sh * sh, :] = aux_coors - correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys + 
aux_coors, aux_keys = cartesian_product( + hidden_idx, hidden_idx, hidden_coors, hidden_coors + ) + query_coors[si * sh : si * sh + sh * sh, :] = aux_coors + correspond_keys[si * sh : si * sh + sh * sh, :] = aux_keys # connect hidden to output - aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, hidden_coors, output_coors) - query_coors[si * sh + sh * sh:, :] = aux_coors - correspond_keys[si * sh + sh * sh:, :] = aux_keys + aux_coors, aux_keys = cartesian_product( + hidden_idx, output_idx, hidden_coors, output_coors + ) + query_coors[si * sh + sh * sh :, :] = aux_coors + correspond_keys[si * sh + sh * sh :, :] = aux_keys nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis] - conns = np.zeros((correspond_keys.shape[0], 4), dtype=np.float32) # input_idx, output_idx, enabled, weight + conns = np.zeros( + (correspond_keys.shape[0], 4), dtype=np.float32 + ) # input_idx, output_idx, enabled, weight conns[:, 0:2] = correspond_keys conns[:, 2] = 1 # enabled is True diff --git a/tensorneat/algorithm/neat/__init__.py b/tensorneat/algorithm/neat/__init__.py index 1af3e2b..1338b5b 100644 --- a/tensorneat/algorithm/neat/__init__.py +++ b/tensorneat/algorithm/neat/__init__.py @@ -2,4 +2,4 @@ from .ga import * from .gene import * from .genome import * from .species import * -from .neat import NEAT \ No newline at end of file +from .neat import NEAT diff --git a/tensorneat/algorithm/neat/ga/crossover/base.py b/tensorneat/algorithm/neat/ga/crossover/base.py index 206a6f8..8a2dc65 100644 --- a/tensorneat/algorithm/neat/ga/crossover/base.py +++ b/tensorneat/algorithm/neat/ga/crossover/base.py @@ -2,7 +2,6 @@ from utils import State class BaseCrossover: - def setup(self, state=State()): return state diff --git a/tensorneat/algorithm/neat/ga/crossover/default.py b/tensorneat/algorithm/neat/ga/crossover/default.py index c6e3e37..71fd7af 100644 --- a/tensorneat/algorithm/neat/ga/crossover/default.py +++ 
b/tensorneat/algorithm/neat/ga/crossover/default.py @@ -4,7 +4,6 @@ from .base import BaseCrossover class DefaultCrossover(BaseCrossover): - def __call__(self, state, genome, nodes1, conns1, nodes2, conns2): """ use genome1 and genome2 to generate a new genome @@ -19,15 +18,21 @@ class DefaultCrossover(BaseCrossover): # For not homologous genes, use the value of nodes1(winner) # For homologous genes, use the crossover result between nodes1 and nodes2 - new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, - self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False)) + new_nodes = jnp.where( + jnp.isnan(nodes1) | jnp.isnan(nodes2), + nodes1, + self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False), + ) # crossover connections con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) - new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, - self.crossover_gene(randkey2, conns1, conns2, is_conn=True)) + new_conns = jnp.where( + jnp.isnan(conns1) | jnp.isnan(conns2), + conns1, + self.crossover_gene(randkey2, conns1, conns2, is_conn=True), + ) return state.update(randkey=randkey), new_nodes, new_conns @@ -53,7 +58,9 @@ class DefaultCrossover(BaseCrossover): idx = jnp.arange(0, len(seq1)) idx_fixed = jnp.dot(mask, idx) - refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan) + refactor_ar2 = jnp.where( + intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan + ) return refactor_ar2 @@ -61,10 +68,6 @@ class DefaultCrossover(BaseCrossover): r = jax.random.uniform(rand_key, shape=g1.shape) new_gene = jnp.where(r > 0.5, g1, g2) if is_conn: # fix enabled - enabled = jnp.where( - g1[:, 2] + g2[:, 2] > 0, # any of them is enabled - 1, - 0 - ) + enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled new_gene = new_gene.at[:, 2].set(enabled) return new_gene diff --git a/tensorneat/algorithm/neat/ga/mutation/__init__.py 
b/tensorneat/algorithm/neat/ga/mutation/__init__.py index 599f35c..2c12bea 100644 --- a/tensorneat/algorithm/neat/ga/mutation/__init__.py +++ b/tensorneat/algorithm/neat/ga/mutation/__init__.py @@ -1,2 +1,2 @@ from .base import BaseMutation -from .default import DefaultMutation \ No newline at end of file +from .default import DefaultMutation diff --git a/tensorneat/algorithm/neat/ga/mutation/base.py b/tensorneat/algorithm/neat/ga/mutation/base.py index 4e4a0b3..ab7c06b 100644 --- a/tensorneat/algorithm/neat/ga/mutation/base.py +++ b/tensorneat/algorithm/neat/ga/mutation/base.py @@ -2,7 +2,6 @@ from utils import State class BaseMutation: - def setup(self, state=State()): return state diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index b0fc047..0d716c0 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -1,16 +1,15 @@ import jax, jax.numpy as jnp from . import BaseMutation -from utils import fetch_first, fetch_random, I_INT, unflatten_conns, check_cycles +from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles class DefaultMutation(BaseMutation): - def __init__( - self, - conn_add: float = 0.4, - conn_delete: float = 0, - node_add: float = 0.2, - node_delete: float = 0, + self, + conn_add: float = 0.4, + conn_delete: float = 0, + node_add: float = 0.2, + node_delete: float = 0, ): self.conn_add = conn_add self.conn_delete = conn_delete @@ -34,25 +33,45 @@ class DefaultMutation(BaseMutation): new_conns = conns_.at[idx, 2].set(False) # add a new node - new_nodes = genome.add_node(nodes_, new_node_key, genome.node_gene.new_custom_attrs()) + new_nodes = genome.add_node( + nodes_, new_node_key, genome.node_gene.new_custom_attrs() + ) # add two new connections - new_conns = genome.add_conn(new_conns, i_key, new_node_key, True, genome.conn_gene.new_custom_attrs()) - new_conns = genome.add_conn(new_conns, 
new_node_key, o_key, True, genome.conn_gene.new_custom_attrs()) + new_conns = genome.add_conn( + new_conns, + i_key, + new_node_key, + True, + genome.conn_gene.new_custom_attrs(), + ) + new_conns = genome.add_conn( + new_conns, + new_node_key, + o_key, + True, + genome.conn_gene.new_custom_attrs(), + ) return new_nodes, new_conns return jax.lax.cond( - idx == I_INT, + idx == I_INF, lambda: (nodes_, conns_), # do nothing - successful_add_node + successful_add_node, ) def mutate_delete_node(key_, nodes_, conns_): # randomly choose a node - key, idx = self.choice_node_key(key_, nodes_, genome.input_idx, genome.output_idx, - allow_input_keys=False, allow_output_keys=False) + key, idx = self.choice_node_key( + key_, + nodes_, + genome.input_idx, + genome.output_idx, + allow_input_keys=False, + allow_output_keys=False, + ) def successful_delete_node(): # delete the node @@ -62,15 +81,15 @@ class DefaultMutation(BaseMutation): new_conns = jnp.where( ((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None], jnp.nan, - conns_ + conns_, ) return new_nodes, new_conns return jax.lax.cond( - idx == I_INT, + idx == I_INF, lambda: (nodes_, conns_), # do nothing - successful_delete_node + successful_delete_node, ) def mutate_add_conn(key_, nodes_, conns_): @@ -78,26 +97,40 @@ class DefaultMutation(BaseMutation): k1_, k2_ = jax.random.split(key_, num=2) # input node of the connection can be any node - i_key, from_idx = self.choice_node_key(k1_, nodes_, genome.input_idx, genome.output_idx, - allow_input_keys=True, allow_output_keys=True) + i_key, from_idx = self.choice_node_key( + k1_, + nodes_, + genome.input_idx, + genome.output_idx, + allow_input_keys=True, + allow_output_keys=True, + ) # output node of the connection can be any node except input node - o_key, to_idx = self.choice_node_key(k2_, nodes_, genome.input_idx, genome.output_idx, - allow_input_keys=False, allow_output_keys=True) + o_key, to_idx = self.choice_node_key( + k2_, + nodes_, + genome.input_idx, + 
genome.output_idx, + allow_input_keys=False, + allow_output_keys=True, + ) conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key)) - is_already_exist = conn_pos != I_INT + is_already_exist = conn_pos != I_INF def nothing(): return nodes_, conns_ def successful(): - return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs()) + return nodes_, genome.add_conn( + conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs() + ) def already_exist(): return nodes_, conns_.at[conn_pos, 2].set(True) - if genome.network_type == 'feedforward': + if genome.network_type == "feedforward": u_cons = unflatten_conns(nodes_, conns_) cons_exist = ~jnp.isnan(u_cons[0, :, :]) is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx) @@ -105,20 +138,11 @@ class DefaultMutation(BaseMutation): return jax.lax.cond( is_already_exist, already_exist, - lambda: - jax.lax.cond( - is_cycle, - nothing, - successful - ) + lambda: jax.lax.cond(is_cycle, nothing, successful), ) - elif genome.network_type == 'recurrent': - return jax.lax.cond( - is_already_exist, - already_exist, - successful - ) + elif genome.network_type == "recurrent": + return jax.lax.cond(is_already_exist, already_exist, successful) else: raise ValueError(f"Invalid network type: {genome.network_type}") @@ -131,9 +155,9 @@ class DefaultMutation(BaseMutation): return nodes_, genome.delete_conn_by_pos(conns_, idx) return jax.lax.cond( - idx == I_INT, + idx == I_INF, lambda: (nodes_, conns_), # nothing - successfully_delete_connection + successfully_delete_connection, ) k1, k2, k3, k4 = jax.random.split(key, num=4) @@ -142,10 +166,18 @@ class DefaultMutation(BaseMutation): def no(key_, nodes_, conns_): return nodes_, conns_ - nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns) - nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns) - nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, 
no, k3, nodes, conns) - nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns) + nodes, conns = jax.lax.cond( + r1 < self.node_add, mutate_add_node, no, k1, nodes, conns + ) + nodes, conns = jax.lax.cond( + r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns + ) + nodes, conns = jax.lax.cond( + r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns + ) + nodes, conns = jax.lax.cond( + r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns + ) return nodes, conns @@ -163,8 +195,15 @@ class DefaultMutation(BaseMutation): return new_nodes, new_conns - def choice_node_key(self, key, nodes, input_idx, output_idx, - allow_input_keys: bool = False, allow_output_keys: bool = False): + def choice_node_key( + self, + key, + nodes, + input_idx, + output_idx, + allow_input_keys: bool = False, + allow_output_keys: bool = False, + ): """ Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node. 
:param key: @@ -186,7 +225,7 @@ class DefaultMutation(BaseMutation): mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx)) idx = fetch_random(key, mask) - key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan) + key = jnp.where(idx != I_INF, nodes[idx, 0], jnp.nan) return key, idx def choice_connection_key(self, key, conns): @@ -196,7 +235,7 @@ class DefaultMutation(BaseMutation): """ idx = fetch_random(key, ~jnp.isnan(conns[:, 0])) - i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan) - o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan) + i_key = jnp.where(idx != I_INF, conns[idx, 0], jnp.nan) + o_key = jnp.where(idx != I_INF, conns[idx, 1], jnp.nan) return i_key, o_key, idx diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index 4e3a49a..afbd5f6 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -12,10 +12,15 @@ class BaseGene: def setup(self, state=State()): return state - def new_attrs(self, state): + def new_custom_attrs(self, state): + # the attrs which make the least influence on the network, used in add node or add conn in mutation raise NotImplementedError - def mutate(self, state, gene): + def new_random_attrs(self, state, randkey): + # random attributes of the gene. used in initialization. + raise NotImplementedError + + def mutate(self, state, randkey, gene): raise NotImplementedError def distance(self, state, gene1, gene2): diff --git a/tensorneat/algorithm/neat/gene/conn/base.py b/tensorneat/algorithm/neat/gene/conn/base.py index 17a67fc..a4ab3dc 100644 --- a/tensorneat/algorithm/neat/gene/conn/base.py +++ b/tensorneat/algorithm/neat/gene/conn/base.py @@ -3,7 +3,7 @@ from .. import BaseGene class BaseConnGene(BaseGene): "Base class for connection genes." 
- fixed_attrs = ['input_index', 'output_index', 'enabled'] + fixed_attrs = ["input_index", "output_index", "enabled"] def __init__(self): super().__init__() diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py index 8f4d0e8..26d1a80 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -8,15 +8,15 @@ from . import BaseConnGene class DefaultConnGene(BaseConnGene): "Default connection gene, with the same behavior as in NEAT-python." - custom_attrs = ['weight'] + custom_attrs = ["weight"] def __init__( - self, - weight_init_mean: float = 0.0, - weight_init_std: float = 1.0, - weight_mutate_power: float = 0.5, - weight_mutate_rate: float = 0.8, - weight_replace_rate: float = 0.1, + self, + weight_init_mean: float = 0.0, + weight_init_std: float = 1.0, + weight_mutate_power: float = 0.5, + weight_mutate_rate: float = 0.8, + weight_replace_rate: float = 0.1, ): super().__init__() self.weight_init_mean = weight_init_mean @@ -25,28 +25,37 @@ class DefaultConnGene(BaseConnGene): self.weight_mutate_rate = weight_mutate_rate self.weight_replace_rate = weight_replace_rate - def new_attrs(self, state): + def new_custom_attrs(self, state): return state, jnp.array([self.weight_init_mean]) - def mutate(self, state, conn): - randkey_, randkey = jax.random.split(state.randkey, 2) + def new_random_attrs(self, state, randkey): + weight = ( + jax.random.normal(randkey, ()) * self.weight_init_std + + self.weight_init_mean + ) + return jnp.array([weight]) + + def mutate(self, state, randkey, conn): input_index = conn[0] output_index = conn[1] enabled = conn[2] - weight = mutate_float(randkey_, - conn[3], - self.weight_init_mean, - self.weight_init_std, - self.weight_mutate_power, - self.weight_mutate_rate, - self.weight_replace_rate - ) + weight = mutate_float( + randkey, + conn[3], + self.weight_init_mean, + self.weight_init_std, + self.weight_mutate_power, + 
self.weight_mutate_rate, + self.weight_replace_rate, + ) - return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight]) + return jnp.array([input_index, output_index, enabled, weight]) def distance(self, state, attrs1, attrs2): - return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight + return (attrs1[2] != attrs2[2]) + jnp.abs( + attrs1[3] - attrs2[3] + ) # enable + weight def forward(self, state, attrs, inputs): weight = attrs[0] - return state, inputs * weight + return inputs * weight diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 6e118ce..6259527 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -9,29 +9,26 @@ from . import BaseNodeGene class DefaultNodeGene(BaseNodeGene): "Default node gene, with the same behavior as in NEAT-python." - custom_attrs = ['bias', 'response', 'aggregation', 'activation'] + custom_attrs = ["bias", "response", "aggregation", "activation"] def __init__( - self, - bias_init_mean: float = 0.0, - bias_init_std: float = 1.0, - bias_mutate_power: float = 0.5, - bias_mutate_rate: float = 0.7, - bias_replace_rate: float = 0.1, - - response_init_mean: float = 1.0, - response_init_std: float = 0.0, - response_mutate_power: float = 0.5, - response_mutate_rate: float = 0.7, - response_replace_rate: float = 0.1, - - activation_default: callable = Act.sigmoid, - activation_options: Tuple = (Act.sigmoid,), - activation_replace_rate: float = 0.1, - - aggregation_default: callable = Agg.sum, - aggregation_options: Tuple = (Agg.sum,), - aggregation_replace_rate: float = 0.1, + self, + bias_init_mean: float = 0.0, + bias_init_std: float = 1.0, + bias_mutate_power: float = 0.5, + bias_mutate_rate: float = 0.7, + bias_replace_rate: float = 0.1, + response_init_mean: float = 1.0, + response_init_std: float = 0.0, + response_mutate_power: float = 0.5, + 
response_mutate_rate: float = 0.7, + response_replace_rate: float = 0.1, + activation_default: callable = Act.sigmoid, + activation_options: Tuple = (Act.sigmoid,), + activation_replace_rate: float = 0.1, + aggregation_default: callable = Agg.sum, + aggregation_options: Tuple = (Agg.sum,), + aggregation_replace_rate: float = 0.1, ): super().__init__() self.bias_init_mean = bias_init_mean @@ -56,33 +53,66 @@ class DefaultNodeGene(BaseNodeGene): self.aggregation_indices = jnp.arange(len(aggregation_options)) self.aggregation_replace_rate = aggregation_replace_rate - def new_attrs(self, state): - return state, jnp.array( - [self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default] + def new_custom_attrs(self, state): + return jnp.array( + [ + self.bias_init_mean, + self.response_init_mean, + self.activation_default, + self.aggregation_default, + ] ) - def mutate(self, state, node): - k1, k2, k3, k4, randkey = jax.random.split(state.randkey, num=5) + def new_random_attrs(self, state, randkey): + k1, k2, k3, k4 = jax.random.split(randkey, num=4) + bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean + res = ( + jax.random.normal(k2, ()) * self.response_init_std + self.response_init_mean + ) + act = jax.random.randint(k3, (), 0, len(self.activation_options)) + agg = jax.random.randint(k4, (), 0, len(self.aggregation_options)) + return jnp.array([bias, res, act, agg]) + + def mutate(self, state, randkey, node): + k1, k2, k3, k4 = jax.random.split(state.randkey, num=4) index = node[0] - bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std, - self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate) + bias = mutate_float( + k1, + node[1], + self.bias_init_mean, + self.bias_init_std, + self.bias_mutate_power, + self.bias_mutate_rate, + self.bias_replace_rate, + ) - res = mutate_float(k2, node[2], self.response_init_mean, self.response_init_std, - self.response_mutate_power, 
self.response_mutate_rate, self.response_replace_rate) + res = mutate_float( + k2, + node[2], + self.response_init_mean, + self.response_init_std, + self.response_mutate_power, + self.response_mutate_rate, + self.response_replace_rate, + ) - act = mutate_int(k3, node[3], self.activation_indices, self.activation_replace_rate) + act = mutate_int( + k3, node[3], self.activation_indices, self.activation_replace_rate + ) - agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate) + agg = mutate_int( + k4, node[4], self.aggregation_indices, self.aggregation_replace_rate + ) - return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg]) + return jnp.array([index, bias, res, act, agg]) def distance(self, state, node1, node2): - return state, ( - jnp.abs(node1[1] - node2[1]) + - jnp.abs(node1[2] - node2[2]) + - (node1[3] != node2[3]) + - (node1[4] != node2[4]) + return ( + jnp.abs(node1[1] - node2[1]) + + jnp.abs(node1[2] - node2[2]) + + (node1[3] != node2[3]) + + (node1[4] != node2[4]) ) def forward(self, state, attrs, inputs, is_output_node=False): @@ -93,9 +123,7 @@ class DefaultNodeGene(BaseNodeGene): # the last output node should not be activated z = jax.lax.cond( - is_output_node, - lambda: z, - lambda: act(act_idx, z, self.activation_options) + is_output_node, lambda: z, lambda: act(act_idx, z, self.activation_options) ) - return state, z + return z diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index fd711ab..995d13f 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -7,13 +7,13 @@ class BaseGenome: network_type = None def __init__( - self, - num_inputs: int, - num_outputs: int, - max_nodes: int, - max_conns: int, - node_gene: BaseNodeGene = DefaultNodeGene(), - conn_gene: BaseConnGene = DefaultConnGene(), + self, + num_inputs: int, + num_outputs: int, + max_nodes: int, + max_conns: int, + node_gene: BaseNodeGene = 
DefaultNodeGene(), + conn_gene: BaseConnGene = DefaultConnGene(), ): self.num_inputs = num_inputs self.num_outputs = num_outputs @@ -25,6 +25,8 @@ class BaseGenome: self.conn_gene = conn_gene def setup(self, state=State()): + state = self.node_gene.setup(state) + state = self.conn_gene.setup(state) return state def transform(self, state, nodes, conns): diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 98445bf..37a6453 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -1,7 +1,7 @@ from typing import Callable import jax, jax.numpy as jnp -from utils import unflatten_conns, topological_sort, I_INT +from utils import unflatten_conns, topological_sort, I_INF from . import BaseGenome from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene @@ -10,18 +10,21 @@ from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene class DefaultGenome(BaseGenome): """Default genome class, with the same behavior as the NEAT-Python""" - network_type = 'feedforward' + network_type = "feedforward" - def __init__(self, - num_inputs: int, - num_outputs: int, - max_nodes=5, - max_conns=4, - node_gene: BaseNodeGene = DefaultNodeGene(), - conn_gene: BaseConnGene = DefaultConnGene(), - output_transform: Callable = None - ): - super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene) + def __init__( + self, + num_inputs: int, + num_outputs: int, + max_nodes=5, + max_conns=4, + node_gene: BaseNodeGene = DefaultNodeGene(), + conn_gene: BaseConnGene = DefaultConnGene(), + output_transform: Callable = None, + ): + super().__init__( + num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene + ) if output_transform is not None: try: @@ -38,7 +41,7 @@ class DefaultGenome(BaseGenome): u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) seqs = topological_sort(nodes, conn_enable) - return state, seqs, 
nodes, u_conns + return seqs, nodes, u_conns def forward(self, state, inputs, transformed): cal_seqs, nodes, conns = transformed @@ -49,32 +52,34 @@ class DefaultGenome(BaseGenome): nodes_attrs = nodes[:, 1:] def cond_fun(carry): - state_, values, idx = carry - return (idx < N) & (cal_seqs[idx] != I_INT) + values, idx = carry + return (idx < N) & (cal_seqs[idx] != I_INF) def body_func(carry): - state_, values, idx = carry + values, idx = carry i = cal_seqs[idx] def hit(): - s, ins = jax.vmap(self.conn_gene.forward, - in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values) - s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx)) + ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))( + state, conns[:, :, i], values + ) + z = self.node_gene.forward( + state, + nodes_attrs[i], + ins, + is_output_node=jnp.isin(i, self.output_idx), + ) new_values = values.at[i].set(z) - return s, new_values + return new_values # the val of input nodes is obtained by the task, not by calculation - state_, values = jax.lax.cond( - jnp.isin(i, self.input_idx), - lambda: (state_, values), - hit - ) + values = jax.lax.cond(jnp.isin(i, self.input_idx), lambda: values, hit) - return state_, values, idx + 1 + return values, idx + 1 - state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0)) + vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) if self.output_transform is None: - return state, vals[self.output_idx] + return vals[self.output_idx] else: - return state, self.output_transform(vals[self.output_idx]) + return self.output_transform(vals[self.output_idx]) diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 93b3614..3e77271 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -10,19 +10,22 @@ from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, 
DefaultConnGene class RecurrentGenome(BaseGenome): """Default genome class, with the same behavior as the NEAT-Python""" - network_type = 'recurrent' + network_type = "recurrent" - def __init__(self, - num_inputs: int, - num_outputs: int, - max_nodes: int, - max_conns: int, - node_gene: BaseNodeGene = DefaultNodeGene(), - conn_gene: BaseConnGene = DefaultConnGene(), - activate_time: int = 10, - output_transform: Callable = None - ): - super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene) + def __init__( + self, + num_inputs: int, + num_outputs: int, + max_nodes: int, + max_conns: int, + node_gene: BaseNodeGene = DefaultNodeGene(), + conn_gene: BaseConnGene = DefaultConnGene(), + activate_time: int = 10, + output_transform: Callable = None, + ): + super().__init__( + num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene + ) self.activate_time = activate_time if output_transform is not None: @@ -39,45 +42,37 @@ class RecurrentGenome(BaseGenome): conn_enable = u_conns[0] == 1 u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) - return state, nodes, u_conns + return nodes, u_conns def forward(self, state, inputs, transformed): nodes, conns = transformed N = nodes.shape[0] vals = jnp.full((N,), jnp.nan) - nodes_attrs = nodes[:, 1:] + nodes_attrs = nodes[:, 1:] # remove index - def body_func(_, carry): - state_, values = carry + def body_func(_, values): # set input values values = values.at[self.input_idx].set(inputs) # calculate connections - state_, node_ins = jax.vmap( - jax.vmap( - self.conn_gene.forward, - in_axes=(None, 1, None), - out_axes=(None, 0) - ), + node_ins = jax.vmap( + jax.vmap(self.conn_gene.forward, in_axes=(None, 1, None)), in_axes=(None, 1, 0), - out_axes=(None, 0) - )(state_, conns, values) + )(state, conns, values) # calculate nodes - is_output_nodes = jnp.isin( - jnp.arange(N), - self.output_idx + is_output_nodes = jnp.isin(jnp.arange(N), self.output_idx) + values = 
jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))( + state, nodes_attrs, node_ins.T, is_output_nodes ) - state_, values = jax.vmap( - self.node_gene.forward, - in_axes=(None, 0, 0, 0), - out_axes=(None, 0) - )(state_, nodes_attrs, node_ins.T, is_output_nodes) - return state_, values + return values - state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals)) + vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals) - return state, vals[self.output_idx] + if self.output_transform is None: + return vals[self.output_idx] + else: + return self.output_transform(vals[self.output_idx]) diff --git a/tensorneat/algorithm/neat/neat.py b/tensorneat/algorithm/neat/neat.py index 1de420d..057f7b6 100644 --- a/tensorneat/algorithm/neat/neat.py +++ b/tensorneat/algorithm/neat/neat.py @@ -3,58 +3,57 @@ from utils import State from .. import BaseAlgorithm from .species import * from .ga import * +from .genome import * class NEAT(BaseAlgorithm): - def __init__( - self, - species: BaseSpecies, - mutation: BaseMutation = DefaultMutation(), - crossover: BaseCrossover = DefaultCrossover(), + self, + species: BaseSpecies, + mutation: BaseMutation = DefaultMutation(), + crossover: BaseCrossover = DefaultCrossover(), ): - self.genome = species.genome + self.genome: BaseGenome = species.genome self.species = species self.mutation = mutation self.crossover = crossover - def setup(self, randkey): - k1, k2 = jax.random.split(randkey, 2) - return State( - randkey=k1, - generation=jnp.array(0.), - next_node_key=jnp.array(max(*self.genome.input_idx, *self.genome.output_idx) + 2, dtype=jnp.float32), - # inputs nodes, output nodes, 1 hidden node - species=self.species.setup(k2), + def setup(self, state=State()): + state = self.species.setup(state) + state = self.mutation.setup(state) + state = self.crossover.setup(state) + state = state.register( + generation=jnp.array(0.0), + next_node_key=jnp.array( + max(*self.genome.input_idx, *self.genome.output_idx) + 2, 
+ dtype=jnp.float32, + ), ) + return state def ask(self, state: State): - return self.species.ask(state.species) + return state, self.species.ask(state.species) def tell(self, state: State, fitness): k1, k2, randkey = jax.random.split(state.randkey, 3) - state = state.update( - generation=state.generation + 1, - randkey=randkey + state = state.update(generation=state.generation + 1, randkey=randkey) + + state, winner, loser, elite_mask = self.species.update_species( + state.species, fitness ) + state = self.create_next_generation(state, winner, loser, elite_mask) + state = self.species.speciate(state.species) - species_state, winner, loser, elite_mask = self.species.update_species(state.species, fitness, state.generation) - state = state.update(species=species_state) - - state = self.create_next_generation(k2, state, winner, loser, elite_mask) - - species_state = self.species.speciate(state.species, state.generation) - state = state.update(species=species_state) return state - def transform(self, individual): + def transform(self, state, individual): """transform the genome into a neural network""" nodes, conns = individual - return self.genome.transform(nodes, conns) + return self.genome.transform(state, nodes, conns) - def forward(self, inputs, transformed): - return self.genome.forward(inputs, transformed) + def forward(self, state, inputs, transformed): + return self.genome.forward(state, inputs, transformed) @property def num_inputs(self): @@ -68,12 +67,12 @@ class NEAT(BaseAlgorithm): def pop_size(self): return self.species.pop_size - def create_next_generation(self, randkey, state, winner, loser, elite_mask): + def create_next_generation(self, state, winner, loser, elite_mask): # prepare random keys pop_size = self.species.pop_size new_node_keys = jnp.arange(pop_size) + state.next_node_key - k1, k2 = jax.random.split(randkey, 2) + k1, k2, randkey = jax.random.split(state.randkey, 3) crossover_rand_keys = jax.random.split(k1, pop_size) mutate_rand_keys = 
jax.random.split(k2, pop_size) @@ -81,12 +80,14 @@ class NEAT(BaseAlgorithm): lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser] # batch crossover - n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0)) - (crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc)) + n_nodes, n_conns = jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))( + crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc + ) # batch mutation - m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0)) - (mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys)) + m_n_nodes, m_n_conns = jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))( + mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys + ) # elitism don't mutate pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) @@ -94,20 +95,21 @@ class NEAT(BaseAlgorithm): # update next node key all_nodes_keys = pop_nodes[:, :, 0] - max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)) + max_node_key = jnp.max( + jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys) + ) next_node_key = max_node_key + 1 return state.update( - species=state.species.update( - pop_nodes=pop_nodes, - pop_conns=pop_conns, - ), + randkey=randkey, + pop_nodes=pop_nodes, + pop_conns=pop_conns, next_node_key=next_node_key, ) def member_count(self, state: State): - return state.species.member_count + return state, state.species.member_count def generation(self, state: State): # to analysis the algorithm - return state.generation + return state, state.generation diff --git a/tensorneat/algorithm/neat/species/base.py b/tensorneat/algorithm/neat/species/base.py index 682fc8f..f2b9cc1 100644 --- a/tensorneat/algorithm/neat/species/base.py +++ b/tensorneat/algorithm/neat/species/base.py @@ -1,15 +1,20 @@ from utils import State +from ..genome import BaseGenome class BaseSpecies: - def setup(self, key, state=State()): + genome: BaseGenome + pop_size: 
int + species_size: int + + def setup(self, state=State()): return state def ask(self, state: State): raise NotImplementedError - def update_species(self, state, fitness, generation): + def update_species(self, state, fitness): raise NotImplementedError - def speciate(self, state, generation): + def speciate(self, state): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species/default.py index 0bf2eec..8bab142 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species/default.py @@ -6,23 +6,23 @@ from .base import BaseSpecies class DefaultSpecies(BaseSpecies): - - def __init__(self, - genome: BaseGenome, - pop_size, - species_size, - compatibility_disjoint: float = 1.0, - compatibility_weight: float = 0.4, - max_stagnation: int = 15, - species_elitism: int = 2, - spawn_number_change_rate: float = 0.5, - genome_elitism: int = 2, - survival_threshold: float = 0.2, - min_species_size: int = 1, - compatibility_threshold: float = 3., - initialize_method: str = 'one_hidden_node', - # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'} - ): + def __init__( + self, + genome: BaseGenome, + pop_size, + species_size, + compatibility_disjoint: float = 1.0, + compatibility_weight: float = 0.4, + max_stagnation: int = 15, + species_elitism: int = 2, + spawn_number_change_rate: float = 0.5, + genome_elitism: int = 2, + survival_threshold: float = 0.2, + min_species_size: int = 1, + compatibility_threshold: float = 3.0, + initialize_method: str = "one_hidden_node", + # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'} + ): self.genome = genome self.pop_size = pop_size self.species_size = species_size @@ -40,21 +40,38 @@ class DefaultSpecies(BaseSpecies): self.species_arange = jnp.arange(self.species_size) - def setup(self, key, state=State()): - k1, k2 = jax.random.split(key, 2) - pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, 
self.initialize_method) + def setup(self, state=State()): + state = self.genome.setup(state) + k1, randkey = jax.random.split(state.randkey, 2) + pop_nodes, pop_conns = initialize_population( + self.pop_size, self.genome, k1, self.initialize_method + ) - species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species - best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species - last_improved = jnp.full((self.species_size,), jnp.nan) # the last generation that the species improved - member_count = jnp.full((self.species_size,), jnp.nan) # the number of members of each species + species_keys = jnp.full( + (self.species_size,), jnp.nan + ) # the unique index (primary key) for each species + best_fitness = jnp.full( + (self.species_size,), jnp.nan + ) # the best fitness of each species + last_improved = jnp.full( + (self.species_size,), jnp.nan + ) # the last generation that the species improved + member_count = jnp.full( + (self.species_size,), jnp.nan + ) # the number of members of each species idx2species = jnp.zeros(self.pop_size) # the species index of each individual # nodes for each center genome of each species - center_nodes = jnp.full((self.species_size, self.genome.max_nodes, self.genome.node_gene.length), jnp.nan) + center_nodes = jnp.full( + (self.species_size, self.genome.max_nodes, self.genome.node_gene.length), + jnp.nan, + ) # connections for each center genome of each species - center_conns = jnp.full((self.species_size, self.genome.max_conns, self.genome.conn_gene.length), jnp.nan) + center_conns = jnp.full( + (self.species_size, self.genome.max_conns, self.genome.conn_gene.length), + jnp.nan, + ) species_keys = species_keys.at[0].set(0) best_fitness = best_fitness.at[0].set(-jnp.inf)
pop_conns=pop_conns, species_keys=species_keys, @@ -80,14 +97,14 @@ class DefaultSpecies(BaseSpecies): ) def ask(self, state): - return state.pop_nodes, state.pop_conns + return state, state.pop_nodes, state.pop_conns - def update_species(self, state, fitness, generation): + def update_species(self, state, fitness): # update the fitness of each species - species_fitness = self.update_species_fitness(state, fitness) + state, species_fitness = self.update_species_fitness(state, fitness) # stagnation species - state, species_fitness = self.stagnation(state, generation, species_fitness) + state, species_fitness = self.stagnation(state, species_fitness) # sort species_info by their fitness. (also push nan to the end) sort_indices = jnp.argsort(species_fitness)[::-1] @@ -101,11 +118,13 @@ class DefaultSpecies(BaseSpecies): ) # decide the number of members of each species by their fitness - spawn_number = self.cal_spawn_numbers(state) + state, spawn_number = self.cal_spawn_numbers(state) k1, k2 = jax.random.split(state.randkey) # crossover info - winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness) + winner, loser, elite_mask = self.create_crossover_pair( + state, k1, spawn_number, fitness + ) return state.update(randkey=k2), winner, loser, elite_mask @@ -116,42 +135,50 @@ class DefaultSpecies(BaseSpecies): """ def aux_func(idx): - s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf) + s_fitness = jnp.where( + state.idx2species == state.species_keys[idx], fitness, -jnp.inf + ) val = jnp.max(s_fitness) return val - return jax.vmap(aux_func)(self.species_arange) + return state, jax.vmap(aux_func)(self.species_arange) - def stagnation(self, state, generation, species_fitness): + def stagnation(self, state, species_fitness): """ stagnation species. those species whose fitness is not better than the best fitness of the species for a long time will be stagnation. 
elitism species never stagnation - - generation: the current generation """ def check_stagnation(idx): # determine whether the species stagnation st = ( - (species_fitness[idx] <= state.best_fitness[idx]) & # not better than the best fitness of the species - (generation - state.last_improved[idx] > self.max_stagnation) # for a long time - ) + species_fitness[idx] <= state.best_fitness[idx] + ) & ( # not better than the best fitness of the species + state.generation - state.last_improved[idx] > self.max_stagnation + ) # for a long time # update last_improved and best_fitness li, bf = jax.lax.cond( species_fitness[idx] > state.best_fitness[idx], - lambda: (generation, species_fitness[idx]), # update - lambda: (state.last_improved[idx], state.best_fitness[idx]) # not update + lambda: (state.generation, species_fitness[idx]), # update + lambda: ( + state.last_improved[idx], + state.best_fitness[idx], + ), # not update ) return st, bf, li - spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)(self.species_arange) + spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)( + self.species_arange + ) # elite species will not be stagnation species_rank = rank_elements(species_fitness) - spe_st = jnp.where(species_rank < self.species_elitism, False, spe_st) # elitism never stagnation + spe_st = jnp.where( + species_rank < self.species_elitism, False, spe_st + ) # elitism never stagnation # set stagnation species to nan def update_func(idx): @@ -173,8 +200,8 @@ class DefaultSpecies(BaseSpecies): state.member_count[idx], species_fitness[idx], state.center_nodes[idx], - state.center_conns[idx] - ) # not stagnation species + state.center_conns[idx], + ), # not stagnation species ) ( @@ -184,18 +211,20 @@ class DefaultSpecies(BaseSpecies): member_count, species_fitness, center_nodes, - center_conns - ) = ( - jax.vmap(update_func)(self.species_arange)) + center_conns, + ) = jax.vmap(update_func)(self.species_arange) - return state.update( - 
species_keys=species_keys, - best_fitness=best_fitness, - last_improved=last_improved, - member_count=member_count, - center_nodes=center_nodes, - center_conns=center_conns, - ), species_fitness + return ( + state.update( + species_keys=species_keys, + best_fitness=best_fitness, + last_improved=last_improved, + member_count=member_count, + center_nodes=center_nodes, + center_conns=center_conns, + ), + species_fitness, + ) def cal_spawn_numbers(self, state): """ @@ -209,17 +238,26 @@ class DefaultSpecies(BaseSpecies): is_species_valid = ~jnp.isnan(species_keys) valid_species_num = jnp.sum(is_species_valid) - denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6 + denominator = ( + (valid_species_num + 1) * valid_species_num / 2 + ) # obtain 3 + 2 + 1 = 6 rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1] spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17] - spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 + spawn_number_rate = jnp.where( + is_species_valid, spawn_number_rate, 0 + ) # set invalid species to 0 - target_spawn_number = jnp.floor(spawn_number_rate * self.pop_size) # calculate member + target_spawn_number = jnp.floor( + spawn_number_rate * self.pop_size + ) # calculate member # Avoid too much variation of numbers for a species previous_size = state.member_count - spawn_number = previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate + spawn_number = ( + previous_size + + (target_spawn_number - previous_size) * self.spawn_number_change_rate + ) spawn_number = spawn_number.astype(jnp.int32) # must control the sum of spawn_number to be equal to pop_size @@ -228,9 +266,9 @@ class DefaultSpecies(BaseSpecies): # add error to the first species to control the sum of spawn_number spawn_number = spawn_number.at[0].add(error) - return spawn_number + return state, spawn_number - def create_crossover_pair(self, state, 
randkey, spawn_number, fitness): + def create_crossover_pair(self, state, spawn_number, fitness): s_idx = self.species_arange p_idx = jnp.arange(self.pop_size) @@ -241,10 +279,18 @@ class DefaultSpecies(BaseSpecies): members_fitness = jnp.where(members, fitness, -jnp.inf) sorted_member_indices = jnp.argsort(members_fitness)[::-1] - survive_size = jnp.floor(self.survival_threshold * members_num).astype(jnp.int32) + survive_size = jnp.floor(self.survival_threshold * members_num).astype( + jnp.int32 + ) select_pro = (p_idx < survive_size) / survive_size - fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, self.pop_size), replace=True, p=select_pro) + fa, ma = jax.random.choice( + key, + sorted_member_indices, + shape=(2, self.pop_size), + replace=True, + p=select_pro, + ) # elite fa = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, fa) @@ -252,7 +298,10 @@ class DefaultSpecies(BaseSpecies): elite = jnp.where(p_idx < self.genome_elitism, True, False) return fa, ma, elite - fas, mas, elites = jax.vmap(aux_func)(jax.random.split(randkey, self.species_size), s_idx) + randkey_, randkey = jax.random.split(state.randkey) + fas, mas, elites = jax.vmap(aux_func)( + jax.random.split(randkey_, self.species_size), s_idx + ) spawn_number_cum = jnp.cumsum(spawn_number) @@ -261,7 +310,11 @@ class DefaultSpecies(BaseSpecies): # elite genomes are at the beginning of the species idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx) - return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species] + return ( + fas[loc, idx_in_species], + mas[loc, idx_in_species], + elites[loc, idx_in_species], + ) part1, part2, elite_mask = jax.vmap(aux_func)(p_idx) @@ -269,14 +322,18 @@ class DefaultSpecies(BaseSpecies): winner = jnp.where(is_part1_win, part1, part2) loser = jnp.where(is_part1_win, part2, part1) - return winner, loser, elite_mask + return state(randkey=randkey), winner, loser, elite_mask - def speciate(self, state, 
generation): + def speciate(self, state): # prepare distance functions - o2p_distance_func = jax.vmap(self.distance, in_axes=(None, None, 0, 0)) # one to population + o2p_distance_func = jax.vmap( + self.distance, in_axes=(None, None, 0, 0) + ) # one to population # idx to specie key - idx2species = jnp.full((self.pop_size,), jnp.nan) # NaN means not assigned to any species + idx2species = jnp.full( + (self.pop_size,), jnp.nan + ) # NaN means not assigned to any species # the distance between genomes to its center genomes o2c_distances = jnp.full((self.pop_size,), jnp.inf) @@ -286,15 +343,16 @@ class DefaultSpecies(BaseSpecies): # i, idx2species, center_nodes, center_conns, o2c_distances i, i2s, cns, ccs, o2c = carry - return ( - (i < self.species_size) & - (~jnp.isnan(state.species_keys[i])) + return (i < self.species_size) & ( + ~jnp.isnan(state.species_keys[i]) ) # current species is existing def body_func(carry): i, i2s, cns, ccs, o2c = carry - distances = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns) + distances = o2p_distance_func( + cns[i], ccs[i], state.pop_nodes, state.pop_conns + ) # find the closest one closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) @@ -308,9 +366,11 @@ class DefaultSpecies(BaseSpecies): return i + 1, i2s, cns, ccs, o2c - _, idx2species, center_nodes, center_conns, o2c_distances = \ - jax.lax.while_loop(cond_func, body_func, - (0, idx2species, state.center_nodes, state.center_conns, o2c_distances)) + _, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop( + cond_func, + body_func, + (0, idx2species, state.center_nodes, state.center_conns, o2c_distances), + ) state = state.update( idx2species=idx2species, @@ -326,7 +386,9 @@ class DefaultSpecies(BaseSpecies): current_species_existed = ~jnp.isnan(sk[i]) not_all_assigned = jnp.any(jnp.isnan(i2s)) not_reach_species_upper_bounds = i < self.species_size - return not_reach_species_upper_bounds & (current_species_existed | 
not_all_assigned) + return not_reach_species_upper_bounds & ( + current_species_existed | not_all_assigned + ) def body_func(carry): i, i2s, cns, ccs, sk, o2c, nsk = carry @@ -335,7 +397,7 @@ class DefaultSpecies(BaseSpecies): jnp.isnan(sk[i]), # whether the current species is existing or not create_new_species, # if not existing, create a new specie update_exist_specie, # if existing, update the specie - (i, i2s, cns, ccs, sk, o2c, nsk) + (i, i2s, cns, ccs, sk, o2c, nsk), ) return i + 1, i2s, cns, ccs, sk, o2c, nsk @@ -371,7 +433,9 @@ class DefaultSpecies(BaseSpecies): def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c): # distance between such center genome and ppo genomes - o2p_distance = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns) + o2p_distance = o2p_distance_func( + cns[i], ccs[i], state.pop_nodes, state.pop_conns + ) close_enough_mask = o2p_distance < self.compatibility_threshold # when a genome is not assigned or the distance between its current center is bigger than this center @@ -388,11 +452,26 @@ class DefaultSpecies(BaseSpecies): return i2s, o2c # update idx2species - _, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop( + ( + _, + idx2species, + center_nodes, + center_conns, + species_keys, + _, + next_species_key, + ) = jax.lax.while_loop( cond_func, body_func, - (0, state.idx2species, center_nodes, center_conns, state.species_keys, o2c_distances, - state.next_species_key) + ( + 0, + state.idx2species, + center_nodes, + center_conns, + state.species_keys, + o2c_distances, + state.next_species_key, + ), ) # if there are still some pop genomes not assigned to any species, add them to the last genome @@ -402,14 +481,18 @@ class DefaultSpecies(BaseSpecies): # complete info of species which is created in this generation new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness) best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness) - last_improved = 
jnp.where(new_created_mask, generation, state.last_improved) + last_improved = jnp.where( + new_created_mask, state.generation, state.last_improved + ) # update members count def count_members(idx): return jax.lax.cond( jnp.isnan(species_keys[idx]), # if the species is not existing lambda: jnp.nan, # nan - lambda: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members + lambda: jnp.sum( + idx2species == species_keys[idx], dtype=jnp.float32 + ), # count members ) member_count = jax.vmap(count_members)(self.species_arange) @@ -422,7 +505,7 @@ class DefaultSpecies(BaseSpecies): idx2species=idx2species, center_nodes=center_nodes, center_conns=center_conns, - next_species_key=next_species_key + next_species_key=next_species_key, ) def distance(self, nodes1, conns1, nodes2, conns2): @@ -446,7 +529,9 @@ class DefaultSpecies(BaseSpecies): keys = nodes[:, 0] sorted_indices = jnp.argsort(keys, axis=0) nodes = nodes[sorted_indices] - nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end + nodes = jnp.concatenate( + [nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0 + ) # add a nan row to the end fr, sr = nodes[:-1], nodes[1:] # first row, second row # flag location of homologous nodes @@ -460,7 +545,10 @@ class DefaultSpecies(BaseSpecies): hnd = jnp.where(jnp.isnan(hnd), 0, hnd) homologous_distance = jnp.sum(hnd * intersect_mask) - val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight + val = ( + non_homologous_cnt * self.compatibility_disjoint + + homologous_distance * self.compatibility_weight + ) return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division @@ -476,7 +564,9 @@ class DefaultSpecies(BaseSpecies): keys = cons[:, :2] sorted_indices = jnp.lexsort(keys.T[::-1]) cons = cons[sorted_indices] - cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end + cons = jnp.concatenate( + 
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0 + ) # add a nan row to the end fr, sr = cons[:-1], cons[1:] # first row, second row # both genome has such connection @@ -487,19 +577,22 @@ class DefaultSpecies(BaseSpecies): hcd = jnp.where(jnp.isnan(hcd), 0, hcd) homologous_distance = jnp.sum(hcd * intersect_mask) - val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight + val = ( + non_homologous_cnt * self.compatibility_disjoint + + homologous_distance * self.compatibility_weight + ) return jnp.where(max_cnt == 0, 0, val / max_cnt) -def initialize_population(pop_size, genome, randkey, init_method='default'): +def initialize_population(pop_size, genome, randkey, init_method="default"): rand_keys = jax.random.split(randkey, pop_size) - if init_method == 'one_hidden_node': + if init_method == "one_hidden_node": init_func = init_one_hidden_node - elif init_method == 'dense_hideen_layer': + elif init_method == "dense_hideen_layer": init_func = init_dense_hideen_layer - elif init_method == 'no_hidden_random': + elif init_method == "no_hidden_random": init_func = init_no_hidden_random else: raise ValueError("Unknown initialization method: {}".format(init_method)) @@ -521,12 +614,16 @@ def init_one_hidden_node(genome, randkey): nodes = nodes.at[output_idx, 0].set(output_idx) nodes = nodes.at[new_node_key, 0].set(new_node_key) - rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1) - input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[ - len(input_idx):len(input_idx) + len( - output_idx)], rand_keys_nodes[-1] + rand_keys_nodes = jax.random.split( + randkey, num=len(input_idx) + len(output_idx) + 1 + ) + input_keys, output_keys, hidden_key = ( + rand_keys_nodes[: len(input_idx)], + rand_keys_nodes[len(input_idx) : len(input_idx) + len(output_idx)], + rand_keys_nodes[-1], + ) - node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None, 
0)) + node_attr_func = jax.vmap(genome.node_gene.new_attrs, in_axes=(None, 0)) input_attrs = node_attr_func(input_keys) output_attrs = node_attr_func(output_keys) hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key) @@ -544,7 +641,10 @@ def init_one_hidden_node(genome, randkey): conns = conns.at[output_idx, 2].set(True) rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx)) - input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):] + input_conn_keys, output_conn_keys = ( + rand_keys_conns[: len(input_idx)], + rand_keys_conns[len(input_idx) :], + ) conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None, 0)) input_conn_attrs = conn_attr_func(input_conn_keys) @@ -563,8 +663,12 @@ def init_dense_hideen_layer(genome, randkey, hiddens=20): input_size = len(input_idx) output_size = len(output_idx) - hidden_idx = jnp.arange(input_size + output_size, input_size + output_size + hiddens) - nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32) + hidden_idx = jnp.arange( + input_size + output_size, input_size + output_size + hiddens + ) + nodes = jnp.full( + (genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32 + ) nodes = nodes.at[input_idx, 0].set(input_idx) nodes = nodes.at[output_idx, 0].set(output_idx) nodes = nodes.at[hidden_idx, 0].set(hidden_idx) @@ -572,8 +676,8 @@ def init_dense_hideen_layer(genome, randkey, hiddens=20): total_idx = input_size + output_size + hiddens rand_keys_n = jax.random.split(k1, num=total_idx) input_keys = rand_keys_n[:input_size] - output_keys = rand_keys_n[input_size:input_size + output_size] - hidden_keys = rand_keys_n[input_size + output_size:] + output_keys = rand_keys_n[input_size : input_size + output_size] + hidden_keys = rand_keys_n[input_size + output_size :] node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0)) input_attrs = node_attr_func(input_keys) @@ -585,21 
+689,31 @@ def init_dense_hideen_layer(genome, randkey, hiddens=20): nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs) total_connections = input_size * hiddens + hiddens * output_size - conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32) + conns = jnp.full( + (genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32 + ) rand_keys_c = jax.random.split(k2, num=total_connections) conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0)) conns_attrs = conns_attr_func(rand_keys_c) - input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing='ij') - hidden_to_output_ids, output_ids = jnp.meshgrid(hidden_idx, output_idx, indexing='ij') + input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing="ij") + hidden_to_output_ids, output_ids = jnp.meshgrid( + hidden_idx, output_idx, indexing="ij" + ) - conns = conns.at[:input_size * hiddens, 0].set(input_to_hidden_ids.flatten()) - conns = conns.at[:input_size * hiddens, 1].set(hidden_ids.flatten()) - conns = conns.at[input_size * hiddens: total_connections, 0].set(hidden_to_output_ids.flatten()) - conns = conns.at[input_size * hiddens: total_connections, 1].set(output_ids.flatten()) - conns = conns.at[:input_size * hiddens + hiddens * output_size, 2].set(True) - conns = conns.at[:input_size * hiddens + hiddens * output_size, 3:].set(conns_attrs) + conns = conns.at[: input_size * hiddens, 0].set(input_to_hidden_ids.flatten()) + conns = conns.at[: input_size * hiddens, 1].set(hidden_ids.flatten()) + conns = conns.at[input_size * hiddens : total_connections, 0].set( + hidden_to_output_ids.flatten() + ) + conns = conns.at[input_size * hiddens : total_connections, 1].set( + output_ids.flatten() + ) + conns = conns.at[: input_size * hiddens + hiddens * output_size, 2].set(True) + conns = conns.at[: input_size * hiddens + hiddens * output_size, 3:].set( + conns_attrs + ) return nodes, conns @@ -615,8 +729,8 @@ def 
init_no_hidden_random(genome, randkey): total_idx = len(input_idx) + len(output_idx) rand_keys_n = jax.random.split(k1, num=total_idx) - input_keys = rand_keys_n[:len(input_idx)] - output_keys = rand_keys_n[len(input_idx):] + input_keys = rand_keys_n[: len(input_idx)] + output_keys = rand_keys_n[len(input_idx) :] node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0)) input_attrs = node_attr_func(input_keys) diff --git a/tensorneat/examples/brax/ant.py b/tensorneat/examples/brax/ant.py index 082d202..1f8bb46 100644 --- a/tensorneat/examples/brax/ant.py +++ b/tensorneat/examples/brax/ant.py @@ -16,7 +16,8 @@ if __name__ == '__main__': node_gene=DefaultNodeGene( activation_options=(Act.tanh,), activation_default=Act.tanh, - ) + ), + output_transform=Act.tanh ), pop_size=1000, species_size=10, diff --git a/tensorneat/examples/with_evox/ray_test.py b/tensorneat/examples/with_evox/ray_test.py new file mode 100644 index 0000000..6b850da --- /dev/null +++ b/tensorneat/examples/with_evox/ray_test.py @@ -0,0 +1,5 @@ +import ray +ray.init(num_gpus=2) + +available_resources = ray.available_resources() +print("Available resources:", available_resources) diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index b45424a..5ffa22d 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -28,70 +28,53 @@ class Pipeline: self.generation_limit = generation_limit self.pop_size = self.algorithm.pop_size - print(self.problem.input_shape, self.problem.output_shape) + # print(self.problem.input_shape, self.problem.output_shape) # TODO: make each algorithm's input_num and output_num assert algorithm.num_inputs == self.problem.input_shape[-1], \ f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}" - # self.act_func = self.algorithm.act - - # for _ in range(len(self.problem.input_shape) - 1): - # self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None)) - self.best_genome = None self.best_fitness 
= float('-inf') self.generation_timestamp = None - def setup(self): - key = jax.random.PRNGKey(self.seed) - key, algorithm_key, evaluate_key = jax.random.split(key, 3) - - # TODO: Problem should has setup function to maintain state - return State( - randkey=key, - alg=self.algorithm.setup(algorithm_key), - pro=self.problem.setup(evaluate_key), - ) + def setup(self, state=State()): + state = state.register(randkey=jax.random.PRNGKey(self.seed)) + state = self.algorithm.setup(state) + state = self.problem.setup(state) + return state def step(self, state): - key, sub_key = jax.random.split(state.randkey) - keys = jax.random.split(key, self.pop_size) + randkey_, randkey = jax.random.split(state.randkey) + keys = jax.random.split(randkey_, self.pop_size) - pop = self.algorithm.ask(state.alg) + state, pop = self.algorithm.ask(state) - pop_transformed = jax.vmap(self.algorithm.transform)(pop) + state, pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0), out_axes=(None, 0))(state, pop) - fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))( - keys, - state.pro, - self.algorithm.forward, - pop_transformed - ) + state, fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0), out_axes=(None, 0))( + keys, + state, + self.algorithm.forward, + pop_transformed + ) - # fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses) + state = self.algorithm.tell(state, fitnesses) - alg_state = self.algorithm.tell(state.alg, fitnesses) + return state.update(randkey=randkey), fitnesses - return state.update( - randkey=sub_key, - alg=alg_state, - ), fitnesses - - def auto_run(self, ini_state): - state = ini_state + def auto_run(self, state): print("start compile") tic = time.time() - compiled_step = jax.jit(self.step).lower(ini_state).compile() - + compiled_step = jax.jit(self.step).lower(state).compile() print(f"compile finished, cost time: {time.time() - tic:.6f}s", ) + for _ in range(self.generation_limit): self.generation_timestamp 
= time.time() - previous_pop = self.algorithm.ask(state.alg) + state, previous_pop = self.algorithm.ask(state) - state, fitnesses = compiled_step(state) fitnesses = jax.device_get(fitnesses) @@ -101,13 +84,15 @@ class Pipeline: if max(fitnesses) >= self.fitness_target: print("Fitness limit reached!") return state, self.best_genome - node= previous_pop[0][0][:,0] - node_count = jnp.sum(~jnp.isnan(node)) - conn= previous_pop[1][0][:,0] - conn_count = jnp.sum(~jnp.isnan(conn)) - if(w%5==0): - print("node_count",node_count) - print("conn_count",conn_count) + + # node = previous_pop[0][0][:, 0] + # node_count = jnp.sum(~jnp.isnan(node)) + # conn = previous_pop[1][0][:, 0] + # conn_count = jnp.sum(~jnp.isnan(conn)) + # if (w % 5 == 0): + # print("node_count", node_count) + # print("conn_count", conn_count) + print("Generation limit reached!") return state, self.best_genome @@ -124,13 +109,13 @@ class Pipeline: self.best_fitness = fitnesses[max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx] - member_count = jax.device_get(self.algorithm.member_count(state.alg)) + member_count = jax.device_get(self.algorithm.member_count(state)) species_sizes = [int(i) for i in member_count if i > 0] - print(f"Generation: {self.algorithm.generation(state.alg)}", + print(f"Generation: {self.algorithm.generation(state)}", f"species: {len(species_sizes)}, {species_sizes}", f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") def show(self, state, best, *args, **kwargs): - transformed = self.algorithm.transform(best) - self.problem.show(state.randkey, state.pro, self.algorithm.forward, transformed, *args, **kwargs) + state, transformed = self.algorithm.transform(state, best) + self.problem.show(state.randkey, state, self.algorithm.forward, transformed, *args, **kwargs) diff --git a/tensorneat/problem/base.py b/tensorneat/problem/base.py index 1e740c2..67e73c1 100644 --- a/tensorneat/problem/base.py +++ b/tensorneat/problem/base.py @@ 
-6,9 +6,9 @@ from utils import State class BaseProblem: jitable = None - def setup(self, randkey, state: State = State()): + def setup(self, state: State = State()): """initialize the state of the problem""" - pass + return state def evaluate(self, randkey, state: State, act_func: Callable, params): """evaluate one individual""" diff --git a/tensorneat/problem/func_fit/func_fit.py b/tensorneat/problem/func_fit/func_fit.py index 5d09e26..3c71415 100644 --- a/tensorneat/problem/func_fit/func_fit.py +++ b/tensorneat/problem/func_fit/func_fit.py @@ -16,12 +16,12 @@ class FuncFit(BaseProblem): assert error_method in {'mse', 'rmse', 'mae', 'mape'} self.error_method = error_method - def setup(self, randkey, state: State = State()): + def setup(self, state: State = State()): return state def evaluate(self, randkey, state, act_func, params): - predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params) + state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params) if self.error_method == 'mse': loss = jnp.mean((predict - self.targets) ** 2) @@ -38,12 +38,14 @@ class FuncFit(BaseProblem): else: raise NotImplementedError - return -loss + return state, -loss def show(self, randkey, state, act_func, params, *args, **kwargs): - predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params) + state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) - loss = -self.evaluate(randkey, state, act_func, params) + state, loss = self.evaluate(randkey, state, act_func, params) + loss = -loss + msg = "" for i in range(inputs.shape[0]): msg += f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}\n" diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index b9f2329..4edf924 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ 
b/tensorneat/problem/rl_env/rl_jit.py @@ -17,29 +17,29 @@ class RLEnv(BaseProblem): init_obs, init_env_state = self.reset(rng_reset) def cond_func(carry): - _, _, _, done, _, count = carry + _, _, _, _, done, _, count = carry return ~done & (count < self.max_step) def body_func(carry): - obs, env_state, rng, done, tr, count = carry # tr -> total reward - action = act_func(obs, params) + state_, obs, env_state, rng, done, tr, count = carry # tr -> total reward + state_, action = act_func(state_, obs, params) next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) next_rng, _ = jax.random.split(rng) - return next_obs, next_env_state, next_rng, done, tr + reward, count + 1 + return state_, next_obs, next_env_state, next_rng, done, tr + reward, count + 1 - _, _, _, _, total_reward, _ = jax.lax.while_loop( + state, _, _, _, _, total_reward, _ = jax.lax.while_loop( cond_func, body_func, - (init_obs, init_env_state, rng_episode, False, 0.0, 0) + (state, init_obs, init_env_state, rng_episode, False, 0.0, 0) ) - return total_reward + return state, total_reward - @partial(jax.jit, static_argnums=(0,)) + # @partial(jax.jit, static_argnums=(0,)) def step(self, randkey, env_state, action): return self.env_step(randkey, env_state, action) - @partial(jax.jit, static_argnums=(0,)) + # @partial(jax.jit, static_argnums=(0,)) def reset(self, randkey): return self.env_reset(randkey) diff --git a/tensorneat/test/crossover_mutation.py b/tensorneat/test/crossover_mutation.py new file mode 100644 index 0000000..8da8761 --- /dev/null +++ b/tensorneat/test/crossover_mutation.py @@ -0,0 +1,52 @@ +import jax, jax.numpy as jnp +from utils import Act +from algorithm.neat import * +import numpy as np + + +def main(): + algorithm = NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=3, + num_outputs=1, + max_nodes=100, + max_conns=100, + ), + pop_size=1000, + species_size=10, + compatibility_threshold=3.5, + ), + mutation=DefaultMutation( + conn_add=0.4, + 
conn_delete=0, + node_add=0.9, + node_delete=0, + ), + ) + + state = algorithm.setup(jax.random.key(0)) + pop_nodes, pop_conns = algorithm.species.ask(state.species) + + batch_transform = jax.vmap(algorithm.genome.transform) + batch_forward = jax.vmap(algorithm.forward, in_axes=(None, 0)) + + for _ in range(50): + winner, losser = jax.random.randint(state.randkey, (2, 1000), 0, 1000) + elite_mask = jnp.zeros((1000,), dtype=jnp.bool_) + elite_mask = elite_mask.at[:5].set(1) + + state = algorithm.create_next_generation(jax.random.key(0), state, winner, losser, elite_mask) + pop_nodes, pop_conns = algorithm.species.ask(state.species) + + transforms = batch_transform(pop_nodes, pop_conns) + outputs = batch_forward(jnp.array([1, 0, 1]), transforms) + + try: + assert not jnp.any(jnp.isnan(outputs)) + except: + print(_) + + +if __name__ == '__main__': + main() diff --git a/tensorneat/test/nan_fitness.py b/tensorneat/test/nan_fitness.py new file mode 100644 index 0000000..3097ebc --- /dev/null +++ b/tensorneat/test/nan_fitness.py @@ -0,0 +1,42 @@ +import jax, jax.numpy as jnp +from utils import Act +from algorithm.neat import * +import numpy as np + + +def main(): + node_path = "../examples/brax/nan_node.npy" + conn_path = "../examples/brax/nan_conn.npy" + nodes = np.load(node_path) + conns = np.load(conn_path) + nodes, conns = jax.device_put([nodes, conns]) + + genome = DefaultGenome( + num_inputs=8, + num_outputs=2, + max_nodes=20, + max_conns=20, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh,), + activation_default=Act.tanh, + ) + ) + + transformed = genome.transform(nodes, conns) + seq, nodes, conns = transformed + print(seq) + + exit(0) + # print(*transformed, sep='\n') + + key = jax.random.key(0) + dummy_input = jnp.zeros((8,)) + output = genome.forward(dummy_input, transformed) + print(output) + + +if __name__ == '__main__': + a = jnp.array([1, 3, 5, 6, 8]) + b = jnp.array([1, 2, 3]) + print(jnp.isin(a, b)) + # main() diff --git 
a/tensorneat/test/test_genome.py b/tensorneat/test/test_genome.py index d889097..12ff43e 100644 --- a/tensorneat/test/test_genome.py +++ b/tensorneat/test/test_genome.py @@ -7,21 +7,25 @@ import jax, jax.numpy as jnp def test_default(): # index, bias, response, activation, aggregation - nodes = jnp.array([ - [0, 0, 1, 0, 0], # in[0] - [1, 0, 1, 0, 0], # in[1] - [2, 0.5, 1, 0, 0], # out[0], - [3, 1, 1, 0, 0], # hidden[0], - [4, -1, 1, 0, 0], # hidden[1], - ]) + nodes = jnp.array( + [ + [0, 0, 1, 0, 0], # in[0] + [1, 0, 1, 0, 0], # in[1] + [2, 0.5, 1, 0, 0], # out[0], + [3, 1, 1, 0, 0], # hidden[0], + [4, -1, 1, 0, 0], # hidden[1], + ] + ) # in_node, out_node, enable, weight - conns = jnp.array([ - [0, 3, 1, 0.5], # in[0] -> hidden[0] - [1, 4, 1, 0.5], # in[1] -> hidden[1] - [3, 2, 1, 0.5], # hidden[0] -> out[0] - [4, 2, 1, 0.5], # hidden[1] -> out[0] - ]) + conns = jnp.array( + [ + [0, 3, 1, 0.5], # in[0] -> hidden[0] + [1, 4, 1, 0.5], # in[1] -> hidden[1] + [3, 2, 1, 0.5], # hidden[0] -> out[0] + [4, 2, 1, 0.5], # hidden[1] -> out[0] + ] + ) genome = DefaultGenome( num_inputs=2, @@ -30,34 +34,37 @@ def test_default(): max_conns=4, node_gene=DefaultNodeGene( activation_default=Act.identity, - activation_options=(Act.identity, ), + activation_options=(Act.identity,), aggregation_default=Agg.sum, - aggregation_options=(Agg.sum, ), + aggregation_options=(Agg.sum,), ), ) state = genome.setup(State(randkey=jax.random.key(0))) - state, *transformed = genome.transform(state, nodes, conns) - print(*transformed, sep='\n') + transformed = genome.transform(state, nodes, conns) + print(*transformed, sep="\n") - inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - state, outputs = jax.jit(jax.vmap(genome.forward, - in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed) + inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) + outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))( + state, inputs, transformed + ) print(outputs) assert 
jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) # expected: [[0.5], [0.75], [0.75], [1]] - print('\n-------------------------------------------------------\n') + print("\n-------------------------------------------------------\n") conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] print(conns) - state, *transformed = genome.transform(state, nodes, conns) - print(*transformed, sep='\n') + transformed = genome.transform(state, nodes, conns) + print(*transformed, sep="\n") - inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed) + inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) + outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))( + state, inputs, transformed + ) print(outputs) assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) # expected: [[0.5], [0.75], [0.5], [0.75]] @@ -66,21 +73,25 @@ def test_default(): def test_recurrent(): # index, bias, response, activation, aggregation - nodes = jnp.array([ - [0, 0, 1, 0, 0], # in[0] - [1, 0, 1, 0, 0], # in[1] - [2, 0.5, 1, 0, 0], # out[0], - [3, 1, 1, 0, 0], # hidden[0], - [4, -1, 1, 0, 0], # hidden[1], - ]) + nodes = jnp.array( + [ + [0, 0, 1, 0, 0], # in[0] + [1, 0, 1, 0, 0], # in[1] + [2, 0.5, 1, 0, 0], # out[0], + [3, 1, 1, 0, 0], # hidden[0], + [4, -1, 1, 0, 0], # hidden[1], + ] + ) # in_node, out_node, enable, weight - conns = jnp.array([ - [0, 3, 1, 0.5], # in[0] -> hidden[0] - [1, 4, 1, 0.5], # in[1] -> hidden[1] - [3, 2, 1, 0.5], # hidden[0] -> out[0] - [4, 2, 1, 0.5], # hidden[1] -> out[0] - ]) + conns = jnp.array( + [ + [0, 3, 1, 0.5], # in[0] -> hidden[0] + [1, 4, 1, 0.5], # in[1] -> hidden[1] + [3, 2, 1, 0.5], # hidden[0] -> out[0] + [4, 2, 1, 0.5], # hidden[1] -> out[0] + ] + ) genome = RecurrentGenome( num_inputs=2, @@ -89,35 +100,38 @@ def test_recurrent(): max_conns=4, node_gene=DefaultNodeGene( activation_default=Act.identity, - 
activation_options=(Act.identity, ), + activation_options=(Act.identity,), aggregation_default=Agg.sum, - aggregation_options=(Agg.sum, ), + aggregation_options=(Agg.sum,), ), activate_time=3, ) state = genome.setup(State(randkey=jax.random.key(0))) - state, *transformed = genome.transform(state, nodes, conns) - print(*transformed, sep='\n') + transformed = genome.transform(state, nodes, conns) + print(*transformed, sep="\n") - inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - state, outputs = jax.jit(jax.vmap(genome.forward, - in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed) + inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) + outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))( + state, inputs, transformed + ) print(outputs) assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) # expected: [[0.5], [0.75], [0.75], [1]] - print('\n-------------------------------------------------------\n') + print("\n-------------------------------------------------------\n") conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] print(conns) - state, *transformed = genome.transform(state, nodes, conns) - print(*transformed, sep='\n') + transformed = genome.transform(state, nodes, conns) + print(*transformed, sep="\n") - inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed) + inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) + outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))( + state, inputs, transformed + ) print(outputs) assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) - # expected: [[0.5], [0.75], [0.5], [0.75]] \ No newline at end of file + # expected: [[0.5], [0.75], [0.5], [0.75]] diff --git a/tensorneat/test/test_nan_fitness.py b/tensorneat/test/test_nan_fitness.py new file mode 100644 index 0000000..79247cd --- /dev/null +++ 
b/tensorneat/test/test_nan_fitness.py @@ -0,0 +1,35 @@ +import jax, jax.numpy as jnp +from utils import Act +from algorithm.neat import * +import numpy as np + + +def main(): + node_path = "../examples/brax/nan_node.npy" + conn_path = "../examples/brax/nan_conn.npy" + nodes = np.load(node_path) + conns = np.load(conn_path) + nodes, conns = jax.device_put([nodes, conns]) + + genome = DefaultGenome( + num_inputs=8, + num_outputs=2, + max_nodes=20, + max_conns=20, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh,), + activation_default=Act.tanh, + ) + ) + + transformed = genome.transform(nodes, conns) + print(*transformed, sep='\n') + + key = jax.random.key(0) + dummy_input = jnp.zeros((8,)) + output = genome.forward(dummy_input, transformed) + print(output) + + +if __name__ == '__main__': + main() diff --git a/tensorneat/utils/activation.py b/tensorneat/utils/activation.py index 11e34f3..c4bb494 100644 --- a/tensorneat/utils/activation.py +++ b/tensorneat/utils/activation.py @@ -3,7 +3,6 @@ import jax.numpy as jnp class Act: - @staticmethod def sigmoid(z): z = jnp.clip(5 * z, -10, 10) @@ -36,11 +35,7 @@ class Act: @staticmethod def inv(z): - z = jnp.where( - z > 0, - jnp.maximum(z, 1e-7), - jnp.minimum(z, -1e-7) - ) + z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7)) return 1 / z @staticmethod diff --git a/tensorneat/utils/aggregation.py b/tensorneat/utils/aggregation.py index 63df1e4..114abc2 100644 --- a/tensorneat/utils/aggregation.py +++ b/tensorneat/utils/aggregation.py @@ -3,7 +3,6 @@ import jax.numpy as jnp class Agg: - @staticmethod def sum(z): z = jnp.where(jnp.isnan(z), 0, z) @@ -63,5 +62,5 @@ def agg(idx, z, agg_funcs): return jax.lax.cond( jnp.all(jnp.isnan(z)), lambda: jnp.nan, # all inputs are nan - lambda: jax.lax.switch(idx, agg_funcs, z) # otherwise + lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise ) diff --git a/tensorneat/utils/graph.py b/tensorneat/utils/graph.py index ef4eb19..b2c2a4f 100644 --- 
a/tensorneat/utils/graph.py +++ b/tensorneat/utils/graph.py @@ -6,7 +6,7 @@ Only used in feed-forward networks. import jax from jax import jit, Array, numpy as jnp -from .tools import fetch_first, I_INT +from .tools import fetch_first, I_INF @jit @@ -17,16 +17,16 @@ def topological_sort(nodes: Array, conns: Array) -> Array: """ in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0)) - res = jnp.full(in_degree.shape, I_INT) + res = jnp.full(in_degree.shape, I_INF) def cond_fun(carry): res_, idx_, in_degree_ = carry - i = fetch_first(in_degree_ == 0.) - return i != I_INT + i = fetch_first(in_degree_ == 0.0) + return i != I_INF def body_func(carry): res_, idx_, in_degree_ = carry - i = fetch_first(in_degree_ == 0.) + i = fetch_first(in_degree_ == 0.0) # add to res and flag it is already in it res_ = res_.at[idx_].set(i) @@ -65,4 +65,4 @@ def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array: return visited_, new_visited_ _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited)) - return visited[from_idx] \ No newline at end of file + return visited[from_idx] diff --git a/tensorneat/utils/state.py b/tensorneat/utils/state.py index 8212129..84a4fe6 100644 --- a/tensorneat/utils/state.py +++ b/tensorneat/utils/state.py @@ -3,9 +3,8 @@ from jax.tree_util import register_pytree_node_class @register_pytree_node_class class State: - def __init__(self, **kwargs): - self.__dict__['state_dict'] = kwargs + self.__dict__["state_dict"] = kwargs def registered_keys(self): return self.state_dict.keys() diff --git a/tensorneat/utils/tools.py b/tensorneat/utils/tools.py index 0103296..c592496 100644 --- a/tensorneat/utils/tools.py +++ b/tensorneat/utils/tools.py @@ -4,13 +4,14 @@ import numpy as np import jax from jax import numpy as jnp, Array, jit, vmap -I_INT = np.iinfo(jnp.int32).max # infinite int +I_INF = np.iinfo(jnp.int32).max # infinite int def unflatten_conns(nodes, conns): """ - transform the (C, CL) connections 
to (CL-2, N, N), 2 is for the input index and output index) - :return: + transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index), which CL means + connection length, N means the number of nodes, C means the number of connections + returns the un_flattened connections with shape (CL-2, N, N) """ N = nodes.shape[0] CL = conns.shape[1] @@ -33,7 +34,7 @@ def key_to_indices(key, keys): @jit -def fetch_first(mask, default=I_INT) -> Array: +def fetch_first(mask, default=I_INF) -> Array: """ fetch the first True index :param mask: array of bool @@ -45,18 +46,18 @@ def fetch_first(mask, default=I_INT) -> Array: @jit -def fetch_random(rand_key, mask, default=I_INT) -> Array: +def fetch_random(randkey, mask, default=I_INF) -> Array: """ similar to fetch_first, but fetch a random True index """ true_cnt = jnp.sum(mask) cumsum = jnp.cumsum(mask) - target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1) + target = jax.random.randint(randkey, shape=(), minval=1, maxval=true_cnt + 1) mask = jnp.where(true_cnt == 0, False, cumsum >= target) return fetch_first(mask, default) -@partial(jit, static_argnames=['reverse']) +@partial(jit, static_argnames=["reverse"]) def rank_elements(array, reverse=False): """ rank the element in the array. 
@@ -68,8 +69,17 @@ def rank_elements(array, reverse=False): @jit -def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate): - k1, k2, k3 = jax.random.split(key, num=3) +def mutate_float( + randkey, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate +): + """ + mutate a float value + uniformly pick r from [0, 1] + r in [0, mutate_rate) -> add noise + r in [mutate_rate, mutate_rate + replace_rate) -> create a new value to replace the original value + otherwise -> keep the original value + """ + k1, k2, k3 = jax.random.split(randkey, num=3) noise = jax.random.normal(k1, ()) * mutate_power replace = jax.random.normal(k2, ()) * init_std + init_mean r = jax.random.uniform(k3, ()) @@ -77,30 +87,32 @@ def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, repla val = jnp.where( r < mutate_rate, val + noise, - jnp.where( - (mutate_rate < r) & (r < mutate_rate + replace_rate), - replace, - val - ) + jnp.where((mutate_rate < r) & (r < mutate_rate + replace_rate), replace, val), ) return val @jit -def mutate_int(key, val, options, replace_rate): - k1, k2 = jax.random.split(key, num=2) +def mutate_int(randkey, val, options, replace_rate): + """ + mutate an int value + uniformly pick r from [0, 1] + r in [0, replace_rate) -> create a new value to replace the original value + otherwise -> keep the original value + """ + k1, k2 = jax.random.split(randkey, num=2) r = jax.random.uniform(k1, ()) - val = jnp.where( - r < replace_rate, - jax.random.choice(k2, options), - val - ) + val = jnp.where(r < replace_rate, jax.random.choice(k2, options), val) return val + def argmin_with_mask(arr, mask): + """ + find the index of the minimum element in the array, but only consider the element with True mask + """ masked_arr = jnp.where(mask, arr, jnp.inf) min_idx = jnp.argmin(masked_arr) - return min_idx \ No newline at end of file + return min_idx