From 6970e6a6d5eda0711df834aa720bfa213c2b1892 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 21 Feb 2024 15:41:08 +0800 Subject: [PATCH] finish all refactoring --- algorithm/base.py | 25 ++++- algorithm/hyperneat/__init__.py | 2 + algorithm/hyperneat/hyperneat.py | 116 +++++++++++++++++++++ algorithm/hyperneat/substrate/__init__.py | 3 + algorithm/hyperneat/substrate/base.py | 27 +++++ algorithm/hyperneat/substrate/default.py | 38 +++++++ algorithm/hyperneat/substrate/full.py | 76 ++++++++++++++ algorithm/neat/__init__.py | 2 + algorithm/neat/ga/crossover/default.py | 3 +- algorithm/neat/ga/mutation/default.py | 35 ++++--- algorithm/neat/gene/conn/default.py | 3 +- algorithm/neat/gene/node/default.py | 5 +- algorithm/neat/genome/base.py | 1 - algorithm/neat/genome/default.py | 19 +++- algorithm/neat/genome/recurrent.py | 4 +- algorithm/neat/neat.py | 49 ++++++--- algorithm/neat/species/default.py | 55 +++++----- examples/brax/ant.py | 62 ++++++----- examples/brax/half_cheetah.py | 68 ++++++------ examples/brax/reacher.py | 62 ++++++----- examples/brax_env.py | 73 ------------- examples/brax_render.py | 54 ---------- examples/func_fit/xor.py | 43 ++++---- examples/func_fit/xor3d_hyperneat.py | 51 +++++++++ examples/func_fit/xor_hyperneat.py | 41 -------- examples/func_fit/xor_recurrent.py | 62 +++++------ examples/general_xor.py | 36 ------- examples/gymnax/acrobot.py | 39 ------- examples/gymnax/arcbot.py | 34 ++++++ examples/gymnax/cartpole.py | 102 +++++------------- examples/gymnax/mountain_car.py | 57 +++++----- examples/gymnax/mountain_car_continuous.py | 62 ++++++----- examples/gymnax/pendulum.py | 65 ++++++------ examples/gymnax/reacher.py | 57 +++++----- pipeline.py | 90 ++++++++-------- problem/base.py | 9 +- problem/func_fit/func_fit.py | 13 ++- problem/rl_env/__init__.py | 4 +- problem/rl_env/gymnax_env.py | 1 - problem/rl_env/rl_jit.py | 5 +- t.py | 66 +----------- test/test_genome.py | 4 + utils/activation.py | 55 ++++------ utils/aggregation.py | 3 + 44 files changed, 856 insertions(+), 825 deletions(-) create mode 100644 algorithm/hyperneat/__init__.py create mode 100644 algorithm/hyperneat/hyperneat.py create mode 100644 algorithm/hyperneat/substrate/__init__.py create mode 100644 algorithm/hyperneat/substrate/base.py create mode 100644 algorithm/hyperneat/substrate/default.py create mode 100644 algorithm/hyperneat/substrate/full.py delete mode 100644 examples/brax_env.py delete mode 100644 examples/brax_render.py create mode 100644 examples/func_fit/xor3d_hyperneat.py delete mode 100644 examples/func_fit/xor_hyperneat.py delete mode 100644 examples/general_xor.py delete mode 100644 examples/gymnax/acrobot.py create mode 100644 examples/gymnax/arcbot.py diff --git a/algorithm/base.py b/algorithm/base.py index 36789c8..93aeafe 100644 --- a/algorithm/base.py +++ b/algorithm/base.py @@ -16,9 +16,30 @@ class BaseAlgorithm: """update the state of the algorithm""" raise NotImplementedError - def transform(self, state: State): + def transform(self, individual): """transform the genome into a neural network""" raise NotImplementedError def forward(self, inputs, transformed): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + @property + def num_inputs(self): + raise NotImplementedError + + @property + def num_outputs(self): + raise NotImplementedError + + @property + def pop_size(self): + raise NotImplementedError + + def member_count(self, state: State): + # to analysis the species + raise NotImplementedError + + def generation(self, state: State): + # to analysis the algorithm + raise NotImplementedError + diff --git a/algorithm/hyperneat/__init__.py b/algorithm/hyperneat/__init__.py new file mode 100644 index 0000000..374e2aa --- /dev/null +++ b/algorithm/hyperneat/__init__.py @@ -0,0 +1,2 @@ +from .hyperneat import HyperNEAT +from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate \ No newline at end of file diff --git a/algorithm/hyperneat/hyperneat.py b/algorithm/hyperneat/hyperneat.py new file mode 100644 index 0000000..eab7693 --- /dev/null +++ b/algorithm/hyperneat/hyperneat.py @@ -0,0 +1,116 @@ +import jax, jax.numpy as jnp + +from utils import State, Act, Agg +from .. import BaseAlgorithm, NEAT +from ..neat.gene import BaseNodeGene, BaseConnGene +from ..neat.genome import RecurrentGenome +from .substrate import * + + +class HyperNEAT(BaseAlgorithm): + + def __init__( + self, + substrate: BaseSubstrate, + neat: NEAT, + below_threshold: float = 0.3, + max_weight: float = 5., + activation=Act.sigmoid, + aggregation=Agg.sum, + activate_time: int = 10, + ): + assert substrate.query_coors.shape[1] == neat.num_inputs, \ + "Substrate input size should be equal to NEAT input size" + + self.substrate = substrate + self.neat = neat + self.below_threshold = below_threshold + self.max_weight = max_weight + self.hyper_genome = RecurrentGenome( + num_inputs=substrate.num_inputs, + num_outputs=substrate.num_outputs, + max_nodes=substrate.nodes_cnt, + max_conns=substrate.conns_cnt, + node_gene=HyperNodeGene(activation, aggregation), + conn_gene=HyperNEATConnGene(), + activate_time=activate_time, + ) + + def setup(self, randkey): + return State( + neat_state=self.neat.setup(randkey) + ) + + def ask(self, state: State): + return self.neat.ask(state.neat_state) + + def tell(self, state: State, fitness): + return state.update( + neat_state=self.neat.tell(state.neat_state, fitness) + ) + + def transform(self, individual): + transformed = self.neat.transform(individual) + query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(self.substrate.query_coors, transformed) + + # mute the connection with weight below threshold + query_res = jnp.where( + (-self.below_threshold < query_res) & (query_res < self.below_threshold), + 0., + query_res + ) + + # make query res in range [-max_weight, max_weight] + query_res = jnp.where(query_res > 0, query_res - self.below_threshold, query_res) + query_res = jnp.where(query_res < 0, query_res + self.below_threshold, query_res) + query_res = query_res / (1 - self.below_threshold) * self.max_weight + + h_nodes, h_conns = self.substrate.make_nodes(query_res), self.substrate.make_conn(query_res) + return self.hyper_genome.transform(h_nodes, h_conns) + + def forward(self, inputs, transformed): + # add bias + inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])]) + return self.hyper_genome.forward(inputs_with_bias, transformed) + + @property + def num_inputs(self): + return self.substrate.num_inputs - 1 # remove bias + + @property + def num_outputs(self): + return self.substrate.num_outputs + + @property + def pop_size(self): + return self.neat.pop_size + + def member_count(self, state: State): + return self.neat.member_count(state.neat_state) + + def generation(self, state: State): + return self.neat.generation(state.neat_state) + + +class HyperNodeGene(BaseNodeGene): + + def __init__(self, + activation=Act.sigmoid, + aggregation=Agg.sum, + ): + super().__init__() + self.activation = activation + self.aggregation = aggregation + + def forward(self, attrs, inputs): + return self.activation( + self.aggregation(inputs) + ) + + +class HyperNEATConnGene(BaseConnGene): + custom_attrs = ['weight'] + + def forward(self, attrs, inputs): + weight = attrs[0] + return inputs * weight diff --git a/algorithm/hyperneat/substrate/__init__.py b/algorithm/hyperneat/substrate/__init__.py new file mode 100644 index 0000000..f532c43 --- /dev/null +++ b/algorithm/hyperneat/substrate/__init__.py @@ -0,0 +1,3 @@ +from .base import BaseSubstrate +from .default import DefaultSubstrate +from .full import FullSubstrate diff --git a/algorithm/hyperneat/substrate/base.py b/algorithm/hyperneat/substrate/base.py new file mode 100644 index 0000000..3a15832 --- /dev/null +++ b/algorithm/hyperneat/substrate/base.py @@ -0,0 +1,27 @@ +class BaseSubstrate: + + def make_nodes(self, query_res): + raise NotImplementedError + + def make_conn(self, query_res): + raise NotImplementedError + + @property + def query_coors(self): + raise NotImplementedError + + @property + def num_inputs(self): + raise NotImplementedError + + @property + def num_outputs(self): + raise NotImplementedError + + @property + def nodes_cnt(self): + raise NotImplementedError + + @property + def conns_cnt(self): + raise NotImplementedError diff --git a/algorithm/hyperneat/substrate/default.py b/algorithm/hyperneat/substrate/default.py new file mode 100644 index 0000000..a7273dc --- /dev/null +++ b/algorithm/hyperneat/substrate/default.py @@ -0,0 +1,38 @@ +import jax.numpy as jnp +from . import BaseSubstrate + + +class DefaultSubstrate(BaseSubstrate): + + def __init__(self, num_inputs, num_outputs, coors, nodes, conns): + self.inputs = num_inputs + self.outputs = num_outputs + self.coors = jnp.array(coors) + self.nodes = jnp.array(nodes) + self.conns = jnp.array(conns) + + def make_nodes(self, query_res): + return self.nodes + + def make_conn(self, query_res): + return self.conns.at[:, 3:].set(query_res) # change weight + + @property + def query_coors(self): + return self.coors + + @property + def num_inputs(self): + return self.inputs + + @property + def num_outputs(self): + return self.outputs + + @property + def nodes_cnt(self): + return self.nodes.shape[0] + + @property + def conns_cnt(self): + return self.conns.shape[0] diff --git a/algorithm/hyperneat/substrate/full.py b/algorithm/hyperneat/substrate/full.py new file mode 100644 index 0000000..98ec869 --- /dev/null +++ b/algorithm/hyperneat/substrate/full.py @@ -0,0 +1,76 @@ +import numpy as np +from .default import DefaultSubstrate + + +class FullSubstrate(DefaultSubstrate): + + def __init__(self, + input_coors=((-1, -1), (0, -1), (1, -1)), + hidden_coors=((-1, 0), (0, 0), (1, 0)), + output_coors=((0, 1),), + ): + query_coors, nodes, conns = analysis_substrate(input_coors, output_coors, hidden_coors) + super().__init__( + len(input_coors), + len(output_coors), + query_coors, + nodes, + conns + ) + + +def analysis_substrate(input_coors, output_coors, hidden_coors): + input_coors = np.array(input_coors) + output_coors = np.array(output_coors) + hidden_coors = np.array(hidden_coors) + + cd = input_coors.shape[1] # coordinate dimensions + si = input_coors.shape[0] # input coordinate size + so = output_coors.shape[0] # output coordinate size + sh = hidden_coors.shape[0] # hidden coordinate size + + input_idx = np.arange(si) + output_idx = np.arange(si, si + so) + hidden_idx = np.arange(si + so, si + so + sh) + + total_conns = si * sh + sh * sh + sh * so + query_coors = np.zeros((total_conns, cd * 2)) + correspond_keys = np.zeros((total_conns, 2)) + + # connect input to hidden + aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, input_coors, hidden_coors) + query_coors[0: si * sh, :] = aux_coors + correspond_keys[0: si * sh, :] = aux_keys + + # connect hidden to hidden + aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, hidden_coors, hidden_coors) + query_coors[si * sh: si * sh + sh * sh, :] = aux_coors + correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys + + # connect hidden to output + aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, hidden_coors, output_coors) + query_coors[si * sh + sh * sh:, :] = aux_coors + correspond_keys[si * sh + sh * sh:, :] = aux_keys + + nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis] + conns = np.zeros((correspond_keys.shape[0], 4), dtype=np.float32) # input_idx, output_idx, enabled, weight + conns[:, 0:2] = correspond_keys + conns[:, 2] = 1 # enabled is True + + return query_coors, nodes, conns + + +def cartesian_product(keys1, keys2, coors1, coors2): + len1 = keys1.shape[0] + len2 = keys2.shape[0] + + repeated_coors1 = np.repeat(coors1, len2, axis=0) + repeated_keys1 = np.repeat(keys1, len2) + + tiled_coors2 = np.tile(coors2, (len1, 1)) + tiled_keys2 = np.tile(keys2, len1) + + new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1) + correspond_keys = np.column_stack((repeated_keys1, tiled_keys2)) + + return new_coors, correspond_keys diff --git a/algorithm/neat/__init__.py b/algorithm/neat/__init__.py index 44bd257..97185ca 100644 --- a/algorithm/neat/__init__.py +++ b/algorithm/neat/__init__.py @@ -1,3 +1,5 @@ from .gene import * from .genome import * +from .species import * from .neat import NEAT + diff --git a/algorithm/neat/ga/crossover/default.py b/algorithm/neat/ga/crossover/default.py index adabd2d..4d01b41 100644 --- a/algorithm/neat/ga/crossover/default.py +++ b/algorithm/neat/ga/crossover/default.py @@ -3,7 +3,8 @@ import jax, jax.numpy as jnp from .base import BaseCrossover class DefaultCrossover(BaseCrossover): - def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2): + + def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2): """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) diff --git a/algorithm/neat/ga/mutation/default.py b/algorithm/neat/ga/mutation/default.py index 7cc1446..33ba6fe 100644 --- a/algorithm/neat/ga/mutation/default.py +++ b/algorithm/neat/ga/mutation/default.py @@ -92,7 +92,7 @@ class DefaultMutation(BaseMutation): return nodes_, conns_ def successful(): - return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conns.new_custom_attrs()) + return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs()) def already_exist(): return nodes_, conns_.at[conn_pos, 2].set(True) @@ -105,11 +105,12 @@ class DefaultMutation(BaseMutation): return jax.lax.cond( is_already_exist, already_exist, - jax.lax.cond( - is_cycle, - nothing, - successful - ) + lambda: + jax.lax.cond( + is_cycle, + nothing, + successful + ) ) elif genome.network_type == 'recurrent': @@ -138,23 +139,23 @@ class DefaultMutation(BaseMutation): k1, k2, k3, k4 = jax.random.split(randkey, num=4) r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) - def no(k, g): - return g + def no(key_, nodes_, conns_): + return nodes_, conns_ - genome = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns) - genome = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns) - genome = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns) - genome = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns) + nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns) + nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns) + nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns) + nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns) - return genome + return nodes, conns def mutate_values(self, randkey, genome, nodes, conns): k1, k2 = jax.random.split(randkey, num=2) - nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0]) - conns_keys = jax.random.split(k2, num=genome.conns.shape[0]) + nodes_keys = jax.random.split(k1, num=nodes.shape[0]) + conns_keys = jax.random.split(k2, num=conns.shape[0]) - new_nodes = jax.vmap(genome.nodes.mutate, in_axes=(0, 0))(nodes_keys, nodes) - new_conns = jax.vmap(genome.conns.mutate, in_axes=(0, 0))(conns_keys, conns) + new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes) + new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns) # nan nodes not changed new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) diff --git a/algorithm/neat/gene/conn/default.py b/algorithm/neat/gene/conn/default.py index f5915c6..39db6e2 100644 --- a/algorithm/neat/gene/conn/default.py +++ b/algorithm/neat/gene/conn/default.py @@ -7,8 +7,7 @@ from . import BaseConnGene class DefaultConnGene(BaseConnGene): "Default connection gene, with the same behavior as in NEAT-python." - fixed_attrs = ['input_index', 'output_index', 'enabled'] - attrs = ['weight'] + custom_attrs = ['weight'] def __init__( self, diff --git a/algorithm/neat/gene/node/default.py b/algorithm/neat/gene/node/default.py index 6580072..3f43d4f 100644 --- a/algorithm/neat/gene/node/default.py +++ b/algorithm/neat/gene/node/default.py @@ -9,7 +9,6 @@ from . import BaseNodeGene class DefaultNodeGene(BaseNodeGene): "Default node gene, with the same behavior as in NEAT-python." - fixed_attrs = ['index'] custom_attrs = ['bias', 'response', 'aggregation', 'activation'] def __init__( @@ -82,8 +81,8 @@ class DefaultNodeGene(BaseNodeGene): return ( jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + - node1[3] != node2[3] + - node1[4] != node2[4] + (node1[3] != node2[3]) + + (node1[4] != node2[4]) ) def forward(self, attrs, inputs): diff --git a/algorithm/neat/genome/base.py b/algorithm/neat/genome/base.py index 4567193..928c0f5 100644 --- a/algorithm/neat/genome/base.py +++ b/algorithm/neat/genome/base.py @@ -4,7 +4,6 @@ from utils import fetch_first class BaseGenome: - network_type = None def __init__( diff --git a/algorithm/neat/genome/default.py b/algorithm/neat/genome/default.py index 786d8df..768145d 100644 --- a/algorithm/neat/genome/default.py +++ b/algorithm/neat/genome/default.py @@ -1,3 +1,5 @@ +from typing import Callable + import jax, jax.numpy as jnp from utils import unflatten_conns, topological_sort, I_INT @@ -13,10 +15,20 @@ class DefaultGenome(BaseGenome): def __init__(self, num_inputs: int, num_outputs: int, + max_nodes=5, + max_conns=4, node_gene: BaseNodeGene = DefaultNodeGene(), conn_gene: BaseConnGene = DefaultConnGene(), + output_transform: Callable = None ): - super().__init__(num_inputs, num_outputs, node_gene, conn_gene) + super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene) + + if output_transform is not None: + try: + aux = output_transform(jnp.zeros(num_outputs)) + except Exception as e: + raise ValueError(f"Output transform function failed: {e}") + self.output_transform = output_transform def transform(self, nodes, conns): u_conns = unflatten_conns(nodes, conns) @@ -72,4 +84,7 @@ class DefaultGenome(BaseGenome): vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) - return vals[self.output_idx] + if self.output_transform is None: + return vals[self.output_idx] + else: + return self.output_transform(vals[self.output_idx]) diff --git a/algorithm/neat/genome/recurrent.py b/algorithm/neat/genome/recurrent.py index b1cef54..2f4e630 100644 --- a/algorithm/neat/genome/recurrent.py +++ b/algorithm/neat/genome/recurrent.py @@ -13,11 +13,13 @@ class RecurrentGenome(BaseGenome): def __init__(self, num_inputs: int, num_outputs: int, + max_nodes: int, + max_conns: int, node_gene: BaseNodeGene = DefaultNodeGene(), conn_gene: BaseConnGene = DefaultConnGene(), activate_time: int = 10, ): - super().__init__(num_inputs, num_outputs, node_gene, conn_gene) + super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene) self.activate_time = activate_time def transform(self, nodes, conns): diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index d7c97f8..1de420d 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -1,20 +1,19 @@ import jax, jax.numpy as jnp from utils import State from .. import BaseAlgorithm -from .genome import * from .species import * from .ga import * + class NEAT(BaseAlgorithm): def __init__( self, - genome: BaseGenome, species: BaseSpecies, mutation: BaseMutation = DefaultMutation(), crossover: BaseCrossover = DefaultCrossover(), ): - self.genome = genome + self.genome = species.genome self.species = species self.mutation = mutation self.crossover = crossover @@ -23,14 +22,14 @@ class NEAT(BaseAlgorithm): k1, k2 = jax.random.split(randkey, 2) return State( randkey=k1, - generation=0, - next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2, + generation=jnp.array(0.), + next_node_key=jnp.array(max(*self.genome.input_idx, *self.genome.output_idx) + 2, dtype=jnp.float32), # inputs nodes, output nodes, 1 hidden node species=self.species.setup(k2), ) def ask(self, state: State): - return self.species.ask(state) + return self.species.ask(state.species) def tell(self, state: State, fitness): k1, k2, randkey = jax.random.split(state.randkey, 3) @@ -40,25 +39,39 @@ class NEAT(BaseAlgorithm): randkey=randkey ) - state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation) + species_state, winner, loser, elite_mask = self.species.update_species(state.species, fitness, state.generation) + state = state.update(species=species_state) state = self.create_next_generation(k2, state, winner, loser, elite_mask) - state = self.species.speciate(state, state.generation) - + species_state = self.species.speciate(state.species, state.generation) + state = state.update(species=species_state) return state - def transform(self, state: State): + def transform(self, individual): """transform the genome into a neural network""" - raise NotImplementedError + nodes, conns = individual + return self.genome.transform(nodes, conns) def forward(self, inputs, transformed): - raise NotImplementedError + return self.genome.forward(inputs, transformed) + + @property + def num_inputs(self): + return self.genome.num_inputs + + @property + def num_outputs(self): + return self.genome.num_outputs + + @property + def pop_size(self): + return self.species.pop_size def create_next_generation(self, randkey, state, winner, loser, elite_mask): # prepare random keys pop_size = self.species.pop_size - new_node_keys = jnp.arange(pop_size) + state.species.next_node_key + new_node_keys = jnp.arange(pop_size) + state.next_node_key k1, k2 = jax.random.split(randkey, 2) crossover_rand_keys = jax.random.split(k1, pop_size) @@ -69,11 +82,11 @@ class NEAT(BaseAlgorithm): # batch crossover n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0)) - (crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc)) + (crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc)) # batch mutation m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0)) - (mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys)) + (mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys)) # elitism don't mutate pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) @@ -92,3 +105,9 @@ class NEAT(BaseAlgorithm): next_node_key=next_node_key, ) + def member_count(self, state: State): + return state.species.member_count + + def generation(self, state: State): + # to analysis the algorithm + return state.generation diff --git a/algorithm/neat/species/default.py b/algorithm/neat/species/default.py index c04e653..7cf3e93 100644 --- a/algorithm/neat/species/default.py +++ b/algorithm/neat/species/default.py @@ -2,9 +2,10 @@ import numpy as np import jax, jax.numpy as jnp from utils import State, rank_elements, argmin_with_mask, fetch_first from ..genome import BaseGenome +from .base import BaseSpecies -class DefaultSpecies: +class DefaultSpecies(BaseSpecies): def __init__(self, genome: BaseGenome, @@ -18,9 +19,8 @@ class DefaultSpecies: genome_elitism: int = 2, survival_threshold: float = 0.2, min_species_size: int = 1, - compatibility_threshold: float = 3.5 + compatibility_threshold: float = 3. ): - self.genome = genome self.pop_size = pop_size self.species_size = species_size @@ -59,8 +59,12 @@ class DefaultSpecies: center_nodes = center_nodes.at[0].set(pop_nodes[0]) center_conns = center_conns.at[0].set(pop_conns[0]) + pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns)) + return State( randkey=randkey, + pop_nodes=pop_nodes, + pop_conns=pop_conns, species_keys=species_keys, best_fitness=best_fitness, last_improved=last_improved, @@ -68,7 +72,7 @@ class DefaultSpecies: idx2species=idx2species, center_nodes=center_nodes, center_conns=center_conns, - next_species_key=1, # 0 is reserved for the first species + next_species_key=jnp.array(1), # 0 is reserved for the first species ) def ask(self, state): @@ -99,7 +103,7 @@ class DefaultSpecies: # crossover info winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness) - return state(randkey=k2), winner, loser, elite_mask + return state.update(randkey=k2), winner, loser, elite_mask def update_species_fitness(self, state, fitness): """ @@ -156,17 +160,17 @@ class DefaultSpecies: jnp.nan, # last_improved jnp.nan, # member_count -jnp.inf, # species_fitness - jnp.full_like(center_nodes[idx], jnp.nan), # center_nodes - jnp.full_like(center_conns[idx], jnp.nan), # center_conns + jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes + jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns ), # stagnation species lambda: ( - species_keys[idx], + state.species_keys[idx], best_fitness[idx], last_improved[idx], state.member_count[idx], species_fitness[idx], - center_nodes[idx], - center_conns[idx] + state.center_nodes[idx], + state.center_conns[idx] ) # not stagnation species ) @@ -216,7 +220,7 @@ class DefaultSpecies: spawn_number = spawn_number.astype(jnp.int32) # must control the sum of spawn_number to be equal to pop_size - error = state.P - jnp.sum(spawn_number) + error = self.pop_size - jnp.sum(spawn_number) # add error to the first species to control the sum of spawn_number spawn_number = spawn_number.at[0].add(error) @@ -287,14 +291,14 @@ class DefaultSpecies: def body_func(carry): i, i2s, cns, ccs, o2c = carry - distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns) + distances = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns) # find the closest one closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) - i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i]) - cns = cns.set(i, state.pop_nodes[closest_idx]) - ccs = ccs.set(i, state.pop_conns[closest_idx]) + i2s = i2s.at[closest_idx].set(state.species_keys[i]) + cns = cns.at[i].set(state.pop_nodes[closest_idx]) + ccs = ccs.at[i].set(state.pop_conns[closest_idx]) # the genome with closest_idx will become the new center, thus its distance to center is 0. o2c = o2c.at[closest_idx].set(0) @@ -346,8 +350,8 @@ class DefaultSpecies: o2c = o2c.at[idx].set(0) # update center genomes - cns = cns.set(i, state.pop_nodes[idx]) - ccs = ccs.set(i, state.pop_conns[idx]) + cns = cns.at[i].set(state.pop_nodes[idx]) + ccs = ccs.at[i].set(state.pop_conns[idx]) # find the members for the new species i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c) @@ -384,7 +388,7 @@ class DefaultSpecies: _, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop( cond_func, body_func, - (0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances, + (0, state.idx2species, center_nodes, center_conns, state.species_keys, o2c_distances, state.next_species_key) ) @@ -401,8 +405,8 @@ class DefaultSpecies: def count_members(idx): return jax.lax.cond( jnp.isnan(species_keys[idx]), # if the species is not existing - lambda _: jnp.nan, # nan - lambda _: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members + lambda: jnp.nan, # nan + lambda: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members ) member_count = jax.vmap(count_members)(self.species_arange) @@ -422,7 +426,8 @@ class DefaultSpecies: """ The distance between two genomes """ - return self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2) + d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2) + return d def node_distance(self, nodes1, nodes2): """ @@ -494,18 +499,18 @@ def initialize_population(pop_size, genome): o_nodes[input_idx, 0] = genome.input_idx o_nodes[output_idx, 0] = genome.output_idx o_nodes[new_node_key, 0] = new_node_key # one hidden node - o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_attrs() - o_nodes[new_node_key, 1:] = genome.node_gene.new_attrs() # one hidden node + o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs() + o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden o_conns[input_idx, 0:2] = input_conns # in key, out key o_conns[input_idx, 2] = True # enabled - o_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs() + o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs() output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes o_conns[output_idx, 0:2] = output_conns # in key, out key o_conns[output_idx, 2] = True # enabled - o_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs() + o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs() # repeat origin genome for P times to create population pop_nodes = np.tile(o_nodes, (pop_size, 1, 1)) diff --git a/examples/brax/ant.py b/examples/brax/ant.py index 60f3d8b..082d202 100644 --- a/examples/brax/ant.py +++ b/examples/brax/ant.py @@ -1,38 +1,36 @@ -import jax.numpy as jnp - -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import BraxEnv, BraxConfig - - -def example_conf(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=10000, - pop_size=100 - ), - neat=NeatConfig( - inputs=27, - outputs=8, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - problem=BraxConfig( - env_name="ant" - ) - ) +from algorithm.neat import * +from problem.rl_env import BraxEnv +from utils import Act if __name__ == '__main__': - conf = example_conf() + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=27, + num_outputs=8, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh,), + activation_default=Act.tanh, + ) + ), + pop_size=1000, + species_size=10, + ), + ), + problem=BraxEnv( + env_name='ant', + ), + generation_limit=10000, + fitness_target=5000 + ) - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, BraxEnv) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/examples/brax/half_cheetah.py b/examples/brax/half_cheetah.py index dbc6207..da4823c 100644 --- a/examples/brax/half_cheetah.py +++ b/examples/brax/half_cheetah.py @@ -1,42 +1,36 @@ -import jax.numpy as jnp - -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import BraxEnv, BraxConfig - - -# ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d'] - - -def example_conf(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=10000, - generation_limit=10, - pop_size=100 - ), - neat=NeatConfig( - inputs=17, - outputs=6, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - problem=BraxConfig( - env_name="halfcheetah" - ) - ) +from algorithm.neat import * +from problem.rl_env import BraxEnv +from utils import Act if __name__ == '__main__': - conf = example_conf() - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, BraxEnv) + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=17, + num_outputs=6, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh,), + activation_default=Act.tanh, + ) + ), + pop_size=1000, + species_size=10, + ), + ), + problem=BraxEnv( + env_name='halhcheetah', + ), + generation_limit=10000, + fitness_target=5000 + ) + + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) - pipeline.show(state, best, save_path="half_cheetah.gif", ) + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/examples/brax/reacher.py b/examples/brax/reacher.py index a6ed280..73f4d14 100644 --- a/examples/brax/reacher.py +++ b/examples/brax/reacher.py @@ -1,38 +1,36 @@ -import jax.numpy as jnp - -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import BraxEnv, BraxConfig - - -def example_conf(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=10000, - pop_size=1000 - ), - neat=NeatConfig( - inputs=11, - outputs=2, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - problem=BraxConfig( - env_name="reacher" - ) - ) +from algorithm.neat import * +from problem.rl_env import BraxEnv +from utils import Act if __name__ == '__main__': - conf = example_conf() + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=11, + num_outputs=2, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh,), + activation_default=Act.tanh, + ) + ), + pop_size=100, + species_size=10, + ), + ), + problem=BraxEnv( + env_name='reacher', + ), + generation_limit=10000, + fitness_target=5000 + ) - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, BraxEnv) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/examples/brax_env.py b/examples/brax_env.py deleted file mode 100644 index 2124f24..0000000 --- a/examples/brax_env.py +++ /dev/null @@ -1,73 +0,0 @@ -import imageio -import jax - -import brax -from brax import envs -from brax.io import image -import matplotlib.pyplot as plt - -import time -from tqdm import tqdm -import numpy as np - - -def inference_func(key, *args): - return jax.random.normal(key, shape=(env.action_size,)) - - -env_name = "ant" -backend = "generalized" - -env = envs.create(env_name=env_name, backend=backend) - -jit_env_reset = jax.jit(env.reset) -jit_env_step = jax.jit(env.step) -jit_inference_fn = jax.jit(inference_func) - -rng = jax.random.PRNGKey(seed=1) -ori_state = jit_env_reset(rng=rng) -state = ori_state - -render_history = [] - -for i in range(100): - act_rng, rng = jax.random.split(rng) - - tic = time.time() - act = jit_inference_fn(act_rng, state.obs) - state = jit_env_step(state, act) - print("step time: ", time.time() - tic) - - render_history.append(state.pipeline_state) - - # img = image.render_array(sys=env.sys, state=pipeline_state, width=512, height=512) - # print("render time: ", time.time() - tic) - - # plt.imsave("../images/ant_{}.png".format(i), img) - - reward = state.reward - done = state.done - print(i, reward) - -render_history = jax.device_get(render_history) -# print(render_history) - -imgs = [image.render_array(sys=env.sys, state=s, width=512, height=512) for s in tqdm(render_history)] - - -# for i, s in enumerate(tqdm(render_history)): -# img = image.render_array(sys=env.sys, state=s, width=512, height=512) -# print(img.shape) -# # print(type(img)) -# plt.imsave("../images/ant_{}.png".format(i), img) - - -def create_gif(image_list, gif_name, duration): - with imageio.get_writer(gif_name, mode='I', duration=duration) as writer: - for image in image_list: - # 确保图像的数据类型正确 - formatted_image = np.array(image, dtype=np.uint8) - writer.append_data(formatted_image) - - -create_gif(imgs, "../images/ant.gif", 0.1) diff --git a/examples/brax_render.py b/examples/brax_render.py deleted file mode 100644 index 4347697..0000000 --- a/examples/brax_render.py +++ /dev/null @@ -1,54 +0,0 @@ -import brax -from brax import envs -from brax.envs.wrappers import gym as gym_wrapper -from brax.io import image -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt -import traceback - -# print(f"Using Brax {brax.__version__}, Jax {jax.__version__}") -# print("From GymWrapper, env.reset()") -# try: -# env = envs.create("inverted_pendulum", -# batch_size=1, -# episode_length=150, -# backend='generalized') -# env = gym_wrapper.GymWrapper(env) -# env.reset() -# img = env.render(mode='rgb_array') -# plt.imshow(img) -# except Exception: -# traceback.print_exc() -# -# print("From GymWrapper, env.reset() and action") -# try: -# env = envs.create("inverted_pendulum", -# batch_size=1, -# episode_length=150, -# backend='generalized') -# env = gym_wrapper.GymWrapper(env) -# env.reset() -# action = jnp.zeros(env.action_space.shape) -# env.step(action) -# img = env.render(mode='rgb_array') -# plt.imshow(img) -# except Exception: -# traceback.print_exc() - -print("From brax env") -try: - env = envs.create("inverted_pendulum", - batch_size=1, - episode_length=150, - backend='generalized') - key = jax.random.PRNGKey(0) - initial_env_state = env.reset(key) - base_state = initial_env_state.pipeline_state - pipeline_state = env.pipeline_init(base_state.q.ravel(), base_state.qd.ravel()) - img = image.render_array(sys=env.sys, state=pipeline_state, width=256, height=256) - print(f"pixel values: [{img.min()}, {img.max()}]") - plt.imshow(img) - plt.show() -except Exception: - traceback.print_exc() \ No newline at end of file diff --git a/examples/func_fit/xor.py b/examples/func_fit/xor.py index 11429ae..94be087 100644 --- a/examples/func_fit/xor.py +++ b/examples/func_fit/xor.py @@ -1,32 +1,31 @@ -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.func_fit import XOR, FuncFitConfig +from algorithm.neat import * + +from problem.func_fit import XOR3d if __name__ == '__main__': - # running config - config = Config( - basic=BasicConfig( - seed=42, - fitness_target=-1e-2, - pop_size=10000 + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=3, + num_outputs=1, + max_nodes=50, + max_conns=100, + ), + pop_size=10000, + species_size=10, + compatibility_threshold=3.5, + ), ), - neat=NeatConfig( - inputs=2, - outputs=1 - ), - gene=NormalGeneConfig(), - problem=FuncFitConfig( - error_method='rmse' - ) + problem=XOR3d(), + generation_limit=10000, + fitness_target=-1e-8 ) - # define algorithm: NEAT with NormalGene - algorithm = NEAT(config, NormalGene) - # full pipeline - pipeline = Pipeline(config, algorithm, XOR) + # initialize state state = pipeline.setup() + # print(state) # run until terminate state, best = pipeline.auto_run(state) # show result diff --git a/examples/func_fit/xor3d_hyperneat.py b/examples/func_fit/xor3d_hyperneat.py new file mode 100644 index 0000000..6f0e7e2 --- /dev/null +++ b/examples/func_fit/xor3d_hyperneat.py @@ -0,0 +1,51 @@ +from pipeline import Pipeline +from algorithm.neat import * +from algorithm.hyperneat import * +from utils import Act + +from problem.func_fit import XOR3d + +if __name__ == '__main__': + pipeline = Pipeline( + algorithm=HyperNEAT( + substrate=FullSubstrate( + input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)], + hidden_coors=[ + (-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5), + (-1, 0), (0.333, 0), (-0.333, 0), (1, 0), + (-1, 0.5), (0.333, 0.5), (-0.333, 0.5), (1, 0.5), + ], + output_coors=[(0, 1), ], + ), + neat=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=4, # [-1, -1, -1, 0] + num_outputs=1, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_default=Act.tanh, + activation_options=(Act.tanh,), + ), + ), + pop_size=10000, + species_size=10, + compatibility_threshold=3.5, + ), + ), + activation=Act.sigmoid, + activate_time=10, + ), + problem=XOR3d(), + generation_limit=300, + fitness_target=-1e-6 + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) + # show result + pipeline.show(state, best) diff --git a/examples/func_fit/xor_hyperneat.py b/examples/func_fit/xor_hyperneat.py deleted file mode 100644 index cfd23f1..0000000 --- a/examples/func_fit/xor_hyperneat.py +++ /dev/null @@ -1,41 +0,0 @@ -from config import * -from pipeline import Pipeline -from algorithm.neat import NormalGene, NormalGeneConfig -from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig -from problem.func_fit import XOR3d, FuncFitConfig -from utils import Act - - -if __name__ == '__main__': - config = Config( - basic=BasicConfig( - seed=42, - fitness_target=0, - pop_size=1000 - ), - neat=NeatConfig( - max_nodes=50, - max_conns=100, - max_species=30, - inputs=4, - outputs=1 - ), - hyperneat=HyperNeatConfig( - inputs=3, - outputs=1 - ), - substrate=NormalSubstrateConfig( - input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)), - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh, ), - ), - problem=FuncFitConfig() - ) - - algorithm = HyperNEAT(config, NormalGene, NormalSubstrate) - pipeline = Pipeline(config, algorithm, XOR3d) - state = pipeline.setup() - state, best = pipeline.auto_run(state) - pipeline.show(state, best) diff --git a/examples/func_fit/xor_recurrent.py b/examples/func_fit/xor_recurrent.py index d100fd8..1e1f3bd 100644 --- a/examples/func_fit/xor_recurrent.py +++ b/examples/func_fit/xor_recurrent.py @@ -1,41 +1,41 @@ -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig -from problem.func_fit import XOR3d, FuncFitConfig +from algorithm.neat import * +from problem.func_fit import XOR3d +from utils.activation import ACT_ALL +from utils.aggregation import AGG_ALL if __name__ == '__main__': - config = Config( - basic=BasicConfig( - seed=42, - fitness_target=-1e-2, - generation_limit=300, - pop_size=1000 + pipeline = Pipeline( + seed=0, + algorithm=NEAT( + species=DefaultSpecies( + genome=RecurrentGenome( + num_inputs=3, + num_outputs=1, + max_nodes=50, + max_conns=100, + activate_time=5, + node_gene=DefaultNodeGene( + activation_options=ACT_ALL, + # aggregation_options=AGG_ALL, + activation_replace_rate=0.2 + ), + ), + pop_size=10000, + species_size=10, + compatibility_threshold=3.5, + ), ), - neat=NeatConfig( - network_type="recurrent", - max_nodes=50, - max_conns=100, - max_species=30, - conn_add=0.5, - conn_delete=0.5, - node_add=0.4, - node_delete=0.4, - inputs=3, - outputs=1 - ), - gene=RecurrentGeneConfig( - activate_times=10 - ), - problem=FuncFitConfig( - error_method='rmse' - ) + problem=XOR3d(), + generation_limit=10000, + fitness_target=-1e-8 ) - algorithm = NEAT(config, RecurrentGene) - pipeline = Pipeline(config, algorithm, XOR3d) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) + # print(state) + # run until terminate state, best = pipeline.auto_run(state) + # show result pipeline.show(state, best) diff --git a/examples/general_xor.py b/examples/general_xor.py deleted file mode 100644 index a2d45ee..0000000 --- a/examples/general_xor.py +++ /dev/null @@ -1,36 +0,0 @@ -from config import * -from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.func_fit import XOR, FuncFitConfig - -if __name__ == '__main__': - config = Config( - basic=BasicConfig( - seed=42, - fitness_target=-1e-2, - pop_size=10000 - ), - neat=NeatConfig( - max_nodes=50, - max_conns=100, - max_species=30, - conn_add=0.8, - conn_delete=0, - node_add=0.4, - node_delete=0, - inputs=2, - outputs=1 - ), - gene=NormalGeneConfig(), - problem=FuncFitConfig( - error_method='rmse' - ) - ) - - algorithm = NEAT(config, NormalGene) - pipeline = Pipeline(config, algorithm, XOR) - state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) - pipeline.show(state, best) diff --git a/examples/gymnax/acrobot.py b/examples/gymnax/acrobot.py deleted file mode 100644 index 3867ab3..0000000 --- a/examples/gymnax/acrobot.py +++ /dev/null @@ -1,39 +0,0 @@ -import jax.numpy as jnp - -from config import * -from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import GymNaxConfig, GymNaxEnv - - -def example_conf(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=0, - pop_size=10000 - ), - neat=NeatConfig( - inputs=6, - outputs=3, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - problem=GymNaxConfig( - env_name='Acrobot-v1', - output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2} - ) - ) - - -if __name__ == '__main__': - conf = example_conf() - - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, GymNaxEnv) - state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) diff --git a/examples/gymnax/arcbot.py b/examples/gymnax/arcbot.py new file mode 100644 index 0000000..e56ffa1 --- /dev/null +++ b/examples/gymnax/arcbot.py @@ -0,0 +1,34 @@ +import jax.numpy as jnp + +from pipeline import Pipeline +from algorithm.neat import * + +from problem.rl_env import GymNaxEnv + +if __name__ == '__main__': + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=6, + num_outputs=3, + max_nodes=50, + max_conns=100, + output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2} + ), + pop_size=10000, + species_size=10, + ), + ), + problem=GymNaxEnv( + env_name='Acrobot-v1', + ), + generation_limit=10000, + fitness_target=-62 + ) + + # initialize state + state = pipeline.setup() + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/examples/gymnax/cartpole.py b/examples/gymnax/cartpole.py index d5564c5..75e9d88 100644 --- a/examples/gymnax/cartpole.py +++ b/examples/gymnax/cartpole.py @@ -1,84 +1,34 @@ import jax.numpy as jnp -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import GymNaxConfig, GymNaxEnv - - -def example_conf1(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=500, - pop_size=10000 - ), - neat=NeatConfig( - inputs=4, - outputs=1, - ), - gene=NormalGeneConfig( - activation_default=Act.sigmoid, - activation_options=(Act.sigmoid,), - ), - problem=GymNaxConfig( - env_name='CartPole-v1', - output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1} - ) - ) - - -def example_conf2(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=500, - pop_size=10000 - ), - neat=NeatConfig( - inputs=4, - outputs=1, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - problem=GymNaxConfig( - env_name='CartPole-v1', - output_transform=lambda out: jnp.where(out[0] > 0, 1, 0) # the action of cartpole is {0, 1} - ) - ) - - -def example_conf3(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=501, - pop_size=10000 - ), - neat=NeatConfig( - inputs=4, - outputs=2, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - problem=GymNaxConfig( - env_name='CartPole-v1', - output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} - ) - ) +from algorithm.neat import * +from problem.rl_env import GymNaxEnv if __name__ == '__main__': - # all config files above can solve cartpole - conf = example_conf3() + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=4, + num_outputs=2, + max_nodes=50, + max_conns=100, + output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} + ), + pop_size=10000, + species_size=10, + ), + ), + problem=GymNaxEnv( + env_name='CartPole-v1', + ), + generation_limit=10000, + fitness_target=500 + ) - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, GymNaxEnv) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/examples/gymnax/mountain_car.py b/examples/gymnax/mountain_car.py index 6c3a43d..d9082cf 100644 --- a/examples/gymnax/mountain_car.py +++ b/examples/gymnax/mountain_car.py @@ -1,39 +1,34 @@ import jax.numpy as jnp -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import GymNaxConfig, GymNaxEnv - - -def example_conf(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=0, - pop_size=10000 - ), - neat=NeatConfig( - inputs=2, - outputs=3, - ), - gene=NormalGeneConfig( - activation_default=Act.sigmoid, - activation_options=(Act.sigmoid,), - ), - problem=GymNaxConfig( - env_name='MountainCar-v0', - output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1, 2} - ) - ) +from algorithm.neat import * +from problem.rl_env import GymNaxEnv if __name__ == '__main__': - conf = example_conf() + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=2, + num_outputs=3, + max_nodes=50, + max_conns=100, + output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2} + ), + pop_size=10000, + species_size=10, + ), + ), + problem=GymNaxEnv( + env_name='MountainCar-v0', + ), + generation_limit=10000, + fitness_target=0 + ) - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, GymNaxEnv) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/examples/gymnax/mountain_car_continuous.py b/examples/gymnax/mountain_car_continuous.py index 41169a4..f863e52 100644 --- a/examples/gymnax/mountain_car_continuous.py +++ b/examples/gymnax/mountain_car_continuous.py @@ -1,38 +1,36 @@ -import jax.numpy as jnp - -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import GymNaxConfig, GymNaxEnv - - -def example_conf(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=100, - pop_size=10000 - ), - neat=NeatConfig( - inputs=2, - outputs=1, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - problem=GymNaxConfig( - env_name='MountainCarContinuous-v0' - ) - ) +from algorithm.neat import * +from problem.rl_env import GymNaxEnv +from utils import Act if __name__ == '__main__': - conf = example_conf() + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=2, + num_outputs=1, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh, ), + activation_default=Act.tanh, + ) + ), + pop_size=10000, + species_size=10, + ), + ), + problem=GymNaxEnv( + env_name='MountainCarContinuous-v0', + ), + generation_limit=10000, + fitness_target=500 + ) - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, GymNaxEnv) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/examples/gymnax/pendulum.py b/examples/gymnax/pendulum.py index 5a75832..7073dbe 100644 --- a/examples/gymnax/pendulum.py +++ b/examples/gymnax/pendulum.py @@ -1,40 +1,37 @@ -import jax.numpy as jnp - -from config import * from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import GymNaxConfig, GymNaxEnv - - -def example_conf(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=0, - pop_size=10000 - ), - neat=NeatConfig( - inputs=3, - outputs=1, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - problem=GymNaxConfig( - env_name='Pendulum-v1', - output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2] - ) - ) +from algorithm.neat import * +from problem.rl_env import GymNaxEnv +from utils import Act if __name__ == '__main__': - conf = example_conf() + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=3, + num_outputs=1, + max_nodes=50, + max_conns=100, + node_gene=DefaultNodeGene( + activation_options=(Act.tanh,), + activation_default=Act.tanh, + ), + output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2] + ), + pop_size=10000, + species_size=10, + ), + ), + problem=GymNaxEnv( + env_name='Pendulum-v1', + ), + generation_limit=10000, + fitness_target=0 + ) - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, GymNaxEnv) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) - + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/examples/gymnax/reacher.py b/examples/gymnax/reacher.py index 39afdbd..acf23aa 100644 --- a/examples/gymnax/reacher.py +++ b/examples/gymnax/reacher.py @@ -1,36 +1,33 @@ -from config import * +import jax.numpy as jnp + from pipeline import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import GymNaxConfig, GymNaxEnv - - -def example_conf(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=500, - pop_size=10000 - ), - neat=NeatConfig( - inputs=8, - outputs=2, - ), - gene=NormalGeneConfig( - activation_default=Act.sigmoid, - activation_options=(Act.sigmoid,), - ), - problem=GymNaxConfig( - env_name='Reacher-misc', - ) - ) +from algorithm.neat import * +from problem.rl_env import GymNaxEnv if __name__ == '__main__': - conf = example_conf() + pipeline = Pipeline( + algorithm=NEAT( + species=DefaultSpecies( + genome=DefaultGenome( + num_inputs=8, + num_outputs=2, + max_nodes=50, + max_conns=100, + ), + pop_size=10000, + species_size=10, + ), + ), + problem=GymNaxEnv( + env_name='Reacher-misc', + ), + generation_limit=10000, + fitness_target =500 + ) - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, GymNaxEnv) + # initialize state state = pipeline.setup() - pipeline.pre_compile(state) - state, best = pipeline.auto_run(state) + # print(state) + # run until terminate + state, best = pipeline.auto_run(state) \ No newline at end of file diff --git a/pipeline.py b/pipeline.py index 8b6bdcb..ad33945 100644 --- a/pipeline.py +++ b/pipeline.py @@ -1,25 +1,23 @@ from functools import partial -from typing import Type -import jax +import jax, jax.numpy as jnp import time import numpy as np -from algorithm import NEAT, HyperNEAT -from config import Config -from core import State, Algorithm, Problem +from algorithm import BaseAlgorithm +from problem import BaseProblem +from utils import State class Pipeline: def __init__( - self, - algorithm: Algorithm, - problem: Problem, - seed: int = 42, - fitness_target: float = 1, - generation_limit: int = 1000, - pop_size: int = 100, + self, + algorithm: BaseAlgorithm, + problem: BaseProblem, + seed: int = 42, + fitness_target: float = 1, + generation_limit: int = 1000, ): assert problem.jitable, "Currently, problem must be jitable" @@ -28,17 +26,18 @@ class Pipeline: self.seed = seed self.fitness_target = fitness_target self.generation_limit = generation_limit - self.pop_size = pop_size + self.pop_size = self.algorithm.pop_size print(self.problem.input_shape, self.problem.output_shape) # TODO: make each algorithm's input_num and output_num - assert algorithm.input_num == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}" + assert algorithm.num_inputs == self.problem.input_shape[-1], \ + f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}" - self.act_func = self.algorithm.act + # self.act_func = self.algorithm.act - for _ in range(len(self.problem.input_shape) - 1): - self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None)) + # for _ in range(len(self.problem.input_shape) - 1): + # self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None)) self.best_genome = None self.best_fitness = float('-inf') @@ -46,41 +45,57 @@ class Pipeline: def setup(self): key = jax.random.PRNGKey(self.seed) - algorithm_key, evaluate_key = jax.random.split(key, 2) + key, algorithm_key, evaluate_key = jax.random.split(key, 3) # TODO: Problem should has setup function to maintain state return State( + randkey=key, alg=self.algorithm.setup(algorithm_key), pro=self.problem.setup(evaluate_key), ) - @partial(jax.jit, static_argnums=(0,)) def step(self, state): - key, sub_key = jax.random.split(state.evaluate_key) + key, sub_key = jax.random.split(state.randkey) keys = jax.random.split(key, self.pop_size) - pop = self.algorithm.ask(state) + pop = self.algorithm.ask(state.alg) - pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop) + pop_transformed = jax.vmap(self.algorithm.transform)(pop) - fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func, - pop_transformed) + fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))( + keys, + state.pro, + self.algorithm.forward, + pop_transformed + ) - state = self.algorithm.tell(state, fitnesses) + fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses) - return state.update(evaluate_key=sub_key), fitnesses + alg_state = self.algorithm.tell(state.alg, fitnesses) + + return state.update( + randkey=sub_key, + alg=alg_state, + ), fitnesses def auto_run(self, ini_state): state = ini_state + compiled_step = jax.jit(self.step).lower(ini_state).compile() + for _ in range(self.generation_limit): self.generation_timestamp = time.time() - previous_pop = self.algorithm.ask(state) + previous_pop = self.algorithm.ask(state.alg) - state, fitnesses = self.step(state) + state, fitnesses = compiled_step(state) fitnesses = jax.device_get(fitnesses) + for idx, fitnesses_i in enumerate(fitnesses): + if np.isnan(fitnesses_i): + print("Fitness is nan") + print(previous_pop[0][idx], previous_pop[1][idx]) + assert False self.analysis(state, previous_pop, fitnesses) @@ -102,22 +117,15 @@ class Pipeline: max_idx = np.argmax(fitnesses) if fitnesses[max_idx] > self.best_fitness: self.best_fitness = fitnesses[max_idx] - self.best_genome = pop[max_idx] + self.best_genome = pop[0][max_idx], pop[1][max_idx] - member_count = jax.device_get(state.species_info.member_count) + member_count = jax.device_get(self.algorithm.member_count(state.alg)) species_sizes = [int(i) for i in member_count if i > 0] - print(f"Generation: {state.generation}", + print(f"Generation: {self.algorithm.generation(state.alg)}", f"species: {len(species_sizes)}, {species_sizes}", f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") - def show(self, state, genome, *args, **kwargs): - transformed = self.algorithm.transform(state, genome) - self.problem.show(state.evaluate_key, state, self.act_func, transformed, *args, **kwargs) - - def pre_compile(self, state): - tic = time.time() - print("start compile") - self.step.lower(self, state).compile() - print(f"compile finished, cost time: {time.time() - tic}s") - + def show(self, state, best, *args, **kwargs): + transformed = self.algorithm.transform(best) + self.problem.show(state.randkey, state.pro, self.algorithm.forward, transformed, *args, **kwargs) diff --git a/problem/base.py b/problem/base.py index 28118f1..1e740c2 100644 --- a/problem/base.py +++ b/problem/base.py @@ -1,19 +1,14 @@ from typing import Callable -from config import ProblemConfig -from core.state import State +from utils import State class BaseProblem: - jitable = None - def __init__(self): - pass - def setup(self, randkey, state: State = State()): """initialize the state of the problem""" - raise NotImplementedError + pass def evaluate(self, randkey, state: State, act_func: Callable, params): """evaluate one individual""" diff --git a/problem/func_fit/func_fit.py b/problem/func_fit/func_fit.py index 360796e..5d09e26 100644 --- a/problem/func_fit/func_fit.py +++ b/problem/func_fit/func_fit.py @@ -1,24 +1,27 @@ import jax import jax.numpy as jnp +from utils import State from .. import BaseProblem -class FuncFit(BaseProblem): +class FuncFit(BaseProblem): jitable = True def __init__(self, - error_method: str = 'mse' - ): + error_method: str = 'mse' + ): super().__init__() assert error_method in {'mse', 'rmse', 'mae', 'mape'} self.error_method = error_method + def setup(self, randkey, state: State = State()): + return state def evaluate(self, randkey, state, act_func, params): - predict = act_func(state, self.inputs, params) + predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params) if self.error_method == 'mse': loss = jnp.mean((predict - self.targets) ** 2) @@ -38,7 +41,7 @@ class FuncFit(BaseProblem): return -loss def show(self, randkey, state, act_func, params, *args, **kwargs): - predict = act_func(state, self.inputs, params) + predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) loss = -self.evaluate(randkey, state, act_func, params) msg = "" diff --git a/problem/rl_env/__init__.py b/problem/rl_env/__init__.py index acc4653..ac447c1 100644 --- a/problem/rl_env/__init__.py +++ b/problem/rl_env/__init__.py @@ -1,2 +1,2 @@ -from .gymnax_env import GymNaxEnv, GymNaxConfig -from .brax_env import BraxEnv, BraxConfig +from .gymnax_env import GymNaxEnv +from .brax_env import BraxEnv diff --git a/problem/rl_env/gymnax_env.py b/problem/rl_env/gymnax_env.py index 5912df5..a0b2bb3 100644 --- a/problem/rl_env/gymnax_env.py +++ b/problem/rl_env/gymnax_env.py @@ -3,7 +3,6 @@ import gymnax from .rl_jit import RLEnv - class GymNaxEnv(RLEnv): def __init__(self, env_name): diff --git a/problem/rl_env/rl_jit.py b/problem/rl_env/rl_jit.py index 07e68f6..128ebfb 100644 --- a/problem/rl_env/rl_jit.py +++ b/problem/rl_env/rl_jit.py @@ -4,8 +4,8 @@ import jax from .. import BaseProblem -class RLEnv(BaseProblem): +class RLEnv(BaseProblem): jitable = True # TODO: move output transform to algorithm @@ -19,9 +19,10 @@ class RLEnv(BaseProblem): def cond_func(carry): _, _, _, done, _ = carry return ~done + def body_func(carry): obs, env_state, rng, _, tr = carry # total reward - action = act_func(state, obs, params) + action = act_func(obs, params) next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) next_rng, _ = jax.random.split(rng) return next_obs, next_env_state, next_rng, done, tr + reward diff --git a/t.py b/t.py index 2318d8b..5335a7a 100644 --- a/t.py +++ b/t.py @@ -1,64 +1,4 @@ -from algorithm.neat import * -from utils import Act, Agg +import jax.numpy as jnp -import jax, jax.numpy as jnp - - -def main(): - - # index, bias, response, activation, aggregation - nodes = jnp.array([ - [0, 0, 1, 0, 0], # in[0] - [1, 0, 1, 0, 0], # in[1] - [2, 0.5, 1, 0, 0], # out[0], - [3, 1, 1, 0, 0], # hidden[0], - [4, -1, 1, 0, 0], # hidden[1], - ]) - - # in_node, out_node, enable, weight - conns = jnp.array([ - [0, 3, 1, 0.5], # in[0] -> hidden[0] - [1, 4, 1, 0.5], # in[1] -> hidden[1] - [3, 2, 1, 0.5], # hidden[0] -> out[0] - [4, 2, 1, 0.5], # hidden[1] -> out[0] - ]) - - genome = RecurrentGenome( - num_inputs=2, - num_outputs=1, - node_gene=DefaultNodeGene( - activation_default=Act.identity, - activation_options=(Act.identity, ), - aggregation_default=Agg.sum, - aggregation_options=(Agg.sum, ), - ), - activate_time=3 - ) - - transformed = genome.transform(nodes, conns) - print(*transformed, sep='\n') - - inputs = jnp.array([0, 0]) - outputs = genome.forward(inputs, transformed) - print(outputs) - - inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed) - print(outputs) - expected: [[0.5], [0.75], [0.75], [1]] - - print('\n-------------------------------------------------------\n') - - conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] - print(conns) - - transformed = genome.transform(nodes, conns) - print(*transformed, sep='\n') - - inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed) - print(outputs) - expected: [[0.5], [0.75], [0.5], [0.75]] - -if __name__ == '__main__': - main() \ No newline at end of file +a = jnp.zeros((0, 9, 9)) +print(a) \ No newline at end of file diff --git a/test/test_genome.py b/test/test_genome.py index d77a89a..712d272 100644 --- a/test/test_genome.py +++ b/test/test_genome.py @@ -26,6 +26,8 @@ def test_default(): genome = DefaultGenome( num_inputs=2, num_outputs=1, + max_nodes=5, + max_conns=4, node_gene=DefaultNodeGene( activation_default=Act.identity, activation_options=(Act.identity, ), @@ -80,6 +82,8 @@ def test_recurrent(): genome = RecurrentGenome( num_inputs=2, num_outputs=1, + max_nodes=5, + max_conns=4, node_gene=DefaultNodeGene( activation_default=Act.identity, activation_options=(Act.identity, ), diff --git a/utils/activation.py b/utils/activation.py index 9795c31..03a2b4d 100644 --- a/utils/activation.py +++ b/utils/activation.py @@ -6,48 +6,26 @@ class Act: @staticmethod def sigmoid(z): - z = jnp.clip(z * 5, -60, 60) + z = jnp.clip(5 * z, -10, 10) return 1 / (1 + jnp.exp(-z)) @staticmethod def tanh(z): - z = jnp.clip(z * 2.5, -60, 60) return jnp.tanh(z) @staticmethod def sin(z): - z = jnp.clip(z * 5, -60, 60) return jnp.sin(z) - @staticmethod - def gauss(z): - z = jnp.clip(z * 5, -3.4, 3.4) - return jnp.exp(-z ** 2) - @staticmethod def relu(z): return jnp.maximum(z, 0) - @staticmethod - def elu(z): - return jnp.where(z > 0, z, jnp.exp(z) - 1) - @staticmethod def lelu(z): leaky = 0.005 return jnp.where(z > 0, z, leaky * z) - @staticmethod - def selu(z): - lam = 1.0507009873554804934193349852946 - alpha = 1.6732632423543772848170429916717 - return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1)) - - @staticmethod - def softplus(z): - z = jnp.clip(z * 5, -60, 60) - return 0.2 * jnp.log(1 + jnp.exp(z)) - @staticmethod def identity(z): return z @@ -58,7 +36,11 @@ class Act: @staticmethod def inv(z): - z = jnp.maximum(z, 1e-7) + z = jnp.where( + z > 0, + jnp.maximum(z, 1e-7), + jnp.minimum(z, -1e-7) + ) return 1 / z @staticmethod @@ -68,24 +50,27 @@ class Act: @staticmethod def exp(z): - z = jnp.clip(z, -60, 60) + z = jnp.clip(z, -10, 10) return jnp.exp(z) @staticmethod def abs(z): return jnp.abs(z) - @staticmethod - def hat(z): - return jnp.maximum(0, 1 - jnp.abs(z)) - @staticmethod - def square(z): - return z ** 2 - - @staticmethod - def cube(z): - return z ** 3 +ACT_ALL = ( + Act.sigmoid, + Act.tanh, + Act.sin, + Act.relu, + Act.lelu, + Act.identity, + Act.clamped, + Act.inv, + Act.log, + Act.exp, + Act.abs, +) def act(idx, z, act_funcs): diff --git a/utils/aggregation.py b/utils/aggregation.py index 2e5d94a..63df1e4 100644 --- a/utils/aggregation.py +++ b/utils/aggregation.py @@ -51,6 +51,9 @@ class Agg: return mean_without_zeros +AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean) + + def agg(idx, z, agg_funcs): """ calculate activation function for inputs of node