hyper neat

This commit is contained in:
wls2002
2023-07-24 19:25:02 +08:00
parent ac295c1921
commit ebad574431
24 changed files with 542 additions and 103 deletions

View File

@@ -0,0 +1,2 @@
from .neat import *
from .hyper_neat import *

View File

@@ -0,0 +1,2 @@
from .hyper_neat import HyperNEAT
from .substrate import NormalSubstrate, NormalSubstrateConfig

View File

@@ -0,0 +1,122 @@
from typing import Type
import jax
from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome
from utils import Activation, Aggregation
from algorithm.neat import NEAT
from .substrate import analysis_substrate
class HyperNEAT(Algorithm):
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
self.config = config
self.neat = neat
self.substrate = substrate
self.forward_func = None
def setup(self, randkey, state=State()):
neat_key, randkey = jax.random.split(randkey)
state = state.update(
below_threshold=self.config.hyper_neat.below_threshold,
max_weight=self.config.hyper_neat.max_weight,
)
state = self.neat.setup(neat_key, state)
state = self.substrate.setup(self.config.substrate, state)
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
h_conns[:, 0:2] = correspond_keys
state = state.update(
h_input_idx=h_input_idx,
h_output_idx=h_output_idx,
h_hidden_idx=h_hidden_idx,
h_nodes=h_nodes,
h_conns=h_conns,
query_coors=query_coors,
)
self.forward_func = HyperNEATGene.create_forward(self.config.hyper_neat, state)
return state
def ask(self, state: State):
return state.pop_genomes
def tell(self, state: State, fitness):
return self.neat.tell(state, fitness)
def forward(self, inputs: Array, transformed: Array):
return self.forward_func(inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
t = self.neat.forward_transform(state, genome)
query_res = vmap(self.neat.forward, in_axes=(0, None))(state.query_coors, t)
# mute the connection with weight below threshold
query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res)
query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res)
query_res = query_res / (1 - state.below_threshold) * state.max_weight
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
class HyperNEATGene:
node_attrs = [] # no node attributes
conn_attrs = ['weight']
@staticmethod
def forward_transform(genome: Genome):
N = genome.nodes.shape[0]
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32)
out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32)
weights = genome.conns[:, 2]
u_conns = u_conns.at[in_keys, out_keys].set(weights)
return genome.nodes, u_conns
@staticmethod
def create_forward(config: HyperNeatConfig, state: State):
act = Activation.name2func[config.activation]
agg = Aggregation.name2func[config.aggregation]
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
def forward(inputs, transform):
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
nodes, weights = transform
input_idx = state.h_input_idx
output_idx = state.h_output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
def body_func(i, values):
values = values.at[input_idx].set(inputs_with_bias)
nodes_ins = values * weights.T
values = batch_agg(nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]
return forward

View File

@@ -0,0 +1,2 @@
from .normal import NormalSubstrate, NormalSubstrateConfig
from .tools import analysis_substrate

View File

@@ -0,0 +1,25 @@
from dataclasses import dataclass
from typing import Tuple
import numpy as np
from core import Substrate, State
from config import SubstrateConfig
@dataclass(frozen=True)
class NormalSubstrateConfig(SubstrateConfig):
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple[Tuple[float]] = ((0, 1), )
class NormalSubstrate(Substrate):
@staticmethod
def setup(config: NormalSubstrateConfig, state: State = State()):
return state.update(
input_coors=np.asarray(config.input_coors, dtype=np.float32),
output_coors=np.asarray(config.output_coors, dtype=np.float32),
hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32),
)

View File

@@ -0,0 +1,50 @@
from typing import Type
import numpy as np
def analysis_substrate(state):
cd = state.input_coors.shape[1] # coordinate dimensions
si = state.input_coors.shape[0] # input coordinate size
so = state.output_coors.shape[0] # output coordinate size
sh = state.hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys

View File

@@ -0,0 +1,2 @@
from .neat import NEAT
from .gene import *

View File

@@ -1 +1,2 @@
from .normal import NormalGene, NormalGeneConfig
from .recurrent import RecurrentGene, RecurrentGeneConfig

View File

@@ -0,0 +1,84 @@
from dataclasses import dataclass
import jax
from jax import Array, numpy as jnp, vmap
from .normal import NormalGene, NormalGeneConfig
from core import State, Genome
from utils import Activation, Aggregation, unflatten_conns
@dataclass(frozen=True)
class RecurrentGeneConfig(NormalGeneConfig):
activate_times: int = 10
def __post_init__(self):
super().__post_init__()
assert self.activate_times > 0
class RecurrentGene(NormalGene):
@staticmethod
def forward_transform(state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return genome.nodes, u_conns
@staticmethod
def create_forward(state: State, config: RecurrentGeneConfig):
activation_funcs = [Activation.name2func[name] for name in config.activation_options]
aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, activation_funcs, z)
return res
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, aggregation_funcs, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
batch_act, batch_agg = vmap(act), vmap(agg)
def forward(inputs, transform) -> Array:
nodes, cons = transform
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = cons[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]
return forward

View File

@@ -7,7 +7,7 @@ import numpy as np
from config import Config
from core import Algorithm, State, Gene, Genome
from .ga import crossover, create_mutate
from .species import update_species, create_speciate
from .species import SpeciesInfo, update_species, create_speciate
class NEAT(Algorithm):
@@ -22,9 +22,9 @@ class NEAT(Algorithm):
def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm"""
input_idx = np.arange(self.config.basic.num_inputs)
output_idx = np.arange(self.config.basic.num_inputs,
self.config.basic.num_inputs + self.config.basic.num_outputs)
input_idx = np.arange(self.config.neat.inputs)
output_idx = np.arange(self.config.neat.inputs,
self.config.neat.inputs + self.config.neat.outputs)
state = state.update(
P=self.config.basic.pop_size,
@@ -49,22 +49,13 @@ class NEAT(Algorithm):
state = self.gene_type.setup(self.config.gene, state)
pop_genomes = self._initialize_genomes(state)
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
member_count = np.full((state.S,), np.nan, dtype=np.float32)
species_info = SpeciesInfo.initialize(state)
idx2species = jnp.zeros(state.P, dtype=jnp.float32)
species_keys[0] = 0
best_fitness[0] = -np.inf
last_improved[0] = 0
member_count[0] = state.P
center_nodes = jnp.full((state.S, state.N, state.NL), jnp.nan, dtype=jnp.float32)
center_conns = jnp.full((state.S, state.C, state.CL), jnp.nan, dtype=jnp.float32)
center_nodes = center_nodes.at[0, :, :].set(pop_genomes.nodes[0, :, :])
center_conns = center_conns.at[0, :, :].set(pop_genomes.conns[0, :, :])
center_genomes = vmap(Genome)(center_nodes, center_conns)
center_genomes = Genome(center_nodes, center_conns)
center_genomes = center_genomes.set(0, pop_genomes[0])
generation = 0
next_node_key = max(*state.input_idx, *state.output_idx) + 2
@@ -73,10 +64,7 @@ class NEAT(Algorithm):
state = state.update(
randkey=randkey,
pop_genomes=pop_genomes,
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
species_info=species_info,
idx2species=idx2species,
center_genomes=center_genomes,
@@ -135,7 +123,7 @@ class NEAT(Algorithm):
pop_nodes = np.tile(o_nodes, (state.P, 1, 1))
pop_conns = np.tile(o_conns, (state.P, 1, 1))
return vmap(Genome)(pop_nodes, pop_conns)
return Genome(pop_nodes, pop_conns)
def _create_tell(self):
mutate = create_mutate(self.config.neat, self.gene_type)

View File

@@ -1 +1,2 @@
from .operations import update_species, create_speciate
from .species_info import SpeciesInfo

View File

@@ -6,6 +6,7 @@ from jax import numpy as jnp, vmap
from core import Gene, Genome
from utils import rank_elements, fetch_first
from .distance import create_distance
from .species_info import SpeciesInfo
def update_species(state, randkey, fitness):
@@ -18,15 +19,9 @@ def update_species(state, randkey, fitness):
# sort species_info by their fitness. (push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
center_nodes = state.center_genomes.nodes[sort_indices]
center_conns = state.center_genomes.conns[sort_indices]
state = state.update(
species_keys=state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices],
last_improved=state.last_improved[sort_indices],
member_count=state.member_count[sort_indices],
center_genomes=Genome(center_nodes, center_conns),
species_info=state.species_info[sort_indices],
center_genomes=state.center_genomes[sort_indices],
)
# decide the number of members of each species by their fitness
@@ -45,11 +40,11 @@ def update_species_fitness(state, fitness):
"""
def aux_func(idx):
s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf)
s_fitness = jnp.where(state.idx2species == state.species_info.species_keys[idx], fitness, -jnp.inf)
f = jnp.max(s_fitness)
return f
return vmap(aux_func)(jnp.arange(state.species_keys.shape[0]))
return vmap(aux_func)(jnp.arange(state.species_info.size()))
def stagnation(state, species_fitness):
@@ -61,7 +56,7 @@ def stagnation(state, species_fitness):
def aux_func(idx):
s_fitness = species_fitness[idx]
sk, bf, li = state.species_keys[idx], state.best_fitness[idx], state.last_improved[idx]
sk, bf, li, _ = state.species_info.get(idx)
st = (s_fitness <= bf) & (state.generation - li > state.max_stagnation)
li = jnp.where(s_fitness > bf, state.generation, li)
bf = jnp.where(s_fitness > bf, s_fitness, bf)
@@ -78,18 +73,19 @@ def stagnation(state, species_fitness):
species_keys = jnp.where(spe_st, jnp.nan, species_keys)
best_fitness = jnp.where(spe_st, jnp.nan, best_fitness)
last_improved = jnp.where(spe_st, jnp.nan, last_improved)
member_count = jnp.where(spe_st, jnp.nan, state.member_count)
member_count = jnp.where(spe_st, jnp.nan, state.species_info.member_count)
species_fitness = jnp.where(spe_st, -jnp.inf, species_fitness)
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count)
# TODO: Simplify the coded
center_nodes = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.nodes)
center_conns = jnp.where(spe_st[:, None, None], jnp.nan, state.center_genomes.conns)
state = state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
center_genomes=state.center_genomes.update(center_nodes, center_conns)
species_info=species_info,
center_genomes=Genome(center_nodes, center_conns)
)
return state, species_fitness
@@ -103,18 +99,20 @@ def cal_spawn_numbers(state):
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
is_species_valid = ~jnp.isnan(state.species_keys)
species_keys = state.species_info.species_keys
is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - jnp.arange(state.species_keys.shape[0]) # obtain [3, 2, 1]
rank_score = valid_species_num - jnp.arange(species_keys.shape[0]) # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
# Avoid too much variation of numbers in a species
previous_size = state.member_count
previous_size = state.species_info.member_count
spawn_number = previous_size + (target_spawn_number - previous_size) * state.spawn_number_change_rate
# jax.debug.print("previous_size: {}, spawn_number: {}", previous_size, spawn_number)
spawn_number = spawn_number.astype(jnp.int32)
@@ -127,14 +125,14 @@ def cal_spawn_numbers(state):
def create_crossover_pair(state, randkey, spawn_number, fitness):
species_size = state.species_keys.shape[0]
species_size = state.species_info.size()
pop_size = fitness.shape[0]
s_idx = jnp.arange(species_size)
p_idx = jnp.arange(pop_size)
# def aux_func(key, idx):
def aux_func(key, idx):
members = state.idx2species == state.species_keys[idx]
members = state.idx2species == state.species_info.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
@@ -176,7 +174,7 @@ def create_speciate(gene_type: Type[Gene]):
distance = create_distance(gene_type)
def speciate(state):
pop_size, species_size = state.idx2species.shape[0], state.species_keys.shape[0]
pop_size, species_size = state.idx2species.shape[0], state.species_info.size()
# prepare distance functions
o2p_distance_func = vmap(distance, in_axes=(None, None, 0)) # one to population
@@ -191,25 +189,23 @@ def create_speciate(gene_type: Type[Gene]):
def cond_func(carry):
i, i2s, cgs, o2c = carry
return (i < species_size) & (~jnp.isnan(state.species_keys[i])) # current species is existing
return (i < species_size) & (~jnp.isnan(state.species_info.species_keys[i])) # current species is existing
def body_func(carry):
i, i2s, cgs, o2c = carry
distances = o2p_distance_func(state, Genome(cgs.nodes[i], cgs.conns[i]), state.pop_genomes)
distances = o2p_distance_func(state, cgs[i], state.pop_genomes)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
# jax.debug.print("closest_idx: {}", closest_idx)
i2s = i2s.at[closest_idx].set(state.species_keys[i])
cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[closest_idx])
cc = cgs.conns.at[i].set(state.pop_genomes.conns[closest_idx])
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cgs = cgs.set(i, state.pop_genomes[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, Genome(cn, cc), o2c
return i + 1, i2s, cgs, o2c
_, idx2species, center_genomes, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2species, state.center_genomes, o2c_distances))
@@ -247,15 +243,13 @@ def create_speciate(gene_type: Type[Gene]):
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, members_count]
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk)
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cn = cgs.nodes.at[i].set(state.pop_genomes.nodes[idx])
cc = cgs.conns.at[i].set(state.pop_genomes.conns[idx])
cgs = Genome(cn, cc)
cgs = cgs.set(i, state.pop_genomes[idx])
i2s, o2c = speciate_by_threshold(i, i2s, cgs, sk, o2c)
@@ -273,8 +267,7 @@ def create_speciate(gene_type: Type[Gene]):
def speciate_by_threshold(i, i2s, cgs, sk, o2c):
# distance between such center genome and ppo genomes
center = Genome(cgs.nodes[i], cgs.conns[i])
o2p_distance = o2p_distance_func(state, center, state.pop_genomes)
o2p_distance = o2p_distance_func(state, cgs[i], state.pop_genomes)
close_enough_mask = o2p_distance < state.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
@@ -294,32 +287,31 @@ def create_speciate(gene_type: Type[Gene]):
_, idx2species, center_genomes, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_genomes, state.species_keys, o2c_distances, state.next_species_key)
(0, state.idx2species, state.center_genomes, state.species_info.species_keys, o2c_distances, state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.last_improved)
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species_info.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species_info.best_fitness)
last_improved = jnp.where(new_created_mask, state.generation, state.species_info.last_improved)
# update members count
def count_members(idx):
key = species_keys[idx]
count = jnp.sum(idx2species == key)
count = jnp.sum(idx2species == key, dtype=jnp.float32)
count = jnp.where(jnp.isnan(key), jnp.nan, count)
return count
member_count = vmap(count_members)(jnp.arange(species_size))
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
members_count=member_count,
species_info = SpeciesInfo(species_keys, best_fitness, last_improved, member_count),
idx2species=idx2species,
center_genomes=center_genomes,
next_species_key=next_species_key

View File

@@ -0,0 +1,55 @@
from jax.tree_util import register_pytree_node_class
import numpy as np
import jax.numpy as jnp
@register_pytree_node_class
class SpeciesInfo:
def __init__(self, species_keys, best_fitness, last_improved, member_count):
self.species_keys = species_keys
self.best_fitness = best_fitness
self.last_improved = last_improved
self.member_count = member_count
@classmethod
def initialize(cls, state):
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
member_count = np.full((state.S,), np.nan, dtype=np.float32)
species_keys[0] = 0
best_fitness[0] = -np.inf
last_improved[0] = 0
member_count[0] = state.P
return cls(species_keys, best_fitness, last_improved, member_count)
def __getitem__(self, i):
return SpeciesInfo(self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i])
def get(self, i):
return self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i]
def set(self, idx, value):
species_keys = self.species_keys.at[idx].set(value[0])
best_fitness = self.best_fitness.at[idx].set(value[1])
last_improved = self.last_improved.at[idx].set(value[2])
member_count = self.member_count.at[idx].set(value[3])
return SpeciesInfo(species_keys, best_fitness, last_improved, member_count)
def remove(self, idx):
return self.set(idx, jnp.array([jnp.nan] * 4))
def size(self):
return self.species_keys.shape[0]
def tree_flatten(self):
children = self.species_keys, self.best_fitness, self.last_improved, self.member_count
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)

View File

@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Union
@dataclass(frozen=True)
@@ -7,34 +6,31 @@ class BasicConfig:
seed: int = 42
fitness_target: float = 1
generation_limit: int = 1000
num_inputs: int = 2
num_outputs: int = 1
pop_size: int = 100
def __post_init__(self):
assert self.num_inputs > 0, "the inputs number of the problem must be greater than 0"
assert self.num_outputs > 0, "the outputs number of the problem must be greater than 0"
assert self.pop_size > 0, "the population size must be greater than 0"
@dataclass(frozen=True)
class NeatConfig:
network_type: str = "feedforward"
activate_times: Union[int, None] = None # None means the network is feedforward
maximum_nodes: int = 100
maximum_conns: int = 50
inputs: int = 2
outputs: int = 1
maximum_nodes: int = 50
maximum_conns: int = 100
maximum_species: int = 10
# genome config
compatibility_disjoint: float = 1
compatibility_weight: float = 0.5
conn_add: float = 0.4
conn_delete: float = 0.4
conn_delete: float = 0
node_add: float = 0.2
node_delete: float = 0.2
node_delete: float = 0
# species config
compatibility_threshold: float = 3.0
compatibility_threshold: float = 3.5
species_elitism: int = 2
max_stagnation: int = 15
genome_elitism: int = 2
@@ -44,11 +40,9 @@ class NeatConfig:
def __post_init__(self):
assert self.network_type in ["feedforward", "recurrent"], "the network type must be feedforward or recurrent"
if self.network_type == "feedforward":
assert self.activate_times is None, "the activate times of feedforward network must be None"
else:
assert isinstance(self.activate_times, int), "the activate times of recurrent network must be int"
assert self.activate_times > 0, "the activate times of recurrent network must be greater than 0"
assert self.inputs > 0, "the inputs number of neat must be greater than 0"
assert self.outputs > 0, "the outputs number of neat must be greater than 0"
assert self.maximum_nodes > 0, "the maximum nodes must be greater than 0"
assert self.maximum_conns > 0, "the maximum connections must be greater than 0"
@@ -56,10 +50,10 @@ class NeatConfig:
assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0"
assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0"
assert self.conn_add > 0, "the connection add probability must be greater than 0"
assert self.conn_delete > 0, "the connection delete probability must be greater than 0"
assert self.node_add > 0, "the node add probability must be greater than 0"
assert self.node_delete > 0, "the node delete probability must be greater than 0"
assert self.conn_add >= 0, "the connection add probability must be greater than 0"
assert self.conn_delete >= 0, "the connection delete probability must be greater than 0"
assert self.node_add >= 0, "the node add probability must be greater than 0"
assert self.node_delete >= 0, "the node delete probability must be greater than 0"
assert self.compatibility_threshold > 0, "the compatibility threshold must be greater than 0"
assert self.species_elitism > 0, "the species elitism must be greater than 0"
@@ -77,18 +71,21 @@ class HyperNeatConfig:
activation: str = "sigmoid"
aggregation: str = "sum"
activate_times: int = 5
inputs: int = 2
outputs: int = 1
def __post_init__(self):
assert self.below_threshold > 0, "the below threshold must be greater than 0"
assert self.max_weight > 0, "the max weight must be greater than 0"
assert self.activate_times > 0, "the activate times must be greater than 0"
assert self.inputs > 0, "the inputs number of hyper neat must be greater than 0"
assert self.outputs > 0, "the outputs number of hyper neat must be greater than 0"
@dataclass(frozen=True)
class GeneConfig:
pass
@dataclass(frozen=True)
class SubstrateConfig:
pass

View File

@@ -2,4 +2,4 @@ from .algorithm import Algorithm
from .state import State
from .genome import Genome
from .gene import Gene
from .substrate import Substrate

View File

@@ -40,7 +40,7 @@ class Gene:
@staticmethod
def forward_transform(state: State, genome: Genome):
return jnp.zeros(0) # transformed
@staticmethod
def create_forward(state: State, config: GeneConfig):
return lambda *args: args # forward function

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
@@ -11,6 +13,15 @@ class Genome:
self.nodes = nodes
self.conns = conns
def __repr__(self):
return f"Genome(nodes={self.nodes}, conns={self.conns})"
def __getitem__(self, idx):
return self.__class__(self.nodes[idx], self.conns[idx])
def set(self, idx, value: Genome):
return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns))
def update(self, nodes, conns):
return self.__class__(nodes, conns)
@@ -73,5 +84,4 @@ class Genome:
def tree_unflatten(cls, aux_data, children):
return cls(*children)
def __repr__(self):
return f"Genome(nodes={self.nodes}, conns={self.conns})"

8
core/substrate.py Normal file
View File

@@ -0,0 +1,8 @@
from config import SubstrateConfig
class Substrate:
@staticmethod
def setup(state, config: SubstrateConfig):
return state

View File

@@ -12,11 +12,10 @@ print(asdict(config))
pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3))
pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5))
pop_genomes1 = jax.vmap(Genome)(pop_nodes, pop_conns)
pop_genomes2 = Genome(pop_nodes, pop_conns)
pop_genomes = Genome(pop_nodes, pop_conns)
print(pop_genomes)
print(pop_genomes[0])
print(pop_genomes[0: 20])
@jax.vmap
def pop_cnts(genome):

View File

@@ -15,5 +15,9 @@ def func(d):
d = {0: 1, 1: NetworkType.ANN.value}
n = None
print(n or d)
print(d)
print(func(d))

View File

@@ -1,10 +1,9 @@
import jax
import numpy as np
from config import Config, BasicConfig
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from algorithm.neat.neat import NEAT
from algorithm import NEAT, NormalGene, NormalGeneConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -23,8 +22,14 @@ def evaluate(forward_func):
if __name__ == '__main__':
config = Config(
basic=BasicConfig(fitness_target=4),
gene=NormalGeneConfig()
basic=BasicConfig(
fitness_target=3.99999,
pop_size=10000
),
neat=NeatConfig(
maximum_nodes=50,
maximum_conns=100,
)
)
algorithm = NEAT(config, NormalGene)
pipeline = Pipeline(config, algorithm)

49
examples/xor_hyperNEAT.py Normal file
View File

@@ -0,0 +1,49 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig
from algorithm import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=1000
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100,
inputs=4,
outputs=1
),
gene=RecurrentGeneConfig(
activation_default="tanh",
activation_options=("tanh", ),
),
substrate=NormalSubstrateConfig(),
)
neat = NEAT(config, RecurrentGene)
hyperNEAT = HyperNEAT(config, neat, NormalSubstrate)
pipeline = Pipeline(config, hyperNEAT)
pipeline.auto_run(evaluate)

39
examples/xor_recurrent.py Normal file
View File

@@ -0,0 +1,39 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=10000
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100
),
gene=RecurrentGeneConfig()
)
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)

View File

@@ -11,7 +11,7 @@ from core import Algorithm, Genome
class Pipeline:
"""
Neat algorithm pipeline.
Simple pipeline.
"""
def __init__(self, config: Config, algorithm: Algorithm):
@@ -38,7 +38,9 @@ class Pipeline:
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
def tell(self, fitness):
self.state = self.tell_func(self.state, fitness)
# self.state = self.tell_func(self.state, fitness)
new_state = self.tell_func(self.state, fitness)
self.state = new_state
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.basic.generation_limit):
@@ -73,9 +75,9 @@ class Pipeline:
self.best_fitness = fitnesses[max_idx]
self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx])
member_count = jax.device_get(self.state.member_count)
member_count = jax.device_get(self.state.species_info.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.state.generation}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")