From 18c3d44c795c5607c96e69b6b7a7f94877bcc2dd Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sun, 26 May 2024 18:08:43 +0800 Subject: [PATCH] complete fully stateful! use black to format all files! --- .../algorithm/neat/ga/crossover/base.py | 2 +- .../algorithm/neat/ga/crossover/default.py | 13 +- tensorneat/algorithm/neat/ga/mutation/base.py | 2 +- .../algorithm/neat/ga/mutation/default.py | 60 +++-- .../algorithm/neat/gene/conn/default.py | 2 +- .../algorithm/neat/gene/node/default.py | 8 +- .../gene/node/default_without_response.py | 106 ++++++++ tensorneat/algorithm/neat/genome/base.py | 118 ++++++--- tensorneat/algorithm/neat/genome/default.py | 12 +- tensorneat/algorithm/neat/genome/recurrent.py | 12 +- tensorneat/algorithm/neat/neat.py | 44 ++-- tensorneat/algorithm/neat/species/default.py | 246 ++++-------------- tensorneat/examples/brax/ant.py | 10 +- tensorneat/examples/brax/half_cheetah.py | 10 +- tensorneat/examples/brax/reacher.py | 10 +- tensorneat/examples/brax/walker.py | 10 +- tensorneat/examples/func_fit/xor.py | 16 +- .../examples/func_fit/xor3d_hyperneat.py | 23 +- tensorneat/examples/func_fit/xor_recurrent.py | 24 +- tensorneat/examples/gymnax/arcbot.py | 12 +- tensorneat/examples/gymnax/cartpole.py | 12 +- .../examples/gymnax/cartpole_hyperneat.py | 29 +-- tensorneat/examples/gymnax/mountain_car.py | 12 +- .../gymnax/mountain_car_continuous.py | 12 +- tensorneat/examples/gymnax/pendulum.py | 11 +- tensorneat/examples/gymnax/reacher.py | 8 +- tensorneat/examples/with_evox/ray_test.py | 1 + tensorneat/pipeline.py | 63 +++-- tensorneat/problem/base.py | 4 +- tensorneat/problem/func_fit/func_fit.py | 30 ++- tensorneat/problem/func_fit/xor.py | 17 +- tensorneat/problem/func_fit/xor3d.py | 36 +-- tensorneat/problem/rl_env/brax_env.py | 24 +- tensorneat/problem/rl_env/gymnax_env.py | 3 +- tensorneat/problem/rl_env/rl_jit.py | 30 +-- tensorneat/test/crossover_mutation.py | 6 +- tensorneat/test/nan_fitness.py | 4 +- tensorneat/test/test_genome.py | 27 
++ tensorneat/test/test_nan_fitness.py | 6 +- tensorneat/utils/__init__.py | 2 +- tensorneat/utils/tools.py | 38 +++ 41 files changed, 620 insertions(+), 495 deletions(-) create mode 100644 tensorneat/algorithm/neat/gene/node/default_without_response.py diff --git a/tensorneat/algorithm/neat/ga/crossover/base.py b/tensorneat/algorithm/neat/ga/crossover/base.py index 8a2dc65..b59ce6c 100644 --- a/tensorneat/algorithm/neat/ga/crossover/base.py +++ b/tensorneat/algorithm/neat/ga/crossover/base.py @@ -5,5 +5,5 @@ class BaseCrossover: def setup(self, state=State()): return state - def __call__(self, state, genome, nodes1, nodes2, conns1, conns2): + def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/crossover/default.py b/tensorneat/algorithm/neat/ga/crossover/default.py index 71fd7af..5c3014f 100644 --- a/tensorneat/algorithm/neat/ga/crossover/default.py +++ b/tensorneat/algorithm/neat/ga/crossover/default.py @@ -4,12 +4,12 @@ from .base import BaseCrossover class DefaultCrossover(BaseCrossover): - def __call__(self, state, genome, nodes1, conns1, nodes2, conns2): + def __call__(self, state, randkey, genome, nodes1, conns1, nodes2, conns2): """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) """ - randkey1, randkey2, randkey = jax.random.split(state.randkey, 3) + randkey1, randkey2 = jax.random.split(randkey, 2) # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] @@ -34,11 +34,12 @@ class DefaultCrossover(BaseCrossover): self.crossover_gene(randkey2, conns1, conns2, is_conn=True), ) - return state.update(randkey=randkey), new_nodes, new_conns + return new_nodes, new_conns def align_array(self, seq1, seq2, ar2, is_conn: bool): """ - After I review this code, I found that it is the most difficult part of the code. Please never change it! 
+ After I review this code, I found that it is the most difficult part of the code. + Please consider carefully before change it! make ar2 align with ar1. :param seq1: :param seq2: @@ -64,8 +65,8 @@ class DefaultCrossover(BaseCrossover): return refactor_ar2 - def crossover_gene(self, rand_key, g1, g2, is_conn): - r = jax.random.uniform(rand_key, shape=g1.shape) + def crossover_gene(self, randkey, g1, g2, is_conn): + r = jax.random.uniform(randkey, shape=g1.shape) new_gene = jnp.where(r > 0.5, g1, g2) if is_conn: # fix enabled enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled diff --git a/tensorneat/algorithm/neat/ga/mutation/base.py b/tensorneat/algorithm/neat/ga/mutation/base.py index ab7c06b..68bd05a 100644 --- a/tensorneat/algorithm/neat/ga/mutation/base.py +++ b/tensorneat/algorithm/neat/ga/mutation/base.py @@ -5,5 +5,5 @@ class BaseMutation: def setup(self, state=State()): return state - def __call__(self, state, genome, nodes, conns, new_node_key): + def __call__(self, state, randkey, genome, nodes, conns, new_node_key): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index 0d716c0..7bb32c8 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -1,6 +1,16 @@ import jax, jax.numpy as jnp from . 
import BaseMutation -from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles +from utils import ( + fetch_first, + fetch_random, + I_INF, + unflatten_conns, + check_cycles, + add_node, + add_conn, + delete_node_by_pos, + delete_conn_by_pos, +) class DefaultMutation(BaseMutation): @@ -16,15 +26,17 @@ class DefaultMutation(BaseMutation): self.node_add = node_add self.node_delete = node_delete - def __call__(self, state, genome, nodes, conns, new_node_key): - k1, k2, randkey = jax.random.split(state.randkey) + def __call__(self, state, randkey, genome, nodes, conns, new_node_key): + k1, k2 = jax.random.split(randkey) - nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key) - nodes, conns = self.mutate_values(k2, genome, nodes, conns) + nodes, conns = self.mutate_structure( + state, k1, genome, nodes, conns, new_node_key + ) + nodes, conns = self.mutate_values(state, k2, genome, nodes, conns) - return state.update(randkey=randkey), nodes, conns + return nodes, conns - def mutate_structure(self, key, genome, nodes, conns, new_node_key): + def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key): def mutate_add_node(key_, nodes_, conns_): i_key, o_key, idx = self.choice_connection_key(key_, conns_) @@ -33,24 +45,24 @@ class DefaultMutation(BaseMutation): new_conns = conns_.at[idx, 2].set(False) # add a new node - new_nodes = genome.add_node( - nodes_, new_node_key, genome.node_gene.new_custom_attrs() + new_nodes = add_node( + nodes_, new_node_key, genome.node_gene.new_custom_attrs(state) ) # add two new connections - new_conns = genome.add_conn( + new_conns = add_conn( new_conns, i_key, new_node_key, True, - genome.conn_gene.new_custom_attrs(), + genome.conn_gene.new_custom_attrs(state), ) - new_conns = genome.add_conn( + new_conns = add_conn( new_conns, new_node_key, o_key, True, - genome.conn_gene.new_custom_attrs(), + genome.conn_gene.new_custom_attrs(state), ) return new_nodes, new_conns @@ -75,7 
+87,7 @@ class DefaultMutation(BaseMutation): def successful_delete_node(): # delete the node - new_nodes = genome.delete_node_by_pos(nodes_, idx) + new_nodes = delete_node_by_pos(nodes_, idx) # delete all connections new_conns = jnp.where( @@ -123,8 +135,8 @@ class DefaultMutation(BaseMutation): return nodes_, conns_ def successful(): - return nodes_, genome.add_conn( - conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs() + return nodes_, add_conn( + conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state) ) def already_exist(): @@ -152,7 +164,7 @@ class DefaultMutation(BaseMutation): i_key, o_key, idx = self.choice_connection_key(key_, conns_) def successfully_delete_connection(): - return nodes_, genome.delete_conn_by_pos(conns_, idx) + return nodes_, delete_conn_by_pos(conns_, idx) return jax.lax.cond( idx == I_INF, @@ -160,7 +172,7 @@ class DefaultMutation(BaseMutation): successfully_delete_connection, ) - k1, k2, k3, k4 = jax.random.split(key, num=4) + k1, k2, k3, k4 = jax.random.split(randkey, num=4) r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) def no(key_, nodes_, conns_): @@ -181,13 +193,17 @@ class DefaultMutation(BaseMutation): return nodes, conns - def mutate_values(self, key, genome, nodes, conns): - k1, k2 = jax.random.split(key, num=2) + def mutate_values(self, state, randkey, genome, nodes, conns): + k1, k2 = jax.random.split(randkey, num=2) nodes_keys = jax.random.split(k1, num=nodes.shape[0]) conns_keys = jax.random.split(k2, num=conns.shape[0]) - new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes) - new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns) + new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))( + state, nodes_keys, nodes + ) + new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))( + state, conns_keys, conns + ) # nan nodes not changed new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) diff --git a/tensorneat/algorithm/neat/gene/conn/default.py 
b/tensorneat/algorithm/neat/gene/conn/default.py index 26d1a80..2f2ed04 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -26,7 +26,7 @@ class DefaultConnGene(BaseConnGene): self.weight_replace_rate = weight_replace_rate def new_custom_attrs(self, state): - return state, jnp.array([self.weight_init_mean]) + return jnp.array([self.weight_init_mean]) def new_random_attrs(self, state, randkey): weight = ( diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 6259527..8a884b7 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -109,10 +109,10 @@ class DefaultNodeGene(BaseNodeGene): def distance(self, state, node1, node2): return ( - jnp.abs(node1[1] - node2[1]) - + jnp.abs(node1[2] - node2[2]) - + (node1[3] != node2[3]) - + (node1[4] != node2[4]) + jnp.abs(node1[1] - node2[1]) # bias + + jnp.abs(node1[2] - node2[2]) # response + + (node1[3] != node2[3]) # activation + + (node1[4] != node2[4]) # aggregation ) def forward(self, state, attrs, inputs, is_output_node=False): diff --git a/tensorneat/algorithm/neat/gene/node/default_without_response.py b/tensorneat/algorithm/neat/gene/node/default_without_response.py new file mode 100644 index 0000000..32c9419 --- /dev/null +++ b/tensorneat/algorithm/neat/gene/node/default_without_response.py @@ -0,0 +1,106 @@ +from typing import Tuple + +import jax, jax.numpy as jnp + +from utils import Act, Agg, act, agg, mutate_int, mutate_float +from . import BaseNodeGene + + +class NodeGeneWithoutResponse(BaseNodeGene): + """ + Default node gene, with the same behavior as in NEAT-python. + The attribute response is removed. 
+ """ + + custom_attrs = ["bias", "aggregation", "activation"] + + def __init__( + self, + bias_init_mean: float = 0.0, + bias_init_std: float = 1.0, + bias_mutate_power: float = 0.5, + bias_mutate_rate: float = 0.7, + bias_replace_rate: float = 0.1, + activation_default: callable = Act.sigmoid, + activation_options: Tuple = (Act.sigmoid,), + activation_replace_rate: float = 0.1, + aggregation_default: callable = Agg.sum, + aggregation_options: Tuple = (Agg.sum,), + aggregation_replace_rate: float = 0.1, + ): + super().__init__() + self.bias_init_mean = bias_init_mean + self.bias_init_std = bias_init_std + self.bias_mutate_power = bias_mutate_power + self.bias_mutate_rate = bias_mutate_rate + self.bias_replace_rate = bias_replace_rate + + self.activation_default = activation_options.index(activation_default) + self.activation_options = activation_options + self.activation_indices = jnp.arange(len(activation_options)) + self.activation_replace_rate = activation_replace_rate + + self.aggregation_default = aggregation_options.index(aggregation_default) + self.aggregation_options = aggregation_options + self.aggregation_indices = jnp.arange(len(aggregation_options)) + self.aggregation_replace_rate = aggregation_replace_rate + + def new_custom_attrs(self, state): + return jnp.array( + [ + self.bias_init_mean, + self.activation_default, + self.aggregation_default, + ] + ) + + def new_random_attrs(self, state, randkey): + k1, k2, k3, k4 = jax.random.split(randkey, num=4) + bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean + act = jax.random.randint(k3, (), 0, len(self.activation_options)) + agg = jax.random.randint(k4, (), 0, len(self.aggregation_options)) + return jnp.array([bias, act, agg]) + + def mutate(self, state, randkey, node): + k1, k2, k3, k4 = jax.random.split(state.randkey, num=4) + index = node[0] + + bias = mutate_float( + k1, + node[1], + self.bias_init_mean, + self.bias_init_std, + self.bias_mutate_power, + self.bias_mutate_rate, 
+ self.bias_replace_rate, + ) + + act = mutate_int( + k3, node[3], self.activation_indices, self.activation_replace_rate + ) + + agg = mutate_int( + k4, node[4], self.aggregation_indices, self.aggregation_replace_rate + ) + + return jnp.array([index, bias, act, agg]) + + def distance(self, state, node1, node2): + return ( + jnp.abs(node1[1] - node2[1]) # bias + + (node1[3] != node2[3]) # activation + + (node1[4] != node2[4]) # aggregation + ) + + def forward(self, state, attrs, inputs, is_output_node=False): + bias, act_idx, agg_idx = attrs + + z = agg(agg_idx, inputs, self.aggregation_options) + z = bias + z + + # the last output node should not be activated + z = jax.lax.cond( + is_output_node, lambda: z, lambda: act(act_idx, z, self.activation_options) + ) + + return z diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index 995d13f..7f7bf39 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -1,6 +1,7 @@ -import jax.numpy as jnp -from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene -from utils import fetch_first, State +import jax, jax.numpy as jnp +from ..gene import BaseNodeGene, BaseConnGene +from ..ga import BaseMutation, BaseCrossover +from utils import State class BaseGenome: @@ -12,8 +13,10 @@ class BaseGenome: num_outputs: int, max_nodes: int, max_conns: int, - node_gene: BaseNodeGene = DefaultNodeGene(), - conn_gene: BaseConnGene = DefaultConnGene(), + node_gene: BaseNodeGene, + conn_gene: BaseConnGene, + mutation: BaseMutation, + crossover: BaseCrossover, ): self.num_inputs = num_inputs self.num_outputs = num_outputs @@ -23,10 +26,14 @@ class BaseGenome: self.max_conns = max_conns self.node_gene = node_gene self.conn_gene = conn_gene + self.mutation = mutation + self.crossover = crossover def setup(self, state=State()): state = self.node_gene.setup(state) state = self.conn_gene.setup(state) + state = self.mutation.setup(state) + 
state = self.crossover.setup(state) return state def transform(self, state, nodes, conns): @@ -35,36 +42,81 @@ class BaseGenome: def forward(self, state, inputs, transformed): raise NotImplementedError - def add_node(self, nodes, new_key: int, attrs): - """ - Add a new node to the genome. - The new node will place at the first NaN row. - """ - exist_keys = nodes[:, 0] - pos = fetch_first(jnp.isnan(exist_keys)) - new_nodes = nodes.at[pos, 0].set(new_key) - return new_nodes.at[pos, 1:].set(attrs) + def execute_mutation(self, state, randkey, nodes, conns, new_node_key): + return self.mutation(state, randkey, self, nodes, conns, new_node_key) - def delete_node_by_pos(self, nodes, pos): - """ - Delete a node from the genome. - Delete the node by its pos in nodes. - """ - return nodes.at[pos].set(jnp.nan) + def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2): + return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2) - def add_conn(self, conns, i_key, o_key, enable: bool, attrs): + def initialize(self, state, randkey): """ - Add a new connection to the genome. - The new connection will place at the first NaN row. - """ - con_keys = conns[:, 0] - pos = fetch_first(jnp.isnan(con_keys)) - new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable])) - return new_conns.at[pos, 3:].set(attrs) + Default initialization method for the genome. + Add an extra hidden node. + Make all input nodes and output nodes connected to the hidden node. + All attributes will be initialized randomly using gene.new_random_attrs method. 
- def delete_conn_by_pos(self, conns, pos): + For example, a network with 2 inputs and 1 output, the structure will be: + nodes: + [ + [0, attrs0], # input node 0 + [1, attrs1], # input node 1 + [2, attrs2], # output node 0 + [3, attrs3], # hidden node + [NaN, NaN], # empty node + ] + conns: + [ + [0, 3, attrs0], # input node 0 -> hidden node + [1, 3, attrs1], # input node 1 -> hidden node + [3, 2, attrs2], # hidden node -> output node 0 + [NaN, NaN], + [NaN, NaN], + ] """ - Delete a connection from the genome. - Delete the connection by its idx. - """ - return conns.at[pos].set(jnp.nan) + + k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns + # initialize nodes + new_node_key = ( + max([*self.input_idx, *self.output_idx]) + 1 + ) # the key for the hidden node + node_keys = jnp.concatenate( + [self.input_idx, self.output_idx, jnp.array([new_node_key])] + ) # the list of all node keys + + # initialize nodes and connections with NaN + nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan) + conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan) + + # set keys for input nodes, output nodes and hidden node + nodes = nodes.at[node_keys, 0].set(node_keys) + + # generate random attributes for nodes + node_keys = jax.random.split(k1, len(node_keys)) + random_node_attrs = jax.vmap( + self.node_gene.new_random_attrs, in_axes=(None, 0) + )(state, node_keys) + nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs) + + # initialize conns + # input-hidden connections + input_conns = jnp.c_[ + self.input_idx, jnp.full_like(self.input_idx, new_node_key) + ] + conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys + conns = conns.at[self.input_idx, 2].set(True) # enable + + # output-hidden connections + output_conns = jnp.c_[ + jnp.full_like(self.output_idx, new_node_key), self.output_idx + ] + conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys + conns = conns.at[self.output_idx, 2].set(True) # 
enable + + conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx)) + # generate random attributes for conns + random_conn_attrs = jax.vmap( + self.conn_gene.new_random_attrs, in_axes=(None, 0) + )(state, conn_keys) + conns = conns.at[: len(conn_keys), 3:].set(random_conn_attrs) + + return nodes, conns diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 37a6453..a2b6a50 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -5,6 +5,7 @@ from utils import unflatten_conns, topological_sort, I_INF from . import BaseGenome from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene +from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover class DefaultGenome(BaseGenome): @@ -20,10 +21,19 @@ class DefaultGenome(BaseGenome): max_conns=4, node_gene: BaseNodeGene = DefaultNodeGene(), conn_gene: BaseConnGene = DefaultConnGene(), + mutation: BaseMutation = DefaultMutation(), + crossover: BaseCrossover = DefaultCrossover(), output_transform: Callable = None, ): super().__init__( - num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene + num_inputs, + num_outputs, + max_nodes, + max_conns, + node_gene, + conn_gene, + mutation, + crossover, ) if output_transform is not None: diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 3e77271..88d88e8 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -5,6 +5,7 @@ from utils import unflatten_conns from . 
import BaseGenome from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene +from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover class RecurrentGenome(BaseGenome): @@ -20,11 +21,20 @@ class RecurrentGenome(BaseGenome): max_conns: int, node_gene: BaseNodeGene = DefaultNodeGene(), conn_gene: BaseConnGene = DefaultConnGene(), + mutation: BaseMutation = DefaultMutation(), + crossover: BaseCrossover = DefaultCrossover(), activate_time: int = 10, output_transform: Callable = None, ): super().__init__( - num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene + num_inputs, + num_outputs, + max_nodes, + max_conns, + node_gene, + conn_gene, + mutation, + crossover, ) self.activate_time = activate_time diff --git a/tensorneat/algorithm/neat/neat.py b/tensorneat/algorithm/neat/neat.py index 057f7b6..f0e65e0 100644 --- a/tensorneat/algorithm/neat/neat.py +++ b/tensorneat/algorithm/neat/neat.py @@ -10,18 +10,12 @@ class NEAT(BaseAlgorithm): def __init__( self, species: BaseSpecies, - mutation: BaseMutation = DefaultMutation(), - crossover: BaseCrossover = DefaultCrossover(), ): - self.genome: BaseGenome = species.genome self.species = species - self.mutation = mutation - self.crossover = crossover + self.genome = species.genome def setup(self, state=State()): state = self.species.setup(state) - state = self.mutation.setup(state) - state = self.crossover.setup(state) state = state.register( generation=jnp.array(0.0), next_node_key=jnp.array( @@ -32,18 +26,16 @@ class NEAT(BaseAlgorithm): return state def ask(self, state: State): - return state, self.species.ask(state.species) + return self.species.ask(state) def tell(self, state: State, fitness): k1, k2, randkey = jax.random.split(state.randkey, 3) state = state.update(generation=state.generation + 1, randkey=randkey) - state, winner, loser, elite_mask = self.species.update_species( - state.species, fitness - ) + state, winner, loser, elite_mask = 
self.species.update_species(state, fitness) state = self.create_next_generation(state, winner, loser, elite_mask) - state = self.species.speciate(state.species) + state = self.species.speciate(state) return state @@ -73,21 +65,25 @@ class NEAT(BaseAlgorithm): new_node_keys = jnp.arange(pop_size) + state.next_node_key k1, k2, randkey = jax.random.split(state.randkey, 3) - crossover_rand_keys = jax.random.split(k1, pop_size) - mutate_rand_keys = jax.random.split(k2, pop_size) + crossover_randkeys = jax.random.split(k1, pop_size) + mutate_randkeys = jax.random.split(k2, pop_size) - wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner] - lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser] + wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner] + lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser] # batch crossover - n_nodes, n_conns = jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))( - crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc - ) + n_nodes, n_conns = jax.vmap( + self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0) + )( + state, crossover_randkeys, wpn, wpc, lpn, lpc + ) # new_nodes, new_conns # batch mutation - m_n_nodes, m_n_conns = jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))( - mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys - ) + m_n_nodes, m_n_conns = jax.vmap( + self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0) + )( + state, mutate_randkeys, n_nodes, n_conns, new_node_keys + ) # mutated_new_nodes, mutated_new_conns # elitism don't mutate pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) @@ -108,8 +104,8 @@ class NEAT(BaseAlgorithm): ) def member_count(self, state: State): - return state, state.species.member_count + return state.member_count def generation(self, state: State): # to analysis the algorithm - return state, state.generation + return state.generation diff --git a/tensorneat/algorithm/neat/species/default.py 
b/tensorneat/algorithm/neat/species/default.py index 8bab142..f82141c 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species/default.py @@ -1,10 +1,22 @@ -import numpy as np import jax, jax.numpy as jnp from utils import State, rank_elements, argmin_with_mask, fetch_first from ..genome import BaseGenome from .base import BaseSpecies +""" +Core procedures of NEAT algorithm, contains the following steps: +1. Update the fitness of each species; +2. Decide which species will be stagnation; +3. Decide the number of members of each species in the next generation; +4. Choice the crossover pair for each species; +5. Divided the whole new population into different species; + +This class use tensor operation to imitate the behavior of NEAT algorithm which implemented in NEAT-python. +The code may be hard to understand. Fortunately, we don't need to overwrite it in most cases. +""" + + class DefaultSpecies(BaseSpecies): def __init__( self, @@ -20,8 +32,6 @@ class DefaultSpecies(BaseSpecies): survival_threshold: float = 0.2, min_species_size: int = 1, compatibility_threshold: float = 3.0, - initialize_method: str = "one_hidden_node", - # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'} ): self.genome = genome self.pop_size = pop_size @@ -36,15 +46,17 @@ class DefaultSpecies(BaseSpecies): self.survival_threshold = survival_threshold self.min_species_size = min_species_size self.compatibility_threshold = compatibility_threshold - self.initialize_method = initialize_method self.species_arange = jnp.arange(self.species_size) def setup(self, state=State()): state = self.genome.setup(state) k1, randkey = jax.random.split(state.randkey, 2) - pop_nodes, pop_conns = initialize_population( - self.pop_size, self.genome, k1, self.initialize_method + + # initialize the population + initialize_keys = jax.random.split(randkey, self.pop_size) + pop_nodes, pop_conns = jax.vmap(self.genome.initialize, in_axes=(None, 0))( + state, 
initialize_keys ) species_keys = jnp.full( @@ -82,8 +94,9 @@ class DefaultSpecies(BaseSpecies): pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns)) + state = state.update(randkey=randkey) + return state.register( - randkey=randkey, pop_nodes=pop_nodes, pop_conns=pop_conns, species_keys=species_keys, @@ -97,7 +110,7 @@ class DefaultSpecies(BaseSpecies): ) def ask(self, state): - return state, state.pop_nodes, state.pop_conns + return state.pop_nodes, state.pop_conns def update_species(self, state, fitness): # update the fitness of each species @@ -122,8 +135,8 @@ class DefaultSpecies(BaseSpecies): k1, k2 = jax.random.split(state.randkey) # crossover info - winner, loser, elite_mask = self.create_crossover_pair( - state, k1, spawn_number, fitness + state, winner, loser, elite_mask = self.create_crossover_pair( + state, spawn_number, fitness ) return state.update(randkey=k2), winner, loser, elite_mask @@ -322,12 +335,12 @@ class DefaultSpecies(BaseSpecies): winner = jnp.where(is_part1_win, part1, part2) loser = jnp.where(is_part1_win, part2, part1) - return state(randkey=randkey), winner, loser, elite_mask + return state.update(randkey=randkey), winner, loser, elite_mask def speciate(self, state): # prepare distance functions o2p_distance_func = jax.vmap( - self.distance, in_axes=(None, None, 0, 0) + self.distance, in_axes=(None, None, None, 0, 0) ) # one to population # idx to specie key @@ -351,7 +364,7 @@ class DefaultSpecies(BaseSpecies): i, i2s, cns, ccs, o2c = carry distances = o2p_distance_func( - cns[i], ccs[i], state.pop_nodes, state.pop_conns + state, cns[i], ccs[i], state.pop_nodes, state.pop_conns ) # find the closest one @@ -434,7 +447,7 @@ class DefaultSpecies(BaseSpecies): def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c): # distance between such center genome and ppo genomes o2p_distance = o2p_distance_func( - cns[i], ccs[i], state.pop_nodes, state.pop_conns + state, cns[i], ccs[i], state.pop_nodes, state.pop_conns ) close_enough_mask = 
o2p_distance < self.compatibility_threshold @@ -508,14 +521,16 @@ class DefaultSpecies(BaseSpecies): next_species_key=next_species_key, ) - def distance(self, nodes1, conns1, nodes2, conns2): + def distance(self, state, nodes1, conns1, nodes2, conns2): """ The distance between two genomes """ - d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2) + d = self.node_distance(state, nodes1, nodes2) + self.conn_distance( + state, conns1, conns2 + ) return d - def node_distance(self, nodes1, nodes2): + def node_distance(self, state, nodes1, nodes2): """ The distance of the nodes part for two genomes """ @@ -541,7 +556,9 @@ class DefaultSpecies(BaseSpecies): non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) # calculate the distance of homologous nodes - hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(fr, sr) + hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))( + state, fr, sr + ) # homologous node distance hnd = jnp.where(jnp.isnan(hnd), 0, hnd) homologous_distance = jnp.sum(hnd * intersect_mask) @@ -550,9 +567,11 @@ class DefaultSpecies(BaseSpecies): + homologous_distance * self.compatibility_weight ) - return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division + val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize - def conn_distance(self, conns1, conns2): + return val + + def conn_distance(self, state, conns1, conns2): """ The distance of the conns part for two genomes """ @@ -573,7 +592,9 @@ class DefaultSpecies(BaseSpecies): intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) - hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(fr, sr) + hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))( + state, fr, sr + ) # homologous connection distance hcd = jnp.where(jnp.isnan(hcd), 0, hcd) homologous_distance = jnp.sum(hcd * intersect_mask) @@ -582,185 
+603,6 @@ class DefaultSpecies(BaseSpecies): + homologous_distance * self.compatibility_weight ) - return jnp.where(max_cnt == 0, 0, val / max_cnt) + val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize - -def initialize_population(pop_size, genome, randkey, init_method="default"): - rand_keys = jax.random.split(randkey, pop_size) - - if init_method == "one_hidden_node": - init_func = init_one_hidden_node - elif init_method == "dense_hideen_layer": - init_func = init_dense_hideen_layer - elif init_method == "no_hidden_random": - init_func = init_no_hidden_random - else: - raise ValueError("Unknown initialization method: {}".format(init_method)) - - pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys) - - return pop_nodes, pop_conns - - -# one hidden node -def init_one_hidden_node(genome, randkey): - input_idx, output_idx = genome.input_idx, genome.output_idx - new_node_key = max([*input_idx, *output_idx]) + 1 - - nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan) - conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan) - - nodes = nodes.at[input_idx, 0].set(input_idx) - nodes = nodes.at[output_idx, 0].set(output_idx) - nodes = nodes.at[new_node_key, 0].set(new_node_key) - - rand_keys_nodes = jax.random.split( - randkey, num=len(input_idx) + len(output_idx) + 1 - ) - input_keys, output_keys, hidden_key = ( - rand_keys_nodes[: len(input_idx)], - rand_keys_nodes[len(input_idx) : len(input_idx) + len(output_idx)], - rand_keys_nodes[-1], - ) - - node_attr_func = jax.vmap(genome.node_gene.new_attrs, in_axes=(None, 0)) - input_attrs = node_attr_func(input_keys) - output_attrs = node_attr_func(output_keys) - hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key) - - nodes = nodes.at[input_idx, 1:].set(input_attrs) - nodes = nodes.at[output_idx, 1:].set(output_attrs) - nodes = nodes.at[new_node_key, 1:].set(hidden_attrs) - - input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)] - 
conns = conns.at[input_idx, 0:2].set(input_conns) - conns = conns.at[input_idx, 2].set(True) - - output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx] - conns = conns.at[output_idx, 0:2].set(output_conns) - conns = conns.at[output_idx, 2].set(True) - - rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx)) - input_conn_keys, output_conn_keys = ( - rand_keys_conns[: len(input_idx)], - rand_keys_conns[len(input_idx) :], - ) - - conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None, 0)) - input_conn_attrs = conn_attr_func(input_conn_keys) - output_conn_attrs = conn_attr_func(output_conn_keys) - - conns = conns.at[input_idx, 3:].set(input_conn_attrs) - conns = conns.at[output_idx, 3:].set(output_conn_attrs) - - return nodes, conns - - -# random dense connections with 1 hidden layer -def init_dense_hideen_layer(genome, randkey, hiddens=20): - k1, k2, k3 = jax.random.split(randkey, num=3) - input_idx, output_idx = genome.input_idx, genome.output_idx - input_size = len(input_idx) - output_size = len(output_idx) - - hidden_idx = jnp.arange( - input_size + output_size, input_size + output_size + hiddens - ) - nodes = jnp.full( - (genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32 - ) - nodes = nodes.at[input_idx, 0].set(input_idx) - nodes = nodes.at[output_idx, 0].set(output_idx) - nodes = nodes.at[hidden_idx, 0].set(hidden_idx) - - total_idx = input_size + output_size + hiddens - rand_keys_n = jax.random.split(k1, num=total_idx) - input_keys = rand_keys_n[:input_size] - output_keys = rand_keys_n[input_size : input_size + output_size] - hidden_keys = rand_keys_n[input_size + output_size :] - - node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0)) - input_attrs = node_attr_func(input_keys) - output_attrs = node_attr_func(output_keys) - hidden_attrs = node_attr_func(hidden_keys) - - nodes = nodes.at[input_idx, 1:].set(input_attrs) - nodes = nodes.at[output_idx, 
1:].set(output_attrs) - nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs) - - total_connections = input_size * hiddens + hiddens * output_size - conns = jnp.full( - (genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32 - ) - - rand_keys_c = jax.random.split(k2, num=total_connections) - conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0)) - conns_attrs = conns_attr_func(rand_keys_c) - - input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing="ij") - hidden_to_output_ids, output_ids = jnp.meshgrid( - hidden_idx, output_idx, indexing="ij" - ) - - conns = conns.at[: input_size * hiddens, 0].set(input_to_hidden_ids.flatten()) - conns = conns.at[: input_size * hiddens, 1].set(hidden_ids.flatten()) - conns = conns.at[input_size * hiddens : total_connections, 0].set( - hidden_to_output_ids.flatten() - ) - conns = conns.at[input_size * hiddens : total_connections, 1].set( - output_ids.flatten() - ) - conns = conns.at[: input_size * hiddens + hiddens * output_size, 2].set(True) - conns = conns.at[: input_size * hiddens + hiddens * output_size, 3:].set( - conns_attrs - ) - - return nodes, conns - - -# random sparse connections with no hidden nodes -def init_no_hidden_random(genome, randkey): - k1, k2, k3 = jax.random.split(randkey, num=3) - input_idx, output_idx = genome.input_idx, genome.output_idx - - nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan) - nodes = nodes.at[input_idx, 0].set(input_idx) - nodes = nodes.at[output_idx, 0].set(output_idx) - - total_idx = len(input_idx) + len(output_idx) - rand_keys_n = jax.random.split(k1, num=total_idx) - input_keys = rand_keys_n[: len(input_idx)] - output_keys = rand_keys_n[len(input_idx) :] - - node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0)) - input_attrs = node_attr_func(input_keys) - output_attrs = node_attr_func(output_keys) - nodes = nodes.at[input_idx, 1:].set(input_attrs) - nodes = nodes.at[output_idx, 
1:].set(output_attrs) - - conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan) - - num_connections_per_output = 4 - total_connections = len(output_idx) * num_connections_per_output - - def create_connections_for_output(key): - permuted_inputs = jax.random.permutation(key, input_idx) - selected_inputs = permuted_inputs[:num_connections_per_output] - return selected_inputs - - conn_keys = jax.random.split(k2, num=len(output_idx)) - connections = jax.vmap(create_connections_for_output)(conn_keys) - connections = connections.flatten() - - output_repeats = jnp.repeat(output_idx, num_connections_per_output) - - rand_keys_c = jax.random.split(k3, num=total_connections) - conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0)) - conns_attrs = conns_attr_func(rand_keys_c) - - conns = conns.at[:total_connections, 0].set(connections) - conns = conns.at[:total_connections, 1].set(output_repeats) - conns = conns.at[:total_connections, 2].set(True) # enabled - conns = conns.at[:total_connections, 3:].set(conns_attrs) - - return nodes, conns + return val diff --git a/tensorneat/examples/brax/ant.py b/tensorneat/examples/brax/ant.py index 1f8bb46..faff804 100644 --- a/tensorneat/examples/brax/ant.py +++ b/tensorneat/examples/brax/ant.py @@ -4,7 +4,7 @@ from algorithm.neat import * from problem.rl_env import BraxEnv from utils import Act -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -17,21 +17,21 @@ if __name__ == '__main__': activation_options=(Act.tanh,), activation_default=Act.tanh, ), - output_transform=Act.tanh + output_transform=Act.tanh, ), pop_size=1000, species_size=10, ), ), problem=BraxEnv( - env_name='ant', + env_name="ant", ), generation_limit=10000, - fitness_target=5000 + fitness_target=5000, ) # initialize state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of file + state, best = 
pipeline.auto_run(state) diff --git a/tensorneat/examples/brax/half_cheetah.py b/tensorneat/examples/brax/half_cheetah.py index 4d31efe..5fa1ca6 100644 --- a/tensorneat/examples/brax/half_cheetah.py +++ b/tensorneat/examples/brax/half_cheetah.py @@ -4,7 +4,7 @@ from algorithm.neat import * from problem.rl_env import BraxEnv from utils import Act -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -16,21 +16,21 @@ if __name__ == '__main__': node_gene=DefaultNodeGene( activation_options=(Act.tanh,), activation_default=Act.tanh, - ) + ), ), pop_size=1000, species_size=10, ), ), problem=BraxEnv( - env_name='halfcheetah', + env_name="halfcheetah", ), generation_limit=10000, - fitness_target=5000 + fitness_target=5000, ) # initialize state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of file + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/brax/reacher.py b/tensorneat/examples/brax/reacher.py index 73f4d14..c9a27aa 100644 --- a/tensorneat/examples/brax/reacher.py +++ b/tensorneat/examples/brax/reacher.py @@ -4,7 +4,7 @@ from algorithm.neat import * from problem.rl_env import BraxEnv from utils import Act -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -16,21 +16,21 @@ if __name__ == '__main__': node_gene=DefaultNodeGene( activation_options=(Act.tanh,), activation_default=Act.tanh, - ) + ), ), pop_size=100, species_size=10, ), ), problem=BraxEnv( - env_name='reacher', + env_name="reacher", ), generation_limit=10000, - fitness_target=5000 + fitness_target=5000, ) # initialize state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of file + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/brax/walker.py b/tensorneat/examples/brax/walker.py 
index 993567d..d128d21 100644 --- a/tensorneat/examples/brax/walker.py +++ b/tensorneat/examples/brax/walker.py @@ -4,7 +4,7 @@ from algorithm.neat import * from problem.rl_env import BraxEnv from utils import Act -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -16,21 +16,21 @@ if __name__ == '__main__': node_gene=DefaultNodeGene( activation_options=(Act.tanh,), activation_default=Act.tanh, - ) + ), ), pop_size=10000, species_size=10, ), ), problem=BraxEnv( - env_name='walker2d', + env_name="walker2d", ), generation_limit=10000, - fitness_target=5000 + fitness_target=5000, ) # initialize state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of file + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index 89c8a7d..b628343 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -4,7 +4,7 @@ from algorithm.neat import * from problem.func_fit import XOR3d from utils import Act -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -18,22 +18,22 @@ if __name__ == '__main__': activation_options=(Act.tanh,), ), output_transform=Act.sigmoid, # the activation function for output node + mutation=DefaultMutation( + node_add=0.05, + conn_add=0.2, + node_delete=0, + conn_delete=0, + ), ), pop_size=10000, species_size=10, compatibility_threshold=3.5, survival_threshold=0.01, # magic ), - mutation=DefaultMutation( - node_add=0.05, - conn_add=0.2, - node_delete=0, - conn_delete=0, - ) ), problem=XOR3d(), generation_limit=10000, - fitness_target=-1e-8 + fitness_target=-1e-8, ) # initialize state diff --git a/tensorneat/examples/func_fit/xor3d_hyperneat.py b/tensorneat/examples/func_fit/xor3d_hyperneat.py index 412066b..933d4aa 100644 --- 
a/tensorneat/examples/func_fit/xor3d_hyperneat.py +++ b/tensorneat/examples/func_fit/xor3d_hyperneat.py @@ -5,17 +5,28 @@ from utils import Act from problem.func_fit import XOR3d -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=HyperNEAT( substrate=FullSubstrate( input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)], hidden_coors=[ - (-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5), - (-1, 0), (0.333, 0), (-0.333, 0), (1, 0), - (-1, 0.5), (0.333, 0.5), (-0.333, 0.5), (1, 0.5), + (-1, -0.5), + (0.333, -0.5), + (-0.333, -0.5), + (1, -0.5), + (-1, 0), + (0.333, 0), + (-0.333, 0), + (1, 0), + (-1, 0.5), + (0.333, 0.5), + (-0.333, 0.5), + (1, 0.5), + ], + output_coors=[ + (0, 1), ], - output_coors=[(0, 1), ], ), neat=NEAT( species=DefaultSpecies( @@ -42,7 +53,7 @@ if __name__ == '__main__': ), problem=XOR3d(), generation_limit=300, - fitness_target=-1e-6 + fitness_target=-1e-6, ) # initialize state diff --git a/tensorneat/examples/func_fit/xor_recurrent.py b/tensorneat/examples/func_fit/xor_recurrent.py index 41aad28..7599262 100644 --- a/tensorneat/examples/func_fit/xor_recurrent.py +++ b/tensorneat/examples/func_fit/xor_recurrent.py @@ -1,10 +1,11 @@ from pipeline import Pipeline from algorithm.neat import * +from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse from problem.func_fit import XOR3d from utils.activation import ACT_ALL, Act -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( seed=0, algorithm=NEAT( @@ -15,27 +16,26 @@ if __name__ == '__main__': max_nodes=50, max_conns=100, activate_time=5, - node_gene=DefaultNodeGene( - activation_options=ACT_ALL, - activation_replace_rate=0.2 + node_gene=NodeGeneWithoutResponse( + activation_options=ACT_ALL, activation_replace_rate=0.2 + ), + output_transform=Act.sigmoid, + mutation=DefaultMutation( + node_add=0.05, + conn_add=0.2, + node_delete=0, + conn_delete=0, ), - output_transform=Act.sigmoid ), pop_size=10000, 
species_size=10, compatibility_threshold=3.5, survival_threshold=0.03, ), - mutation=DefaultMutation( - node_add=0.05, - conn_add=0.2, - node_delete=0, - conn_delete=0, - ) ), problem=XOR3d(), generation_limit=10000, - fitness_target=-1e-8 + fitness_target=-1e-8, ) # initialize state diff --git a/tensorneat/examples/gymnax/arcbot.py b/tensorneat/examples/gymnax/arcbot.py index e56ffa1..dc0bd4d 100644 --- a/tensorneat/examples/gymnax/arcbot.py +++ b/tensorneat/examples/gymnax/arcbot.py @@ -5,7 +5,7 @@ from algorithm.neat import * from problem.rl_env import GymNaxEnv -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -14,21 +14,23 @@ if __name__ == '__main__': num_outputs=3, max_nodes=50, max_conns=100, - output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2} + output_transform=lambda out: jnp.argmax( + out + ), # the action of acrobot is {0, 1, 2} ), pop_size=10000, species_size=10, ), ), problem=GymNaxEnv( - env_name='Acrobot-v1', + env_name="Acrobot-v1", ), generation_limit=10000, - fitness_target=-62 + fitness_target=-62, ) # initialize state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of file + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/gymnax/cartpole.py b/tensorneat/examples/gymnax/cartpole.py index 75e9d88..0199e3d 100644 --- a/tensorneat/examples/gymnax/cartpole.py +++ b/tensorneat/examples/gymnax/cartpole.py @@ -5,7 +5,7 @@ from algorithm.neat import * from problem.rl_env import GymNaxEnv -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -14,21 +14,23 @@ if __name__ == '__main__': num_outputs=2, max_nodes=50, max_conns=100, - output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} + output_transform=lambda out: jnp.argmax( + out + ), # the action of cartpole is {0, 
1} ), pop_size=10000, species_size=10, ), ), problem=GymNaxEnv( - env_name='CartPole-v1', + env_name="CartPole-v1", ), generation_limit=10000, - fitness_target=500 + fitness_target=500, ) # initialize state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of file + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/gymnax/cartpole_hyperneat.py b/tensorneat/examples/gymnax/cartpole_hyperneat.py index 3a689ff..200cec5 100644 --- a/tensorneat/examples/gymnax/cartpole_hyperneat.py +++ b/tensorneat/examples/gymnax/cartpole_hyperneat.py @@ -10,11 +10,7 @@ from problem.rl_env import GymNaxConfig, GymNaxEnv def example_conf(): return Config( - basic=BasicConfig( - seed=42, - fitness_target=500, - pop_size=10000 - ), + basic=BasicConfig(seed=42, fitness_target=500, pop_size=10000), neat=NeatConfig( inputs=4, outputs=1, @@ -23,28 +19,31 @@ def example_conf(): activation_default=Act.tanh, activation_options=(Act.tanh,), ), - hyperneat=HyperNeatConfig( - activation=Act.sigmoid, - inputs=4, - outputs=2 - ), + hyperneat=HyperNeatConfig(activation=Act.sigmoid, inputs=4, outputs=2), substrate=NormalSubstrateConfig( input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)), hidden_coors=( # (-1, -0.5), (-0.5, -0.5), (0, -0.5), (0.5, -0.5), - (1, 0), (-1, 0), (-0.5, 0), (0, 0), (0.5, 0), (1, 0), + (1, 0), + (-1, 0), + (-0.5, 0), + (0, 0), + (0.5, 0), + (1, 0), # (1, 0.5), (-1, 0.5), (-0.5, 0.5), (0, 0.5), (0.5, 0.5), (1, 0.5), ), output_coors=((-1, 1), (1, 1)), ), problem=GymNaxConfig( - env_name='CartPole-v1', - output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} - ) + env_name="CartPole-v1", + output_transform=lambda out: jnp.argmax( + out + ), # the action of cartpole is {0, 1} + ), ) -if __name__ == '__main__': +if __name__ == "__main__": conf = example_conf() algorithm = HyperNEAT(conf, NormalGene, NormalSubstrate) diff --git 
a/tensorneat/examples/gymnax/mountain_car.py b/tensorneat/examples/gymnax/mountain_car.py index d9082cf..38c4eb0 100644 --- a/tensorneat/examples/gymnax/mountain_car.py +++ b/tensorneat/examples/gymnax/mountain_car.py @@ -5,7 +5,7 @@ from algorithm.neat import * from problem.rl_env import GymNaxEnv -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -14,21 +14,23 @@ if __name__ == '__main__': num_outputs=3, max_nodes=50, max_conns=100, - output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2} + output_transform=lambda out: jnp.argmax( + out + ), # the action of mountain car is {0, 1, 2} ), pop_size=10000, species_size=10, ), ), problem=GymNaxEnv( - env_name='MountainCar-v0', + env_name="MountainCar-v0", ), generation_limit=10000, - fitness_target=0 + fitness_target=0, ) # initialize state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of file + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/gymnax/mountain_car_continuous.py b/tensorneat/examples/gymnax/mountain_car_continuous.py index f863e52..1edd4f1 100644 --- a/tensorneat/examples/gymnax/mountain_car_continuous.py +++ b/tensorneat/examples/gymnax/mountain_car_continuous.py @@ -4,7 +4,7 @@ from algorithm.neat import * from problem.rl_env import GymNaxEnv from utils import Act -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -14,23 +14,23 @@ if __name__ == '__main__': max_nodes=50, max_conns=100, node_gene=DefaultNodeGene( - activation_options=(Act.tanh, ), + activation_options=(Act.tanh,), activation_default=Act.tanh, - ) + ), ), pop_size=10000, species_size=10, ), ), problem=GymNaxEnv( - env_name='MountainCarContinuous-v0', + env_name="MountainCarContinuous-v0", ), generation_limit=10000, - fitness_target=500 + fitness_target=500, ) # initialize 
state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of file + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/gymnax/pendulum.py b/tensorneat/examples/gymnax/pendulum.py index 7073dbe..9542c37 100644 --- a/tensorneat/examples/gymnax/pendulum.py +++ b/tensorneat/examples/gymnax/pendulum.py @@ -4,7 +4,7 @@ from algorithm.neat import * from problem.rl_env import GymNaxEnv from utils import Act -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -17,21 +17,22 @@ if __name__ == '__main__': activation_options=(Act.tanh,), activation_default=Act.tanh, ), - output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2] + output_transform=lambda out: out + * 2, # the action of pendulum is [-2, 2] ), pop_size=10000, species_size=10, ), ), problem=GymNaxEnv( - env_name='Pendulum-v1', + env_name="Pendulum-v1", ), generation_limit=10000, - fitness_target=0 + fitness_target=0, ) # initialize state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of file + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/gymnax/reacher.py b/tensorneat/examples/gymnax/reacher.py index acf23aa..d6b5345 100644 --- a/tensorneat/examples/gymnax/reacher.py +++ b/tensorneat/examples/gymnax/reacher.py @@ -5,7 +5,7 @@ from algorithm.neat import * from problem.rl_env import GymNaxEnv -if __name__ == '__main__': +if __name__ == "__main__": pipeline = Pipeline( algorithm=NEAT( species=DefaultSpecies( @@ -20,14 +20,14 @@ if __name__ == '__main__': ), ), problem=GymNaxEnv( - env_name='Reacher-misc', + env_name="Reacher-misc", ), generation_limit=10000, - fitness_target =500 + fitness_target=500, ) # initialize state state = pipeline.setup() # print(state) # run until terminate - state, best = pipeline.auto_run(state) \ No newline at end of 
file + state, best = pipeline.auto_run(state) diff --git a/tensorneat/examples/with_evox/ray_test.py b/tensorneat/examples/with_evox/ray_test.py index 6b850da..0a6e8c9 100644 --- a/tensorneat/examples/with_evox/ray_test.py +++ b/tensorneat/examples/with_evox/ray_test.py @@ -1,4 +1,5 @@ import ray + ray.init(num_gpus=2) available_resources = ray.available_resources() diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index 5ffa22d..b3d7990 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -10,14 +10,13 @@ from utils import State class Pipeline: - def __init__( - self, - algorithm: BaseAlgorithm, - problem: BaseProblem, - seed: int = 42, - fitness_target: float = 1, - generation_limit: int = 1000, + self, + algorithm: BaseAlgorithm, + problem: BaseProblem, + seed: int = 42, + fitness_target: float = 1, + generation_limit: int = 1000, ): assert problem.jitable, "Currently, problem must be jitable" @@ -31,32 +30,35 @@ class Pipeline: # print(self.problem.input_shape, self.problem.output_shape) # TODO: make each algorithm's input_num and output_num - assert algorithm.num_inputs == self.problem.input_shape[-1], \ - f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}" + assert ( + algorithm.num_inputs == self.problem.input_shape[-1] + ), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}" self.best_genome = None - self.best_fitness = float('-inf') + self.best_fitness = float("-inf") self.generation_timestamp = None def setup(self, state=State()): + print("initializing") state = state.register(randkey=jax.random.PRNGKey(self.seed)) state = self.algorithm.setup(state) state = self.problem.setup(state) + print("initializing finished") return state def step(self, state): + randkey_, randkey = jax.random.split(state.randkey) keys = jax.random.split(randkey_, self.pop_size) - state, pop = self.algorithm.ask(state) + pop = self.algorithm.ask(state) 
- state, pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0), out_axes=(None, 0))(state, pop) + pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))( + state, pop + ) - state, fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0), out_axes=(None, 0))( - keys, - state, - self.algorithm.forward, - pop_transformed + fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))( + state, keys, self.algorithm.forward, pop_transformed ) state = self.algorithm.tell(state, fitnesses) @@ -67,13 +69,15 @@ class Pipeline: print("start compile") tic = time.time() compiled_step = jax.jit(self.step).lower(state).compile() - print(f"compile finished, cost time: {time.time() - tic:.6f}s", ) + print( + f"compile finished, cost time: {time.time() - tic:.6f}s", + ) for _ in range(self.generation_limit): self.generation_timestamp = time.time() - state, previous_pop = self.algorithm.ask(state) + previous_pop = self.algorithm.ask(state) state, fitnesses = compiled_step(state) @@ -98,7 +102,12 @@ class Pipeline: def analysis(self, state, pop, fitnesses): - max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) + max_f, min_f, mean_f, std_f = ( + max(fitnesses), + min(fitnesses), + np.mean(fitnesses), + np.std(fitnesses), + ) new_timestamp = time.time() @@ -112,10 +121,14 @@ class Pipeline: member_count = jax.device_get(self.algorithm.member_count(state)) species_sizes = [int(i) for i in member_count if i > 0] - print(f"Generation: {self.algorithm.generation(state)}", - f"species: {len(species_sizes)}, {species_sizes}", - f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") + print( + f"Generation: {self.algorithm.generation(state)}", + f"species: {len(species_sizes)}, {species_sizes}", + f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms", + ) def show(self, state, best, *args, 
**kwargs): - state, transformed = self.algorithm.transform(state, best) - self.problem.show(state.randkey, state, self.algorithm.forward, transformed, *args, **kwargs) + transformed = self.algorithm.transform(state, best) + self.problem.show( + state, state.randkey, self.algorithm.forward, transformed, *args, **kwargs + ) diff --git a/tensorneat/problem/base.py b/tensorneat/problem/base.py index 67e73c1..712c989 100644 --- a/tensorneat/problem/base.py +++ b/tensorneat/problem/base.py @@ -10,7 +10,7 @@ class BaseProblem: """initialize the state of the problem""" return state - def evaluate(self, randkey, state: State, act_func: Callable, params): + def evaluate(self, state: State, randkey, act_func: Callable, params): """evaluate one individual""" raise NotImplementedError @@ -32,7 +32,7 @@ class BaseProblem: """ raise NotImplementedError - def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs): + def show(self, state: State, randkey, act_func: Callable, params, *args, **kwargs): """ show how a genome perform in this problem """ diff --git a/tensorneat/problem/func_fit/func_fit.py b/tensorneat/problem/func_fit/func_fit.py index 3c71415..ff55342 100644 --- a/tensorneat/problem/func_fit/func_fit.py +++ b/tensorneat/problem/func_fit/func_fit.py @@ -8,42 +8,44 @@ from .. 
import BaseProblem class FuncFit(BaseProblem): jitable = True - def __init__(self, - error_method: str = 'mse' - ): + def __init__(self, error_method: str = "mse"): super().__init__() - assert error_method in {'mse', 'rmse', 'mae', 'mape'} + assert error_method in {"mse", "rmse", "mae", "mape"} self.error_method = error_method def setup(self, state: State = State()): return state - def evaluate(self, randkey, state, act_func, params): + def evaluate(self, state, randkey, act_func, params): - state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params) + predict = jax.vmap(act_func, in_axes=(None, 0, None))( + state, self.inputs, params + ) - if self.error_method == 'mse': + if self.error_method == "mse": loss = jnp.mean((predict - self.targets) ** 2) - elif self.error_method == 'rmse': + elif self.error_method == "rmse": loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2)) - elif self.error_method == 'mae': + elif self.error_method == "mae": loss = jnp.mean(jnp.abs(predict - self.targets)) - elif self.error_method == 'mape': + elif self.error_method == "mape": loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets)) else: raise NotImplementedError - return state, -loss + return -loss - def show(self, randkey, state, act_func, params, *args, **kwargs): - state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params) + def show(self, state, randkey, act_func, params, *args, **kwargs): + predict = jax.vmap(act_func, in_axes=(None, 0, None))( + state, self.inputs, params + ) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) - state, loss = self.evaluate(randkey, state, act_func, params) + loss = self.evaluate(state, randkey, act_func, params) loss = -loss msg = "" diff --git a/tensorneat/problem/func_fit/xor.py b/tensorneat/problem/func_fit/xor.py index 65b1e0c..c9544b9 100644 --- a/tensorneat/problem/func_fit/xor.py +++ 
b/tensorneat/problem/func_fit/xor.py @@ -4,27 +4,16 @@ from .func_fit import FuncFit class XOR(FuncFit): - - def __init__(self, error_method: str = 'mse'): + def __init__(self, error_method: str = "mse"): super().__init__(error_method) @property def inputs(self): - return np.array([ - [0, 0], - [0, 1], - [1, 0], - [1, 1] - ]) + return np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) @property def targets(self): - return np.array([ - [0], - [1], - [1], - [0] - ]) + return np.array([[0], [1], [1], [0]]) @property def input_shape(self): diff --git a/tensorneat/problem/func_fit/xor3d.py b/tensorneat/problem/func_fit/xor3d.py index 1ae8b1b..7c9877d 100644 --- a/tensorneat/problem/func_fit/xor3d.py +++ b/tensorneat/problem/func_fit/xor3d.py @@ -4,35 +4,27 @@ from .func_fit import FuncFit class XOR3d(FuncFit): - - def __init__(self, error_method: str = 'mse'): + def __init__(self, error_method: str = "mse"): super().__init__(error_method) @property def inputs(self): - return np.array([ - [0, 0, 0], - [0, 0, 1], - [0, 1, 0], - [0, 1, 1], - [1, 0, 0], - [1, 0, 1], - [1, 1, 0], - [1, 1, 1], - ]) + return np.array( + [ + [0, 0, 0], + [0, 0, 1], + [0, 1, 0], + [0, 1, 1], + [1, 0, 0], + [1, 0, 1], + [1, 1, 0], + [1, 1, 1], + ] + ) @property def targets(self): - return np.array([ - [0], - [1], - [1], - [0], - [1], - [0], - [0], - [1] - ]) + return np.array([[0], [1], [1], [0], [1], [0], [0], [1]]) @property def input_shape(self): diff --git a/tensorneat/problem/rl_env/brax_env.py b/tensorneat/problem/rl_env/brax_env.py index 4e0b505..7a65bfe 100644 --- a/tensorneat/problem/rl_env/brax_env.py +++ b/tensorneat/problem/rl_env/brax_env.py @@ -25,7 +25,19 @@ class BraxEnv(RLEnv): def output_shape(self): return (self.env.action_size,) - def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs): + def show( + self, + state, + randkey, + act_func, + params, + save_path=None, + height=512, + width=512, + duration=0.1, + *args, + 
**kwargs + ): import jax import imageio @@ -48,11 +60,13 @@ class BraxEnv(RLEnv): key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs) reward += r - imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in - tqdm(state_histories, desc="Rendering")] + imgs = [ + image.render_array(sys=self.env.sys, state=s, width=width, height=height) + for s in tqdm(state_histories, desc="Rendering") + ] def create_gif(image_list, gif_name, duration): - with imageio.get_writer(gif_name, mode='I', duration=duration) as writer: + with imageio.get_writer(gif_name, mode="I", duration=duration) as writer: for image in image_list: formatted_image = np.array(image, dtype=np.uint8) writer.append_data(formatted_image) @@ -60,5 +74,3 @@ class BraxEnv(RLEnv): create_gif(imgs, save_path, duration=0.1) print("Gif saved to: ", save_path) print("Total reward: ", reward) - - diff --git a/tensorneat/problem/rl_env/gymnax_env.py b/tensorneat/problem/rl_env/gymnax_env.py index a0b2bb3..95af8fa 100644 --- a/tensorneat/problem/rl_env/gymnax_env.py +++ b/tensorneat/problem/rl_env/gymnax_env.py @@ -4,7 +4,6 @@ from .rl_jit import RLEnv class GymNaxEnv(RLEnv): - def __init__(self, env_name): super().__init__() assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" @@ -24,5 +23,5 @@ class GymNaxEnv(RLEnv): def output_shape(self): return self.env.action_space(self.env_params).shape - def show(self, randkey, state, act_func, params, *args, **kwargs): + def show(self, state, randkey, act_func, params, *args, **kwargs): raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).") diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 4edf924..73ba04e 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -12,29 +12,29 @@ class RLEnv(BaseProblem): super().__init__() self.max_step = max_step - def evaluate(self, randkey, state, act_func, params): 
+ def evaluate(self, state, randkey, act_func, params): rng_reset, rng_episode = jax.random.split(randkey) init_obs, init_env_state = self.reset(rng_reset) def cond_func(carry): - _, _, _, _, done, _, count = carry - return ~done & (count < self.max_step) + _, _, _, done, _, count = carry + return ~done & (count < self.max_step) def body_func(carry): - state_, obs, env_state, rng, done, tr, count = carry # tr -> total reward - state_, action = act_func(state_, obs, params) - next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) - next_rng, _ = jax.random.split(rng) - return state_, next_obs, next_env_state, next_rng, done, tr + reward, count + 1 + obs, env_state, rng, done, tr, count = carry # tr -> total reward + action = act_func(state, obs, params) + next_obs, next_env_state, reward, done, _ = self.step( + rng, env_state, action + ) + next_rng, _ = jax.random.split(rng) + return next_obs, next_env_state, next_rng, done, tr + reward, count + 1 - state, _, _, _, _, total_reward, _ = jax.lax.while_loop( - cond_func, - body_func, - (state, init_obs, init_env_state, rng_episode, False, 0.0, 0) + _, _, _, _, total_reward, _ = jax.lax.while_loop( + cond_func, body_func, (init_obs, init_env_state, rng_episode, False, 0.0, 0) ) - return state, total_reward - + return total_reward + # @partial(jax.jit, static_argnums=(0,)) def step(self, randkey, env_state, action): return self.env_step(randkey, env_state, action) @@ -57,5 +57,5 @@ class RLEnv(BaseProblem): def output_shape(self): raise NotImplementedError - def show(self, randkey, state, act_func, params, *args, **kwargs): + def show(self, state, randkey, act_func, params, *args, **kwargs): raise NotImplementedError diff --git a/tensorneat/test/crossover_mutation.py b/tensorneat/test/crossover_mutation.py index 8da8761..2e6e0bf 100644 --- a/tensorneat/test/crossover_mutation.py +++ b/tensorneat/test/crossover_mutation.py @@ -36,7 +36,9 @@ def main(): elite_mask = jnp.zeros((1000,), 
dtype=jnp.bool_) elite_mask = elite_mask.at[:5].set(1) - state = algorithm.create_next_generation(jax.random.key(0), state, winner, losser, elite_mask) + state = algorithm.create_next_generation( + jax.random.key(0), state, winner, losser, elite_mask + ) pop_nodes, pop_conns = algorithm.species.ask(state.species) transforms = batch_transform(pop_nodes, pop_conns) @@ -48,5 +50,5 @@ def main(): print(_) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tensorneat/test/nan_fitness.py b/tensorneat/test/nan_fitness.py index 3097ebc..4b59f31 100644 --- a/tensorneat/test/nan_fitness.py +++ b/tensorneat/test/nan_fitness.py @@ -19,7 +19,7 @@ def main(): node_gene=DefaultNodeGene( activation_options=(Act.tanh,), activation_default=Act.tanh, - ) + ), ) transformed = genome.transform(nodes, conns) @@ -35,7 +35,7 @@ def main(): print(output) -if __name__ == '__main__': +if __name__ == "__main__": a = jnp.array([1, 3, 5, 6, 8]) b = jnp.array([1, 2, 3]) print(jnp.isin(a, b)) diff --git a/tensorneat/test/test_genome.py b/tensorneat/test/test_genome.py index 12ff43e..b7da2c9 100644 --- a/tensorneat/test/test_genome.py +++ b/tensorneat/test/test_genome.py @@ -2,6 +2,7 @@ from algorithm.neat import * from utils import Act, Agg, State import jax, jax.numpy as jnp +from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse def test_default(): @@ -135,3 +136,29 @@ def test_recurrent(): print(outputs) assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) # expected: [[0.5], [0.75], [0.5], [0.75]] + + +def test_random_initialize(): + genome = DefaultGenome( + num_inputs=2, + num_outputs=1, + max_nodes=5, + max_conns=4, + node_gene=NodeGeneWithoutResponse( + activation_default=Act.identity, + activation_options=(Act.identity,), + aggregation_default=Agg.sum, + aggregation_options=(Agg.sum,), + ), + ) + state = genome.setup() + key = jax.random.PRNGKey(0) + nodes, conns = genome.initialize(state, key) + transformed = 
genome.transform(state, nodes, conns) + print(*transformed, sep="\n") + + inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]]) + outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))( + state, inputs, transformed + ) + print(outputs) diff --git a/tensorneat/test/test_nan_fitness.py b/tensorneat/test/test_nan_fitness.py index 79247cd..c2575ef 100644 --- a/tensorneat/test/test_nan_fitness.py +++ b/tensorneat/test/test_nan_fitness.py @@ -19,11 +19,11 @@ def main(): node_gene=DefaultNodeGene( activation_options=(Act.tanh,), activation_default=Act.tanh, - ) + ), ) transformed = genome.transform(nodes, conns) - print(*transformed, sep='\n') + print(*transformed, sep="\n") key = jax.random.key(0) dummy_input = jnp.zeros((8,)) @@ -31,5 +31,5 @@ def main(): print(output) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tensorneat/utils/__init__.py b/tensorneat/utils/__init__.py index 077fcd4..7de0984 100644 --- a/tensorneat/utils/__init__.py +++ b/tensorneat/utils/__init__.py @@ -2,4 +2,4 @@ from .activation import Act, act from .aggregation import Agg, agg from .tools import * from .graph import * -from .state import State \ No newline at end of file +from .state import State diff --git a/tensorneat/utils/tools.py b/tensorneat/utils/tools.py index c592496..5f9bc11 100644 --- a/tensorneat/utils/tools.py +++ b/tensorneat/utils/tools.py @@ -116,3 +116,41 @@ def argmin_with_mask(arr, mask): masked_arr = jnp.where(mask, arr, jnp.inf) min_idx = jnp.argmin(masked_arr) return min_idx + + +def add_node(nodes, new_key: int, attrs): + """ + Add a new node to the genome. + The new node will place at the first NaN row. + """ + exist_keys = nodes[:, 0] + pos = fetch_first(jnp.isnan(exist_keys)) + new_nodes = nodes.at[pos, 0].set(new_key) + return new_nodes.at[pos, 1:].set(attrs) + + +def delete_node_by_pos(nodes, pos): + """ + Delete a node from the genome. + Delete the node by its pos in nodes. 
+    """
+    return nodes.at[pos].set(jnp.nan)
+
+
+def add_conn(conns, i_key, o_key, enable: bool, attrs):
+    """
+    Add a new connection to the genome.
+    The new connection will be placed at the first NaN row.
+    """
+    con_keys = conns[:, 0]
+    pos = fetch_first(jnp.isnan(con_keys))
+    new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
+    return new_conns.at[pos, 3:].set(attrs)
+
+
+def delete_conn_by_pos(conns, pos):
+    """
+    Delete a connection from the genome.
+    Delete the connection by its pos in conns.
+    """
+    return conns.at[pos].set(jnp.nan)