finish all refactoring

This commit is contained in:
wls2002
2024-02-21 15:41:08 +08:00
parent aac41a089d
commit 6970e6a6d5
44 changed files with 856 additions and 825 deletions

View File

@@ -16,9 +16,30 @@ class BaseAlgorithm:
"""update the state of the algorithm"""
raise NotImplementedError
def transform(self, state: State):
def transform(self, individual):
"""transform the genome into a neural network"""
raise NotImplementedError
def forward(self, inputs, transformed):
raise NotImplementedError
raise NotImplementedError
@property
def num_inputs(self):
raise NotImplementedError
@property
def num_outputs(self):
raise NotImplementedError
@property
def pop_size(self):
raise NotImplementedError
def member_count(self, state: State):
# to analysis the species
raise NotImplementedError
def generation(self, state: State):
# to analysis the algorithm
raise NotImplementedError

View File

@@ -0,0 +1,2 @@
from .hyperneat import HyperNEAT
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate

View File

@@ -0,0 +1,116 @@
import jax, jax.numpy as jnp
from utils import State, Act, Agg
from .. import BaseAlgorithm, NEAT
from ..neat.gene import BaseNodeGene, BaseConnGene
from ..neat.genome import RecurrentGenome
from .substrate import *
class HyperNEAT(BaseAlgorithm):
    """HyperNEAT built on top of a NEAT instance.

    The wrapped NEAT algorithm evolves networks that are queried once per row
    of ``substrate.query_coors``; the query results become the connection
    attributes of a recurrent network laid out by the substrate.
    """

    def __init__(
            self,
            substrate: BaseSubstrate,
            neat: NEAT,
            below_threshold: float = 0.3,
            max_weight: float = 5.,
            activation=Act.sigmoid,
            aggregation=Agg.sum,
            activate_time: int = 10,
    ):
        """
        :param substrate: supplies the query coordinates, nodes and connections
        :param neat: the NEAT algorithm evolving the queried networks
        :param below_threshold: query results strictly inside
            (-below_threshold, below_threshold) are muted to 0
        :param max_weight: scale factor applied to the surviving query results
        :param activation: activation used by every node of the substrate network
        :param aggregation: aggregation used by every node of the substrate network
        :param activate_time: forwarded to the RecurrentGenome of the substrate network
        """
        # every query row is fed to the NEAT network, so the widths must agree
        assert substrate.query_coors.shape[1] == neat.num_inputs, \
            "Substrate input size should be equal to NEAT input size"

        self.substrate = substrate
        self.neat = neat
        self.below_threshold = below_threshold
        self.max_weight = max_weight

        # genome describing the substrate network itself
        substrate_node_gene = HyperNodeGene(activation, aggregation)
        substrate_conn_gene = HyperNEATConnGene()
        self.hyper_genome = RecurrentGenome(
            num_inputs=substrate.num_inputs,
            num_outputs=substrate.num_outputs,
            max_nodes=substrate.nodes_cnt,
            max_conns=substrate.conns_cnt,
            node_gene=substrate_node_gene,
            conn_gene=substrate_conn_gene,
            activate_time=activate_time,
        )

    def setup(self, randkey):
        """Create the initial state; it only wraps the state of the inner NEAT."""
        inner_state = self.neat.setup(randkey)
        return State(neat_state=inner_state)

    def ask(self, state: State):
        """Return whatever the inner NEAT returns for its current state."""
        return self.neat.ask(state.neat_state)

    def tell(self, state: State, fitness):
        """Advance the inner NEAT with the given fitness and store its new state."""
        inner_state = self.neat.tell(state.neat_state, fitness)
        return state.update(neat_state=inner_state)

    def transform(self, individual):
        """Turn one NEAT individual into the transformed substrate network."""
        cppn = self.neat.transform(individual)

        # evaluate the NEAT network once per query coordinate row
        query_all = jax.vmap(self.neat.forward, in_axes=(0, None))
        weights = query_all(self.substrate.query_coors, cppn)

        threshold = self.below_threshold

        # mute the connections whose query result is below the threshold
        is_muted = (-threshold < weights) & (weights < threshold)
        weights = jnp.where(is_muted, 0., weights)

        # shift the surviving values toward zero by the threshold, then rescale;
        # this yields [-max_weight, max_weight] assuming the query results lie
        # in [-1, 1] -- TODO confirm against the NEAT output activation
        weights = jnp.where(weights > 0, weights - threshold, weights)
        weights = jnp.where(weights < 0, weights + threshold, weights)
        weights = weights / (1 - threshold) * self.max_weight

        h_nodes = self.substrate.make_nodes(weights)
        h_conns = self.substrate.make_conn(weights)
        return self.hyper_genome.transform(h_nodes, h_conns)

    def forward(self, inputs, transformed):
        """Run the substrate network on ``inputs`` with a constant 1 appended."""
        # add bias
        bias = jnp.array([1])
        inputs_with_bias = jnp.concatenate([inputs, bias])
        return self.hyper_genome.forward(inputs_with_bias, transformed)

    @property
    def num_inputs(self):
        """Number of external inputs: the substrate inputs minus the bias slot."""
        return self.substrate.num_inputs - 1  # remove bias

    @property
    def num_outputs(self):
        """Number of outputs, as reported by the substrate."""
        return self.substrate.num_outputs

    @property
    def pop_size(self):
        """Population size of the inner NEAT."""
        return self.neat.pop_size

    def member_count(self, state: State):
        """Delegate to the inner NEAT's ``member_count``."""
        return self.neat.member_count(state.neat_state)

    def generation(self, state: State):
        """Delegate to the inner NEAT's ``generation``."""
        return self.neat.generation(state.neat_state)
class HyperNodeGene(BaseNodeGene):
    """Node gene with one fixed activation and aggregation for all nodes."""

    def __init__(self, activation=Act.sigmoid, aggregation=Agg.sum):
        """
        :param activation: callable applied to the aggregated inputs
        :param aggregation: callable that combines the incoming values
        """
        super().__init__()
        self.activation = activation
        self.aggregation = aggregation

    def forward(self, attrs, inputs):
        """Aggregate ``inputs`` and activate the result; ``attrs`` is not used."""
        aggregated = self.aggregation(inputs)
        return self.activation(aggregated)
class HyperNEATConnGene(BaseConnGene):
    """Connection gene whose only custom attribute is its weight."""

    custom_attrs = ['weight']

    def forward(self, attrs, inputs):
        """Scale ``inputs`` by the weight stored at ``attrs[0]``."""
        return inputs * attrs[0]

View File

@@ -0,0 +1,3 @@
from .base import BaseSubstrate
from .default import DefaultSubstrate
from .full import FullSubstrate

View File

@@ -0,0 +1,27 @@
class BaseSubstrate:
    """Interface of a substrate; every member raises until a subclass overrides it."""

    def make_nodes(self, query_res):
        """Build the nodes of the substrate network from the query results."""
        raise NotImplementedError()

    def make_conn(self, query_res):
        """Build the connections of the substrate network from the query results."""
        raise NotImplementedError()

    @property
    def query_coors(self):
        """Coordinates that are queried, one row per query."""
        raise NotImplementedError()

    @property
    def num_inputs(self):
        """Number of inputs of the substrate network."""
        raise NotImplementedError()

    @property
    def num_outputs(self):
        """Number of outputs of the substrate network."""
        raise NotImplementedError()

    @property
    def nodes_cnt(self):
        """Number of nodes of the substrate network."""
        raise NotImplementedError()

    @property
    def conns_cnt(self):
        """Number of connections of the substrate network."""
        raise NotImplementedError()

View File

@@ -0,0 +1,38 @@
import jax.numpy as jnp
from . import BaseSubstrate
class DefaultSubstrate(BaseSubstrate):
    """Substrate defined by explicitly supplied coordinates, nodes and connections."""

    def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
        """
        :param num_inputs: value reported by ``num_inputs``
        :param num_outputs: value reported by ``num_outputs``
        :param coors: query coordinates, converted to a jax array
        :param nodes: node table, converted to a jax array
        :param conns: connection table, converted to a jax array
        """
        self.inputs = num_inputs
        self.outputs = num_outputs
        self.coors = jnp.array(coors)
        self.nodes = jnp.array(nodes)
        self.conns = jnp.array(conns)

    def make_nodes(self, query_res):
        """Return the stored nodes unchanged; ``query_res`` is ignored."""
        return self.nodes

    def make_conn(self, query_res):
        """Return a copy of the connections with columns 3 onward set to ``query_res``."""
        weight_columns = self.conns.at[:, 3:]
        return weight_columns.set(query_res)  # change weight

    @property
    def query_coors(self):
        """The stored query coordinates."""
        return self.coors

    @property
    def num_inputs(self):
        """The input count given at construction."""
        return self.inputs

    @property
    def num_outputs(self):
        """The output count given at construction."""
        return self.outputs

    @property
    def nodes_cnt(self):
        """Number of rows in the node table."""
        return self.nodes.shape[0]

    @property
    def conns_cnt(self):
        """Number of rows in the connection table."""
        return self.conns.shape[0]

View File

@@ -0,0 +1,76 @@
import numpy as np
from .default import DefaultSubstrate
class FullSubstrate(DefaultSubstrate):
    """Substrate whose tables are derived from input, hidden and output coordinates."""

    def __init__(
            self,
            input_coors=((-1, -1), (0, -1), (1, -1)),
            hidden_coors=((-1, 0), (0, 0), (1, 0)),
            output_coors=((0, 1),),
    ):
        """
        :param input_coors: one coordinate per input node
        :param hidden_coors: one coordinate per hidden node
        :param output_coors: one coordinate per output node
        """
        query_coors, nodes, conns = analysis_substrate(input_coors, output_coors, hidden_coors)
        super().__init__(
            num_inputs=len(input_coors),
            num_outputs=len(output_coors),
            coors=query_coors,
            nodes=nodes,
            conns=conns,
        )
def analysis_substrate(input_coors, output_coors, hidden_coors):
    """Derive the query coordinates, node table and connection table of a substrate.

    Node keys are assigned in the order inputs, outputs, hidden. Connections
    are created for every input->hidden, hidden->hidden and hidden->output
    pair, in that order.

    :return: ``(query_coors, nodes, conns)`` where each row of ``query_coors``
        is the source coordinate followed by the target coordinate, ``nodes``
        is a column of node keys, and ``conns`` is a float32 table with
        columns (input key, output key, enabled, weight).
    """
    in_xy = np.array(input_coors)
    out_xy = np.array(output_coors)
    hid_xy = np.array(hidden_coors)

    dims = in_xy.shape[1]  # coordinate dimensions
    n_in, n_out, n_hid = in_xy.shape[0], out_xy.shape[0], hid_xy.shape[0]

    in_keys = np.arange(n_in)
    out_keys = np.arange(n_in, n_in + n_out)
    hid_keys = np.arange(n_in + n_out, n_in + n_out + n_hid)

    # (source keys, target keys, source coordinates, target coordinates)
    layer_pairs = (
        (in_keys, hid_keys, in_xy, hid_xy),  # connect input to hidden
        (hid_keys, hid_keys, hid_xy, hid_xy),  # connect hidden to hidden
        (hid_keys, out_keys, hid_xy, out_xy),  # connect hidden to output
    )

    n_conns = n_in * n_hid + n_hid * n_hid + n_hid * n_out
    query_coors = np.zeros((n_conns, dims * 2))
    conn_keys = np.zeros((n_conns, 2))

    # fill the tables segment by segment, keeping a running row offset
    start = 0
    for src_keys, dst_keys, src_xy, dst_xy in layer_pairs:
        stop = start + src_keys.shape[0] * dst_keys.shape[0]
        pair_coors, pair_keys = cartesian_product(src_keys, dst_keys, src_xy, dst_xy)
        query_coors[start:stop, :] = pair_coors
        conn_keys[start:stop, :] = pair_keys
        start = stop

    nodes = np.concatenate((in_keys, out_keys, hid_keys))[..., np.newaxis]

    # input_idx, output_idx, enabled, weight
    conns = np.zeros((conn_keys.shape[0], 4), dtype=np.float32)
    conns[:, 0:2] = conn_keys
    conns[:, 2] = 1  # enabled is True

    return query_coors, nodes, conns
def cartesian_product(keys1, keys2, coors1, coors2):
    """Pair every entry of the first group with every entry of the second.

    Rows are ordered with the first group varying slowest.

    :return: ``(new_coors, correspond_keys)``; each row of ``new_coors`` is a
        coordinate from ``coors1`` followed by one from ``coors2``, and the
        matching row of ``correspond_keys`` holds the two keys.
    """
    n_first = keys1.shape[0]
    n_second = keys2.shape[0]

    # first group: each entry repeated once per entry of the second group
    first_keys = np.repeat(keys1, n_second)
    first_coors = np.repeat(coors1, n_second, axis=0)

    # second group: the whole group repeated once per entry of the first group
    second_keys = np.tile(keys2, n_first)
    second_coors = np.tile(coors2, (n_first, 1))

    new_coors = np.concatenate((first_coors, second_coors), axis=1)
    correspond_keys = np.stack((first_keys, second_keys), axis=1)
    return new_coors, correspond_keys

View File

@@ -1,3 +1,5 @@
from .gene import *
from .genome import *
from .species import *
from .neat import NEAT

View File

@@ -3,7 +3,8 @@ import jax, jax.numpy as jnp
from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)

View File

@@ -92,7 +92,7 @@ class DefaultMutation(BaseMutation):
return nodes_, conns_
def successful():
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conns.new_custom_attrs())
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())
def already_exist():
return nodes_, conns_.at[conn_pos, 2].set(True)
@@ -105,11 +105,12 @@ class DefaultMutation(BaseMutation):
return jax.lax.cond(
is_already_exist,
already_exist,
jax.lax.cond(
is_cycle,
nothing,
successful
)
lambda:
jax.lax.cond(
is_cycle,
nothing,
successful
)
)
elif genome.network_type == 'recurrent':
@@ -138,23 +139,23 @@ class DefaultMutation(BaseMutation):
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(k, g):
return g
def no(key_, nodes_, conns_):
return nodes_, conns_
genome = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
genome = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
genome = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
genome = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
return genome
return nodes, conns
def mutate_values(self, randkey, genome, nodes, conns):
k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0])
conns_keys = jax.random.split(k2, num=genome.conns.shape[0])
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.nodes.mutate, in_axes=(0, 0))(nodes_keys, nodes)
new_conns = jax.vmap(genome.conns.mutate, in_axes=(0, 0))(conns_keys, conns)
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -7,8 +7,7 @@ from . import BaseConnGene
class DefaultConnGene(BaseConnGene):
"Default connection gene, with the same behavior as in NEAT-python."
fixed_attrs = ['input_index', 'output_index', 'enabled']
attrs = ['weight']
custom_attrs = ['weight']
def __init__(
self,

View File

@@ -9,7 +9,6 @@ from . import BaseNodeGene
class DefaultNodeGene(BaseNodeGene):
"Default node gene, with the same behavior as in NEAT-python."
fixed_attrs = ['index']
custom_attrs = ['bias', 'response', 'aggregation', 'activation']
def __init__(
@@ -82,8 +81,8 @@ class DefaultNodeGene(BaseNodeGene):
return (
jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) +
node1[3] != node2[3] +
node1[4] != node2[4]
(node1[3] != node2[3]) +
(node1[4] != node2[4])
)
def forward(self, attrs, inputs):

View File

@@ -4,7 +4,6 @@ from utils import fetch_first
class BaseGenome:
network_type = None
def __init__(

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns, topological_sort, I_INT
@@ -13,10 +15,20 @@ class DefaultGenome(BaseGenome):
def __init__(self,
num_inputs: int,
num_outputs: int,
max_nodes=5,
max_conns=4,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
output_transform: Callable = None
):
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
if output_transform is not None:
try:
aux = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
@@ -72,4 +84,7 @@ class DefaultGenome(BaseGenome):
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[self.output_idx]
if self.output_transform is None:
return vals[self.output_idx]
else:
return self.output_transform(vals[self.output_idx])

View File

@@ -13,11 +13,13 @@ class RecurrentGenome(BaseGenome):
def __init__(self,
num_inputs: int,
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
):
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time
def transform(self, nodes, conns):

View File

@@ -1,20 +1,19 @@
import jax, jax.numpy as jnp
from utils import State
from .. import BaseAlgorithm
from .genome import *
from .species import *
from .ga import *
class NEAT(BaseAlgorithm):
def __init__(
self,
genome: BaseGenome,
species: BaseSpecies,
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
):
self.genome = genome
self.genome = species.genome
self.species = species
self.mutation = mutation
self.crossover = crossover
@@ -23,14 +22,14 @@ class NEAT(BaseAlgorithm):
k1, k2 = jax.random.split(randkey, 2)
return State(
randkey=k1,
generation=0,
next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2,
generation=jnp.array(0.),
next_node_key=jnp.array(max(*self.genome.input_idx, *self.genome.output_idx) + 2, dtype=jnp.float32),
# inputs nodes, output nodes, 1 hidden node
species=self.species.setup(k2),
)
def ask(self, state: State):
return self.species.ask(state)
return self.species.ask(state.species)
def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
@@ -40,25 +39,39 @@ class NEAT(BaseAlgorithm):
randkey=randkey
)
state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation)
species_state, winner, loser, elite_mask = self.species.update_species(state.species, fitness, state.generation)
state = state.update(species=species_state)
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
state = self.species.speciate(state, state.generation)
species_state = self.species.speciate(state.species, state.generation)
state = state.update(species=species_state)
return state
def transform(self, state: State):
def transform(self, individual):
"""transform the genome into a neural network"""
raise NotImplementedError
nodes, conns = individual
return self.genome.transform(nodes, conns)
def forward(self, inputs, transformed):
raise NotImplementedError
return self.genome.forward(inputs, transformed)
@property
def num_inputs(self):
return self.genome.num_inputs
@property
def num_outputs(self):
return self.genome.num_outputs
@property
def pop_size(self):
return self.species.pop_size
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
# prepare random keys
pop_size = self.species.pop_size
new_node_keys = jnp.arange(pop_size) + state.species.next_node_key
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
@@ -69,11 +82,11 @@ class NEAT(BaseAlgorithm):
# batch crossover
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
# batch mutation
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
@@ -92,3 +105,9 @@ class NEAT(BaseAlgorithm):
next_node_key=next_node_key,
)
def member_count(self, state: State):
return state.species.member_count
def generation(self, state: State):
# to analysis the algorithm
return state.generation

View File

@@ -2,9 +2,10 @@ import numpy as np
import jax, jax.numpy as jnp
from utils import State, rank_elements, argmin_with_mask, fetch_first
from ..genome import BaseGenome
from .base import BaseSpecies
class DefaultSpecies:
class DefaultSpecies(BaseSpecies):
def __init__(self,
genome: BaseGenome,
@@ -18,9 +19,8 @@ class DefaultSpecies:
genome_elitism: int = 2,
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.5
compatibility_threshold: float = 3.
):
self.genome = genome
self.pop_size = pop_size
self.species_size = species_size
@@ -59,8 +59,12 @@ class DefaultSpecies:
center_nodes = center_nodes.at[0].set(pop_nodes[0])
center_conns = center_conns.at[0].set(pop_conns[0])
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
return State(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
@@ -68,7 +72,7 @@ class DefaultSpecies:
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=1, # 0 is reserved for the first species
next_species_key=jnp.array(1), # 0 is reserved for the first species
)
def ask(self, state):
@@ -99,7 +103,7 @@ class DefaultSpecies:
# crossover info
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
return state(randkey=k2), winner, loser, elite_mask
return state.update(randkey=k2), winner, loser, elite_mask
def update_species_fitness(self, state, fitness):
"""
@@ -156,17 +160,17 @@ class DefaultSpecies:
jnp.nan, # last_improved
jnp.nan, # member_count
-jnp.inf, # species_fitness
jnp.full_like(center_nodes[idx], jnp.nan), # center_nodes
jnp.full_like(center_conns[idx], jnp.nan), # center_conns
jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes
jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns
), # stagnation species
lambda: (
species_keys[idx],
state.species_keys[idx],
best_fitness[idx],
last_improved[idx],
state.member_count[idx],
species_fitness[idx],
center_nodes[idx],
center_conns[idx]
state.center_nodes[idx],
state.center_conns[idx]
) # not stagnation species
)
@@ -216,7 +220,7 @@ class DefaultSpecies:
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number)
error = self.pop_size - jnp.sum(spawn_number)
# add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error)
@@ -287,14 +291,14 @@ class DefaultSpecies:
def body_func(carry):
i, i2s, cns, ccs, o2c = carry
distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns)
distances = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
cns = cns.set(i, state.pop_nodes[closest_idx])
ccs = ccs.set(i, state.pop_conns[closest_idx])
i2s = i2s.at[closest_idx].set(state.species_keys[i])
cns = cns.at[i].set(state.pop_nodes[closest_idx])
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
@@ -346,8 +350,8 @@ class DefaultSpecies:
o2c = o2c.at[idx].set(0)
# update center genomes
cns = cns.set(i, state.pop_nodes[idx])
ccs = ccs.set(i, state.pop_conns[idx])
cns = cns.at[i].set(state.pop_nodes[idx])
ccs = ccs.at[i].set(state.pop_conns[idx])
# find the members for the new species
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
@@ -384,7 +388,7 @@ class DefaultSpecies:
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances,
(0, state.idx2species, center_nodes, center_conns, state.species_keys, o2c_distances,
state.next_species_key)
)
@@ -401,8 +405,8 @@ class DefaultSpecies:
def count_members(idx):
return jax.lax.cond(
jnp.isnan(species_keys[idx]), # if the species is not existing
lambda _: jnp.nan, # nan
lambda _: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
lambda: jnp.nan, # nan
lambda: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
)
member_count = jax.vmap(count_members)(self.species_arange)
@@ -422,7 +426,8 @@ class DefaultSpecies:
"""
The distance between two genomes
"""
return self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
return d
def node_distance(self, nodes1, nodes2):
"""
@@ -494,18 +499,18 @@ def initialize_population(pop_size, genome):
o_nodes[input_idx, 0] = genome.input_idx
o_nodes[output_idx, 0] = genome.output_idx
o_nodes[new_node_key, 0] = new_node_key # one hidden node
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_attrs()
o_nodes[new_node_key, 1:] = genome.node_gene.new_attrs() # one hidden node
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs()
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs()
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
# repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))

View File

@@ -1,38 +1,36 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import BraxEnv, BraxConfig
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=10000,
pop_size=100
),
neat=NeatConfig(
inputs=27,
outputs=8,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=BraxConfig(
env_name="ant"
)
)
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
conf = example_conf()
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=27,
num_outputs=8,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name='ant',
),
generation_limit=10000,
fitness_target=5000
)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, BraxEnv)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,42 +1,36 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import BraxEnv, BraxConfig
# ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=10000,
generation_limit=10,
pop_size=100
),
neat=NeatConfig(
inputs=17,
outputs=6,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=BraxConfig(
env_name="halfcheetah"
)
)
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
conf = example_conf()
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, BraxEnv)
# NEAT with a feed-forward genome (17 inputs, 6 outputs, tanh nodes only) on a Brax task.
pipeline = Pipeline(
    algorithm=NEAT(
        species=DefaultSpecies(
            genome=DefaultGenome(
                num_inputs=17,
                num_outputs=6,
                max_nodes=50,
                max_conns=100,
                node_gene=DefaultNodeGene(
                    activation_options=(Act.tanh,),
                    activation_default=Act.tanh,
                )
            ),
            pop_size=1000,
            species_size=10,
        ),
    ),
    problem=BraxEnv(
        # must be a registered Brax environment name
        env_name='halfcheetah',
    ),
    generation_limit=10000,
    fitness_target=5000
)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
pipeline.show(state, best, save_path="half_cheetah.gif", )
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,38 +1,36 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import BraxEnv, BraxConfig
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=10000,
pop_size=1000
),
neat=NeatConfig(
inputs=11,
outputs=2,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=BraxConfig(
env_name="reacher"
)
)
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
conf = example_conf()
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=11,
num_outputs=2,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=100,
species_size=10,
),
),
problem=BraxEnv(
env_name='reacher',
),
generation_limit=10000,
fitness_target=5000
)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, BraxEnv)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,73 +0,0 @@
import imageio
import jax
import brax
from brax import envs
from brax.io import image
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import numpy as np
def inference_func(key, *args):
    """Sample a normally distributed action; extra arguments are ignored."""
    action_shape = (env.action_size,)
    return jax.random.normal(key, shape=action_shape)
env_name = "ant"
backend = "generalized"
env = envs.create(env_name=env_name, backend=backend)
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_func)
rng = jax.random.PRNGKey(seed=1)
ori_state = jit_env_reset(rng=rng)
state = ori_state
render_history = []
for i in range(100):
act_rng, rng = jax.random.split(rng)
tic = time.time()
act = jit_inference_fn(act_rng, state.obs)
state = jit_env_step(state, act)
print("step time: ", time.time() - tic)
render_history.append(state.pipeline_state)
# img = image.render_array(sys=env.sys, state=pipeline_state, width=512, height=512)
# print("render time: ", time.time() - tic)
# plt.imsave("../images/ant_{}.png".format(i), img)
reward = state.reward
done = state.done
print(i, reward)
render_history = jax.device_get(render_history)
# print(render_history)
imgs = [image.render_array(sys=env.sys, state=s, width=512, height=512) for s in tqdm(render_history)]
# for i, s in enumerate(tqdm(render_history)):
# img = image.render_array(sys=env.sys, state=s, width=512, height=512)
# print(img.shape)
# # print(type(img))
# plt.imsave("../images/ant_{}.png".format(i), img)
def create_gif(image_list, gif_name, duration):
    """Write every frame of ``image_list`` to the GIF file ``gif_name``.

    ``duration`` is passed through to ``imageio.get_writer``.
    """
    with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
        for frame in image_list:
            # make sure the frame has the correct data type
            frame_u8 = np.array(frame, dtype=np.uint8)
            writer.append_data(frame_u8)
create_gif(imgs, "../images/ant.gif", 0.1)

View File

@@ -1,54 +0,0 @@
import brax
from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.io import image
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import traceback
# print(f"Using Brax {brax.__version__}, Jax {jax.__version__}")
# print("From GymWrapper, env.reset()")
# try:
# env = envs.create("inverted_pendulum",
# batch_size=1,
# episode_length=150,
# backend='generalized')
# env = gym_wrapper.GymWrapper(env)
# env.reset()
# img = env.render(mode='rgb_array')
# plt.imshow(img)
# except Exception:
# traceback.print_exc()
#
# print("From GymWrapper, env.reset() and action")
# try:
# env = envs.create("inverted_pendulum",
# batch_size=1,
# episode_length=150,
# backend='generalized')
# env = gym_wrapper.GymWrapper(env)
# env.reset()
# action = jnp.zeros(env.action_space.shape)
# env.step(action)
# img = env.render(mode='rgb_array')
# plt.imshow(img)
# except Exception:
# traceback.print_exc()
print("From brax env")
try:
env = envs.create("inverted_pendulum",
batch_size=1,
episode_length=150,
backend='generalized')
key = jax.random.PRNGKey(0)
initial_env_state = env.reset(key)
base_state = initial_env_state.pipeline_state
pipeline_state = env.pipeline_init(base_state.q.ravel(), base_state.qd.ravel())
img = image.render_array(sys=env.sys, state=pipeline_state, width=256, height=256)
print(f"pixel values: [{img.min()}, {img.max()}]")
plt.imshow(img)
plt.show()
except Exception:
traceback.print_exc()

View File

@@ -1,32 +1,31 @@
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig
from algorithm.neat import *
from problem.func_fit import XOR3d
if __name__ == '__main__':
# running config
config = Config(
basic=BasicConfig(
seed=42,
fitness_target=-1e-2,
pop_size=10000
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
),
neat=NeatConfig(
inputs=2,
outputs=1
),
gene=NormalGeneConfig(),
problem=FuncFitConfig(
error_method='rmse'
)
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8
)
# define algorithm: NEAT with NormalGene
algorithm = NEAT(config, NormalGene)
# full pipeline
pipeline = Pipeline(config, algorithm, XOR)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result

View File

@@ -0,0 +1,51 @@
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.hyperneat import *
from utils import Act
from problem.func_fit import XOR3d
if __name__ == '__main__':
    # substrate layout: 4 input nodes, 3 rows of 4 hidden nodes, 1 output node
    substrate = FullSubstrate(
        input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)],
        hidden_coors=[
            (-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5),
            (-1, 0), (0.333, 0), (-0.333, 0), (1, 0),
            (-1, 0.5), (0.333, 0.5), (-0.333, 0.5), (1, 0.5),
        ],
        output_coors=[(0, 1), ],
    )

    # the evolved network takes one query row: a source (x, y) and a target (x, y)
    cppn_node_gene = DefaultNodeGene(
        activation_default=Act.tanh,
        activation_options=(Act.tanh,),
    )
    cppn_genome = DefaultGenome(
        num_inputs=4,
        num_outputs=1,
        max_nodes=50,
        max_conns=100,
        node_gene=cppn_node_gene,
    )
    species = DefaultSpecies(
        genome=cppn_genome,
        pop_size=10000,
        species_size=10,
        compatibility_threshold=3.5,
    )
    neat = NEAT(species=species)

    algorithm = HyperNEAT(
        substrate=substrate,
        neat=neat,
        activation=Act.sigmoid,
        activate_time=10,
    )

    pipeline = Pipeline(
        algorithm=algorithm,
        problem=XOR3d(),
        generation_limit=300,
        fitness_target=-1e-6
    )

    # initialize state
    state = pipeline.setup()
    # print(state)

    # run until terminate
    state, best = pipeline.auto_run(state)

    # show result
    pipeline.show(state, best)

View File

@@ -1,41 +0,0 @@
from config import *
from pipeline import Pipeline
from algorithm.neat import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
from problem.func_fit import XOR3d, FuncFitConfig
from utils import Act
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
seed=42,
fitness_target=0,
pop_size=1000
),
neat=NeatConfig(
max_nodes=50,
max_conns=100,
max_species=30,
inputs=4,
outputs=1
),
hyperneat=HyperNeatConfig(
inputs=3,
outputs=1
),
substrate=NormalSubstrateConfig(
input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh, ),
),
problem=FuncFitConfig()
)
algorithm = HyperNEAT(config, NormalGene, NormalSubstrate)
pipeline = Pipeline(config, algorithm, XOR3d)
state = pipeline.setup()
state, best = pipeline.auto_run(state)
pipeline.show(state, best)

View File

@@ -1,41 +1,41 @@
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from problem.func_fit import XOR3d, FuncFitConfig
from algorithm.neat import *
from problem.func_fit import XOR3d
from utils.activation import ACT_ALL
from utils.aggregation import AGG_ALL
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
seed=42,
fitness_target=-1e-2,
generation_limit=300,
pop_size=1000
pipeline = Pipeline(
seed=0,
algorithm=NEAT(
species=DefaultSpecies(
genome=RecurrentGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
activate_time=5,
node_gene=DefaultNodeGene(
activation_options=ACT_ALL,
# aggregation_options=AGG_ALL,
activation_replace_rate=0.2
),
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
),
neat=NeatConfig(
network_type="recurrent",
max_nodes=50,
max_conns=100,
max_species=30,
conn_add=0.5,
conn_delete=0.5,
node_add=0.4,
node_delete=0.4,
inputs=3,
outputs=1
),
gene=RecurrentGeneConfig(
activate_times=10
),
problem=FuncFitConfig(
error_method='rmse'
)
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8
)
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm, XOR3d)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

View File

@@ -1,36 +0,0 @@
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
seed=42,
fitness_target=-1e-2,
pop_size=10000
),
neat=NeatConfig(
max_nodes=50,
max_conns=100,
max_species=30,
conn_add=0.8,
conn_delete=0,
node_add=0.4,
node_delete=0,
inputs=2,
outputs=1
),
gene=NormalGeneConfig(),
problem=FuncFitConfig(
error_method='rmse'
)
)
algorithm = NEAT(config, NormalGene)
pipeline = Pipeline(config, algorithm, XOR)
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
pipeline.show(state, best)

View File

@@ -1,39 +0,0 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=0,
pop_size=10000
),
neat=NeatConfig(
inputs=6,
outputs=3,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='Acrobot-v1',
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2}
)
)
if __name__ == '__main__':
conf = example_conf()
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)

34
examples/gymnax/arcbot.py Normal file
View File

@@ -0,0 +1,34 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
    # Example: run NEAT on the gymnax 'Acrobot-v1' environment.
    # Objects are built innermost-first, in the same order the nested
    # constructor calls would evaluate them.

    # Six network inputs and three outputs; the outputs are reduced to a
    # single discrete action by taking the argmax.
    genome = DefaultGenome(
        num_inputs=6,
        num_outputs=3,
        max_nodes=50,
        max_conns=100,
        # the action of acrobot is {0, 1, 2}
        output_transform=lambda out: jnp.argmax(out),
    )

    species = DefaultSpecies(
        genome=genome,
        pop_size=10000,
        species_size=10,
    )

    algorithm = NEAT(species=species)

    problem = GymNaxEnv(
        env_name='Acrobot-v1',
    )

    pipeline = Pipeline(
        algorithm=algorithm,
        problem=problem,
        generation_limit=10000,
        fitness_target=-62,
    )

    # Build the initial state.
    state = pipeline.setup()

    # Run until the pipeline terminates.
    state, best = pipeline.auto_run(state)

View File

@@ -1,84 +1,34 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf1():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
),
neat=NeatConfig(
inputs=4,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
)
)
def example_conf2():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
),
neat=NeatConfig(
inputs=4,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.where(out[0] > 0, 1, 0) # the action of cartpole is {0, 1}
)
)
def example_conf3():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=501,
pop_size=10000
),
neat=NeatConfig(
inputs=4,
outputs=2,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
)
)
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
# all config files above can solve cartpole
conf = example_conf3()
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=4,
num_outputs=2,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='CartPole-v1',
),
generation_limit=10000,
fitness_target=500
)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,39 +1,34 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=0,
pop_size=10000
),
neat=NeatConfig(
inputs=2,
outputs=3,
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='MountainCar-v0',
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1, 2}
)
)
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
conf = example_conf()
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=2,
num_outputs=3,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='MountainCar-v0',
),
generation_limit=10000,
fitness_target=0
)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,38 +1,36 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=100,
pop_size=10000
),
neat=NeatConfig(
inputs=2,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='MountainCarContinuous-v0'
)
)
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
from utils import Act
if __name__ == '__main__':
conf = example_conf()
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh, ),
activation_default=Act.tanh,
)
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='MountainCarContinuous-v0',
),
generation_limit=10000,
fitness_target=500
)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,40 +1,37 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=0,
pop_size=10000
),
neat=NeatConfig(
inputs=3,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='Pendulum-v1',
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
)
)
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
from utils import Act
if __name__ == '__main__':
conf = example_conf()
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Pendulum-v1',
),
generation_limit=10000,
fitness_target=0
)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,36 +1,33 @@
from config import *
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
),
neat=NeatConfig(
inputs=8,
outputs=2,
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='Reacher-misc',
)
)
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
conf = example_conf()
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=8,
num_outputs=2,
max_nodes=50,
max_conns=100,
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Reacher-misc',
),
generation_limit=10000,
fitness_target =500
)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
# initialize state
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,25 +1,23 @@
from functools import partial
from typing import Type
import jax
import jax, jax.numpy as jnp
import time
import numpy as np
from algorithm import NEAT, HyperNEAT
from config import Config
from core import State, Algorithm, Problem
from algorithm import BaseAlgorithm
from problem import BaseProblem
from utils import State
class Pipeline:
def __init__(
self,
algorithm: Algorithm,
problem: Problem,
seed: int = 42,
fitness_target: float = 1,
generation_limit: int = 1000,
pop_size: int = 100,
self,
algorithm: BaseAlgorithm,
problem: BaseProblem,
seed: int = 42,
fitness_target: float = 1,
generation_limit: int = 1000,
):
assert problem.jitable, "Currently, problem must be jitable"
@@ -28,17 +26,18 @@ class Pipeline:
self.seed = seed
self.fitness_target = fitness_target
self.generation_limit = generation_limit
self.pop_size = pop_size
self.pop_size = self.algorithm.pop_size
print(self.problem.input_shape, self.problem.output_shape)
# TODO: make each algorithm's input_num and output_num
assert algorithm.input_num == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
assert algorithm.num_inputs == self.problem.input_shape[-1], \
f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
self.act_func = self.algorithm.act
# self.act_func = self.algorithm.act
for _ in range(len(self.problem.input_shape) - 1):
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
# for _ in range(len(self.problem.input_shape) - 1):
# self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
self.best_genome = None
self.best_fitness = float('-inf')
@@ -46,41 +45,57 @@ class Pipeline:
def setup(self):
key = jax.random.PRNGKey(self.seed)
algorithm_key, evaluate_key = jax.random.split(key, 2)
key, algorithm_key, evaluate_key = jax.random.split(key, 3)
# TODO: Problem should have a setup function to maintain state
return State(
randkey=key,
alg=self.algorithm.setup(algorithm_key),
pro=self.problem.setup(evaluate_key),
)
@partial(jax.jit, static_argnums=(0,))
def step(self, state):
key, sub_key = jax.random.split(state.evaluate_key)
key, sub_key = jax.random.split(state.randkey)
keys = jax.random.split(key, self.pop_size)
pop = self.algorithm.ask(state)
pop = self.algorithm.ask(state.alg)
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
pop_transformed = jax.vmap(self.algorithm.transform)(pop)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
pop_transformed)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(
keys,
state.pro,
self.algorithm.forward,
pop_transformed
)
state = self.algorithm.tell(state, fitnesses)
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
return state.update(evaluate_key=sub_key), fitnesses
alg_state = self.algorithm.tell(state.alg, fitnesses)
return state.update(
randkey=sub_key,
alg=alg_state,
), fitnesses
def auto_run(self, ini_state):
state = ini_state
compiled_step = jax.jit(self.step).lower(ini_state).compile()
for _ in range(self.generation_limit):
self.generation_timestamp = time.time()
previous_pop = self.algorithm.ask(state)
previous_pop = self.algorithm.ask(state.alg)
state, fitnesses = self.step(state)
state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses)
for idx, fitnesses_i in enumerate(fitnesses):
if np.isnan(fitnesses_i):
print("Fitness is nan")
print(previous_pop[0][idx], previous_pop[1][idx])
assert False
self.analysis(state, previous_pop, fitnesses)
@@ -102,22 +117,15 @@ class Pipeline:
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx]
member_count = jax.device_get(state.species_info.member_count)
member_count = jax.device_get(self.algorithm.member_count(state.alg))
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {state.generation}",
print(f"Generation: {self.algorithm.generation(state.alg)}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, genome, *args, **kwargs):
transformed = self.algorithm.transform(state, genome)
self.problem.show(state.evaluate_key, state, self.act_func, transformed, *args, **kwargs)
def pre_compile(self, state):
tic = time.time()
print("start compile")
self.step.lower(self, state).compile()
print(f"compile finished, cost time: {time.time() - tic}s")
def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(best)
self.problem.show(state.randkey, state.pro, self.algorithm.forward, transformed, *args, **kwargs)

View File

@@ -1,19 +1,14 @@
from typing import Callable
from config import ProblemConfig
from core.state import State
from utils import State
class BaseProblem:
jitable = None
def __init__(self):
pass
def setup(self, randkey, state: State = State()):
"""initialize the state of the problem"""
raise NotImplementedError
pass
def evaluate(self, randkey, state: State, act_func: Callable, params):
"""evaluate one individual"""

View File

@@ -1,24 +1,27 @@
import jax
import jax.numpy as jnp
from utils import State
from .. import BaseProblem
class FuncFit(BaseProblem):
class FuncFit(BaseProblem):
jitable = True
def __init__(self,
error_method: str = 'mse'
):
error_method: str = 'mse'
):
super().__init__()
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
self.error_method = error_method
def setup(self, randkey, state: State = State()):
return state
def evaluate(self, randkey, state, act_func, params):
predict = act_func(state, self.inputs, params)
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
if self.error_method == 'mse':
loss = jnp.mean((predict - self.targets) ** 2)
@@ -38,7 +41,7 @@ class FuncFit(BaseProblem):
return -loss
def show(self, randkey, state, act_func, params, *args, **kwargs):
predict = act_func(state, self.inputs, params)
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params)
msg = ""

View File

@@ -1,2 +1,2 @@
from .gymnax_env import GymNaxEnv, GymNaxConfig
from .brax_env import BraxEnv, BraxConfig
from .gymnax_env import GymNaxEnv
from .brax_env import BraxEnv

View File

@@ -3,7 +3,6 @@ import gymnax
from .rl_jit import RLEnv
class GymNaxEnv(RLEnv):
def __init__(self, env_name):

View File

@@ -4,8 +4,8 @@ import jax
from .. import BaseProblem
class RLEnv(BaseProblem):
class RLEnv(BaseProblem):
jitable = True
# TODO: move output transform to algorithm
@@ -19,9 +19,10 @@ class RLEnv(BaseProblem):
def cond_func(carry):
_, _, _, done, _ = carry
return ~done
def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward
action = act_func(state, obs, params)
action = act_func(obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng)
return next_obs, next_env_state, next_rng, done, tr + reward

66
t.py
View File

@@ -1,64 +1,4 @@
from algorithm.neat import *
from utils import Act, Agg
import jax.numpy as jnp
import jax, jax.numpy as jnp
def main():
# index, bias, response, activation, aggregation
nodes = jnp.array([
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
])
# in_node, out_node, enable, weight
conns = jnp.array([
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
])
genome = RecurrentGenome(
num_inputs=2,
num_outputs=1,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
),
activate_time=3
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([0, 0])
outputs = genome.forward(inputs, transformed)
print(outputs)
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
print(outputs)
expected: [[0.5], [0.75], [0.75], [1]]
print('\n-------------------------------------------------------\n')
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
print(outputs)
expected: [[0.5], [0.75], [0.5], [0.75]]
if __name__ == '__main__':
main()
a = jnp.zeros((0, 9, 9))
print(a)

View File

@@ -26,6 +26,8 @@ def test_default():
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
@@ -80,6 +82,8 @@ def test_recurrent():
genome = RecurrentGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),

View File

@@ -6,48 +6,26 @@ class Act:
@staticmethod
def sigmoid(z):
z = jnp.clip(z * 5, -60, 60)
z = jnp.clip(5 * z, -10, 10)
return 1 / (1 + jnp.exp(-z))
@staticmethod
def tanh(z):
z = jnp.clip(z * 2.5, -60, 60)
return jnp.tanh(z)
@staticmethod
def sin(z):
z = jnp.clip(z * 5, -60, 60)
return jnp.sin(z)
@staticmethod
def gauss(z):
z = jnp.clip(z * 5, -3.4, 3.4)
return jnp.exp(-z ** 2)
@staticmethod
def relu(z):
return jnp.maximum(z, 0)
@staticmethod
def elu(z):
return jnp.where(z > 0, z, jnp.exp(z) - 1)
@staticmethod
def lelu(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@staticmethod
def selu(z):
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
@staticmethod
def softplus(z):
z = jnp.clip(z * 5, -60, 60)
return 0.2 * jnp.log(1 + jnp.exp(z))
@staticmethod
def identity(z):
return z
@@ -58,7 +36,11 @@ class Act:
@staticmethod
def inv(z):
z = jnp.maximum(z, 1e-7)
z = jnp.where(
z > 0,
jnp.maximum(z, 1e-7),
jnp.minimum(z, -1e-7)
)
return 1 / z
@staticmethod
@@ -68,24 +50,27 @@ class Act:
@staticmethod
def exp(z):
z = jnp.clip(z, -60, 60)
z = jnp.clip(z, -10, 10)
return jnp.exp(z)
@staticmethod
def abs(z):
return jnp.abs(z)
@staticmethod
def hat(z):
return jnp.maximum(0, 1 - jnp.abs(z))
@staticmethod
def square(z):
return z ** 2
@staticmethod
def cube(z):
return z ** 3
ACT_ALL = (
Act.sigmoid,
Act.tanh,
Act.sin,
Act.relu,
Act.lelu,
Act.identity,
Act.clamped,
Act.inv,
Act.log,
Act.exp,
Act.abs,
)
def act(idx, z, act_funcs):

View File

@@ -51,6 +51,9 @@ class Agg:
return mean_without_zeros
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
def agg(idx, z, agg_funcs):
"""
calculate activation function for inputs of node