remove create_func....

This commit is contained in:
wls2002
2023-08-02 13:26:01 +08:00
parent 85318f98f3
commit 1499e062fe
34 changed files with 558 additions and 1022 deletions

View File

@@ -1,2 +1 @@
from .neat import * from .neat import NEAT
from .hyper_neat import *

View File

@@ -1,2 +0,0 @@
from .hyper_neat import HyperNEAT
from .substrate import NormalSubstrate, NormalSubstrateConfig

View File

@@ -1,122 +0,0 @@
from typing import Type
import jax
from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome
from utils import Activation, Aggregation
from algorithm.neat import NEAT
from .substrate import analysis_substrate
class HyperNEAT(Algorithm):
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
self.config = config
self.neat = neat
self.substrate = substrate
self.forward_func = None
def setup(self, randkey, state=State()):
neat_key, randkey = jax.random.split(randkey)
state = state.update(
below_threshold=self.config.hyper_neat.below_threshold,
max_weight=self.config.hyper_neat.max_weight,
)
state = self.neat.setup(neat_key, state)
state = self.substrate.setup(self.config.substrate, state)
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
h_conns[:, 0:2] = correspond_keys
state = state.update(
h_input_idx=h_input_idx,
h_output_idx=h_output_idx,
h_hidden_idx=h_hidden_idx,
h_nodes=h_nodes,
h_conns=h_conns,
query_coors=query_coors,
)
self.forward_func = HyperNEATGene.create_forward(self.config.hyper_neat, state)
return state
def ask(self, state: State):
return state.pop_genomes
def tell(self, state: State, fitness):
return self.neat.tell(state, fitness)
def forward(self, inputs: Array, transformed: Array):
return self.forward_func(inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
t = self.neat.forward_transform(state, genome)
query_res = vmap(self.neat.forward, in_axes=(0, None))(state.query_coors, t)
# mute the connection with weight below threshold
query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res)
query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res)
query_res = query_res / (1 - state.below_threshold) * state.max_weight
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
class HyperNEATGene:
node_attrs = [] # no node attributes
conn_attrs = ['weight']
@staticmethod
def forward_transform(genome: Genome):
N = genome.nodes.shape[0]
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32)
out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32)
weights = genome.conns[:, 2]
u_conns = u_conns.at[in_keys, out_keys].set(weights)
return genome.nodes, u_conns
@staticmethod
def create_forward(config: HyperNeatConfig, state: State):
act = Activation.name2func[config.activation]
agg = Aggregation.name2func[config.aggregation]
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
def forward(inputs, transform):
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
nodes, weights = transform
input_idx = state.h_input_idx
output_idx = state.h_output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
def body_func(i, values):
values = values.at[input_idx].set(inputs_with_bias)
nodes_ins = values * weights.T
values = batch_agg(nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]
return forward

View File

@@ -1,2 +0,0 @@
from .normal import NormalSubstrate, NormalSubstrateConfig
from .tools import analysis_substrate

View File

@@ -1,25 +0,0 @@
from dataclasses import dataclass
from typing import Tuple
import numpy as np
from core import Substrate, State
from config import SubstrateConfig
@dataclass(frozen=True)
class NormalSubstrateConfig(SubstrateConfig):
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple[Tuple[float]] = ((0, 1), )
class NormalSubstrate(Substrate):
@staticmethod
def setup(config: NormalSubstrateConfig, state: State = State()):
return state.update(
input_coors=np.asarray(config.input_coors, dtype=np.float32),
output_coors=np.asarray(config.output_coors, dtype=np.float32),
hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32),
)

View File

@@ -1,50 +0,0 @@
from typing import Type
import numpy as np
def analysis_substrate(state):
cd = state.input_coors.shape[1] # coordinate dimensions
si = state.input_coors.shape[0] # input coordinate size
so = state.output_coors.shape[0] # output coordinate size
sh = state.hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys

View File

@@ -1,2 +1 @@
from .neat import NEAT from .neat import NEAT
from .gene import *

View File

@@ -1,2 +1,3 @@
from .crossover import crossover from .crossover import crossover
from .mutate import create_mutate from .mutate import mutate
from .operation import create_next_generation

View File

@@ -9,7 +9,7 @@ def crossover(randkey, genome1: Genome, genome2: Genome):
use genome1 and genome2 to generate a new genome use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!) notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
""" """
randkey_1, randkey_2, key= jax.random.split(randkey, 3) randkey_1, randkey_2, key = jax.random.split(randkey, 3)
# crossover nodes # crossover nodes
keys1, keys2 = genome1.nodes[:, 0], genome2.nodes[:, 0] keys1, keys2 = genome1.nodes[:, 0], genome2.nodes[:, 0]

View File

@@ -1,4 +1,4 @@
from typing import Tuple, Type from typing import Tuple
import jax import jax
from jax import Array, numpy as jnp, vmap from jax import Array, numpy as jnp, vmap
@@ -8,144 +8,141 @@ from core import State, Gene, Genome
from utils import check_cycles, fetch_random, fetch_first, I_INT, unflatten_conns from utils import check_cycles, fetch_random, fetch_first, I_INT, unflatten_conns
def create_mutate(config: NeatConfig, gene_type: Type[Gene]): def mutate(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key):
""" """
Create function to mutate a single genome Mutate a population of genomes
""" """
k1, k2 = jax.random.split(randkey)
def mutate_structure(state: State, randkey, genome: Genome, new_node_key): genome = mutate_structure(config, gene, state, k1, genome, new_node_key)
genome = mutate_values(gene, state, randkey, genome)
def mutate_add_node(key_, genome_: Genome): return genome
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
def nothing():
return genome_
def successful_add_node():
# disable the connection
new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False))
# add a new node
new_genome = new_genome.add_node(new_node_key, gene_type.new_node_attrs(state))
# add two new connections
new_genome = new_genome.add_conn(i_key, new_node_key, True, gene_type.new_conn_attrs(state))
new_genome = new_genome.add_conn(new_node_key, o_key, True, gene_type.new_conn_attrs(state))
return new_genome
# if from_idx == I_INT, that means no connection exist, do nothing
return jax.lax.cond(idx == I_INT, nothing, successful_add_node)
def mutate_delete_node(key_, genome_: Genome):
# TODO: Do we really need to delete a node?
# randomly choose a node
key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=False, allow_output_keys=False)
def nothing():
return genome_
def successful_delete_node():
# delete the node
new_genome = genome_.delete_node_by_pos(idx)
# delete all connections
new_conns = jnp.where(((new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key))[:, None],
jnp.nan, new_genome.conns)
return new_genome.update_conns(new_conns)
return jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
def mutate_add_conn(key_, genome_: Genome):
# randomly choose two nodes
k1_, k2_ = jax.random.split(key_, num=2)
i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=False, allow_output_keys=True)
conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key))
def nothing():
return genome_
def successful():
return genome_.add_conn(i_key, o_key, True, gene_type.new_conn_attrs(state))
def already_exist():
return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True))
is_already_exist = conn_pos != I_INT def mutate_structure(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key):
def mutate_add_node(key_, genome_: Genome):
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
if config.network_type == 'feedforward': def nothing():
u_cons = unflatten_conns(genome_.nodes, genome_.conns) return genome_
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) def successful_add_node():
return jax.lax.switch(choice, [already_exist, nothing, successful]) # disable the connection
new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False))
elif config.network_type == 'recurrent': # add a new node
return jax.lax.cond(is_already_exist, already_exist, successful) new_genome = new_genome.add_node(new_node_key, gene.new_node_attrs(state))
else: # add two new connections
raise ValueError(f"Invalid network type: {config.network_type}") new_genome = new_genome.add_conn(i_key, new_node_key, True, gene.new_conn_attrs(state))
new_genome = new_genome.add_conn(new_node_key, o_key, True, gene.new_conn_attrs(state))
def mutate_delete_conn(key_, genome_: Genome): return new_genome
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
def nothing(): # if from_idx == I_INT, that means no connection exist, do nothing
return genome_ return jax.lax.cond(idx == I_INT, nothing, successful_add_node)
def successfully_delete_connection(): def mutate_delete_node(key_, genome_: Genome):
return genome_.delete_conn_by_pos(idx) # TODO: Do we really need to delete a node?
# randomly choose a node
key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=False, allow_output_keys=False)
return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection) def nothing():
return genome_
k1, k2, k3, k4 = jax.random.split(randkey, num=4) def successful_delete_node():
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) # delete the node
new_genome = genome_.delete_node_by_pos(idx)
def no(k, g): # delete all connections
return g new_conns = jnp.where(((new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key))[:, None],
jnp.nan, new_genome.conns)
genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome) return new_genome.update_conns(new_conns)
genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome)
genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome)
genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome)
return genome return jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
def mutate_values(state: State, randkey, genome: Genome): def mutate_add_conn(key_, genome_: Genome):
k1, k2 = jax.random.split(randkey, num=2) # randomly choose two nodes
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0]) k1_, k2_ = jax.random.split(key_, num=2)
conns_keys = jax.random.split(k2, num=genome.conns.shape[0]) i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=False, allow_output_keys=True)
nodes_attrs, conns_attrs = genome.nodes[:, 1:], genome.conns[:, 3:] conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key))
new_nodes_attrs = vmap(gene_type.mutate_node, in_axes=(None, 0, 0))(state, nodes_attrs, nodes_keys) def nothing():
new_conns_attrs = vmap(gene_type.mutate_conn, in_axes=(None, 0, 0))(state, conns_attrs, conns_keys) return genome_
# nan nodes not changed def successful():
new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs) return genome_.add_conn(i_key, o_key, True, gene.new_conn_attrs(state))
new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs)
new_nodes = genome.nodes.at[:, 1:].set(new_nodes_attrs) def already_exist():
new_conns = genome.conns.at[:, 3:].set(new_conns_attrs) return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True))
return genome.update(new_nodes, new_conns) is_already_exist = conn_pos != I_INT
def mutate(state, randkey, genome: Genome, new_node_key): if config.network_type == 'feedforward':
k1, k2 = jax.random.split(randkey) u_cons = unflatten_conns(genome_.nodes, genome_.conns)
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx)
genome = mutate_structure(state, k1, genome, new_node_key) choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
genome = mutate_values(state, k2, genome) return jax.lax.switch(choice, [already_exist, nothing, successful])
return genome elif config.network_type == 'recurrent':
return jax.lax.cond(is_already_exist, already_exist, successful)
return mutate else:
raise ValueError(f"Invalid network type: {config.network_type}")
def mutate_delete_conn(key_, genome_: Genome):
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
def nothing():
return genome_
def successfully_delete_connection():
return genome_.delete_conn_by_pos(idx)
return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(k, g):
return g
genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome)
genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome)
genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome)
genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome)
return genome
def mutate_values(gene: Gene, state: State, randkey, genome: Genome):
k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0])
conns_keys = jax.random.split(k2, num=genome.conns.shape[0])
nodes_attrs, conns_attrs = genome.nodes[:, 1:], genome.conns[:, 3:]
new_nodes_attrs = vmap(gene.mutate_node, in_axes=(None, 0, 0))(state, nodes_keys, nodes_attrs)
new_conns_attrs = vmap(gene.mutate_conn, in_axes=(None, 0, 0))(state, conns_keys, conns_attrs)
# nan nodes not changed
new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs)
new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs)
new_nodes = genome.nodes.at[:, 1:].set(new_nodes_attrs)
new_conns = genome.conns.at[:, 3:].set(new_conns_attrs)
return genome.update(new_nodes, new_conns)
def choice_node_key(rand_key: Array, nodes: Array, def choice_node_key(rand_key: Array, nodes: Array,

View File

@@ -0,0 +1,40 @@
import jax
from jax import numpy as jnp, vmap
from config import NeatConfig
from core import Genome, State, Gene
from .mutate import mutate
from .crossover import crossover
def create_next_generation(config: NeatConfig, gene: Gene, state: State, randkey, winner, loser, elite_mask):
# prepare random keys
pop_size = state.idx2species.shape[0]
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner]
lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser]
n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc))
# batch mutation
mutate_func = vmap(mutate, in_axes=(None, None, None, 0, 0, 0))
m_n_genomes = mutate_func(config, gene, state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return state.update(
pop_genomes=Genome(pop_nodes, pop_conns),
next_node_key=next_node_key,
)

View File

@@ -1,2 +1 @@
from .normal import NormalGene, NormalGeneConfig from .normal import NormalGene, NormalGeneConfig
from .recurrent import RecurrentGene, RecurrentGeneConfig

View File

@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
from config import GeneConfig from config import GeneConfig
from core import Gene, Genome, State from core import Gene, Genome, State
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -66,48 +66,51 @@ class NormalGene(Gene):
node_attrs = ['bias', 'response', 'aggregation', 'activation'] node_attrs = ['bias', 'response', 'aggregation', 'activation']
conn_attrs = ['weight'] conn_attrs = ['weight']
@staticmethod def __init__(self, config: NormalGeneConfig):
def setup(config: NormalGeneConfig, state: State = State()): self.config = config
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def setup(self, state: State = State()):
return state.update( return state.update(
bias_init_mean=config.bias_init_mean, bias_init_mean=self.config.bias_init_mean,
bias_init_std=config.bias_init_std, bias_init_std=self.config.bias_init_std,
bias_mutate_power=config.bias_mutate_power, bias_mutate_power=self.config.bias_mutate_power,
bias_mutate_rate=config.bias_mutate_rate, bias_mutate_rate=self.config.bias_mutate_rate,
bias_replace_rate=config.bias_replace_rate, bias_replace_rate=self.config.bias_replace_rate,
response_init_mean=config.response_init_mean, response_init_mean=self.config.response_init_mean,
response_init_std=config.response_init_std, response_init_std=self.config.response_init_std,
response_mutate_power=config.response_mutate_power, response_mutate_power=self.config.response_mutate_power,
response_mutate_rate=config.response_mutate_rate, response_mutate_rate=self.config.response_mutate_rate,
response_replace_rate=config.response_replace_rate, response_replace_rate=self.config.response_replace_rate,
activation_replace_rate=config.activation_replace_rate, activation_replace_rate=self.config.activation_replace_rate,
activation_default=0, activation_default=0,
activation_options=jnp.arange(len(config.activation_options)), activation_options=jnp.arange(len(self.config.activation_options)),
aggregation_replace_rate=config.aggregation_replace_rate, aggregation_replace_rate=self.config.aggregation_replace_rate,
aggregation_default=0, aggregation_default=0,
aggregation_options=jnp.arange(len(config.aggregation_options)), aggregation_options=jnp.arange(len(self.config.aggregation_options)),
weight_init_mean=config.weight_init_mean, weight_init_mean=self.config.weight_init_mean,
weight_init_std=config.weight_init_std, weight_init_std=self.config.weight_init_std,
weight_mutate_power=config.weight_mutate_power, weight_mutate_power=self.config.weight_mutate_power,
weight_mutate_rate=config.weight_mutate_rate, weight_mutate_rate=self.config.weight_mutate_rate,
weight_replace_rate=config.weight_replace_rate, weight_replace_rate=self.config.weight_replace_rate,
) )
@staticmethod def update(self, state):
def new_node_attrs(state): pass
def new_node_attrs(self, state):
return jnp.array([state.bias_init_mean, state.response_init_mean, return jnp.array([state.bias_init_mean, state.response_init_mean,
state.activation_default, state.aggregation_default]) state.activation_default, state.aggregation_default])
@staticmethod def new_conn_attrs(self, state):
def new_conn_attrs(state):
return jnp.array([state.weight_init_mean]) return jnp.array([state.weight_init_mean])
@staticmethod def mutate_node(self, state, key, attrs: Array):
def mutate_node(state, attrs: Array, key):
k1, k2, k3, k4 = jax.random.split(key, num=4) k1, k2, k3, k4 = jax.random.split(key, num=4)
bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std, bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std,
@@ -120,26 +123,22 @@ class NormalGene(Gene):
return jnp.array([bias, res, act, agg]) return jnp.array([bias, res, act, agg])
@staticmethod def mutate_conn(self, state, key, attrs: Array):
def mutate_conn(state, attrs: Array, key):
weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std, weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std,
state.weight_mutate_power, state.weight_mutate_rate, state.weight_mutate_power, state.weight_mutate_rate,
state.weight_replace_rate) state.weight_replace_rate)
return jnp.array([weight]) return jnp.array([weight])
@staticmethod def distance_node(self, state, node1: Array, node2: Array):
def distance_node(state, node1: Array, node2: Array):
# bias + response + activation + aggregation # bias + response + activation + aggregation
return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \ return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \
(node1[3] != node2[3]) + (node1[4] != node2[4]) (node1[3] != node2[3]) + (node1[4] != node2[4])
@staticmethod def distance_conn(self, state, con1: Array, con2: Array):
def distance_conn(state, con1: Array, con2: Array):
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
@staticmethod def forward_transform(self, state: State, genome: Genome):
def forward_transform(state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns) u_conns = unflatten_conns(genome.nodes, genome.conns)
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
@@ -149,87 +148,46 @@ class NormalGene(Gene):
return seqs, genome.nodes, u_conns return seqs, genome.nodes, u_conns
@staticmethod def forward(self, state: State, inputs, transformed):
def create_forward(state: State, config: NormalGeneConfig): cal_seqs, nodes, cons = transformed
activation_funcs = [Activation.name2func[name] for name in config.activation_options]
aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def act(idx, z): input_idx = state.input_idx
""" output_idx = state.output_idx
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, activation_funcs, z)
return res
def agg(idx, z): N = nodes.shape[0]
""" ini_vals = jnp.full((N,), jnp.nan)
calculate activation function for inputs of node ini_vals = ini_vals.at[input_idx].set(inputs)
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan(): weights = cons[0, :]
return 0.
def not_all_nan(): def cond_fun(carry):
return jax.lax.switch(idx, aggregation_funcs, z) values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan) def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def forward(inputs, transformed) -> Array: def hit():
""" ins = values * weights[:, i]
forward for single input shaped (input_num, ) z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z, self.act_funcs) # z = act(z)
:argument inputs: (input_num, ) new_values = values.at[i].set(z)
:argument cal_seqs: (N, ) return new_values
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (output_num, ) def miss():
""" return values
cal_seqs, nodes, cons = transformed # the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
input_idx = state.input_idx return values, idx + 1
output_idx = state.output_idx
N = nodes.shape[0] vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
weights = cons[0, :] return vals[output_idx]
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z) # z = act(z)
new_values = values.at[i].set(z)
return new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
return values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx]
return forward
@staticmethod @staticmethod
def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate): def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):

View File

@@ -1,84 +0,0 @@
from dataclasses import dataclass
import jax
from jax import Array, numpy as jnp, vmap
from .normal import NormalGene, NormalGeneConfig
from core import State, Genome
from utils import Activation, Aggregation, unflatten_conns
@dataclass(frozen=True)
class RecurrentGeneConfig(NormalGeneConfig):
activate_times: int = 10
def __post_init__(self):
super().__post_init__()
assert self.activate_times > 0
class RecurrentGene(NormalGene):
@staticmethod
def forward_transform(state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return genome.nodes, u_conns
@staticmethod
def create_forward(state: State, config: RecurrentGeneConfig):
activation_funcs = [Activation.name2func[name] for name in config.activation_options]
aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, activation_funcs, z)
return res
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, aggregation_funcs, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
batch_act, batch_agg = vmap(act), vmap(agg)
def forward(inputs, transform) -> Array:
nodes, cons = transform
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = cons[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]
return forward

View File

@@ -1,20 +1,18 @@
from typing import Type
import jax import jax
from jax import numpy as jnp, Array, vmap from jax import numpy as jnp
import numpy as np import numpy as np
from config import Config from config import Config
from core import Algorithm, State, Gene, Genome from core import Algorithm, State, Gene, Genome
from .ga import crossover, create_mutate from .ga import create_next_generation
from .species import SpeciesInfo, update_species, create_speciate from .species import SpeciesInfo, update_species, speciate
class NEAT(Algorithm): class NEAT(Algorithm):
def __init__(self, config: Config, gene_type: Type[Gene]): def __init__(self, config: Config, gene: Gene):
self.config = config self.config = config
self.gene_type = gene_type self.gene = gene
self.forward_func = None self.forward_func = None
self.tell_func = None self.tell_func = None
@@ -31,8 +29,8 @@ class NEAT(Algorithm):
N=self.config.neat.maximum_nodes, N=self.config.neat.maximum_nodes,
C=self.config.neat.maximum_conns, C=self.config.neat.maximum_conns,
S=self.config.neat.maximum_species, S=self.config.neat.maximum_species,
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
max_stagnation=self.config.neat.max_stagnation, max_stagnation=self.config.neat.max_stagnation,
species_elitism=self.config.neat.species_elitism, species_elitism=self.config.neat.species_elitism,
spawn_number_change_rate=self.config.neat.spawn_number_change_rate, spawn_number_change_rate=self.config.neat.spawn_number_change_rate,
@@ -46,7 +44,7 @@ class NEAT(Algorithm):
output_idx=output_idx, output_idx=output_idx,
) )
state = self.gene_type.setup(self.config.gene, state) state = self.gene.setup(state)
pop_genomes = self._initialize_genomes(state) pop_genomes = self._initialize_genomes(state)
species_info = SpeciesInfo.initialize(state) species_info = SpeciesInfo.initialize(state)
@@ -74,26 +72,32 @@ class NEAT(Algorithm):
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32), next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
) )
self.forward_func = self.gene_type.create_forward(state, self.config.gene)
self.tell_func = self._create_tell()
return jax.device_put(state) return jax.device_put(state)
def ask(self, state: State): def ask_algorithm(self, state: State):
"""require the population to be evaluated"""
return state.pop_genomes return state.pop_genomes
def tell(self, state: State, fitness): def tell_algorithm(self, state: State, fitness):
"""update the state of the algorithm""" k1, k2, randkey = jax.random.split(state.randkey, 3)
return self.tell_func(state, fitness)
def forward(self, inputs: Array, transformed: Array): state = state.update(
"""the forward function of a single forward transformation""" generation=state.generation + 1,
return self.forward_func(inputs, transformed) randkey=randkey
)
state, winner, loser, elite_mask = update_species(state, k1, fitness)
state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask)
state = speciate(self.gene, state)
return state
def forward_transform(self, state: State, genome: Genome): def forward_transform(self, state: State, genome: Genome):
"""create the forward transformation of a genome""" return self.gene.forward_transform(state, genome)
return self.gene_type.forward_transform(state, genome)
def forward(self, state: State, inputs, genome: Genome):
return self.gene.forward(state, inputs, genome)
def _initialize_genomes(self, state): def _initialize_genomes(self, state):
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
@@ -106,80 +110,21 @@ class NEAT(Algorithm):
o_nodes[input_idx, 0] = input_idx o_nodes[input_idx, 0] = input_idx
o_nodes[output_idx, 0] = output_idx o_nodes[output_idx, 0] = output_idx
o_nodes[new_node_key, 0] = new_node_key o_nodes[new_node_key, 0] = new_node_key
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene_type.new_node_attrs(state) o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state)
o_nodes[new_node_key, 1:] = self.gene_type.new_node_attrs(state) o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state)
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
o_conns[input_idx, 0:2] = input_conns # in key, out key o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = self.gene_type.new_conn_attrs(state) o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state)
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
o_conns[output_idx, 0:2] = output_conns # in key, out key o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = self.gene_type.new_conn_attrs(state) o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state)
# repeat origin genome for P times to create population # repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (state.P, 1, 1)) pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
pop_conns = np.tile(o_conns, (state.P, 1, 1)) pop_conns = np.tile(o_conns, (state.P, 1, 1))
return Genome(pop_nodes, pop_conns) return Genome(pop_nodes, pop_conns)
def _create_tell(self):
mutate = create_mutate(self.config.neat, self.gene_type)
def create_next_generation(state, randkey, winner, loser, elite_mask):
# prepare random keys
pop_size = state.idx2species.shape[0]
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner]
lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser]
n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc))
# batch mutation
mutate_func = vmap(mutate, in_axes=(None, 0, 0, 0))
m_n_genomes = mutate_func(state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return state.update(
pop_genomes=Genome(pop_nodes, pop_conns),
next_node_key=next_node_key,
)
speciate = create_speciate(self.gene_type)
def tell(state, fitness):
"""
Main update function in NEAT.
"""
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(
generation=state.generation + 1,
randkey=randkey
)
state, winner, loser, elite_mask = update_species(state, k1, fitness)
state = create_next_generation(state, k2, winner, loser, elite_mask)
state = speciate(state)
return state
return tell

View File

@@ -1,2 +1,2 @@
from .operations import update_species, create_speciate
from .species_info import SpeciesInfo from .species_info import SpeciesInfo
from .operations import update_species, speciate

View File

@@ -1,73 +1,71 @@
from typing import Type
from jax import Array, numpy as jnp, vmap from jax import Array, numpy as jnp, vmap
from core import Gene from core import Gene
def create_distance(gene_type: Type[Gene]): def distance(gene: Gene, state, genome1, genome2):
def node_distance(state, nodes1: Array, nodes2: Array): return node_distance(gene, state, genome1.nodes, genome2.nodes) + \
""" connection_distance(gene, state, genome1.conns, genome2.conns)
Calculate the distance between nodes of two genomes.
"""
# statistics nodes count of two genomes
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes def node_distance(gene: Gene, state, nodes1: Array, nodes2: Array):
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0]) """
Calculate the distance between nodes of two genomes.
"""
# statistics nodes count of two genomes
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# calculate the count of non_homologous of two genomes # align homologous nodes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) # this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# calculate the distance of homologous nodes # flag location of homologous nodes
hnd = vmap(gene_type.distance_node, in_axes=(None, 0, 0))(state, fr, sr) intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight # calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division # calculate the distance of homologous nodes
hnd = vmap(gene.distance_node, in_axes=(None, 0, 0))(state, fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
def connection_distance(state, cons1: Array, cons2: Array): val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
"""
Calculate the distance between connections of two genomes.
Similar process as node_distance.
"""
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((cons1, cons2), axis=0) return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) def connection_distance(gene: Gene, state, cons1: Array, cons2: Array):
hcd = vmap(gene_type.distance_conn, in_axes=(None, 0, 0))(state, fr, sr) """
hcd = jnp.where(jnp.isnan(hcd), 0, hcd) Calculate the distance between connections of two genomes.
homologous_distance = jnp.sum(hcd * intersect_mask) Similar process as node_distance.
"""
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight cons = jnp.concatenate((cons1, cons2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
return jnp.where(max_cnt == 0, 0, val / max_cnt) # both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
def distance(state, genome1, genome2): non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
return node_distance(state, genome1.nodes, genome2.nodes) + connection_distance(state, genome1.conns, genome2.conns) hcd = vmap(gene.distance_conn, in_axes=(None, 0, 0))(state, fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
return distance val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt)

View File

@@ -1,11 +1,9 @@
from typing import Type
import jax import jax
from jax import numpy as jnp, vmap from jax import numpy as jnp, vmap
from core import Gene, Genome from core import Gene, Genome, State
from utils import rank_elements, fetch_first from utils import rank_elements, fetch_first
from .distance import create_distance from .distance import distance
from .species_info import SpeciesInfo from .species_info import SpeciesInfo
@@ -170,154 +168,149 @@ def create_crossover_pair(state, randkey, spawn_number, fitness):
return winner, loser, elite_mask return winner, loser, elite_mask
def create_speciate(gene_type: Type[Gene]): def speciate(gene: Gene, state: State):
distance = create_distance(gene_type) pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
def speciate(state): # prepare distance functions
pop_size, species_size = state.idx2species.shape[0], state.species_info.size() o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0)) # one to population
# prepare distance functions # idx to specie key
o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# idx to specie key # the distance between genomes to its center genomes
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species o2c_distances = jnp.full((pop_size,), jnp.inf)
# the distance between genomes to its center genomes # step 1: find new centers
o2c_distances = jnp.full((pop_size,), jnp.inf) def cond_func(carry):
i, i2s, cgs, o2c = carry
# step 1: find new centers return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
def cond_func(carry):
i, i2s, cgs, o2c = carry
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing def body_func(carry):
i, i2s, cgs, o2c = carry
def body_func(carry): distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
i, i2s, cgs, o2c = carry
distances = o2p_distance_func(state, cgs[i], state.pop_genomes) # find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
# find the closest one i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s)) cgs = cgs.set(i, state.pop_genomes[closest_idx])
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i]) # the genome with closest_idx will become the new center, thus its distance to center is 0.
cgs = cgs.set(i, state.pop_genomes[closest_idx]) o2c = o2c.at[closest_idx].set(0)
# the genome with closest_idx will become the new center, thus its distance to center is 0. return i + 1, i2s, cgs, o2c
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cgs, o2c _, idx2species, center_genomes, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
_, idx2species, center_genomes, o2c_distances = \ state = state.update(
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances)) idx2species=idx2species,
center_genomes=center_genomes,
)
state = state.update( # part 2: assign members to each species
idx2species=idx2species, def cond_func(carry):
center_genomes=center_genomes, i, i2s, cgs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cgs, sk, o2c, nsk)
) )
# part 2: assign members to each species return i + 1, i2s, cgs, sk, o2c, nsk
def cond_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i]) def create_new_species(carry):
not_all_assigned = jnp.any(jnp.isnan(i2s)) i, i2s, cgs, sk, o2c, nsk = carry
not_reach_species_upper_bounds = i < species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry): # pick the first one who has not been assigned to any species
i, i2s, cgs, sk, o2c, nsk = carry idx = fetch_first(jnp.isnan(i2s))
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond( # assign it to the new species
jnp.isnan(sk[i]), # whether the current species is existing or not # [key, best score, last update generation, member_count]
create_new_species, # if not existing, create a new specie sk = sk.at[i].set(nsk)
update_exist_specie, # if existing, update the specie i2s = i2s.at[idx].set(nsk)
(i, i2s, cgs, sk, o2c, nsk) o2c = o2c.at[idx].set(0)
)
return i + 1, i2s, cgs, sk, o2c, nsk # update center genomes
cgs = cgs.set(i, state.pop_genomes[idx])
def create_new_species(carry): i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
i, i2s, cgs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species # when a new species is created, it needs to be updated, thus do not change i
idx = fetch_first(jnp.isnan(i2s)) return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
# assign it to the new species def update_exist_specie(carry):
# [key, best score, last update generation, member_count] i, i2s, cgs, sk, o2c, nsk = carry
sk = sk.at[i].set(nsk)
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
cgs = cgs.set(i, state.pop_genomes[idx])
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) # turn to next species
return i + 1, i2s, cgs, sk, o2c, nsk
# when a new species is created, it needs to be updated, thus do not change i def speciate_by_threshold(i, i2s, cgs, sk, o2c):
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key # distance between such center genome and ppo genomes
def update_exist_specie(carry): o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
i, i2s, cgs, sk, o2c, nsk = carry close_enough_mask = o2p_distance < state.compatibility_threshold
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c) # when a genome is not assigned or the distance between its current center is bigger than this center
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
# jax.debug.print("{}", o2p_distance)
mask = close_enough_mask & cacheable_mask
# turn to next species # update species info
return i + 1, i2s, cgs, sk, o2c, nsk i2s = jnp.where(mask, sk[i], i2s)
def speciate_by_threshold(i, i2s, cgs, sk, o2c): # update distance between centers
# distance between such center genome and ppo genomes o2c = jnp.where(mask, o2p_distance, o2c)
o2p_distance = o2p_distance_func(state, cgs[i], state.pop_genomes) return i2s, o2c
close_enough_mask = o2p_distance < state.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center # update idx2species
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c) _, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
# jax.debug.print("{}", o2p_distance) cond_func,
mask = close_enough_mask & cacheable_mask body_func,
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances,
state.next_species_key)
)
# update species info # if there are still some pop genomes not assigned to any species, add them to the last genome
i2s = jnp.where(mask, sk[i], i2s) # this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# update distance between centers # complete info of species which is created in this generation
o2c = jnp.where(mask, o2p_distance, o2c) new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
return i2s, o2c # update members count
def count_members(idx):
key = species_keys[idx]
count = jnp.sum(idx2species == key, dtype=jnp.float32)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
# update idx2species return count
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, state.next_species_key)
)
member_count = vmap(count_members)(jnp.arange(species_size))
# if there are still some pop genomes not assigned to any species, add them to the last genome return state.update(
# this condition can only happen when the number of species is reached species upper bounds species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species) idx2species=idx2species,
center_genomes=center_genomes,
# complete info of species which is created in this generation next_species_key=next_species_key
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness) )
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
# update members count
def count_members(idx):
key = species_keys[idx]
count = jnp.sum(idx2species == key, dtype=jnp.float32)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
return count
member_count = vmap(count_members)(jnp.arange(species_size))
return state.update(
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
idx2species=idx2species,
center_genomes=center_genomes,
next_species_key=next_species_key
)
return speciate
def argmin_with_mask(arr, mask): def argmin_with_mask(arr, mask):

View File

@@ -2,6 +2,7 @@ from jax.tree_util import register_pytree_node_class
import numpy as np import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
@register_pytree_node_class @register_pytree_node_class
class SpeciesInfo: class SpeciesInfo:
@@ -44,7 +45,6 @@ class SpeciesInfo:
def size(self): def size(self):
return self.species_keys.shape[0] return self.species_keys.shape[0]
def tree_flatten(self): def tree_flatten(self):
children = self.species_keys, self.best_fitness, self.last_improved, self.member_count children = self.species_keys, self.best_fitness, self.last_improved, self.member_count
aux_data = None aux_data = None

View File

@@ -1,2 +1 @@
from .config import * from .config import *

View File

@@ -86,6 +86,7 @@ class HyperNeatConfig:
class GeneConfig: class GeneConfig:
pass pass
@dataclass(frozen=True) @dataclass(frozen=True)
class SubstrateConfig: class SubstrateConfig:
pass pass

View File

@@ -1,76 +0,0 @@
[basic]
random_seed = 0
generation_limit = 1000
fitness_threshold = 3.9999
num_inputs = 2
num_outputs = 1
[neat]
network_type = "feedforward"
activate_times = 5
maximum_nodes = 50
maximum_conns = 50
maximum_species = 10
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.4
conn_delete_prob = 0
node_add_prob = 0.2
node_delete_prob = 0
[hyperneat]
below_threshold = 0.2
max_weight = 3
h_activation = "sigmoid"
h_aggregation = "sum"
h_activate_times = 5
[substrate]
input_coors = [[-1, 1], [0, 1], [1, 1]]
hidden_coors = [[-1, 0], [0, 0], [1, 0]]
output_coors = [[0, -1]]
[species]
compatibility_threshold = 3.0
species_elitism = 2
max_stagnation = 15
genome_elitism = 2
survival_threshold = 0.2
min_species_size = 1
spawn_number_change_rate = 0.5
[gene]
# bias
bias_init_mean = 0.0
bias_init_std = 1.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
# response
response_init_mean = 1.0
response_init_std = 0.0
response_mutate_power = 0.0
response_mutate_rate = 0.0
response_replace_rate = 0.0
# activation
activation_default = "sigmoid"
activation_option_names = ["tanh"]
activation_replace_rate = 0.0
# aggregation
aggregation_default = "sum"
aggregation_option_names = ["sum"]
aggregation_replace_rate = 0.0
# weight
weight_init_mean = 0.0
weight_init_std = 1.0
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1
[visualize]
renumber_nodes = True

View File

@@ -1,28 +1,50 @@
from jax import Array from functools import partial
import jax
from .state import State from .state import State
from .genome import Genome from .genome import Genome
EMPTY = lambda *args: args
class Algorithm: class Algorithm:
def setup(self, randkey, state: State = State()): def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm""" """initialize the state of the algorithm"""
pass
raise NotImplementedError
@partial(jax.jit, static_argnums=(0,))
def ask(self, state: State): def ask(self, state: State):
"""require the population to be evaluated""" """require the population to be evaluated"""
pass
return self.ask_algorithm(state)
@partial(jax.jit, static_argnums=(0,))
def tell(self, state: State, fitness): def tell(self, state: State, fitness):
"""update the state of the algorithm""" """update the state of the algorithm"""
pass
def forward(self, inputs: Array, transformed: Array): return self.tell_algorithm(state, fitness)
"""the forward function of a single forward transformation"""
pass @partial(jax.jit, static_argnums=(0,))
def transform(self, state: State, genome: Genome):
"""transform the genome into a neural network"""
return self.forward_transform(state, genome)
@partial(jax.jit, static_argnums=(0,))
def act(self, state: State, inputs, genome: Genome):
return self.forward(state, inputs, genome)
def forward_transform(self, state: State, genome: Genome): def forward_transform(self, state: State, genome: Genome):
"""create the forward transformation of a genome""" raise NotImplementedError
pass
def forward(self, state: State, inputs, genome: Genome):
raise NotImplementedError
def ask_algorithm(self, state: State):
"""ask the specific algorithm for a new population"""
raise NotImplementedError
def tell_algorithm(self, state: State, fitness):
"""tell the specific algorithm the fitness of the population"""
raise NotImplementedError

View File

@@ -1,46 +1,37 @@
from jax import Array, numpy as jnp
from config import GeneConfig from config import GeneConfig
from .state import State from .state import State
from .genome import Genome
class Gene: class Gene:
node_attrs = [] node_attrs = []
conn_attrs = [] conn_attrs = []
@staticmethod def setup(self, state=State()):
def setup(config: GeneConfig, state: State): raise NotImplementedError
return state
@staticmethod def update(self, state):
def new_node_attrs(state: State): raise NotImplementedError
return jnp.zeros(0)
@staticmethod def new_node_attrs(self, state: State):
def new_conn_attrs(state: State): raise NotImplementedError
return jnp.zeros(0)
@staticmethod def new_conn_attrs(self, state: State):
def mutate_node(state: State, attrs: Array, randkey: Array): raise NotImplementedError
return attrs
@staticmethod def mutate_node(self, state: State, randkey, node_attrs):
def mutate_conn(state: State, attrs: Array, randkey: Array): raise NotImplementedError
return attrs
@staticmethod def mutate_conn(self, state: State, randkey, conn_attrs):
def distance_node(state: State, node1: Array, node2: Array): raise NotImplementedError
return node1
@staticmethod def distance_node(self, state: State, node_attrs1, node_attrs2):
def distance_conn(state: State, conn1: Array, conn2: Array): raise NotImplementedError
return conn1
@staticmethod def distance_conn(self, state: State, conn_attrs1, conn_attrs2):
def forward_transform(state: State, genome: Genome): raise NotImplementedError
return jnp.zeros(0) # transformed
@staticmethod def forward_transform(self, state: State, genome):
def create_forward(state: State, config: GeneConfig): raise NotImplementedError
return lambda *args: args # forward function
def forward(self, state: State, inputs, transform):
raise NotImplementedError

View File

@@ -84,4 +84,3 @@ class Genome:
def tree_unflatten(cls, aux_data, children): def tree_unflatten(cls, aux_data, children):
return cls(*children) return cls(*children)

24
examples/test.py Normal file
View File

@@ -0,0 +1,24 @@
from functools import partial
import jax
class A:
def __init__(self):
self.a = 1
self.b = 2
self.isTrue = False
@partial(jax.jit, static_argnums=(0,))
def step(self):
if self.isTrue:
return self.a + 1
else:
return self.b + 1
AA = A()
print(AA.step(), hash(AA))
print(AA.step(), hash(AA))
print(AA.step(), hash(AA))
AA.a = (2, 3, 4)
print(AA.step(), hash(AA))

View File

@@ -3,7 +3,8 @@ import numpy as np
from config import Config, BasicConfig, NeatConfig from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline from pipeline import Pipeline
from algorithm import NEAT, NormalGene, NormalGeneConfig from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -23,15 +24,15 @@ def evaluate(forward_func):
if __name__ == '__main__': if __name__ == '__main__':
config = Config( config = Config(
basic=BasicConfig( basic=BasicConfig(
fitness_target=3.99999, fitness_target=3.9999999,
pop_size=10000 pop_size=10000
), ),
neat=NeatConfig( neat=NeatConfig(
maximum_nodes=20, maximum_nodes=20,
maximum_conns=50, maximum_conns=50,
), )
gene=NormalGeneConfig()
) )
algorithm = NEAT(config, NormalGene) normal_gene = NormalGene(NormalGeneConfig())
algorithm = NEAT(config, normal_gene)
pipeline = Pipeline(config, algorithm) pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate) pipeline.auto_run(evaluate)

View File

@@ -1,49 +0,0 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig
from algorithm import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=100
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100,
inputs=4,
outputs=1
),
gene=RecurrentGeneConfig(
activation_default="tanh",
activation_options=("tanh", ),
),
substrate=NormalSubstrateConfig(),
)
neat = NEAT(config, RecurrentGene)
hyperNEAT = HyperNEAT(config, neat, NormalSubstrate)
pipeline = Pipeline(config, hyperNEAT)
pipeline.auto_run(evaluate)

View File

@@ -1,39 +0,0 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=10000
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100
),
gene=RecurrentGeneConfig()
)
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)

View File

@@ -27,15 +27,15 @@ class Pipeline:
self.evaluate_time = 0 self.evaluate_time = 0
self.forward_func = jit(self.algorithm.forward) self.act_func = jit(self.algorithm.act)
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None))) self.batch_act_func = jit(vmap(self.act_func, in_axes=(None, 0, None)))
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0))) self.pop_batch_act_func = jit(vmap(self.batch_act_func, in_axes=(None, None, 0)))
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0))) self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0)))
self.tell_func = jit(self.algorithm.tell) self.tell_func = jit(self.algorithm.tell)
def ask(self): def ask(self):
pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes) pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes)
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms) return lambda inputs: self.pop_batch_act_func(self.state, inputs, pop_transforms)
def tell(self, fitness): def tell(self, fitness):
# self.state = self.tell_func(self.state, fitness) # self.state = self.tell_func(self.state, fitness)
@@ -81,7 +81,3 @@ class Pipeline:
print(f"Generation: {self.state.generation}", print(f"Generation: {self.state.generation}",
f"species: {len(species_sizes)}, {species_sizes}", f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")

View File

@@ -1,4 +1,35 @@
from .activation import Activation from .activation import Activation, act
from .aggregation import Aggregation from .aggregation import Aggregation, agg
from .tools import * from .tools import *
from .graph import * from .graph import *
Activation.name2func = {
'sigmoid': Activation.sigmoid_act,
'tanh': Activation.tanh_act,
'sin': Activation.sin_act,
'gauss': Activation.gauss_act,
'relu': Activation.relu_act,
'elu': Activation.elu_act,
'lelu': Activation.lelu_act,
'selu': Activation.selu_act,
'softplus': Activation.softplus_act,
'identity': Activation.identity_act,
'clamped': Activation.clamped_act,
'inv': Activation.inv_act,
'log': Activation.log_act,
'exp': Activation.exp_act,
'abs': Activation.abs_act,
'hat': Activation.hat_act,
'square': Activation.square_act,
'cube': Activation.cube_act,
}
Aggregation.name2func = {
'sum': Aggregation.sum_agg,
'product': Aggregation.product_agg,
'max': Aggregation.max_agg,
'min': Aggregation.min_agg,
'maxabs': Aggregation.maxabs_agg,
'median': Aggregation.median_agg,
'mean': Aggregation.mean_agg,
}

View File

@@ -1,8 +1,8 @@
import jax
import jax.numpy as jnp import jax.numpy as jnp
class Activation: class Activation:
name2func = {} name2func = {}
@staticmethod @staticmethod
@@ -89,23 +89,11 @@ class Activation:
return z ** 3 return z ** 3
Activation.name2func = { def act(idx, z, act_funcs):
'sigmoid': Activation.sigmoid_act, """
'tanh': Activation.tanh_act, calculate activation function for each node
'sin': Activation.sin_act, """
'gauss': Activation.gauss_act, idx = jnp.asarray(idx, dtype=jnp.int32)
'relu': Activation.relu_act, # change idx from float to int
'elu': Activation.elu_act, res = jax.lax.switch(idx, act_funcs, z)
'lelu': Activation.lelu_act, return res
'selu': Activation.selu_act,
'softplus': Activation.softplus_act,
'identity': Activation.identity_act,
'clamped': Activation.clamped_act,
'inv': Activation.inv_act,
'log': Activation.log_act,
'exp': Activation.exp_act,
'abs': Activation.abs_act,
'hat': Activation.hat_act,
'square': Activation.square_act,
'cube': Activation.cube_act,
}

View File

@@ -1,8 +1,8 @@
import jax
import jax.numpy as jnp import jax.numpy as jnp
class Aggregation: class Aggregation:
name2func = {} name2func = {}
@staticmethod @staticmethod
@@ -52,12 +52,16 @@ class Aggregation:
return mean_without_zeros return mean_without_zeros
Aggregation.name2func = { def agg(idx, z, agg_funcs):
'sum': Aggregation.sum_agg, """
'product': Aggregation.product_agg, calculate activation function for inputs of node
'max': Aggregation.max_agg, """
'min': Aggregation.min_agg, idx = jnp.asarray(idx, dtype=jnp.int32)
'maxabs': Aggregation.maxabs_agg,
'median': Aggregation.median_agg, def all_nan():
'mean': Aggregation.mean_agg, return 0.
}
def not_all_nan():
return jax.lax.switch(idx, agg_funcs, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)