new architecture

This commit is contained in:
wls2002
2024-01-27 00:52:39 +08:00
parent 4efe9a53c1
commit aac41a089d
65 changed files with 1651 additions and 1783 deletions

View File

@@ -1,2 +1,2 @@
from .neat import NEAT
from .hyperneat import HyperNEAT
from .base import BaseAlgorithm
from .neat import NEAT

24
algorithm/base.py Normal file
View File

@@ -0,0 +1,24 @@
from utils import State
class BaseAlgorithm:
def setup(self, randkey):
"""initialize the state of the algorithm"""
raise NotImplementedError
def ask(self, state: State):
"""require the population to be evaluated"""
raise NotImplementedError
def tell(self, state: State, fitness):
"""update the state of the algorithm"""
raise NotImplementedError
def transform(self, state: State):
"""transform the genome into a neural network"""
raise NotImplementedError
def forward(self, inputs, transformed):
raise NotImplementedError

View File

@@ -1,2 +0,0 @@
from .hyperneat import HyperNEAT
from .substrate import NormalSubstrate, NormalSubstrateConfig

View File

@@ -1,113 +0,0 @@
from typing import Type
import jax
from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome, Gene
from .substrate import analysis_substrate
from algorithm import NEAT
class HyperNEAT(Algorithm):
def __init__(self, config: Config, gene: Type[Gene], substrate: Type[Substrate]):
self.config = config
self.neat = NEAT(config, gene)
self.substrate = substrate
def setup(self, randkey, state=State()):
neat_key, randkey = jax.random.split(randkey)
state = state.update(
below_threshold=self.config.hyperneat.below_threshold,
max_weight=self.config.hyperneat.max_weight,
)
state = self.neat.setup(neat_key, state)
state = self.substrate.setup(self.config.substrate, state)
assert self.config.hyperneat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyperneat.outputs == state.output_coors.shape[0]
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
h_conns[:, 0:2] = correspond_keys
state = state.update(
h_input_idx=h_input_idx,
h_output_idx=h_output_idx,
h_hidden_idx=h_hidden_idx,
h_nodes=h_nodes,
h_conns=h_conns,
query_coors=query_coors,
)
return state
def ask_algorithm(self, state: State):
return state.pop_genomes
def tell_algorithm(self, state: State, fitness):
return self.neat.tell(state, fitness)
def forward(self, state, inputs: Array, transformed: Array):
return HyperNEATGene.forward(self.config.hyperneat, state, inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
t = self.neat.forward_transform(state, genome)
query_res = vmap(self.neat.forward, in_axes=(None, 0, None))(state, state.query_coors, t)
# mute the connection with weight below threshold
query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res)
query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res)
query_res = query_res / (1 - state.below_threshold) * state.max_weight
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
class HyperNEATGene:
node_attrs = [] # no node attributes
conn_attrs = ['weight']
@staticmethod
def forward_transform(genome: Genome):
N = genome.nodes.shape[0]
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32)
out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32)
weights = genome.conns[:, 2]
u_conns = u_conns.at[in_keys, out_keys].set(weights)
return genome.nodes, u_conns
@staticmethod
def forward(config: HyperNeatConfig, state: State, inputs, transformed):
batch_act, batch_agg = jax.vmap(config.activation), jax.vmap(config.aggregation)
nodes, weights = transformed
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
input_idx = state.h_input_idx
output_idx = state.h_output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
def body_func(i, values):
values = values.at[input_idx].set(inputs_with_bias)
nodes_ins = values * weights.T
values = batch_agg(nodes_ins) # z = agg(ins)
# values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]

View File

@@ -1,2 +0,0 @@
from .normal import NormalSubstrate, NormalSubstrateConfig
from .tools import analysis_substrate

View File

@@ -1,25 +0,0 @@
from dataclasses import dataclass
from typing import Tuple
import numpy as np
from core import Substrate, State
from config import SubstrateConfig
@dataclass(frozen=True)
class NormalSubstrateConfig(SubstrateConfig):
input_coors: Tuple = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple = ((0, 1),)
class NormalSubstrate(Substrate):
@staticmethod
def setup(config: NormalSubstrateConfig, state: State = State()):
return state.update(
input_coors=np.asarray(config.input_coors, dtype=np.float32),
output_coors=np.asarray(config.output_coors, dtype=np.float32),
hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32),
)

View File

@@ -1,49 +0,0 @@
import numpy as np
def analysis_substrate(state):
cd = state.input_coors.shape[1] # coordinate dimensions
si = state.input_coors.shape[0] # input coordinate size
so = state.output_coors.shape[0] # output coordinate size
sh = state.hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys

View File

@@ -1,2 +1,3 @@
from .neat import NEAT
from .gene import *
from .genome import *
from .neat import NEAT

View File

@@ -1,3 +1,2 @@
from .crossover import crossover
from .mutate import mutate
from .operation import create_next_generation
from .crossover import BaseCrossover, DefaultCrossover
from .mutation import BaseMutation, DefaultMutation

View File

@@ -1,70 +0,0 @@
import jax
from jax import Array, numpy as jnp
from core import Genome
def crossover(randkey, genome1: Genome, genome2: Genome):
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
"""
randkey_1, randkey_2, key = jax.random.split(randkey, 3)
# crossover nodes
keys1, keys2 = genome1.nodes[:, 0], genome2.nodes[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = align_array(keys1, keys2, genome2.nodes, False)
nodes1 = genome1.nodes
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
con_keys1, con_keys2 = genome1.conns[:, :2], genome2.conns[:, :2]
conns2 = align_array(con_keys1, con_keys2, genome2.conns, True)
conns1 = genome1.conns
new_cons = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, crossover_gene(randkey_2, conns1, conns2))
return genome1.update(new_nodes, new_cons)
def align_array(seq1: Array, seq2: Array, ar2: Array, is_conn: bool) -> Array:
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
make ar2 align with ar1.
:param seq1:
:param seq2:
:param ar2:
:param is_conn:
:return:
align means to intersect part of ar2 will be at the same position as ar1,
non-intersect part of ar2 will be set to Nan
"""
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
if is_conn:
mask = jnp.all(mask, axis=2)
intersect_mask = mask.any(axis=1)
idx = jnp.arange(0, len(seq1))
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
return refactor_ar2
def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)

View File

@@ -0,0 +1,2 @@
from .base import BaseCrossover
from .default import DefaultCrossover

View File

@@ -0,0 +1,3 @@
class BaseCrossover:
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError

View File

@@ -0,0 +1,66 @@
import jax, jax.numpy as jnp
from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
"""
randkey_1, randkey_2, key = jax.random.split(randkey, 3)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = self.align_array(keys1, keys2, nodes2, False)
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
return new_nodes, new_conns
def align_array(self, seq1, seq2, ar2, is_conn: bool):
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
make ar2 align with ar1.
:param seq1:
:param seq2:
:param ar2:
:param is_conn:
:return:
align means to intersect part of ar2 will be at the same position as ar1,
non-intersect part of ar2 will be set to Nan
"""
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
if is_conn:
mask = jnp.all(mask, axis=2)
intersect_mask = mask.any(axis=1)
idx = jnp.arange(0, len(seq1))
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
return refactor_ar2
def crossover_gene(self, rand_key, g1, g2):
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)

View File

@@ -1,186 +0,0 @@
from typing import Tuple
import jax
from jax import Array, numpy as jnp, vmap
from config import NeatConfig
from core import State, Gene, Genome
from utils import check_cycles, fetch_random, fetch_first, I_INT, unflatten_conns
def mutate(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key):
"""
Mutate a population of genomes
"""
k1, k2 = jax.random.split(randkey)
genome = mutate_structure(config, gene, state, k1, genome, new_node_key)
genome = mutate_values(gene, state, randkey, genome)
return genome
def mutate_structure(config: NeatConfig, gene: Gene, state: State, randkey, genome: Genome, new_node_key):
def mutate_add_node(key_, genome_: Genome):
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
def nothing():
return genome_
def successful_add_node():
# disable the connection
new_genome = genome_.update_conns(genome_.conns.at[idx, 2].set(False))
# add a new node
new_genome = new_genome.add_node(new_node_key, gene.new_node_attrs(state))
# add two new connections
new_genome = new_genome.add_conn(i_key, new_node_key, True, gene.new_conn_attrs(state))
new_genome = new_genome.add_conn(new_node_key, o_key, True, gene.new_conn_attrs(state))
return new_genome
# if from_idx == I_INT, that means no connection exist, do nothing
return jax.lax.cond(idx == I_INT, nothing, successful_add_node)
def mutate_delete_node(key_, genome_: Genome):
# TODO: Do we really need to delete a node?
# randomly choose a node
key, idx = choice_node_key(key_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=False, allow_output_keys=False)
def nothing():
return genome_
def successful_delete_node():
# delete the node
new_genome = genome_.delete_node_by_pos(idx)
# delete all connections
new_conns = jnp.where(((new_genome.conns[:, 0] == key) | (new_genome.conns[:, 1] == key))[:, None],
jnp.nan, new_genome.conns)
return new_genome.update_conns(new_conns)
return jax.lax.cond(idx == I_INT, nothing, successful_delete_node)
def mutate_add_conn(key_, genome_: Genome):
# randomly choose two nodes
k1_, k2_ = jax.random.split(key_, num=2)
i_key, from_idx = choice_node_key(k1_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=True, allow_output_keys=True)
o_key, to_idx = choice_node_key(k2_, genome_.nodes, state.input_idx, state.output_idx,
allow_input_keys=False, allow_output_keys=True)
conn_pos = fetch_first((genome_.conns[:, 0] == i_key) & (genome_.conns[:, 1] == o_key))
def nothing():
return genome_
def successful():
return genome_.add_conn(i_key, o_key, True, gene.new_conn_attrs(state))
def already_exist():
return genome_.update_conns(genome_.conns.at[conn_pos, 2].set(True))
is_already_exist = conn_pos != I_INT
if config.network_type == 'feedforward':
u_cons = unflatten_conns(genome_.nodes, genome_.conns)
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
is_cycle = check_cycles(genome_.nodes, cons_exist, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
return jax.lax.switch(choice, [already_exist, nothing, successful])
elif config.network_type == 'recurrent':
return jax.lax.cond(is_already_exist, already_exist, successful)
else:
raise ValueError(f"Invalid network type: {config.network_type}")
def mutate_delete_conn(key_, genome_: Genome):
# randomly choose a connection
i_key, o_key, idx = choice_connection_key(key_, genome_.conns)
def nothing():
return genome_
def successfully_delete_connection():
return genome_.delete_conn_by_pos(idx)
return jax.lax.cond(idx == I_INT, nothing, successfully_delete_connection)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(k, g):
return g
genome = jax.lax.cond(r1 < config.node_add, mutate_add_node, no, k1, genome)
genome = jax.lax.cond(r2 < config.node_delete, mutate_delete_node, no, k2, genome)
genome = jax.lax.cond(r3 < config.conn_add, mutate_add_conn, no, k3, genome)
genome = jax.lax.cond(r4 < config.conn_delete, mutate_delete_conn, no, k4, genome)
return genome
def mutate_values(gene: Gene, state: State, randkey, genome: Genome):
k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0])
conns_keys = jax.random.split(k2, num=genome.conns.shape[0])
nodes_attrs, conns_attrs = genome.nodes[:, 1:], genome.conns[:, 3:]
new_nodes_attrs = vmap(gene.mutate_node, in_axes=(None, 0, 0))(state, nodes_keys, nodes_attrs)
new_conns_attrs = vmap(gene.mutate_conn, in_axes=(None, 0, 0))(state, conns_keys, conns_attrs)
# nan nodes not changed
new_nodes_attrs = jnp.where(jnp.isnan(nodes_attrs), jnp.nan, new_nodes_attrs)
new_conns_attrs = jnp.where(jnp.isnan(conns_attrs), jnp.nan, new_conns_attrs)
new_nodes = genome.nodes.at[:, 1:].set(new_nodes_attrs)
new_conns = genome.conns.at[:, 3:].set(new_conns_attrs)
return genome.update(new_nodes, new_conns)
def choice_node_key(rand_key: Array, nodes: Array,
input_keys: Array, output_keys: Array,
allow_input_keys: bool = False, allow_output_keys: bool = False) -> Tuple[Array, Array]:
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param rand_key:
:param nodes:
:param input_keys:
:param output_keys:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_keys))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_keys))
idx = fetch_random(rand_key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
return key, idx
def choice_connection_key(rand_key: Array, conns: Array):
"""
Randomly choose a connection key from the given connections.
:return: i_key, o_key, idx
"""
idx = fetch_random(rand_key, ~jnp.isnan(conns[:, 0]))
i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)
return i_key, o_key, idx

View File

@@ -0,0 +1,2 @@
from .base import BaseMutation
from .default import DefaultMutation

View File

@@ -0,0 +1,3 @@
class BaseMutation:
def __call__(self, key, genome, nodes, conns, new_node_key):
raise NotImplementedError

View File

@@ -0,0 +1,201 @@
import jax, jax.numpy as jnp
from . import BaseMutation
from utils import fetch_first, fetch_random, I_INT, unflatten_conns, check_cycles
class DefaultMutation(BaseMutation):
def __init__(
self,
conn_add: float = 0.4,
conn_delete: float = 0,
node_add: float = 0.2,
node_delete: float = 0,
):
self.conn_add = conn_add
self.conn_delete = conn_delete
self.node_add = node_add
self.node_delete = node_delete
def __call__(self, randkey, genome, nodes, conns, new_node_key):
k1, k2 = jax.random.split(randkey)
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
return nodes, conns
def mutate_structure(self, randkey, genome, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_):
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
def successful_add_node():
# disable the connection
new_conns = conns_.at[idx, 2].set(False)
# add a new node
new_nodes = genome.add_node(nodes_, new_node_key, genome.node_gene.new_custom_attrs())
# add two new connections
new_conns = genome.add_conn(new_conns, i_key, new_node_key, True, genome.conn_gene.new_custom_attrs())
new_conns = genome.add_conn(new_conns, new_node_key, o_key, True, genome.conn_gene.new_custom_attrs())
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INT,
lambda: (nodes_, conns_), # do nothing
successful_add_node
)
def mutate_delete_node(key_, nodes_, conns_):
# randomly choose a node
key, idx = self.choice_node_key(key_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=False)
def successful_delete_node():
# delete the node
new_nodes = genome.delete_node_by_pos(nodes_, idx)
# delete all connections
new_conns = jnp.where(
((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None],
jnp.nan,
conns_
)
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INT,
lambda: (nodes_, conns_), # do nothing
successful_delete_node
)
def mutate_add_conn(key_, nodes_, conns_):
# randomly choose two nodes
k1_, k2_ = jax.random.split(key_, num=2)
# input node of the connection can be any node
i_key, from_idx = self.choice_node_key(k1_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=True, allow_output_keys=True)
# output node of the connection can be any node except input node
o_key, to_idx = self.choice_node_key(k2_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=True)
conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
is_already_exist = conn_pos != I_INT
def nothing():
return nodes_, conns_
def successful():
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conns.new_custom_attrs())
def already_exist():
return nodes_, conns_.at[conn_pos, 2].set(True)
if genome.network_type == 'feedforward':
u_cons = unflatten_conns(nodes_, conns_)
cons_exist = ~jnp.isnan(u_cons[0, :, :])
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
return jax.lax.cond(
is_already_exist,
already_exist,
jax.lax.cond(
is_cycle,
nothing,
successful
)
)
elif genome.network_type == 'recurrent':
return jax.lax.cond(
is_already_exist,
already_exist,
successful
)
else:
raise ValueError(f"Invalid network type: {genome.network_type}")
def mutate_delete_conn(key_, nodes_, conns_):
# randomly choose a connection
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
def successfully_delete_connection():
return nodes_, genome.delete_conn_by_pos(conns_, idx)
return jax.lax.cond(
idx == I_INT,
lambda: (nodes_, conns_), # nothing
successfully_delete_connection
)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(k, g):
return g
genome = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
genome = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
genome = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
genome = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
return genome
def mutate_values(self, randkey, genome, nodes, conns):
k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0])
conns_keys = jax.random.split(k2, num=genome.conns.shape[0])
new_nodes = jax.vmap(genome.nodes.mutate, in_axes=(0, 0))(nodes_keys, nodes)
new_conns = jax.vmap(genome.conns.mutate, in_axes=(0, 0))(conns_keys, conns)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns)
return new_nodes, new_conns
def choice_node_key(self, rand_key, nodes, input_idx, output_idx,
allow_input_keys: bool = False, allow_output_keys: bool = False):
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param rand_key:
:param nodes:
:param input_idx:
:param output_idx:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))
idx = fetch_random(rand_key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
return key, idx
def choice_connection_key(self, rand_key, conns):
"""
Randomly choose a connection key from the given connections.
:return: i_key, o_key, idx
"""
idx = fetch_random(rand_key, ~jnp.isnan(conns[:, 0]))
i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)
return i_key, o_key, idx

View File

@@ -1,40 +0,0 @@
import jax
from jax import numpy as jnp, vmap
from config import NeatConfig
from core import Genome, State, Gene
from .mutate import mutate
from .crossover import crossover
def create_next_generation(config: NeatConfig, gene: Gene, state: State, randkey, winner, loser, elite_mask):
# prepare random keys
pop_size = state.idx2species.shape[0]
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
# batch crossover
wpn, wpc = state.pop_genomes.nodes[winner], state.pop_genomes.conns[winner]
lpn, lpc = state.pop_genomes.nodes[loser], state.pop_genomes.conns[loser]
n_genomes = vmap(crossover)(crossover_rand_keys, Genome(wpn, wpc), Genome(lpn, lpc))
# batch mutation
mutate_func = vmap(mutate, in_axes=(None, None, None, 0, 0, 0))
m_n_genomes = mutate_func(config, gene, state, mutate_rand_keys, n_genomes, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_genomes.nodes, m_n_genomes.nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_genomes.conns, m_n_genomes.conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return state.update(
pop_genomes=Genome(pop_nodes, pop_conns),
next_node_key=next_node_key,
)

View File

@@ -1,3 +1,3 @@
from .normal import NormalGene, NormalGeneConfig
from .recurrent import RecurrentGene, RecurrentGeneConfig
from .base import BaseGene
from .conn import *
from .node import *

View File

@@ -0,0 +1,23 @@
class BaseGene:
"Base class for node genes or connection genes."
fixed_attrs = []
custom_attrs = []
def __init__(self):
pass
def new_custom_attrs(self):
raise NotImplementedError
def mutate(self, randkey, gene):
raise NotImplementedError
def distance(self, gene1, gene2):
raise NotImplementedError
def forward(self, attrs, inputs):
raise NotImplementedError
@property
def length(self):
return len(self.fixed_attrs) + len(self.custom_attrs)

View File

@@ -0,0 +1,2 @@
from .base import BaseConnGene
from .default import DefaultConnGene

View File

@@ -0,0 +1,12 @@
from .. import BaseGene
class BaseConnGene(BaseGene):
"Base class for connection genes."
fixed_attrs = ['input_index', 'output_index', 'enabled']
def __init__(self):
super().__init__()
def forward(self, attrs, inputs):
raise NotImplementedError

View File

@@ -0,0 +1,51 @@
import jax.numpy as jnp
from utils import mutate_float
from . import BaseConnGene
class DefaultConnGene(BaseConnGene):
"Default connection gene, with the same behavior as in NEAT-python."
fixed_attrs = ['input_index', 'output_index', 'enabled']
attrs = ['weight']
def __init__(
self,
weight_init_mean: float = 0.0,
weight_init_std: float = 1.0,
weight_mutate_power: float = 0.5,
weight_mutate_rate: float = 0.8,
weight_replace_rate: float = 0.1,
):
super().__init__()
self.weight_init_mean = weight_init_mean
self.weight_init_std = weight_init_std
self.weight_mutate_power = weight_mutate_power
self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate
def new_custom_attrs(self):
return jnp.array([self.weight_init_mean])
def mutate(self, key, conn):
input_index = conn[0]
output_index = conn[1]
enabled = conn[2]
weight = mutate_float(key,
conn[3],
self.weight_init_mean,
self.weight_init_std,
self.weight_mutate_power,
self.weight_mutate_rate,
self.weight_replace_rate
)
return jnp.array([input_index, output_index, enabled, weight])
def distance(self, attrs1, attrs2):
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
def forward(self, attrs, inputs):
weight = attrs[0]
return inputs * weight

View File

@@ -0,0 +1,2 @@
from .base import BaseNodeGene
from .default import DefaultNodeGene

View File

@@ -0,0 +1,12 @@
from .. import BaseGene
class BaseNodeGene(BaseGene):
"Base class for node genes."
fixed_attrs = ["index"]
def __init__(self):
super().__init__()
def forward(self, attrs, inputs):
raise NotImplementedError

View File

@@ -0,0 +1,96 @@
from typing import Tuple
import jax, jax.numpy as jnp
from utils import Act, Agg, act, agg, mutate_int, mutate_float
from . import BaseNodeGene
class DefaultNodeGene(BaseNodeGene):
"Default node gene, with the same behavior as in NEAT-python."
fixed_attrs = ['index']
custom_attrs = ['bias', 'response', 'aggregation', 'activation']
def __init__(
self,
bias_init_mean: float = 0.0,
bias_init_std: float = 1.0,
bias_mutate_power: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_init_mean: float = 1.0,
response_init_std: float = 0.0,
response_mutate_power: float = 0.5,
response_mutate_rate: float = 0.7,
response_replace_rate: float = 0.1,
activation_default: callable = Act.sigmoid,
activation_options: Tuple = (Act.sigmoid,),
activation_replace_rate: float = 0.1,
aggregation_default: callable = Agg.sum,
aggregation_options: Tuple = (Agg.sum,),
aggregation_replace_rate: float = 0.1,
):
super().__init__()
self.bias_init_mean = bias_init_mean
self.bias_init_std = bias_init_std
self.bias_mutate_power = bias_mutate_power
self.bias_mutate_rate = bias_mutate_rate
self.bias_replace_rate = bias_replace_rate
self.response_init_mean = response_init_mean
self.response_init_std = response_init_std
self.response_mutate_power = response_mutate_power
self.response_mutate_rate = response_mutate_rate
self.response_replace_rate = response_replace_rate
self.activation_default = activation_options.index(activation_default)
self.activation_options = activation_options
self.activation_indices = jnp.arange(len(activation_options))
self.activation_replace_rate = activation_replace_rate
self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
def new_custom_attrs(self):
return jnp.array(
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
)
def mutate(self, key, node):
k1, k2, k3, k4 = jax.random.split(key, num=4)
index = node[0]
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate)
res = mutate_float(k2, node[2], self.response_init_mean, self.response_init_std,
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate)
act = mutate_int(k3, node[3], self.activation_indices, self.activation_replace_rate)
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
return jnp.array([index, bias, res, act, agg])
def distance(self, node1, node2):
return (
jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) +
node1[3] != node2[3] +
node1[4] != node2[4]
)
def forward(self, attrs, inputs):
bias, res, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options)
z = bias + res * z
z = act(act_idx, z, self.activation_options)
return z

View File

@@ -1,210 +0,0 @@
from dataclasses import dataclass
from typing import Tuple
import jax
from jax import Array, numpy as jnp
from config import GeneConfig
from core import Gene, Genome, State
from utils import Act, Agg, unflatten_conns, topological_sort, I_INT, act, agg
@dataclass(frozen=True)
class NormalGeneConfig(GeneConfig):
bias_init_mean: float = 0.0
bias_init_std: float = 1.0
bias_mutate_power: float = 0.5
bias_mutate_rate: float = 0.7
bias_replace_rate: float = 0.1
response_init_mean: float = 1.0
response_init_std: float = 0.0
response_mutate_power: float = 0.5
response_mutate_rate: float = 0.7
response_replace_rate: float = 0.1
activation_default: callable = Act.sigmoid
activation_options: Tuple = (Act.sigmoid, )
activation_replace_rate: float = 0.1
aggregation_default: callable = Agg.sum
aggregation_options: Tuple = (Agg.sum, )
aggregation_replace_rate: float = 0.1
weight_init_mean: float = 0.0
weight_init_std: float = 1.0
weight_mutate_power: float = 0.5
weight_mutate_rate: float = 0.8
weight_replace_rate: float = 0.1
def __post_init__(self):
assert self.bias_init_std >= 0.0
assert self.bias_mutate_power >= 0.0
assert self.bias_mutate_rate >= 0.0
assert self.bias_replace_rate >= 0.0
assert self.response_init_std >= 0.0
assert self.response_mutate_power >= 0.0
assert self.response_mutate_rate >= 0.0
assert self.response_replace_rate >= 0.0
assert self.activation_default == self.activation_options[0]
assert self.aggregation_default == self.aggregation_options[0]
class NormalGene(Gene):
node_attrs = ['bias', 'response', 'aggregation', 'activation']
conn_attrs = ['weight']
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
self.config = config
def setup(self, state: State = State()):
return state.update(
bias_init_mean=self.config.bias_init_mean,
bias_init_std=self.config.bias_init_std,
bias_mutate_power=self.config.bias_mutate_power,
bias_mutate_rate=self.config.bias_mutate_rate,
bias_replace_rate=self.config.bias_replace_rate,
response_init_mean=self.config.response_init_mean,
response_init_std=self.config.response_init_std,
response_mutate_power=self.config.response_mutate_power,
response_mutate_rate=self.config.response_mutate_rate,
response_replace_rate=self.config.response_replace_rate,
activation_replace_rate=self.config.activation_replace_rate,
activation_default=0,
activation_options=jnp.arange(len(self.config.activation_options)),
aggregation_replace_rate=self.config.aggregation_replace_rate,
aggregation_default=0,
aggregation_options=jnp.arange(len(self.config.aggregation_options)),
weight_init_mean=self.config.weight_init_mean,
weight_init_std=self.config.weight_init_std,
weight_mutate_power=self.config.weight_mutate_power,
weight_mutate_rate=self.config.weight_mutate_rate,
weight_replace_rate=self.config.weight_replace_rate,
)
def update(self, state):
return state
def new_node_attrs(self, state):
return jnp.array([state.bias_init_mean, state.response_init_mean,
state.activation_default, state.aggregation_default])
def new_conn_attrs(self, state):
return jnp.array([state.weight_init_mean])
def mutate_node(self, state, key, attrs: Array):
k1, k2, k3, k4 = jax.random.split(key, num=4)
bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std,
state.bias_mutate_power, state.bias_mutate_rate, state.bias_replace_rate)
res = NormalGene._mutate_float(k2, attrs[1], state.response_init_mean, state.response_init_std,
state.response_mutate_power, state.response_mutate_rate,
state.response_replace_rate)
act = NormalGene._mutate_int(k3, attrs[2], state.activation_options, state.activation_replace_rate)
agg = NormalGene._mutate_int(k4, attrs[3], state.aggregation_options, state.aggregation_replace_rate)
return jnp.array([bias, res, act, agg])
def mutate_conn(self, state, key, attrs: Array):
weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std,
state.weight_mutate_power, state.weight_mutate_rate,
state.weight_replace_rate)
return jnp.array([weight])
def distance_node(self, state, node1: Array, node2: Array):
# bias + response + activation + aggregation
return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \
(node1[3] != node2[3]) + (node1[4] != node2[4])
def distance_conn(self, state, con1: Array, con2: Array):
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
def forward_transform(self, state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# remove enable attr
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(genome.nodes, conn_enable)
return seqs, genome.nodes, u_conns
def forward(self, state: State, inputs, transformed):
cal_seqs, nodes, cons = transformed
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
weights = cons[0, :]
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
new_values = values.at[i].set(z)
return new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
return values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx]
@staticmethod
def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
k1, k2, k3 = jax.random.split(key, num=3)
noise = jax.random.normal(k1, ()) * mutate_power
replace = jax.random.normal(k2, ()) * init_std + init_mean
r = jax.random.uniform(k3, ())
val = jnp.where(
r < mutate_rate,
val + noise,
jnp.where(
(mutate_rate < r) & (r < mutate_rate + replace_rate),
replace,
val
)
)
return val
@staticmethod
def _mutate_int(key, val, options, replace_rate):
k1, k2 = jax.random.split(key, num=2)
r = jax.random.uniform(k1, ())
val = jnp.where(
r < replace_rate,
jax.random.choice(k2, options),
val
)
return val

View File

@@ -1,57 +0,0 @@
from dataclasses import dataclass
import jax
from jax import numpy as jnp, vmap
from .normal import NormalGene, NormalGeneConfig
from core import State, Genome
from utils import unflatten_conns, act, agg
@dataclass(frozen=True)
class RecurrentGeneConfig(NormalGeneConfig):
activate_times: int = 10
def __post_init__(self):
super().__post_init__()
assert self.activate_times > 0
class RecurrentGene(NormalGene):
def __init__(self, config: RecurrentGeneConfig = RecurrentGeneConfig()):
self.config = config
super().__init__(config)
def forward_transform(self, state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return genome.nodes, u_conns
def forward(self, state: State, inputs, transformed):
nodes, conns = transformed
batch_act, batch_agg = vmap(act, in_axes=(0, 0, None)), vmap(agg, in_axes=(0, 0, None))
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = conns[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z)
return values
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
return vals[output_idx]

View File

@@ -0,0 +1,3 @@
from .base import BaseGenome
from .default import DefaultGenome
from .recurrent import RecurrentGenome

View File

@@ -0,0 +1,66 @@
import jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from utils import fetch_first
class BaseGenome:
network_type = None
def __init__(
self,
num_inputs: int,
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.input_idx = jnp.arange(num_inputs)
self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)
self.max_nodes = max_nodes
self.max_conns = max_conns
self.node_gene = node_gene
self.conn_gene = conn_gene
def transform(self, nodes, conns):
raise NotImplementedError
def forward(self, inputs, transformed):
raise NotImplementedError
def add_node(self, nodes, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def delete_node_by_pos(self, nodes, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
return nodes.at[pos].set(jnp.nan)
def add_conn(self, conns, i_key, o_key, enable: bool, attrs):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
return new_conns.at[pos, 3:].set(attrs)
def delete_conn_by_pos(self, conns, pos):
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)

View File

@@ -0,0 +1,75 @@
import jax, jax.numpy as jnp
from utils import unflatten_conns, topological_sort, I_INT
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
class DefaultGenome(BaseGenome):
"""Default genome class, with the same behavior as the NEAT-Python"""
network_type = 'feedforward'
def __init__(self,
num_inputs: int,
num_outputs: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
):
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# DONE: Seems like there is a bug in this line
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# modified: exist conn and enable is true
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
# advanced modified: when and only when enabled is True
conn_enable = u_conns[0] == 1
# remove enable attr
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns
def forward(self, inputs, transformed):
cal_seqs, nodes, conns = transformed
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[self.input_idx].set(inputs)
nodes_attrs = nodes[:, 1:]
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
# ins = values * weights[:, i]
z = self.node_gene.forward(nodes_attrs[i], ins)
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
new_values = values.at[i].set(z)
return new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
return values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[self.output_idx]

View File

@@ -0,0 +1,58 @@
import jax, jax.numpy as jnp
from utils import unflatten_conns
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
class RecurrentGenome(BaseGenome):
"""Default genome class, with the same behavior as the NEAT-Python"""
network_type = 'recurrent'
def __init__(self,
num_inputs: int,
num_outputs: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
):
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
self.activate_time = activate_time
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# remove un-enable connections and remove enable attr
conn_enable = u_conns[0] == 1
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns
def forward(self, inputs, transformed):
nodes, conns = transformed
N = nodes.shape[0]
vals = jnp.full((N,), jnp.nan)
nodes_attrs = nodes[:, 1:]
def body_func(_, values):
# set input values
values = values.at[self.input_idx].set(inputs)
# calculate connections
node_ins = jax.vmap(
jax.vmap(
self.conn_gene.forward,
in_axes=(1, None)
),
in_axes=(1, 0)
)(conns, values)
# calculate nodes
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
return vals[self.output_idx]

View File

@@ -1,87 +1,38 @@
from typing import Type
import jax, jax.numpy as jnp
from utils import State
from .. import BaseAlgorithm
from .genome import *
from .species import *
from .ga import *
import jax
from jax import numpy as jnp
import numpy as np
class NEAT(BaseAlgorithm):
from config import Config
from core import Algorithm, State, Gene, Genome
from .ga import create_next_generation
from .species import SpeciesInfo, update_species, speciate
def __init__(
self,
genome: BaseGenome,
species: BaseSpecies,
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
):
self.genome = genome
self.species = species
self.mutation = mutation
self.crossover = crossover
class NEAT(Algorithm):
def __init__(self, config: Config, gene_type: Type[Gene]):
self.config = config
self.gene = gene_type(config.gene)
self.forward_func = None
self.tell_func = None
def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm"""
input_idx = np.arange(self.config.neat.inputs)
output_idx = np.arange(self.config.neat.inputs,
self.config.neat.inputs + self.config.neat.outputs)
state = state.update(
P=self.config.basic.pop_size,
N=self.config.neat.max_nodes,
C=self.config.neat.max_conns,
S=self.config.neat.max_species,
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
max_stagnation=self.config.neat.max_stagnation,
species_elitism=self.config.neat.species_elitism,
spawn_number_change_rate=self.config.neat.spawn_number_change_rate,
genome_elitism=self.config.neat.genome_elitism,
survival_threshold=self.config.neat.survival_threshold,
compatibility_threshold=self.config.neat.compatibility_threshold,
compatibility_disjoint=self.config.neat.compatibility_disjoint,
compatibility_weight=self.config.neat.compatibility_weight,
input_idx=input_idx,
output_idx=output_idx,
def setup(self, randkey):
k1, k2 = jax.random.split(randkey, 2)
return State(
randkey=k1,
generation=0,
next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2,
# inputs nodes, output nodes, 1 hidden node
species=self.species.setup(k2),
)
state = self.gene.setup(state)
pop_genomes = self._initialize_genomes(state)
species_info = SpeciesInfo.initialize(state)
idx2species = jnp.zeros(state.P, dtype=jnp.float32)
center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
center_genomes = Genome(center_nodes, center_conns)
center_genomes = center_genomes.set(0, pop_genomes[0])
generation = 0
next_node_key = max(*state.input_idx, *state.output_idx) + 2
next_species_key = 1
state = state.update(
randkey=randkey,
pop_genomes=pop_genomes,
species_info=species_info,
idx2species=idx2species,
center_genomes=center_genomes,
# avoid jax auto cast from int to float. that would cause re-compilation.
generation=jnp.asarray(generation, dtype=jnp.int32),
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
)
return jax.device_put(state)
def ask_algorithm(self, state: State):
return state.pop_genomes
def tell_algorithm(self, state: State, fitness):
state = self.gene.update(state)
def ask(self, state: State):
return self.species.ask(state)
def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(
@@ -89,46 +40,55 @@ class NEAT(Algorithm):
randkey=randkey
)
state, winner, loser, elite_mask = update_species(state, k1, fitness)
state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation)
state = create_next_generation(self.config.neat, self.gene, state, k2, winner, loser, elite_mask)
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
state = speciate(self.gene, state)
state = self.species.speciate(state, state.generation)
return state
def forward_transform(self, state: State, genome: Genome):
return self.gene.forward_transform(state, genome)
def transform(self, state: State):
"""transform the genome into a neural network"""
raise NotImplementedError
def forward(self, state: State, inputs, genome: Genome):
return self.gene.forward(state, inputs, genome)
def forward(self, inputs, transformed):
raise NotImplementedError
def _initialize_genomes(self, state):
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
# prepare random keys
pop_size = self.species.pop_size
new_node_keys = jnp.arange(pop_size) + state.species.next_node_key
input_idx = state.input_idx
output_idx = state.output_idx
new_node_key = max([*input_idx, *output_idx]) + 1
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
o_nodes[input_idx, 0] = input_idx
o_nodes[output_idx, 0] = output_idx
o_nodes[new_node_key, 0] = new_node_key
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = self.gene.new_node_attrs(state)
o_nodes[new_node_key, 1:] = self.gene.new_node_attrs(state)
wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner]
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser]
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)]
o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = self.gene.new_conn_attrs(state)
# batch crossover
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx]
o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = self.gene.new_conn_attrs(state)
# batch mutation
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
# repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
pop_conns = np.tile(o_conns, (state.P, 1, 1))
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return state.update(
species=state.species.update(
pop_nodes=pop_nodes,
pop_conns=pop_conns,
),
next_node_key=next_node_key,
)
return Genome(pop_nodes, pop_conns)

View File

@@ -1,2 +1,2 @@
from .species_info import SpeciesInfo
from .operations import update_species, speciate
from .base import BaseSpecies
from .default import DefaultSpecies

View File

@@ -0,0 +1,14 @@
from utils import State
class BaseSpecies:
def setup(self, randkey):
raise NotImplementedError
def ask(self, state: State):
raise NotImplementedError
def update_species(self, state, fitness, generation):
raise NotImplementedError
def speciate(self, state, generation):
raise NotImplementedError

View File

@@ -0,0 +1,514 @@
import numpy as np
import jax, jax.numpy as jnp
from utils import State, rank_elements, argmin_with_mask, fetch_first
from ..genome import BaseGenome
class DefaultSpecies:
def __init__(self,
genome: BaseGenome,
pop_size,
species_size,
compatibility_disjoint: float = 1.0,
compatibility_weight: float = 0.4,
max_stagnation: int = 15,
species_elitism: int = 2,
spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2,
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.5
):
self.genome = genome
self.pop_size = pop_size
self.species_size = species_size
self.compatibility_disjoint = compatibility_disjoint
self.compatibility_weight = compatibility_weight
self.max_stagnation = max_stagnation
self.species_elitism = species_elitism
self.spawn_number_change_rate = spawn_number_change_rate
self.genome_elitism = genome_elitism
self.survival_threshold = survival_threshold
self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold
self.species_arange = jnp.arange(self.species_size)
def setup(self, randkey):
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome)
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
last_improved = jnp.full((self.species_size,), jnp.nan) # the last generation that the species improved
member_count = jnp.full((self.species_size,), jnp.nan) # the number of members of each species
idx2species = jnp.zeros(self.pop_size) # the species index of each individual
# nodes for each center genome of each species
center_nodes = jnp.full((self.species_size, self.genome.max_nodes, self.genome.node_gene.length), jnp.nan)
# connections for each center genome of each species
center_conns = jnp.full((self.species_size, self.genome.max_conns, self.genome.conn_gene.length), jnp.nan)
species_keys = species_keys.at[0].set(0)
best_fitness = best_fitness.at[0].set(-jnp.inf)
last_improved = last_improved.at[0].set(0)
member_count = member_count.at[0].set(self.pop_size)
center_nodes = center_nodes.at[0].set(pop_nodes[0])
center_conns = center_conns.at[0].set(pop_conns[0])
return State(
randkey=randkey,
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=1, # 0 is reserved for the first species
)
def ask(self, state):
return state.pop_nodes, state.pop_conns
def update_species(self, state, fitness, generation):
# update the fitness of each species
species_fitness = self.update_species_fitness(state, fitness)
# stagnation species
state, species_fitness = self.stagnation(state, generation, species_fitness)
# sort species_info by their fitness. (also push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
state = state.update(
species_keys=state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices],
last_improved=state.last_improved[sort_indices],
member_count=state.member_count[sort_indices],
center_nodes=state.center_nodes[sort_indices],
center_conns=state.center_conns[sort_indices],
)
# decide the number of members of each species by their fitness
spawn_number = self.cal_spawn_numbers(state)
k1, k2 = jax.random.split(state.randkey)
# crossover info
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
return state(randkey=k2), winner, loser, elite_mask
def update_species_fitness(self, state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf)
val = jnp.max(s_fitness)
return val
return jax.vmap(aux_func)(self.species_arange)
def stagnation(self, state, generation, species_fitness):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
generation: the current generation
"""
def check_stagnation(idx):
# determine whether the species stagnation
st = (
(species_fitness[idx] <= state.best_fitness[
idx]) & # not better than the best fitness of the species
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
)
# update last_improved and best_fitness
li, bf = jax.lax.cond(
species_fitness[idx] > state.best_fitness[idx],
lambda: (generation, species_fitness[idx]), # update
lambda: (state.last_improved[idx], state.best_fitness[idx]) # not update
)
return st, bf, li
spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)(self.species_arange)
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
spe_st = jnp.where(species_rank < self.species_elitism, False, spe_st) # elitism never stagnation
# set stagnation species to nan
def update_func(idx):
return jax.lax.cond(
spe_st[idx],
lambda: (
jnp.nan, # species_key
jnp.nan, # best_fitness
jnp.nan, # last_improved
jnp.nan, # member_count
-jnp.inf, # species_fitness
jnp.full_like(center_nodes[idx], jnp.nan), # center_nodes
jnp.full_like(center_conns[idx], jnp.nan), # center_conns
), # stagnation species
lambda: (
species_keys[idx],
best_fitness[idx],
last_improved[idx],
state.member_count[idx],
species_fitness[idx],
center_nodes[idx],
center_conns[idx]
) # not stagnation species
)
(
species_keys,
best_fitness,
last_improved,
member_count,
species_fitness,
center_nodes,
center_conns
) = (
jax.vmap(update_func)(self.species_arange))
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
center_nodes=center_nodes,
center_conns=center_conns,
), species_fitness
def cal_spawn_numbers(self, state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
species_keys = state.species_keys
is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * self.pop_size) # calculate member
# Avoid too much variation of numbers for a species
previous_size = state.member_count
spawn_number = previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number)
# add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error)
return spawn_number
def create_crossover_pair(self, state, randkey, spawn_number, fitness):
s_idx = self.species_arange
p_idx = jnp.arange(self.pop_size)
def aux_func(key, idx):
members = state.idx2species == state.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
survive_size = jnp.floor(self.survival_threshold * members_num).astype(jnp.int32)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, self.pop_size), replace=True, p=select_pro)
# elite
fa = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, fa)
ma = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, ma)
elite = jnp.where(p_idx < self.genome_elitism, True, False)
return fa, ma, elite
fas, mas, elites = jax.vmap(aux_func)(jax.random.split(randkey, self.species_size), s_idx)
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
part1, part2, elite_mask = jax.vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask
def speciate(self, state, generation):
# prepare distance functions
o2p_distance_func = jax.vmap(self.distance, in_axes=(None, None, 0, 0)) # one to population
# idx to specie key
idx2species = jnp.full((self.pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((self.pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, o2c_distances
i, i2s, cns, ccs, o2c = carry
return (
(i < self.species_size) &
(~jnp.isnan(state.species_keys[i]))
) # current species is existing
def body_func(carry):
i, i2s, cns, ccs, o2c = carry
distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cns = cns.set(i, state.pop_nodes[closest_idx])
ccs = ccs.set(i, state.pop_conns[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cns, ccs, o2c
_, idx2species, center_nodes, center_conns, o2c_distances = \
jax.lax.while_loop(cond_func, body_func,
(0, idx2species, state.center_nodes, state.center_conns, o2c_distances))
state = state.update(
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
)
# part 2: assign members to each species
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, species_keys, o2c_distances, next_species_key
i, i2s, cns, ccs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < self.species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
_, i2s, cns, ccs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cns, ccs, sk, o2c, nsk)
)
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def create_new_species(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk) # nsk -> next species key
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cns = cns.set(i, state.pop_nodes[idx])
ccs = ccs.set(i, state.pop_conns[idx])
# find the members for the new species
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
return i, i2s, cns, ccs, sk, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
# turn to next species
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
close_enough_mask = o2p_distance < self.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
catchable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
mask = close_enough_mask & catchable_mask
# update species info
i2s = jnp.where(mask, sk[i], i2s)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
return i2s, o2c
# update idx2species
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances,
state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
last_improved = jnp.where(new_created_mask, generation, state.last_improved)
# update members count
def count_members(idx):
return jax.lax.cond(
jnp.isnan(species_keys[idx]), # if the species is not existing
lambda _: jnp.nan, # nan
lambda _: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
)
member_count = jax.vmap(count_members)(self.species_arange)
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=next_species_key
)
def distance(self, nodes1, conns1, nodes2, conns2):
"""
The distance between two genomes
"""
return self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
def node_distance(self, nodes1, nodes2):
"""
The distance of the nodes part for two genomes
"""
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
def conn_distance(self, conns1, conns2):
"""
The distance of the conns part for two genomes
"""
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((conns1, conns2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt)
def initialize_population(pop_size, genome):
o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes
o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections
input_idx, output_idx = genome.input_idx, genome.output_idx
new_node_key = max([*input_idx, *output_idx]) + 1
o_nodes[input_idx, 0] = genome.input_idx
o_nodes[output_idx, 0] = genome.output_idx
o_nodes[new_node_key, 0] = new_node_key # one hidden node
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_attrs()
o_nodes[new_node_key, 1:] = genome.node_gene.new_attrs() # one hidden node
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs()
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs()
# repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
pop_conns = np.tile(o_conns, (pop_size, 1, 1))
return pop_nodes, pop_conns

View File

@@ -1,71 +0,0 @@
from jax import Array, numpy as jnp, vmap
from core import Gene
def distance(gene: Gene, state, genome1, genome2):
return node_distance(gene, state, genome1.nodes, genome2.nodes) + \
connection_distance(gene, state, genome1.conns, genome2.conns)
def node_distance(gene: Gene, state, nodes1: Array, nodes2: Array):
"""
Calculate the distance between nodes of two genomes.
"""
# statistics nodes count of two genomes
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
hnd = vmap(gene.distance_node, in_axes=(None, 0, 0))(state, fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
def connection_distance(gene: Gene, state, cons1: Array, cons2: Array):
"""
Calculate the distance between connections of two genomes.
Similar process as node_distance.
"""
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((cons1, cons2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
hcd = vmap(gene.distance_conn, in_axes=(None, 0, 0))(state, fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * state.compatibility_disjoint + homologous_distance * state.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt)

View File

@@ -1,319 +0,0 @@
import jax
from jax import numpy as jnp, vmap
from core import Gene, Genome, State
from utils import rank_elements, fetch_first
from .distance import distance
from .species_info import SpeciesInfo
def update_species(state, randkey, fitness):
# update the fitness of each species
species_fitness = update_species_fitness(state, fitness)
# stagnation species
state, species_fitness = stagnation(state, species_fitness)
# sort species_info by their fitness. (push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
state = state.update(
species_info=state.species_info[sort_indices],
center_genomes=state.center_genomes[sort_indices],
)
# decide the number of members of each species by their fitness
spawn_number = cal_spawn_numbers(state)
# crossover info
winner, loser, elite_mask = create_crossover_pair(state, randkey, spawn_number, fitness)
return state, winner, loser, elite_mask
def update_species_fitness(state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
s_fitness = jnp.where(state.idx2species == state.species_info.species_keys[idx], fitness, -jnp.inf)
f = jnp.max(s_fitness)
return f
return vmap(aux_func)(jnp.arange(state.species_info.size()))
def stagnation(state, species_fitness):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
"""
def aux_func(idx):
s_fitness = species_fitness[idx]
sk, bf, li, _ = state.species_info.get(idx)
st = (s_fitness <= bf) & (state.generation - li > state.max_stagnation)
li = jnp.where(s_fitness > bf, state.generation, li)
bf = jnp.where(s_fitness > bf, s_fitness, bf)
return st, sk, bf, li
spe_st, species_keys, best_fitness, last_improved = vmap(aux_func)(jnp.arange(species_fitness.shape[0]))
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
spe_st = jnp.where(species_rank < state.species_elitism, False, spe_st) # elitism never stagnation
# set stagnation species to nan
species_keys = jnp.where(spe_st, jnp.nan, species_keys)
best_fitness = jnp.where(spe_st, jnp.nan, best_fitness)
last_improved = jnp.where(spe_st, jnp.nan, last_improved)
member_count = jnp.where(spe_st, jnp.nan, state.species_info.member_count)
species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count)
# TODO: Simplify the coded
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.nodes)
center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.conns)
state = state.update(
species_info=species_info,
center_genomes=Genome(center_nodes, center_conns)
)
return state, species_fitness
def cal_spawn_numbers(state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
species_keys = state.species_info.species_keys
is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - jnp.arange(species_keys.shape[0]) # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
# Avoid too much variation of numbers in a species
previous_size = state.species_info.member_count
spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate
# jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number)
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
return spawn_number
def create_crossover_pair(state, randkey, spawn_number, fitness):
species_size = state.species_info.size()
pop_size = fitness.shape[0]
s_idx = jnp.arange(species_size)
p_idx = jnp.arange(pop_size)
# def aux_func(key, idx):
def aux_func(key, idx):
members = state.idx2species == state.species_info.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
elite_size = state.genome_elitism
survive_size = jnp.floor(state.survival_threshold * members_num).astype(jnp.int32)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, pop_size), replace=True, p=select_pro)
# elite
fa = jnp.where(p_idx < elite_size, sorted_member_indices, fa)
ma = jnp.where(p_idx < elite_size, sorted_member_indices, ma)
elite = jnp.where(p_idx < elite_size, True, False)
return fa, ma, elite
fas, mas, elites = vmap(aux_func)(jax.random.split(randkey, species_size), s_idx)
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
part1, part2, elite_mask = vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask
def speciate(gene: Gene, state: State):
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, None, 0)) # one to population
# idx to specie key
idx2species = jnp.full((pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
i, i2s, cgs, o2c = carry
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
def body_func(carry):
i, i2s, cgs, o2c = carry
distances = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cgs = cgs.set(i, state.pop_genomes[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cgs, o2c
_, idx2species, center_genomes, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
state = state.update(
idx2species=idx2species,
center_genomes=center_genomes,
)
# part 2: assign members to each species
def cond_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cgs, sk, o2c, nsk = carry
_, i2s, cgs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cgs, sk, o2c, nsk)
)
return i + 1, i2s, cgs, sk, o2c, nsk
def create_new_species(carry):
i, i2s, cgs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk)
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cgs = cgs.set(i, state.pop_genomes[idx])
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
# when a new species is created, it needs to be updated, thus do not change i
return i + 1, i2s, cgs, sk, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cgs, sk, o2c, nsk = carry
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
# turn to next species
return i + 1, i2s, cgs, sk, o2c, nsk
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(gene, state, cgs[i], state.pop_genomes)
close_enough_mask = o2p_distance < state.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
cacheable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
# jax.debug.print("{}", o2p_distance)
mask = close_enough_mask & cacheable_mask
# update species info
i2s = jnp.where(mask, sk[i], i2s)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
return i2s, o2c
# update idx2species
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances,
state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
# update members count
def count_members(idx):
key = species_keys[idx]
count = jnp.sum(idx2species == key, dtype=jnp.float32)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
return count
member_count = vmap(count_members)(jnp.arange(species_size))
return state.update(
species_info=SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
idx2species=idx2species,
center_genomes=center_genomes,
next_species_key=next_species_key
)
def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx

View File

@@ -1,55 +0,0 @@
from jax.tree_util import register_pytree_node_class
import numpy as np
import jax.numpy as jnp
@register_pytree_node_class
class SpeciesInfo:
def __init__(self, species_keys, best_fitness, last_improved, member_count):
self.species_keys = species_keys
self.best_fitness = best_fitness
self.last_improved = last_improved
self.member_count = member_count
@classmethod
def initialize(cls, state):
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
member_count = np.full((state.S,), np.nan, dtype=np.float32)
species_keys[0] = 0
best_fitness[0] = -np.inf
last_improved[0] = 0
member_count[0] = state.P
return cls(species_keys, best_fitness, last_improved, member_count)
def __getitem__(self, i):
return SpeciesInfo(self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i])
def get(self, i):
return self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i]
def set(self, idx, value):
species_keys = self.species_keys.at[idx].set(value[0])
best_fitness = self.best_fitness.at[idx].set(value[1])
last_improved = self.last_improved.at[idx].set(value[2])
member_count = self.member_count.at[idx].set(value[3])
return SpeciesInfo(species_keys, best_fitness, last_improved, member_count)
def remove(self, idx):
return self.set(idx, jnp.array([jnp.nan] * 4))
def size(self):
return self.species_keys.shape[0]
def tree_flatten(self):
children = self.species_keys, self.best_fitness, self.last_improved, self.member_count
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)

View File

@@ -1 +0,0 @@
from .config import *

View File

@@ -1,107 +0,0 @@
from dataclasses import dataclass
from utils import Act, Agg
@dataclass(frozen=True)
class BasicConfig:
seed: int = 42
fitness_target: float = 1
generation_limit: int = 1000
pop_size: int = 100
def __post_init__(self):
assert self.pop_size > 0, "the population size must be greater than 0"
@dataclass(frozen=True)
class NeatConfig:
network_type: str = "feedforward"
inputs: int = 2
outputs: int = 1
max_nodes: int = 50
max_conns: int = 100
max_species: int = 10
# genome config
compatibility_disjoint: float = 1
compatibility_weight: float = 0.5
conn_add: float = 0.4
conn_delete: float = 0
node_add: float = 0.2
node_delete: float = 0
# species config
compatibility_threshold: float = 3.5
species_elitism: int = 2
max_stagnation: int = 15
genome_elitism: int = 2
survival_threshold: float = 0.2
min_species_size: int = 1
spawn_number_change_rate: float = 0.5
def __post_init__(self):
assert self.network_type in ["feedforward", "recurrent"], "the network type must be feedforward or recurrent"
assert self.inputs > 0, "the inputs number of neat must be greater than 0"
assert self.outputs > 0, "the outputs number of neat must be greater than 0"
assert self.max_nodes > 0, "the maximum nodes must be greater than 0"
assert self.max_conns > 0, "the maximum connections must be greater than 0"
assert self.max_species > 0, "the maximum species must be greater than 0"
assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0"
assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0"
assert self.conn_add >= 0, "the connection add probability must be greater than 0"
assert self.conn_delete >= 0, "the connection delete probability must be greater than 0"
assert self.node_add >= 0, "the node add probability must be greater than 0"
assert self.node_delete >= 0, "the node delete probability must be greater than 0"
assert self.compatibility_threshold > 0, "the compatibility threshold must be greater than 0"
assert self.species_elitism > 0, "the species elitism must be greater than 0"
assert self.max_stagnation > 0, "the max stagnation must be greater than 0"
assert self.genome_elitism > 0, "the genome elitism must be greater than 0"
assert self.survival_threshold > 0, "the survival threshold must be greater than 0"
assert self.min_species_size > 0, "the min species size must be greater than 0"
assert self.spawn_number_change_rate > 0, "the spawn number change rate must be greater than 0"
@dataclass(frozen=True)
class HyperNeatConfig:
below_threshold: float = 0.2
max_weight: float = 3
activation: callable = Act.sigmoid
aggregation: callable = Agg.sum
activate_times: int = 5
inputs: int = 2
outputs: int = 1
def __post_init__(self):
assert self.below_threshold > 0, "the below threshold must be greater than 0"
assert self.max_weight > 0, "the max weight must be greater than 0"
assert self.activate_times > 0, "the activate times must be greater than 0"
assert self.inputs > 0, "the inputs number of hyper neat must be greater than 0"
assert self.outputs > 0, "the outputs number of hyper neat must be greater than 0"
@dataclass(frozen=True)
class GeneConfig:
pass
@dataclass(frozen=True)
class SubstrateConfig:
pass
@dataclass(frozen=True)
class ProblemConfig:
pass
@dataclass(frozen=True)
class Config:
basic: BasicConfig = BasicConfig()
neat: NeatConfig = NeatConfig()
hyperneat: HyperNeatConfig = HyperNeatConfig()
gene: GeneConfig = GeneConfig()
substrate: SubstrateConfig = SubstrateConfig()
problem: ProblemConfig = ProblemConfig()

View File

@@ -1,6 +0,0 @@
from .algorithm import Algorithm
from .state import State
from .genome import Genome
from .gene import Gene
from .substrate import Substrate
from .problem import Problem

View File

@@ -1,50 +0,0 @@
from functools import partial
import jax
from .state import State
from .genome import Genome
class Algorithm:
def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm"""
raise NotImplementedError
@partial(jax.jit, static_argnums=(0,))
def ask(self, state: State):
"""require the population to be evaluated"""
return self.ask_algorithm(state)
@partial(jax.jit, static_argnums=(0,))
def tell(self, state: State, fitness):
"""update the state of the algorithm"""
return self.tell_algorithm(state, fitness)
@partial(jax.jit, static_argnums=(0,))
def transform(self, state: State, genome: Genome):
"""transform the genome into a neural network"""
return self.forward_transform(state, genome)
@partial(jax.jit, static_argnums=(0,))
def act(self, state: State, inputs, genome: Genome):
return self.forward(state, inputs, genome)
def forward_transform(self, state: State, genome: Genome):
raise NotImplementedError
def forward(self, state: State, inputs, genome: Genome):
raise NotImplementedError
def ask_algorithm(self, state: State):
"""ask the specific algorithm for a new population"""
raise NotImplementedError
def tell_algorithm(self, state: State, fitness):
"""tell the specific algorithm the fitness of the population"""
raise NotImplementedError

View File

@@ -1,40 +0,0 @@
from config import GeneConfig
from .state import State
class Gene:
node_attrs = []
conn_attrs = []
def __init__(self, config: GeneConfig = GeneConfig()):
raise NotImplementedError
def setup(self, state=State()):
raise NotImplementedError
def update(self, state):
raise NotImplementedError
def new_node_attrs(self, state: State):
raise NotImplementedError
def new_conn_attrs(self, state: State):
raise NotImplementedError
def mutate_node(self, state: State, randkey, node_attrs):
raise NotImplementedError
def mutate_conn(self, state: State, randkey, conn_attrs):
raise NotImplementedError
def distance_node(self, state: State, node_attrs1, node_attrs2):
raise NotImplementedError
def distance_conn(self, state: State, conn_attrs1, conn_attrs2):
raise NotImplementedError
def forward_transform(self, state: State, genome):
raise NotImplementedError
def forward(self, state: State, inputs, transform):
raise NotImplementedError

View File

@@ -1,90 +0,0 @@
from __future__ import annotations
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from utils.tools import fetch_first
@register_pytree_node_class
class Genome:
def __init__(self, nodes, conns):
self.nodes = nodes
self.conns = conns
def __repr__(self):
return f"Genome(nodes={self.nodes}, conns={self.conns})"
def __getitem__(self, idx):
return self.__class__(self.nodes[idx], self.conns[idx])
def __eq__(self, other):
nodes_eq = jnp.alltrue((self.nodes == other.nodes) | (jnp.isnan(self.nodes) & jnp.isnan(other.nodes)))
conns_eq = jnp.alltrue((self.conns == other.conns) | (jnp.isnan(self.conns) & jnp.isnan(other.conns)))
return nodes_eq & conns_eq
def set(self, idx, value: Genome):
return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns))
def update(self, nodes, conns):
return self.__class__(nodes, conns)
def update_nodes(self, nodes):
return self.update(nodes, self.conns)
def update_conns(self, conns):
return self.update(self.nodes, conns)
def count(self):
"""Count how many nodes and connections are in the genome."""
nodes_cnt = jnp.sum(~jnp.isnan(self.nodes[:, 0]))
conns_cnt = jnp.sum(~jnp.isnan(self.conns[:, 0]))
return nodes_cnt, conns_cnt
def add_node(self, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = self.nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = self.nodes.at[pos, 0].set(new_key)
new_nodes = new_nodes.at[pos, 1:].set(attrs)
return self.update_nodes(new_nodes)
def delete_node_by_pos(self, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
nodes = self.nodes.at[pos].set(jnp.nan)
return self.update_nodes(nodes)
def add_conn(self, i_key, o_key, enable: bool, attrs):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = self.conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = self.conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
new_conns = new_conns.at[pos, 3:].set(attrs)
return self.update_conns(new_conns)
def delete_conn_by_pos(self, pos):
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
conns = self.conns.at[pos].set(jnp.nan)
return self.update_conns(conns)
def tree_flatten(self):
children = self.nodes, self.conns
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)

View File

@@ -1,29 +0,0 @@
from typing import Callable
from config import ProblemConfig
from .state import State
class Problem:
jitable = None
def __init__(self, problem_config: ProblemConfig = ProblemConfig()):
self.config = problem_config
def evaluate(self, randkey, state: State, act_func: Callable, params):
raise NotImplementedError
@property
def input_shape(self):
raise NotImplementedError
@property
def output_shape(self):
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
"""
show how a genome perform in this problem
"""
raise NotImplementedError

View File

@@ -1,8 +0,0 @@
from config import SubstrateConfig
class Substrate:
@staticmethod
def setup(state, config: SubstrateConfig = SubstrateConfig()):
return state

View File

@@ -12,7 +12,7 @@ def example_conf():
basic=BasicConfig(
seed=42,
fitness_target=10000,
pop_size=1000
pop_size=100
),
neat=NeatConfig(
inputs=27,

View File

@@ -1,7 +1,3 @@
"""
pipeline for jitable env like func_fit, gymnax
"""
from functools import partial
from typing import Type
@@ -16,24 +12,28 @@ from core import State, Algorithm, Problem
class Pipeline:
def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
def __init__(
self,
algorithm: Algorithm,
problem: Problem,
seed: int = 42,
fitness_target: float = 1,
generation_limit: int = 1000,
pop_size: int = 100,
):
assert problem.jitable, "Currently, problem must be jitable"
assert problem_type.jitable, "problem must be jitable"
self.config = config
self.algorithm = algorithm
self.problem = problem_type(config.problem)
self.problem = problem
self.seed = seed
self.fitness_target = fitness_target
self.generation_limit = generation_limit
self.pop_size = pop_size
print(self.problem.input_shape, self.problem.output_shape)
if isinstance(algorithm, NEAT):
assert config.neat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
elif isinstance(algorithm, HyperNEAT):
assert config.hyperneat.inputs == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
else:
raise NotImplementedError
# TODO: make each algorithm's input_num and output_num
assert algorithm.input_num == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
self.act_func = self.algorithm.act
@@ -45,19 +45,19 @@ class Pipeline:
self.generation_timestamp = None
def setup(self):
key = jax.random.PRNGKey(self.config.basic.seed)
key = jax.random.PRNGKey(self.seed)
algorithm_key, evaluate_key = jax.random.split(key, 2)
state = State()
state = self.algorithm.setup(algorithm_key, state)
return state.update(
evaluate_key=evaluate_key
# TODO: Problem should has setup function to maintain state
return State(
alg=self.algorithm.setup(algorithm_key),
pro=self.problem.setup(evaluate_key),
)
@partial(jax.jit, static_argnums=(0,))
def step(self, state):
key, sub_key = jax.random.split(state.evaluate_key)
keys = jax.random.split(key, self.config.basic.pop_size)
keys = jax.random.split(key, self.pop_size)
pop = self.algorithm.ask(state)
@@ -72,7 +72,7 @@ class Pipeline:
def auto_run(self, ini_state):
state = ini_state
for _ in range(self.config.basic.generation_limit):
for _ in range(self.generation_limit):
self.generation_timestamp = time.time()
@@ -84,7 +84,7 @@ class Pipeline:
self.analysis(state, previous_pop, fitnesses)
if max(fitnesses) >= self.config.basic.fitness_target:
if max(fitnesses) >= self.fitness_target:
print("Fitness limit reached!")
return state, self.best_genome
@@ -120,3 +120,4 @@ class Pipeline:
print("start compile")
self.step.lower(self, state).compile()
print(f"compile finished, cost time: {time.time() - tic}s")

View File

@@ -0,0 +1 @@
from .base import BaseProblem

44
problem/base.py Normal file
View File

@@ -0,0 +1,44 @@
from typing import Callable
from config import ProblemConfig
from core.state import State
class BaseProblem:
jitable = None
def __init__(self):
pass
def setup(self, randkey, state: State = State()):
"""initialize the state of the problem"""
raise NotImplementedError
def evaluate(self, randkey, state: State, act_func: Callable, params):
"""evaluate one individual"""
raise NotImplementedError
@property
def input_shape(self):
"""
The input shape for the problem to evaluate
In RL problem, it is the observation space
In function fitting problem, it is the input shape of the function
"""
raise NotImplementedError
@property
def output_shape(self):
"""
The output shape for the problem to evaluate
In RL problem, it is the action space
In function fitting problem, it is the output shape of the function
"""
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
"""
show how a genome perform in this problem
"""
raise NotImplementedError

View File

@@ -1,3 +1,3 @@
from .func_fit import FuncFit, FuncFitConfig
from .func_fit import FuncFit
from .xor import XOR
from .xor3d import XOR3d

View File

@@ -1,42 +1,35 @@
from typing import Callable
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from config import ProblemConfig
from core import Problem, State
from .. import BaseProblem
@dataclass(frozen=True)
class FuncFitConfig(ProblemConfig):
error_method: str = 'mse'
def __post_init__(self):
assert self.error_method in {'mse', 'rmse', 'mae', 'mape'}
class FuncFit(Problem):
class FuncFit(BaseProblem):
jitable = True
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
self.config = config
super().__init__(config)
def __init__(self,
error_method: str = 'mse'
):
super().__init__()
def evaluate(self, randkey, state: State, act_func: Callable, params):
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
self.error_method = error_method
def evaluate(self, randkey, state, act_func, params):
predict = act_func(state, self.inputs, params)
if self.config.error_method == 'mse':
if self.error_method == 'mse':
loss = jnp.mean((predict - self.targets) ** 2)
elif self.config.error_method == 'rmse':
elif self.error_method == 'rmse':
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
elif self.config.error_method == 'mae':
elif self.error_method == 'mae':
loss = jnp.mean(jnp.abs(predict - self.targets))
elif self.config.error_method == 'mape':
elif self.error_method == 'mape':
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
else:
@@ -44,7 +37,7 @@ class FuncFit(Problem):
return -loss
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
def show(self, randkey, state, act_func, params, *args, **kwargs):
predict = act_func(state, self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params)

View File

@@ -1,13 +1,12 @@
import numpy as np
from .func_fit import FuncFit, FuncFitConfig
from .func_fit import FuncFit
class XOR(FuncFit):
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
self.config = config
super().__init__(config)
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):

View File

@@ -1,13 +1,12 @@
import numpy as np
from .func_fit import FuncFit, FuncFitConfig
from .func_fit import FuncFit
class XOR3d(FuncFit):
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
self.config = config
super().__init__(config)
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):
@@ -37,8 +36,8 @@ class XOR3d(FuncFit):
@property
def input_shape(self):
return (8, 3)
return 8, 3
@property
def output_shape(self):
return (8, 1)
return 8, 1

View File

@@ -1,28 +1,13 @@
from dataclasses import dataclass
from typing import Callable
import jax.numpy as jnp
from brax import envs
from core import State
from .rl_jit import RLEnv, RLEnvConfig
@dataclass(frozen=True)
class BraxConfig(RLEnvConfig):
env_name: str = "ant"
backend: str = "generalized"
def __post_init__(self):
# TODO: Check if env_name is registered
# assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered"
pass
from .rl_jit import RLEnv
class BraxEnv(RLEnv):
def __init__(self, config: BraxConfig = BraxConfig()):
super().__init__(config)
self.config = config
self.env = envs.create(env_name=config.env_name, backend=config.backend)
def __init__(self, env_name: str = "ant", backend: str = "generalized"):
super().__init__()
self.env = envs.create(env_name=env_name, backend=backend)
def env_step(self, randkey, env_state, action):
state = self.env.step(env_state, action)
@@ -40,9 +25,7 @@ class BraxEnv(RLEnv):
def output_shape(self):
return (self.env.action_size,)
def show(self, randkey, state: State, act_func: Callable, params, save_path=None, height=512, width=512,
duration=0.1, *args,
**kwargs):
def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs):
import jax
import imageio
@@ -56,8 +39,7 @@ class BraxEnv(RLEnv):
def step(key, env_state, obs):
key, _ = jax.random.split(key)
net_out = act_func(state, obs, params)
action = self.config.output_transform(net_out)
action = act_func(state, obs, params)
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
return key, env_state, obs, r, done
@@ -72,7 +54,6 @@ class BraxEnv(RLEnv):
def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
for image in image_list:
# 确保图像的数据类型正确
formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image)

View File

@@ -1,26 +1,15 @@
from dataclasses import dataclass
from typing import Callable
import gymnax
from core import State
from .rl_jit import RLEnv, RLEnvConfig
from .rl_jit import RLEnv
@dataclass(frozen=True)
class GymNaxConfig(RLEnvConfig):
env_name: str = "CartPole-v1"
def __post_init__(self):
assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered"
class GymNaxEnv(RLEnv):
def __init__(self, config: GymNaxConfig = GymNaxConfig()):
super().__init__(config)
self.config = config
self.env, self.env_params = gymnax.make(config.env_name)
def __init__(self, env_name):
super().__init__()
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
self.env, self.env_params = gymnax.make(env_name)
def env_step(self, randkey, env_state, action):
return self.env.step(randkey, env_state, action, self.env_params)
@@ -36,5 +25,5 @@ class GymNaxEnv(RLEnv):
def output_shape(self):
return self.env.action_space(self.env_params).shape
def show(self, randkey, state: State, act_func: Callable, params):
def show(self, randkey, state, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

View File

@@ -1,28 +1,18 @@
from dataclasses import dataclass
from typing import Callable
from functools import partial
import jax
from config import ProblemConfig
from .. import BaseProblem
from core import Problem, State
@dataclass(frozen=True)
class RLEnvConfig(ProblemConfig):
output_transform: Callable = lambda x: x
class RLEnv(Problem):
class RLEnv(BaseProblem):
jitable = True
def __init__(self, config: RLEnvConfig = RLEnvConfig()):
super().__init__(config)
self.config = config
# TODO: move output transform to algorithm
def __init__(self):
super().__init__()
def evaluate(self, randkey, state: State, act_func: Callable, params):
def evaluate(self, randkey, state, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset)
@@ -31,8 +21,7 @@ class RLEnv(Problem):
return ~done
def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward
net_out = act_func(state, obs, params)
action = self.config.output_transform(net_out)
action = act_func(state, obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng)
return next_obs, next_env_state, next_rng, done, tr + reward
@@ -67,5 +56,5 @@ class RLEnv(Problem):
def output_shape(self):
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
def show(self, randkey, state, act_func, params, *args, **kwargs):
raise NotImplementedError

64
t.py Normal file
View File

@@ -0,0 +1,64 @@
from algorithm.neat import *
from utils import Act, Agg
import jax, jax.numpy as jnp
def main():
# index, bias, response, activation, aggregation
nodes = jnp.array([
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
])
# in_node, out_node, enable, weight
conns = jnp.array([
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
])
genome = RecurrentGenome(
num_inputs=2,
num_outputs=1,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
),
activate_time=3
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([0, 0])
outputs = genome.forward(inputs, transformed)
print(outputs)
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
print(outputs)
expected: [[0.5], [0.75], [0.75], [1]]
print('\n-------------------------------------------------------\n')
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
print(outputs)
expected: [[0.5], [0.75], [0.5], [0.75]]
if __name__ == '__main__':
main()

0
test/__init__.py Normal file
View File

113
test/test_genome.py Normal file
View File

@@ -0,0 +1,113 @@
from algorithm.neat import *
from utils import Act, Agg
import jax, jax.numpy as jnp
def test_default():
# index, bias, response, activation, aggregation
nodes = jnp.array([
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
])
# in_node, out_node, enable, weight
conns = jnp.array([
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
])
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
),
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
print('\n-------------------------------------------------------\n')
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
def test_recurrent():
# index, bias, response, activation, aggregation
nodes = jnp.array([
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
])
# in_node, out_node, enable, weight
conns = jnp.array([
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
])
genome = RecurrentGenome(
num_inputs=2,
num_outputs=1,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
),
activate_time=3,
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
print('\n-------------------------------------------------------\n')
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]

View File

@@ -1,4 +1,5 @@
from .activation import Act, act
from .aggregation import Agg, agg
from .tools import *
from .graph import *
from .graph import *
from .state import State

View File

@@ -57,10 +57,8 @@ def agg(idx, z, agg_funcs):
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, agg_funcs, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
return jax.lax.cond(
jnp.all(jnp.isnan(z)),
lambda: jnp.nan, # all inputs are nan
lambda: jax.lax.switch(idx, agg_funcs, z) # otherwise
)

View File

@@ -5,13 +5,11 @@ import jax
from jax import numpy as jnp, Array, jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = np.full((1, 5), jnp.nan)
EMPTY_CON = np.full((1, 4), jnp.nan)
def unflatten_conns(nodes, conns):
"""
transform the (C, CL) connections to (CL-2, N, N)
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index)
:return:
"""
N = nodes.shape[0]
@@ -66,4 +64,43 @@ def rank_elements(array, reverse=False):
"""
if not reverse:
array = -array
return jnp.argsort(jnp.argsort(array))
return jnp.argsort(jnp.argsort(array))
@jit
def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
k1, k2, k3 = jax.random.split(key, num=3)
noise = jax.random.normal(k1, ()) * mutate_power
replace = jax.random.normal(k2, ()) * init_std + init_mean
r = jax.random.uniform(k3, ())
val = jnp.where(
r < mutate_rate,
val + noise,
jnp.where(
(mutate_rate < r) & (r < mutate_rate + replace_rate),
replace,
val
)
)
return val
@jit
def mutate_int(key, val, options, replace_rate):
k1, k2 = jax.random.split(key, num=2)
r = jax.random.uniform(k1, ())
val = jnp.where(
r < replace_rate,
jax.random.choice(k2, options),
val
)
return val
def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx