finish all refactoring

This commit is contained in:
wls2002
2024-02-21 15:41:08 +08:00
parent aac41a089d
commit 6970e6a6d5
44 changed files with 856 additions and 825 deletions

View File

@@ -16,9 +16,30 @@ class BaseAlgorithm:
"""update the state of the algorithm""" """update the state of the algorithm"""
raise NotImplementedError raise NotImplementedError
def transform(self, state: State): def transform(self, individual):
"""transform the genome into a neural network""" """transform the genome into a neural network"""
raise NotImplementedError raise NotImplementedError
def forward(self, inputs, transformed): def forward(self, inputs, transformed):
raise NotImplementedError raise NotImplementedError
@property
def num_inputs(self):
raise NotImplementedError
@property
def num_outputs(self):
raise NotImplementedError
@property
def pop_size(self):
raise NotImplementedError
def member_count(self, state: State):
# to analysis the species
raise NotImplementedError
def generation(self, state: State):
# to analysis the algorithm
raise NotImplementedError

View File

@@ -0,0 +1,2 @@
from .hyperneat import HyperNEAT
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate

View File

@@ -0,0 +1,116 @@
import jax, jax.numpy as jnp
from utils import State, Act, Agg
from .. import BaseAlgorithm, NEAT
from ..neat.gene import BaseNodeGene, BaseConnGene
from ..neat.genome import RecurrentGenome
from .substrate import *
class HyperNEAT(BaseAlgorithm):
def __init__(
self,
substrate: BaseSubstrate,
neat: NEAT,
below_threshold: float = 0.3,
max_weight: float = 5.,
activation=Act.sigmoid,
aggregation=Agg.sum,
activate_time: int = 10,
):
assert substrate.query_coors.shape[1] == neat.num_inputs, \
"Substrate input size should be equal to NEAT input size"
self.substrate = substrate
self.neat = neat
self.below_threshold = below_threshold
self.max_weight = max_weight
self.hyper_genome = RecurrentGenome(
num_inputs=substrate.num_inputs,
num_outputs=substrate.num_outputs,
max_nodes=substrate.nodes_cnt,
max_conns=substrate.conns_cnt,
node_gene=HyperNodeGene(activation, aggregation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
)
def setup(self, randkey):
return State(
neat_state=self.neat.setup(randkey)
)
def ask(self, state: State):
return self.neat.ask(state.neat_state)
def tell(self, state: State, fitness):
return state.update(
neat_state=self.neat.tell(state.neat_state, fitness)
)
def transform(self, individual):
transformed = self.neat.transform(individual)
query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(self.substrate.query_coors, transformed)
# mute the connection with weight below threshold
query_res = jnp.where(
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
0.,
query_res
)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(query_res > 0, query_res - self.below_threshold, query_res)
query_res = jnp.where(query_res < 0, query_res + self.below_threshold, query_res)
query_res = query_res / (1 - self.below_threshold) * self.max_weight
h_nodes, h_conns = self.substrate.make_nodes(query_res), self.substrate.make_conn(query_res)
return self.hyper_genome.transform(h_nodes, h_conns)
def forward(self, inputs, transformed):
# add bias
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
return self.hyper_genome.forward(inputs_with_bias, transformed)
@property
def num_inputs(self):
return self.substrate.num_inputs - 1 # remove bias
@property
def num_outputs(self):
return self.substrate.num_outputs
@property
def pop_size(self):
return self.neat.pop_size
def member_count(self, state: State):
return self.neat.member_count(state.neat_state)
def generation(self, state: State):
return self.neat.generation(state.neat_state)
class HyperNodeGene(BaseNodeGene):
def __init__(self,
activation=Act.sigmoid,
aggregation=Agg.sum,
):
super().__init__()
self.activation = activation
self.aggregation = aggregation
def forward(self, attrs, inputs):
return self.activation(
self.aggregation(inputs)
)
class HyperNEATConnGene(BaseConnGene):
custom_attrs = ['weight']
def forward(self, attrs, inputs):
weight = attrs[0]
return inputs * weight

View File

@@ -0,0 +1,3 @@
from .base import BaseSubstrate
from .default import DefaultSubstrate
from .full import FullSubstrate

View File

@@ -0,0 +1,27 @@
class BaseSubstrate:
def make_nodes(self, query_res):
raise NotImplementedError
def make_conn(self, query_res):
raise NotImplementedError
@property
def query_coors(self):
raise NotImplementedError
@property
def num_inputs(self):
raise NotImplementedError
@property
def num_outputs(self):
raise NotImplementedError
@property
def nodes_cnt(self):
raise NotImplementedError
@property
def conns_cnt(self):
raise NotImplementedError

View File

@@ -0,0 +1,38 @@
import jax.numpy as jnp
from . import BaseSubstrate
class DefaultSubstrate(BaseSubstrate):
def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
self.inputs = num_inputs
self.outputs = num_outputs
self.coors = jnp.array(coors)
self.nodes = jnp.array(nodes)
self.conns = jnp.array(conns)
def make_nodes(self, query_res):
return self.nodes
def make_conn(self, query_res):
return self.conns.at[:, 3:].set(query_res) # change weight
@property
def query_coors(self):
return self.coors
@property
def num_inputs(self):
return self.inputs
@property
def num_outputs(self):
return self.outputs
@property
def nodes_cnt(self):
return self.nodes.shape[0]
@property
def conns_cnt(self):
return self.conns.shape[0]

View File

@@ -0,0 +1,76 @@
import numpy as np
from .default import DefaultSubstrate
class FullSubstrate(DefaultSubstrate):
def __init__(self,
input_coors=((-1, -1), (0, -1), (1, -1)),
hidden_coors=((-1, 0), (0, 0), (1, 0)),
output_coors=((0, 1),),
):
query_coors, nodes, conns = analysis_substrate(input_coors, output_coors, hidden_coors)
super().__init__(
len(input_coors),
len(output_coors),
query_coors,
nodes,
conns
)
def analysis_substrate(input_coors, output_coors, hidden_coors):
input_coors = np.array(input_coors)
output_coors = np.array(output_coors)
hidden_coors = np.array(hidden_coors)
cd = input_coors.shape[1] # coordinate dimensions
si = input_coors.shape[0] # input coordinate size
so = output_coors.shape[0] # output coordinate size
sh = hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, input_coors, hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, hidden_coors, hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, hidden_coors, output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
conns = np.zeros((correspond_keys.shape[0], 4), dtype=np.float32) # input_idx, output_idx, enabled, weight
conns[:, 0:2] = correspond_keys
conns[:, 2] = 1 # enabled is True
return query_coors, nodes, conns
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys

View File

@@ -1,3 +1,5 @@
from .gene import * from .gene import *
from .genome import * from .genome import *
from .species import *
from .neat import NEAT from .neat import NEAT

View File

@@ -3,7 +3,8 @@ import jax, jax.numpy as jnp
from .base import BaseCrossover from .base import BaseCrossover
class DefaultCrossover(BaseCrossover): class DefaultCrossover(BaseCrossover):
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
""" """
use genome1 and genome2 to generate a new genome use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!) notice that genome1 should have higher fitness than genome2 (genome1 is winner!)

View File

@@ -92,7 +92,7 @@ class DefaultMutation(BaseMutation):
return nodes_, conns_ return nodes_, conns_
def successful(): def successful():
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conns.new_custom_attrs()) return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())
def already_exist(): def already_exist():
return nodes_, conns_.at[conn_pos, 2].set(True) return nodes_, conns_.at[conn_pos, 2].set(True)
@@ -105,11 +105,12 @@ class DefaultMutation(BaseMutation):
return jax.lax.cond( return jax.lax.cond(
is_already_exist, is_already_exist,
already_exist, already_exist,
jax.lax.cond( lambda:
is_cycle, jax.lax.cond(
nothing, is_cycle,
successful nothing,
) successful
)
) )
elif genome.network_type == 'recurrent': elif genome.network_type == 'recurrent':
@@ -138,23 +139,23 @@ class DefaultMutation(BaseMutation):
k1, k2, k3, k4 = jax.random.split(randkey, num=4) k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(k, g): def no(key_, nodes_, conns_):
return g return nodes_, conns_
genome = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns) nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
genome = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns) nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
genome = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns) nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
genome = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns) nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
return genome return nodes, conns
def mutate_values(self, randkey, genome, nodes, conns): def mutate_values(self, randkey, genome, nodes, conns):
k1, k2 = jax.random.split(randkey, num=2) k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0]) nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=genome.conns.shape[0]) conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.nodes.mutate, in_axes=(0, 0))(nodes_keys, nodes) new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
new_conns = jax.vmap(genome.conns.mutate, in_axes=(0, 0))(conns_keys, conns) new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
# nan nodes not changed # nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -7,8 +7,7 @@ from . import BaseConnGene
class DefaultConnGene(BaseConnGene): class DefaultConnGene(BaseConnGene):
"Default connection gene, with the same behavior as in NEAT-python." "Default connection gene, with the same behavior as in NEAT-python."
fixed_attrs = ['input_index', 'output_index', 'enabled'] custom_attrs = ['weight']
attrs = ['weight']
def __init__( def __init__(
self, self,

View File

@@ -9,7 +9,6 @@ from . import BaseNodeGene
class DefaultNodeGene(BaseNodeGene): class DefaultNodeGene(BaseNodeGene):
"Default node gene, with the same behavior as in NEAT-python." "Default node gene, with the same behavior as in NEAT-python."
fixed_attrs = ['index']
custom_attrs = ['bias', 'response', 'aggregation', 'activation'] custom_attrs = ['bias', 'response', 'aggregation', 'activation']
def __init__( def __init__(
@@ -82,8 +81,8 @@ class DefaultNodeGene(BaseNodeGene):
return ( return (
jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) + jnp.abs(node1[2] - node2[2]) +
node1[3] != node2[3] + (node1[3] != node2[3]) +
node1[4] != node2[4] (node1[4] != node2[4])
) )
def forward(self, attrs, inputs): def forward(self, attrs, inputs):

View File

@@ -4,7 +4,6 @@ from utils import fetch_first
class BaseGenome: class BaseGenome:
network_type = None network_type = None
def __init__( def __init__(

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import unflatten_conns, topological_sort, I_INT from utils import unflatten_conns, topological_sort, I_INT
@@ -13,10 +15,20 @@ class DefaultGenome(BaseGenome):
def __init__(self, def __init__(self,
num_inputs: int, num_inputs: int,
num_outputs: int, num_outputs: int,
max_nodes=5,
max_conns=4,
node_gene: BaseNodeGene = DefaultNodeGene(), node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(), conn_gene: BaseConnGene = DefaultConnGene(),
output_transform: Callable = None
): ):
super().__init__(num_inputs, num_outputs, node_gene, conn_gene) super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
if output_transform is not None:
try:
aux = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns): def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns) u_conns = unflatten_conns(nodes, conns)
@@ -72,4 +84,7 @@ class DefaultGenome(BaseGenome):
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[self.output_idx] if self.output_transform is None:
return vals[self.output_idx]
else:
return self.output_transform(vals[self.output_idx])

View File

@@ -13,11 +13,13 @@ class RecurrentGenome(BaseGenome):
def __init__(self, def __init__(self,
num_inputs: int, num_inputs: int,
num_outputs: int, num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(), node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(), conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10, activate_time: int = 10,
): ):
super().__init__(num_inputs, num_outputs, node_gene, conn_gene) super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time self.activate_time = activate_time
def transform(self, nodes, conns): def transform(self, nodes, conns):

View File

@@ -1,20 +1,19 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import State from utils import State
from .. import BaseAlgorithm from .. import BaseAlgorithm
from .genome import *
from .species import * from .species import *
from .ga import * from .ga import *
class NEAT(BaseAlgorithm): class NEAT(BaseAlgorithm):
def __init__( def __init__(
self, self,
genome: BaseGenome,
species: BaseSpecies, species: BaseSpecies,
mutation: BaseMutation = DefaultMutation(), mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(), crossover: BaseCrossover = DefaultCrossover(),
): ):
self.genome = genome self.genome = species.genome
self.species = species self.species = species
self.mutation = mutation self.mutation = mutation
self.crossover = crossover self.crossover = crossover
@@ -23,14 +22,14 @@ class NEAT(BaseAlgorithm):
k1, k2 = jax.random.split(randkey, 2) k1, k2 = jax.random.split(randkey, 2)
return State( return State(
randkey=k1, randkey=k1,
generation=0, generation=jnp.array(0.),
next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2, next_node_key=jnp.array(max(*self.genome.input_idx, *self.genome.output_idx) + 2, dtype=jnp.float32),
# inputs nodes, output nodes, 1 hidden node # inputs nodes, output nodes, 1 hidden node
species=self.species.setup(k2), species=self.species.setup(k2),
) )
def ask(self, state: State): def ask(self, state: State):
return self.species.ask(state) return self.species.ask(state.species)
def tell(self, state: State, fitness): def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3) k1, k2, randkey = jax.random.split(state.randkey, 3)
@@ -40,25 +39,39 @@ class NEAT(BaseAlgorithm):
randkey=randkey randkey=randkey
) )
state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation) species_state, winner, loser, elite_mask = self.species.update_species(state.species, fitness, state.generation)
state = state.update(species=species_state)
state = self.create_next_generation(k2, state, winner, loser, elite_mask) state = self.create_next_generation(k2, state, winner, loser, elite_mask)
state = self.species.speciate(state, state.generation) species_state = self.species.speciate(state.species, state.generation)
state = state.update(species=species_state)
return state return state
def transform(self, state: State): def transform(self, individual):
"""transform the genome into a neural network""" """transform the genome into a neural network"""
raise NotImplementedError nodes, conns = individual
return self.genome.transform(nodes, conns)
def forward(self, inputs, transformed): def forward(self, inputs, transformed):
raise NotImplementedError return self.genome.forward(inputs, transformed)
@property
def num_inputs(self):
return self.genome.num_inputs
@property
def num_outputs(self):
return self.genome.num_outputs
@property
def pop_size(self):
return self.species.pop_size
def create_next_generation(self, randkey, state, winner, loser, elite_mask): def create_next_generation(self, randkey, state, winner, loser, elite_mask):
# prepare random keys # prepare random keys
pop_size = self.species.pop_size pop_size = self.species.pop_size
new_node_keys = jnp.arange(pop_size) + state.species.next_node_key new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2) k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size) crossover_rand_keys = jax.random.split(k1, pop_size)
@@ -69,11 +82,11 @@ class NEAT(BaseAlgorithm):
# batch crossover # batch crossover
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0)) n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc)) (crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
# batch mutation # batch mutation
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0)) m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys)) (mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
# elitism don't mutate # elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
@@ -92,3 +105,9 @@ class NEAT(BaseAlgorithm):
next_node_key=next_node_key, next_node_key=next_node_key,
) )
def member_count(self, state: State):
return state.species.member_count
def generation(self, state: State):
# to analysis the algorithm
return state.generation

View File

@@ -2,9 +2,10 @@ import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import State, rank_elements, argmin_with_mask, fetch_first from utils import State, rank_elements, argmin_with_mask, fetch_first
from ..genome import BaseGenome from ..genome import BaseGenome
from .base import BaseSpecies
class DefaultSpecies: class DefaultSpecies(BaseSpecies):
def __init__(self, def __init__(self,
genome: BaseGenome, genome: BaseGenome,
@@ -18,9 +19,8 @@ class DefaultSpecies:
genome_elitism: int = 2, genome_elitism: int = 2,
survival_threshold: float = 0.2, survival_threshold: float = 0.2,
min_species_size: int = 1, min_species_size: int = 1,
compatibility_threshold: float = 3.5 compatibility_threshold: float = 3.
): ):
self.genome = genome self.genome = genome
self.pop_size = pop_size self.pop_size = pop_size
self.species_size = species_size self.species_size = species_size
@@ -59,8 +59,12 @@ class DefaultSpecies:
center_nodes = center_nodes.at[0].set(pop_nodes[0]) center_nodes = center_nodes.at[0].set(pop_nodes[0])
center_conns = center_conns.at[0].set(pop_conns[0]) center_conns = center_conns.at[0].set(pop_conns[0])
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
return State( return State(
randkey=randkey, randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species_keys=species_keys, species_keys=species_keys,
best_fitness=best_fitness, best_fitness=best_fitness,
last_improved=last_improved, last_improved=last_improved,
@@ -68,7 +72,7 @@ class DefaultSpecies:
idx2species=idx2species, idx2species=idx2species,
center_nodes=center_nodes, center_nodes=center_nodes,
center_conns=center_conns, center_conns=center_conns,
next_species_key=1, # 0 is reserved for the first species next_species_key=jnp.array(1), # 0 is reserved for the first species
) )
def ask(self, state): def ask(self, state):
@@ -99,7 +103,7 @@ class DefaultSpecies:
# crossover info # crossover info
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness) winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
return state(randkey=k2), winner, loser, elite_mask return state.update(randkey=k2), winner, loser, elite_mask
def update_species_fitness(self, state, fitness): def update_species_fitness(self, state, fitness):
""" """
@@ -156,17 +160,17 @@ class DefaultSpecies:
jnp.nan, # last_improved jnp.nan, # last_improved
jnp.nan, # member_count jnp.nan, # member_count
-jnp.inf, # species_fitness -jnp.inf, # species_fitness
jnp.full_like(center_nodes[idx], jnp.nan), # center_nodes jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes
jnp.full_like(center_conns[idx], jnp.nan), # center_conns jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns
), # stagnation species ), # stagnation species
lambda: ( lambda: (
species_keys[idx], state.species_keys[idx],
best_fitness[idx], best_fitness[idx],
last_improved[idx], last_improved[idx],
state.member_count[idx], state.member_count[idx],
species_fitness[idx], species_fitness[idx],
center_nodes[idx], state.center_nodes[idx],
center_conns[idx] state.center_conns[idx]
) # not stagnation species ) # not stagnation species
) )
@@ -216,7 +220,7 @@ class DefaultSpecies:
spawn_number = spawn_number.astype(jnp.int32) spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size # must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number) error = self.pop_size - jnp.sum(spawn_number)
# add error to the first species to control the sum of spawn_number # add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error) spawn_number = spawn_number.at[0].add(error)
@@ -287,14 +291,14 @@ class DefaultSpecies:
def body_func(carry): def body_func(carry):
i, i2s, cns, ccs, o2c = carry i, i2s, cns, ccs, o2c = carry
distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns) distances = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
# find the closest one # find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i]) i2s = i2s.at[closest_idx].set(state.species_keys[i])
cns = cns.set(i, state.pop_nodes[closest_idx]) cns = cns.at[i].set(state.pop_nodes[closest_idx])
ccs = ccs.set(i, state.pop_conns[closest_idx]) ccs = ccs.at[i].set(state.pop_conns[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0. # the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0) o2c = o2c.at[closest_idx].set(0)
@@ -346,8 +350,8 @@ class DefaultSpecies:
o2c = o2c.at[idx].set(0) o2c = o2c.at[idx].set(0)
# update center genomes # update center genomes
cns = cns.set(i, state.pop_nodes[idx]) cns = cns.at[i].set(state.pop_nodes[idx])
ccs = ccs.set(i, state.pop_conns[idx]) ccs = ccs.at[i].set(state.pop_conns[idx])
# find the members for the new species # find the members for the new species
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c) i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
@@ -384,7 +388,7 @@ class DefaultSpecies:
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop( _, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func, cond_func,
body_func, body_func,
(0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances, (0, state.idx2species, center_nodes, center_conns, state.species_keys, o2c_distances,
state.next_species_key) state.next_species_key)
) )
@@ -401,8 +405,8 @@ class DefaultSpecies:
def count_members(idx): def count_members(idx):
return jax.lax.cond( return jax.lax.cond(
jnp.isnan(species_keys[idx]), # if the species is not existing jnp.isnan(species_keys[idx]), # if the species is not existing
lambda _: jnp.nan, # nan lambda: jnp.nan, # nan
lambda _: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members lambda: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
) )
member_count = jax.vmap(count_members)(self.species_arange) member_count = jax.vmap(count_members)(self.species_arange)
@@ -422,7 +426,8 @@ class DefaultSpecies:
""" """
The distance between two genomes The distance between two genomes
""" """
return self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2) d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
return d
def node_distance(self, nodes1, nodes2): def node_distance(self, nodes1, nodes2):
""" """
@@ -494,18 +499,18 @@ def initialize_population(pop_size, genome):
o_nodes[input_idx, 0] = genome.input_idx o_nodes[input_idx, 0] = genome.input_idx
o_nodes[output_idx, 0] = genome.output_idx o_nodes[output_idx, 0] = genome.output_idx
o_nodes[new_node_key, 0] = new_node_key # one hidden node o_nodes[new_node_key, 0] = new_node_key # one hidden node
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_attrs() o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
o_nodes[new_node_key, 1:] = genome.node_gene.new_attrs() # one hidden node o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
o_conns[input_idx, 0:2] = input_conns # in key, out key o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs() o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
o_conns[output_idx, 0:2] = output_conns # in key, out key o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs() o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
# repeat origin genome for P times to create population # repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1)) pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))

View File

@@ -1,38 +1,36 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import BraxEnv, BraxConfig
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=10000,
pop_size=100
),
neat=NeatConfig(
inputs=27,
outputs=8,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=BraxConfig(
env_name="ant"
)
)
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__': if __name__ == '__main__':
conf = example_conf() pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=27,
num_outputs=8,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name='ant',
),
generation_limit=10000,
fitness_target=5000
)
algorithm = NEAT(conf, NormalGene) # initialize state
pipeline = Pipeline(conf, algorithm, BraxEnv)
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) # print(state)
# run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)

View File

@@ -1,42 +1,36 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import BraxEnv, BraxConfig
# ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=10000,
generation_limit=10,
pop_size=100
),
neat=NeatConfig(
inputs=17,
outputs=6,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=BraxConfig(
env_name="halfcheetah"
)
)
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__': if __name__ == '__main__':
conf = example_conf() pipeline = Pipeline(
algorithm = NEAT(conf, NormalGene) algorithm=NEAT(
pipeline = Pipeline(conf, algorithm, BraxEnv) species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=17,
num_outputs=6,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name='halhcheetah',
),
generation_limit=10000,
fitness_target=5000
)
# initialize state
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) # print(state)
# run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)
pipeline.show(state, best, save_path="half_cheetah.gif", )

View File

@@ -1,38 +1,36 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import BraxEnv, BraxConfig
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=10000,
pop_size=1000
),
neat=NeatConfig(
inputs=11,
outputs=2,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=BraxConfig(
env_name="reacher"
)
)
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__': if __name__ == '__main__':
conf = example_conf() pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=11,
num_outputs=2,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=100,
species_size=10,
),
),
problem=BraxEnv(
env_name='reacher',
),
generation_limit=10000,
fitness_target=5000
)
algorithm = NEAT(conf, NormalGene) # initialize state
pipeline = Pipeline(conf, algorithm, BraxEnv)
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) # print(state)
# run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)

View File

@@ -1,73 +0,0 @@
import imageio
import jax
import brax
from brax import envs
from brax.io import image
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import numpy as np
def inference_func(key, *args):
return jax.random.normal(key, shape=(env.action_size,))
env_name = "ant"
backend = "generalized"
env = envs.create(env_name=env_name, backend=backend)
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_func)
rng = jax.random.PRNGKey(seed=1)
ori_state = jit_env_reset(rng=rng)
state = ori_state
render_history = []
for i in range(100):
act_rng, rng = jax.random.split(rng)
tic = time.time()
act = jit_inference_fn(act_rng, state.obs)
state = jit_env_step(state, act)
print("step time: ", time.time() - tic)
render_history.append(state.pipeline_state)
# img = image.render_array(sys=env.sys, state=pipeline_state, width=512, height=512)
# print("render time: ", time.time() - tic)
# plt.imsave("../images/ant_{}.png".format(i), img)
reward = state.reward
done = state.done
print(i, reward)
render_history = jax.device_get(render_history)
# print(render_history)
imgs = [image.render_array(sys=env.sys, state=s, width=512, height=512) for s in tqdm(render_history)]
# for i, s in enumerate(tqdm(render_history)):
# img = image.render_array(sys=env.sys, state=s, width=512, height=512)
# print(img.shape)
# # print(type(img))
# plt.imsave("../images/ant_{}.png".format(i), img)
def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
for image in image_list:
# 确保图像的数据类型正确
formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image)
create_gif(imgs, "../images/ant.gif", 0.1)

View File

@@ -1,54 +0,0 @@
import brax
from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.io import image
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import traceback
# print(f"Using Brax {brax.__version__}, Jax {jax.__version__}")
# print("From GymWrapper, env.reset()")
# try:
# env = envs.create("inverted_pendulum",
# batch_size=1,
# episode_length=150,
# backend='generalized')
# env = gym_wrapper.GymWrapper(env)
# env.reset()
# img = env.render(mode='rgb_array')
# plt.imshow(img)
# except Exception:
# traceback.print_exc()
#
# print("From GymWrapper, env.reset() and action")
# try:
# env = envs.create("inverted_pendulum",
# batch_size=1,
# episode_length=150,
# backend='generalized')
# env = gym_wrapper.GymWrapper(env)
# env.reset()
# action = jnp.zeros(env.action_space.shape)
# env.step(action)
# img = env.render(mode='rgb_array')
# plt.imshow(img)
# except Exception:
# traceback.print_exc()
print("From brax env")
try:
env = envs.create("inverted_pendulum",
batch_size=1,
episode_length=150,
backend='generalized')
key = jax.random.PRNGKey(0)
initial_env_state = env.reset(key)
base_state = initial_env_state.pipeline_state
pipeline_state = env.pipeline_init(base_state.q.ravel(), base_state.qd.ravel())
img = image.render_array(sys=env.sys, state=pipeline_state, width=256, height=256)
print(f"pixel values: [{img.min()}, {img.max()}]")
plt.imshow(img)
plt.show()
except Exception:
traceback.print_exc()

View File

@@ -1,32 +1,31 @@
from config import *
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig from problem.func_fit import XOR3d
if __name__ == '__main__': if __name__ == '__main__':
# running config pipeline = Pipeline(
config = Config( algorithm=NEAT(
basic=BasicConfig( species=DefaultSpecies(
seed=42, genome=DefaultGenome(
fitness_target=-1e-2, num_inputs=3,
pop_size=10000 num_outputs=1,
max_nodes=50,
max_conns=100,
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
), ),
neat=NeatConfig( problem=XOR3d(),
inputs=2, generation_limit=10000,
outputs=1 fitness_target=-1e-8
),
gene=NormalGeneConfig(),
problem=FuncFitConfig(
error_method='rmse'
)
) )
# define algorithm: NEAT with NormalGene
algorithm = NEAT(config, NormalGene)
# full pipeline
pipeline = Pipeline(config, algorithm, XOR)
# initialize state # initialize state
state = pipeline.setup() state = pipeline.setup()
# print(state)
# run until terminate # run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)
# show result # show result

View File

@@ -0,0 +1,51 @@
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.hyperneat import *
from utils import Act
from problem.func_fit import XOR3d
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=HyperNEAT(
substrate=FullSubstrate(
input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)],
hidden_coors=[
(-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5),
(-1, 0), (0.333, 0), (-0.333, 0), (1, 0),
(-1, 0.5), (0.333, 0.5), (-0.333, 0.5), (1, 0.5),
],
output_coors=[(0, 1), ],
),
neat=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=4, # [-1, -1, -1, 0]
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
),
activation=Act.sigmoid,
activate_time=10,
),
problem=XOR3d(),
generation_limit=300,
fitness_target=-1e-6
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

View File

@@ -1,41 +0,0 @@
from config import *
from pipeline import Pipeline
from algorithm.neat import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
from problem.func_fit import XOR3d, FuncFitConfig
from utils import Act
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
seed=42,
fitness_target=0,
pop_size=1000
),
neat=NeatConfig(
max_nodes=50,
max_conns=100,
max_species=30,
inputs=4,
outputs=1
),
hyperneat=HyperNeatConfig(
inputs=3,
outputs=1
),
substrate=NormalSubstrateConfig(
input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh, ),
),
problem=FuncFitConfig()
)
algorithm = HyperNEAT(config, NormalGene, NormalSubstrate)
pipeline = Pipeline(config, algorithm, XOR3d)
state = pipeline.setup()
state, best = pipeline.auto_run(state)
pipeline.show(state, best)

View File

@@ -1,41 +1,41 @@
from config import *
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from problem.func_fit import XOR3d, FuncFitConfig
from problem.func_fit import XOR3d
from utils.activation import ACT_ALL
from utils.aggregation import AGG_ALL
if __name__ == '__main__': if __name__ == '__main__':
config = Config( pipeline = Pipeline(
basic=BasicConfig( seed=0,
seed=42, algorithm=NEAT(
fitness_target=-1e-2, species=DefaultSpecies(
generation_limit=300, genome=RecurrentGenome(
pop_size=1000 num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
activate_time=5,
node_gene=DefaultNodeGene(
activation_options=ACT_ALL,
# aggregation_options=AGG_ALL,
activation_replace_rate=0.2
),
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
), ),
neat=NeatConfig( problem=XOR3d(),
network_type="recurrent", generation_limit=10000,
max_nodes=50, fitness_target=-1e-8
max_conns=100,
max_species=30,
conn_add=0.5,
conn_delete=0.5,
node_add=0.4,
node_delete=0.4,
inputs=3,
outputs=1
),
gene=RecurrentGeneConfig(
activate_times=10
),
problem=FuncFitConfig(
error_method='rmse'
)
) )
algorithm = NEAT(config, RecurrentGene) # initialize state
pipeline = Pipeline(config, algorithm, XOR3d)
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) # print(state)
# run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best) pipeline.show(state, best)

View File

@@ -1,36 +0,0 @@
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
seed=42,
fitness_target=-1e-2,
pop_size=10000
),
neat=NeatConfig(
max_nodes=50,
max_conns=100,
max_species=30,
conn_add=0.8,
conn_delete=0,
node_add=0.4,
node_delete=0,
inputs=2,
outputs=1
),
gene=NormalGeneConfig(),
problem=FuncFitConfig(
error_method='rmse'
)
)
algorithm = NEAT(config, NormalGene)
pipeline = Pipeline(config, algorithm, XOR)
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
pipeline.show(state, best)

View File

@@ -1,39 +0,0 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=0,
pop_size=10000
),
neat=NeatConfig(
inputs=6,
outputs=3,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='Acrobot-v1',
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2}
)
)
if __name__ == '__main__':
conf = example_conf()
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)

34
examples/gymnax/arcbot.py Normal file
View File

@@ -0,0 +1,34 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=6,
num_outputs=3,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Acrobot-v1',
),
generation_limit=10000,
fitness_target=-62
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,84 +1,34 @@
import jax.numpy as jnp import jax.numpy as jnp
from config import *
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf1():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
),
neat=NeatConfig(
inputs=4,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
)
)
def example_conf2():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
),
neat=NeatConfig(
inputs=4,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.where(out[0] > 0, 1, 0) # the action of cartpole is {0, 1}
)
)
def example_conf3():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=501,
pop_size=10000
),
neat=NeatConfig(
inputs=4,
outputs=2,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
)
)
from problem.rl_env import GymNaxEnv
if __name__ == '__main__': if __name__ == '__main__':
# all config files above can solve cartpole pipeline = Pipeline(
conf = example_conf3() algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=4,
num_outputs=2,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='CartPole-v1',
),
generation_limit=10000,
fitness_target=500
)
algorithm = NEAT(conf, NormalGene) # initialize state
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) # print(state)
# run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)

View File

@@ -1,39 +1,34 @@
import jax.numpy as jnp import jax.numpy as jnp
from config import *
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=0,
pop_size=10000
),
neat=NeatConfig(
inputs=2,
outputs=3,
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='MountainCar-v0',
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1, 2}
)
)
from problem.rl_env import GymNaxEnv
if __name__ == '__main__': if __name__ == '__main__':
conf = example_conf() pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=2,
num_outputs=3,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='MountainCar-v0',
),
generation_limit=10000,
fitness_target=0
)
algorithm = NEAT(conf, NormalGene) # initialize state
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) # print(state)
# run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)

View File

@@ -1,38 +1,36 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=100,
pop_size=10000
),
neat=NeatConfig(
inputs=2,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='MountainCarContinuous-v0'
)
)
from problem.rl_env import GymNaxEnv
from utils import Act
if __name__ == '__main__': if __name__ == '__main__':
conf = example_conf() pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh, ),
activation_default=Act.tanh,
)
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='MountainCarContinuous-v0',
),
generation_limit=10000,
fitness_target=500
)
algorithm = NEAT(conf, NormalGene) # initialize state
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) # print(state)
# run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)

View File

@@ -1,40 +1,37 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=0,
pop_size=10000
),
neat=NeatConfig(
inputs=3,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='Pendulum-v1',
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
)
)
from problem.rl_env import GymNaxEnv
from utils import Act
if __name__ == '__main__': if __name__ == '__main__':
conf = example_conf() pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Pendulum-v1',
),
generation_limit=10000,
fitness_target=0
)
algorithm = NEAT(conf, NormalGene) # initialize state
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) # print(state)
# run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)

View File

@@ -1,36 +1,33 @@
from config import * import jax.numpy as jnp
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm.neat import *
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
),
neat=NeatConfig(
inputs=8,
outputs=2,
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='Reacher-misc',
)
)
from problem.rl_env import GymNaxEnv
if __name__ == '__main__': if __name__ == '__main__':
conf = example_conf() pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=8,
num_outputs=2,
max_nodes=50,
max_conns=100,
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Reacher-misc',
),
generation_limit=10000,
fitness_target =500
)
algorithm = NEAT(conf, NormalGene) # initialize state
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) # print(state)
# run until terminate
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)

View File

@@ -1,25 +1,23 @@
from functools import partial from functools import partial
from typing import Type
import jax import jax, jax.numpy as jnp
import time import time
import numpy as np import numpy as np
from algorithm import NEAT, HyperNEAT from algorithm import BaseAlgorithm
from config import Config from problem import BaseProblem
from core import State, Algorithm, Problem from utils import State
class Pipeline: class Pipeline:
def __init__( def __init__(
self, self,
algorithm: Algorithm, algorithm: BaseAlgorithm,
problem: Problem, problem: BaseProblem,
seed: int = 42, seed: int = 42,
fitness_target: float = 1, fitness_target: float = 1,
generation_limit: int = 1000, generation_limit: int = 1000,
pop_size: int = 100,
): ):
assert problem.jitable, "Currently, problem must be jitable" assert problem.jitable, "Currently, problem must be jitable"
@@ -28,17 +26,18 @@ class Pipeline:
self.seed = seed self.seed = seed
self.fitness_target = fitness_target self.fitness_target = fitness_target
self.generation_limit = generation_limit self.generation_limit = generation_limit
self.pop_size = pop_size self.pop_size = self.algorithm.pop_size
print(self.problem.input_shape, self.problem.output_shape) print(self.problem.input_shape, self.problem.output_shape)
# TODO: make each algorithm's input_num and output_num # TODO: make each algorithm's input_num and output_num
assert algorithm.input_num == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}" assert algorithm.num_inputs == self.problem.input_shape[-1], \
f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
self.act_func = self.algorithm.act # self.act_func = self.algorithm.act
for _ in range(len(self.problem.input_shape) - 1): # for _ in range(len(self.problem.input_shape) - 1):
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None)) # self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
self.best_genome = None self.best_genome = None
self.best_fitness = float('-inf') self.best_fitness = float('-inf')
@@ -46,41 +45,57 @@ class Pipeline:
def setup(self): def setup(self):
key = jax.random.PRNGKey(self.seed) key = jax.random.PRNGKey(self.seed)
algorithm_key, evaluate_key = jax.random.split(key, 2) key, algorithm_key, evaluate_key = jax.random.split(key, 3)
# TODO: Problem should has setup function to maintain state # TODO: Problem should has setup function to maintain state
return State( return State(
randkey=key,
alg=self.algorithm.setup(algorithm_key), alg=self.algorithm.setup(algorithm_key),
pro=self.problem.setup(evaluate_key), pro=self.problem.setup(evaluate_key),
) )
@partial(jax.jit, static_argnums=(0,))
def step(self, state): def step(self, state):
key, sub_key = jax.random.split(state.evaluate_key) key, sub_key = jax.random.split(state.randkey)
keys = jax.random.split(key, self.pop_size) keys = jax.random.split(key, self.pop_size)
pop = self.algorithm.ask(state) pop = self.algorithm.ask(state.alg)
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop) pop_transformed = jax.vmap(self.algorithm.transform)(pop)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func, fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(
pop_transformed) keys,
state.pro,
self.algorithm.forward,
pop_transformed
)
state = self.algorithm.tell(state, fitnesses) fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
return state.update(evaluate_key=sub_key), fitnesses alg_state = self.algorithm.tell(state.alg, fitnesses)
return state.update(
randkey=sub_key,
alg=alg_state,
), fitnesses
def auto_run(self, ini_state): def auto_run(self, ini_state):
state = ini_state state = ini_state
compiled_step = jax.jit(self.step).lower(ini_state).compile()
for _ in range(self.generation_limit): for _ in range(self.generation_limit):
self.generation_timestamp = time.time() self.generation_timestamp = time.time()
previous_pop = self.algorithm.ask(state) previous_pop = self.algorithm.ask(state.alg)
state, fitnesses = self.step(state) state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses) fitnesses = jax.device_get(fitnesses)
for idx, fitnesses_i in enumerate(fitnesses):
if np.isnan(fitnesses_i):
print("Fitness is nan")
print(previous_pop[0][idx], previous_pop[1][idx])
assert False
self.analysis(state, previous_pop, fitnesses) self.analysis(state, previous_pop, fitnesses)
@@ -102,22 +117,15 @@ class Pipeline:
max_idx = np.argmax(fitnesses) max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness: if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx] self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[max_idx] self.best_genome = pop[0][max_idx], pop[1][max_idx]
member_count = jax.device_get(state.species_info.member_count) member_count = jax.device_get(self.algorithm.member_count(state.alg))
species_sizes = [int(i) for i in member_count if i > 0] species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {state.generation}", print(f"Generation: {self.algorithm.generation(state.alg)}",
f"species: {len(species_sizes)}, {species_sizes}", f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, genome, *args, **kwargs): def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(state, genome) transformed = self.algorithm.transform(best)
self.problem.show(state.evaluate_key, state, self.act_func, transformed, *args, **kwargs) self.problem.show(state.randkey, state.pro, self.algorithm.forward, transformed, *args, **kwargs)
def pre_compile(self, state):
tic = time.time()
print("start compile")
self.step.lower(self, state).compile()
print(f"compile finished, cost time: {time.time() - tic}s")

View File

@@ -1,19 +1,14 @@
from typing import Callable from typing import Callable
from config import ProblemConfig from utils import State
from core.state import State
class BaseProblem: class BaseProblem:
jitable = None jitable = None
def __init__(self):
pass
def setup(self, randkey, state: State = State()): def setup(self, randkey, state: State = State()):
"""initialize the state of the problem""" """initialize the state of the problem"""
raise NotImplementedError pass
def evaluate(self, randkey, state: State, act_func: Callable, params): def evaluate(self, randkey, state: State, act_func: Callable, params):
"""evaluate one individual""" """evaluate one individual"""

View File

@@ -1,24 +1,27 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from utils import State
from .. import BaseProblem from .. import BaseProblem
class FuncFit(BaseProblem):
class FuncFit(BaseProblem):
jitable = True jitable = True
def __init__(self, def __init__(self,
error_method: str = 'mse' error_method: str = 'mse'
): ):
super().__init__() super().__init__()
assert error_method in {'mse', 'rmse', 'mae', 'mape'} assert error_method in {'mse', 'rmse', 'mae', 'mape'}
self.error_method = error_method self.error_method = error_method
def setup(self, randkey, state: State = State()):
return state
def evaluate(self, randkey, state, act_func, params): def evaluate(self, randkey, state, act_func, params):
predict = act_func(state, self.inputs, params) predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
if self.error_method == 'mse': if self.error_method == 'mse':
loss = jnp.mean((predict - self.targets) ** 2) loss = jnp.mean((predict - self.targets) ** 2)
@@ -38,7 +41,7 @@ class FuncFit(BaseProblem):
return -loss return -loss
def show(self, randkey, state, act_func, params, *args, **kwargs): def show(self, randkey, state, act_func, params, *args, **kwargs):
predict = act_func(state, self.inputs, params) predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params) loss = -self.evaluate(randkey, state, act_func, params)
msg = "" msg = ""

View File

@@ -1,2 +1,2 @@
from .gymnax_env import GymNaxEnv, GymNaxConfig from .gymnax_env import GymNaxEnv
from .brax_env import BraxEnv, BraxConfig from .brax_env import BraxEnv

View File

@@ -3,7 +3,6 @@ import gymnax
from .rl_jit import RLEnv from .rl_jit import RLEnv
class GymNaxEnv(RLEnv): class GymNaxEnv(RLEnv):
def __init__(self, env_name): def __init__(self, env_name):

View File

@@ -4,8 +4,8 @@ import jax
from .. import BaseProblem from .. import BaseProblem
class RLEnv(BaseProblem):
class RLEnv(BaseProblem):
jitable = True jitable = True
# TODO: move output transform to algorithm # TODO: move output transform to algorithm
@@ -19,9 +19,10 @@ class RLEnv(BaseProblem):
def cond_func(carry): def cond_func(carry):
_, _, _, done, _ = carry _, _, _, done, _ = carry
return ~done return ~done
def body_func(carry): def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward obs, env_state, rng, _, tr = carry # total reward
action = act_func(state, obs, params) action = act_func(obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng) next_rng, _ = jax.random.split(rng)
return next_obs, next_env_state, next_rng, done, tr + reward return next_obs, next_env_state, next_rng, done, tr + reward

66
t.py
View File

@@ -1,64 +1,4 @@
from algorithm.neat import * import jax.numpy as jnp
from utils import Act, Agg
import jax, jax.numpy as jnp a = jnp.zeros((0, 9, 9))
print(a)
def main():
# index, bias, response, activation, aggregation
nodes = jnp.array([
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
])
# in_node, out_node, enable, weight
conns = jnp.array([
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
])
genome = RecurrentGenome(
num_inputs=2,
num_outputs=1,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
),
activate_time=3
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([0, 0])
outputs = genome.forward(inputs, transformed)
print(outputs)
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
print(outputs)
expected: [[0.5], [0.75], [0.75], [1]]
print('\n-------------------------------------------------------\n')
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
print(outputs)
expected: [[0.5], [0.75], [0.5], [0.75]]
if __name__ == '__main__':
main()

View File

@@ -26,6 +26,8 @@ def test_default():
genome = DefaultGenome( genome = DefaultGenome(
num_inputs=2, num_inputs=2,
num_outputs=1, num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_default=Act.identity, activation_default=Act.identity,
activation_options=(Act.identity, ), activation_options=(Act.identity, ),
@@ -80,6 +82,8 @@ def test_recurrent():
genome = RecurrentGenome( genome = RecurrentGenome(
num_inputs=2, num_inputs=2,
num_outputs=1, num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_default=Act.identity, activation_default=Act.identity,
activation_options=(Act.identity, ), activation_options=(Act.identity, ),

View File

@@ -6,48 +6,26 @@ class Act:
@staticmethod @staticmethod
def sigmoid(z): def sigmoid(z):
z = jnp.clip(z * 5, -60, 60) z = jnp.clip(5 * z, -10, 10)
return 1 / (1 + jnp.exp(-z)) return 1 / (1 + jnp.exp(-z))
@staticmethod @staticmethod
def tanh(z): def tanh(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z) return jnp.tanh(z)
@staticmethod @staticmethod
def sin(z): def sin(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z) return jnp.sin(z)
@staticmethod
def gauss(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@staticmethod @staticmethod
def relu(z): def relu(z):
return jnp.maximum(z, 0) return jnp.maximum(z, 0)
@staticmethod
def elu(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@staticmethod @staticmethod
def lelu(z): def lelu(z):
leaky = 0.005 leaky = 0.005
return jnp.where(z > 0, z, leaky * z) return jnp.where(z > 0, z, leaky * z)
@staticmethod
def selu(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@staticmethod
def softplus(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@staticmethod @staticmethod
def identity(z): def identity(z):
return z return z
@@ -58,7 +36,11 @@ class Act:
@staticmethod @staticmethod
def inv(z): def inv(z):
z = jnp.maximum(z, 1e-7) z = jnp.where(
z > 0,
jnp.maximum(z, 1e-7),
jnp.minimum(z, -1e-7)
)
return 1 / z return 1 / z
@staticmethod @staticmethod
@@ -68,24 +50,27 @@ class Act:
@staticmethod @staticmethod
def exp(z): def exp(z):
z = jnp.clip(z, -60, 60) z = jnp.clip(z, -10, 10)
return jnp.exp(z) return jnp.exp(z)
@staticmethod @staticmethod
def abs(z): def abs(z):
return jnp.abs(z) return jnp.abs(z)
@staticmethod
def hat(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@staticmethod ACT_ALL = (
def square(z): Act.sigmoid,
return z ** 2 Act.tanh,
Act.sin,
@staticmethod Act.relu,
def cube(z): Act.lelu,
return z ** 3 Act.identity,
Act.clamped,
Act.inv,
Act.log,
Act.exp,
Act.abs,
)
def act(idx, z, act_funcs): def act(idx, z, act_funcs):

View File

@@ -51,6 +51,9 @@ class Agg:
return mean_without_zeros return mean_without_zeros
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
def agg(idx, z, agg_funcs): def agg(idx, z, agg_funcs):
""" """
calculate activation function for inputs of node calculate activation function for inputs of node