complete fully stateful!

use black to format all files!
This commit is contained in:
wls2002
2024-05-26 18:08:43 +08:00
parent cf69b916af
commit 18c3d44c79
41 changed files with 620 additions and 495 deletions

View File

@@ -5,5 +5,5 @@ class BaseCrossover:
def setup(self, state=State()):
return state
def __call__(self, state, genome, nodes1, nodes2, conns1, conns2):
def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError

View File

@@ -4,12 +4,12 @@ from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
def __call__(self, state, randkey, genome, nodes1, conns1, nodes2, conns2):
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
"""
randkey1, randkey2, randkey = jax.random.split(state.randkey, 3)
randkey1, randkey2 = jax.random.split(randkey, 2)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
@@ -34,11 +34,12 @@ class DefaultCrossover(BaseCrossover):
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
)
return state.update(randkey=randkey), new_nodes, new_conns
return new_nodes, new_conns
def align_array(self, seq1, seq2, ar2, is_conn: bool):
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
After I review this code, I found that it is the most difficult part of the code.
Please consider carefully before change it!
make ar2 align with ar1.
:param seq1:
:param seq2:
@@ -64,8 +65,8 @@ class DefaultCrossover(BaseCrossover):
return refactor_ar2
def crossover_gene(self, rand_key, g1, g2, is_conn):
r = jax.random.uniform(rand_key, shape=g1.shape)
def crossover_gene(self, randkey, g1, g2, is_conn):
r = jax.random.uniform(randkey, shape=g1.shape)
new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled

View File

@@ -5,5 +5,5 @@ class BaseMutation:
def setup(self, state=State()):
return state
def __call__(self, state, genome, nodes, conns, new_node_key):
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
raise NotImplementedError

View File

@@ -1,6 +1,16 @@
import jax, jax.numpy as jnp
from . import BaseMutation
from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles
from utils import (
fetch_first,
fetch_random,
I_INF,
unflatten_conns,
check_cycles,
add_node,
add_conn,
delete_node_by_pos,
delete_conn_by_pos,
)
class DefaultMutation(BaseMutation):
@@ -16,15 +26,17 @@ class DefaultMutation(BaseMutation):
self.node_add = node_add
self.node_delete = node_delete
def __call__(self, state, genome, nodes, conns, new_node_key):
k1, k2, randkey = jax.random.split(state.randkey)
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
k1, k2 = jax.random.split(randkey)
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
nodes, conns = self.mutate_structure(
state, k1, genome, nodes, conns, new_node_key
)
nodes, conns = self.mutate_values(state, k2, genome, nodes, conns)
return state.update(randkey=randkey), nodes, conns
return nodes, conns
def mutate_structure(self, key, genome, nodes, conns, new_node_key):
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_):
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
@@ -33,24 +45,24 @@ class DefaultMutation(BaseMutation):
new_conns = conns_.at[idx, 2].set(False)
# add a new node
new_nodes = genome.add_node(
nodes_, new_node_key, genome.node_gene.new_custom_attrs()
new_nodes = add_node(
nodes_, new_node_key, genome.node_gene.new_custom_attrs(state)
)
# add two new connections
new_conns = genome.add_conn(
new_conns = add_conn(
new_conns,
i_key,
new_node_key,
True,
genome.conn_gene.new_custom_attrs(),
genome.conn_gene.new_custom_attrs(state),
)
new_conns = genome.add_conn(
new_conns = add_conn(
new_conns,
new_node_key,
o_key,
True,
genome.conn_gene.new_custom_attrs(),
genome.conn_gene.new_custom_attrs(state),
)
return new_nodes, new_conns
@@ -75,7 +87,7 @@ class DefaultMutation(BaseMutation):
def successful_delete_node():
# delete the node
new_nodes = genome.delete_node_by_pos(nodes_, idx)
new_nodes = delete_node_by_pos(nodes_, idx)
# delete all connections
new_conns = jnp.where(
@@ -123,8 +135,8 @@ class DefaultMutation(BaseMutation):
return nodes_, conns_
def successful():
return nodes_, genome.add_conn(
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs()
return nodes_, add_conn(
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state)
)
def already_exist():
@@ -152,7 +164,7 @@ class DefaultMutation(BaseMutation):
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
def successfully_delete_connection():
return nodes_, genome.delete_conn_by_pos(conns_, idx)
return nodes_, delete_conn_by_pos(conns_, idx)
return jax.lax.cond(
idx == I_INF,
@@ -160,7 +172,7 @@ class DefaultMutation(BaseMutation):
successfully_delete_connection,
)
k1, k2, k3, k4 = jax.random.split(key, num=4)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(key_, nodes_, conns_):
@@ -181,13 +193,17 @@ class DefaultMutation(BaseMutation):
return nodes, conns
def mutate_values(self, key, genome, nodes, conns):
k1, k2 = jax.random.split(key, num=2)
def mutate_values(self, state, randkey, genome, nodes, conns):
k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
state, nodes_keys, nodes
)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
state, conns_keys, conns
)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -26,7 +26,7 @@ class DefaultConnGene(BaseConnGene):
self.weight_replace_rate = weight_replace_rate
def new_custom_attrs(self, state):
return state, jnp.array([self.weight_init_mean])
return jnp.array([self.weight_init_mean])
def new_random_attrs(self, state, randkey):
weight = (

View File

@@ -109,10 +109,10 @@ class DefaultNodeGene(BaseNodeGene):
def distance(self, state, node1, node2):
return (
jnp.abs(node1[1] - node2[1])
+ jnp.abs(node1[2] - node2[2])
+ (node1[3] != node2[3])
+ (node1[4] != node2[4])
jnp.abs(node1[1] - node2[1]) # bias
+ jnp.abs(node1[2] - node2[2]) # response
+ (node1[3] != node2[3]) # activation
+ (node1[4] != node2[4]) # aggregation
)
def forward(self, state, attrs, inputs, is_output_node=False):

View File

@@ -0,0 +1,106 @@
from typing import Tuple
import jax, jax.numpy as jnp
from utils import Act, Agg, act, agg, mutate_int, mutate_float
from . import BaseNodeGene
class NodeGeneWithoutResponse(BaseNodeGene):
"""
Default node gene, with the same behavior as in NEAT-python.
The attribute response is removed.
"""
custom_attrs = ["bias", "aggregation", "activation"]
def __init__(
self,
bias_init_mean: float = 0.0,
bias_init_std: float = 1.0,
bias_mutate_power: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
activation_default: callable = Act.sigmoid,
activation_options: Tuple = (Act.sigmoid,),
activation_replace_rate: float = 0.1,
aggregation_default: callable = Agg.sum,
aggregation_options: Tuple = (Agg.sum,),
aggregation_replace_rate: float = 0.1,
):
super().__init__()
self.bias_init_mean = bias_init_mean
self.bias_init_std = bias_init_std
self.bias_mutate_power = bias_mutate_power
self.bias_mutate_rate = bias_mutate_rate
self.bias_replace_rate = bias_replace_rate
self.activation_default = activation_options.index(activation_default)
self.activation_options = activation_options
self.activation_indices = jnp.arange(len(activation_options))
self.activation_replace_rate = activation_replace_rate
self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
def new_custom_attrs(self, state):
return jnp.array(
[
self.bias_init_mean,
self.activation_default,
self.aggregation_default,
]
)
def new_random_attrs(self, state, randkey):
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
act = jax.random.randint(k3, (), 0, len(self.activation_options))
agg = jax.random.randint(k4, (), 0, len(self.aggregation_options))
return jnp.array([bias, act, agg])
def mutate(self, state, randkey, node):
k1, k2, k3, k4 = jax.random.split(state.randkey, num=4)
index = node[0]
bias = mutate_float(
k1,
node[1],
self.bias_init_mean,
self.bias_init_std,
self.bias_mutate_power,
self.bias_mutate_rate,
self.bias_replace_rate,
)
act = mutate_int(
k3, node[3], self.activation_indices, self.activation_replace_rate
)
agg = mutate_int(
k4, node[4], self.aggregation_indices, self.aggregation_replace_rate
)
return jnp.array([index, bias, act, agg])
def distance(self, state, node1, node2):
return (
jnp.abs(node1[1] - node2[1]) # bias
+ (node1[3] != node2[3]) # activation
+ (node1[4] != node2[4]) # aggregation
)
def forward(self, state, attrs, inputs, is_output_node=False):
bias, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options)
z = bias + z
# the last output node should not be activated
z = jax.lax.cond(
is_output_node, lambda: z, lambda: act(act_idx, z, self.activation_options)
)
return z

View File

@@ -1,6 +1,7 @@
import jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from utils import fetch_first, State
import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover
from utils import State
class BaseGenome:
@@ -12,8 +13,10 @@ class BaseGenome:
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
node_gene: BaseNodeGene,
conn_gene: BaseConnGene,
mutation: BaseMutation,
crossover: BaseCrossover,
):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
@@ -23,10 +26,14 @@ class BaseGenome:
self.max_conns = max_conns
self.node_gene = node_gene
self.conn_gene = conn_gene
self.mutation = mutation
self.crossover = crossover
def setup(self, state=State()):
state = self.node_gene.setup(state)
state = self.conn_gene.setup(state)
state = self.mutation.setup(state)
state = self.crossover.setup(state)
return state
def transform(self, state, nodes, conns):
@@ -35,36 +42,81 @@ class BaseGenome:
def forward(self, state, inputs, transformed):
raise NotImplementedError
def add_node(self, nodes, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
return self.mutation(state, randkey, self, nodes, conns, new_node_key)
def delete_node_by_pos(self, nodes, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
return nodes.at[pos].set(jnp.nan)
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2)
def add_conn(self, conns, i_key, o_key, enable: bool, attrs):
def initialize(self, state, randkey):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
return new_conns.at[pos, 3:].set(attrs)
Default initialization method for the genome.
Add an extra hidden node.
Make all input nodes and output nodes connected to the hidden node.
All attributes will be initialized randomly using gene.new_random_attrs method.
def delete_conn_by_pos(self, conns, pos):
For example, a network with 2 inputs and 1 output, the structure will be:
nodes:
[
[0, attrs0], # input node 0
[1, attrs1], # input node 1
[2, attrs2], # output node 0
[3, attrs3], # hidden node
[NaN, NaN], # empty node
]
conns:
[
[0, 3, attrs0], # input node 0 -> hidden node
[1, 3, attrs1], # input node 1 -> hidden node
[3, 2, attrs2], # hidden node -> output node 0
[NaN, NaN],
[NaN, NaN],
]
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)
k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns
# initialize nodes
new_node_key = (
max([*self.input_idx, *self.output_idx]) + 1
) # the key for the hidden node
node_keys = jnp.concatenate(
[self.input_idx, self.output_idx, jnp.array([new_node_key])]
) # the list of all node keys
# initialize nodes and connections with NaN
nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
# set keys for input nodes, output nodes and hidden node
nodes = nodes.at[node_keys, 0].set(node_keys)
# generate random attributes for nodes
node_keys = jax.random.split(k1, len(node_keys))
random_node_attrs = jax.vmap(
self.node_gene.new_random_attrs, in_axes=(None, 0)
)(state, node_keys)
nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs)
# initialize conns
# input-hidden connections
input_conns = jnp.c_[
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
]
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
conns = conns.at[self.input_idx, 2].set(True) # enable
# output-hidden connections
output_conns = jnp.c_[
jnp.full_like(self.output_idx, new_node_key), self.output_idx
]
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
conns = conns.at[self.output_idx, 2].set(True) # enable
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
# generate random attributes for conns
random_conn_attrs = jax.vmap(
self.conn_gene.new_random_attrs, in_axes=(None, 0)
)(state, conn_keys)
conns = conns.at[: len(conn_keys), 3:].set(random_conn_attrs)
return nodes, conns

View File

@@ -5,6 +5,7 @@ from utils import unflatten_conns, topological_sort, I_INF
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
class DefaultGenome(BaseGenome):
@@ -20,10 +21,19 @@ class DefaultGenome(BaseGenome):
max_conns=4,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
output_transform: Callable = None,
):
super().__init__(
num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene
num_inputs,
num_outputs,
max_nodes,
max_conns,
node_gene,
conn_gene,
mutation,
crossover,
)
if output_transform is not None:

View File

@@ -5,6 +5,7 @@ from utils import unflatten_conns
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
class RecurrentGenome(BaseGenome):
@@ -20,11 +21,20 @@ class RecurrentGenome(BaseGenome):
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
activate_time: int = 10,
output_transform: Callable = None,
):
super().__init__(
num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene
num_inputs,
num_outputs,
max_nodes,
max_conns,
node_gene,
conn_gene,
mutation,
crossover,
)
self.activate_time = activate_time

View File

@@ -10,18 +10,12 @@ class NEAT(BaseAlgorithm):
def __init__(
self,
species: BaseSpecies,
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
):
self.genome: BaseGenome = species.genome
self.species = species
self.mutation = mutation
self.crossover = crossover
self.genome = species.genome
def setup(self, state=State()):
state = self.species.setup(state)
state = self.mutation.setup(state)
state = self.crossover.setup(state)
state = state.register(
generation=jnp.array(0.0),
next_node_key=jnp.array(
@@ -32,18 +26,16 @@ class NEAT(BaseAlgorithm):
return state
def ask(self, state: State):
return state, self.species.ask(state.species)
return self.species.ask(state)
def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(generation=state.generation + 1, randkey=randkey)
state, winner, loser, elite_mask = self.species.update_species(
state.species, fitness
)
state, winner, loser, elite_mask = self.species.update_species(state, fitness)
state = self.create_next_generation(state, winner, loser, elite_mask)
state = self.species.speciate(state.species)
state = self.species.speciate(state)
return state
@@ -73,21 +65,25 @@ class NEAT(BaseAlgorithm):
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
crossover_randkeys = jax.random.split(k1, pop_size)
mutate_randkeys = jax.random.split(k2, pop_size)
wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner]
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser]
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))(
crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc
)
n_nodes, n_conns = jax.vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))(
mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys
)
m_n_nodes, m_n_conns = jax.vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
@@ -108,8 +104,8 @@ class NEAT(BaseAlgorithm):
)
def member_count(self, state: State):
return state, state.species.member_count
return state.member_count
def generation(self, state: State):
# to analysis the algorithm
return state, state.generation
return state.generation

View File

@@ -1,10 +1,22 @@
import numpy as np
import jax, jax.numpy as jnp
from utils import State, rank_elements, argmin_with_mask, fetch_first
from ..genome import BaseGenome
from .base import BaseSpecies
"""
Core procedures of NEAT algorithm, contains the following steps:
1. Update the fitness of each species;
2. Decide which species will be stagnation;
3. Decide the number of members of each species in the next generation;
4. Choice the crossover pair for each species;
5. Divided the whole new population into different species;
This class use tensor operation to imitate the behavior of NEAT algorithm which implemented in NEAT-python.
The code may be hard to understand. Fortunately, we don't need to overwrite it in most cases.
"""
class DefaultSpecies(BaseSpecies):
def __init__(
self,
@@ -20,8 +32,6 @@ class DefaultSpecies(BaseSpecies):
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.0,
initialize_method: str = "one_hidden_node",
# {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
):
self.genome = genome
self.pop_size = pop_size
@@ -36,15 +46,17 @@ class DefaultSpecies(BaseSpecies):
self.survival_threshold = survival_threshold
self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold
self.initialize_method = initialize_method
self.species_arange = jnp.arange(self.species_size)
def setup(self, state=State()):
state = self.genome.setup(state)
k1, randkey = jax.random.split(state.randkey, 2)
pop_nodes, pop_conns = initialize_population(
self.pop_size, self.genome, k1, self.initialize_method
# initialize the population
initialize_keys = jax.random.split(randkey, self.pop_size)
pop_nodes, pop_conns = jax.vmap(self.genome.initialize, in_axes=(None, 0))(
state, initialize_keys
)
species_keys = jnp.full(
@@ -82,8 +94,9 @@ class DefaultSpecies(BaseSpecies):
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
state = state.update(randkey=randkey)
return state.register(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species_keys=species_keys,
@@ -97,7 +110,7 @@ class DefaultSpecies(BaseSpecies):
)
def ask(self, state):
return state, state.pop_nodes, state.pop_conns
return state.pop_nodes, state.pop_conns
def update_species(self, state, fitness):
# update the fitness of each species
@@ -122,8 +135,8 @@ class DefaultSpecies(BaseSpecies):
k1, k2 = jax.random.split(state.randkey)
# crossover info
winner, loser, elite_mask = self.create_crossover_pair(
state, k1, spawn_number, fitness
state, winner, loser, elite_mask = self.create_crossover_pair(
state, spawn_number, fitness
)
return state.update(randkey=k2), winner, loser, elite_mask
@@ -322,12 +335,12 @@ class DefaultSpecies(BaseSpecies):
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return state(randkey=randkey), winner, loser, elite_mask
return state.update(randkey=randkey), winner, loser, elite_mask
def speciate(self, state):
# prepare distance functions
o2p_distance_func = jax.vmap(
self.distance, in_axes=(None, None, 0, 0)
self.distance, in_axes=(None, None, None, 0, 0)
) # one to population
# idx to specie key
@@ -351,7 +364,7 @@ class DefaultSpecies(BaseSpecies):
i, i2s, cns, ccs, o2c = carry
distances = o2p_distance_func(
cns[i], ccs[i], state.pop_nodes, state.pop_conns
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
)
# find the closest one
@@ -434,7 +447,7 @@ class DefaultSpecies(BaseSpecies):
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(
cns[i], ccs[i], state.pop_nodes, state.pop_conns
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
)
close_enough_mask = o2p_distance < self.compatibility_threshold
@@ -508,14 +521,16 @@ class DefaultSpecies(BaseSpecies):
next_species_key=next_species_key,
)
def distance(self, nodes1, conns1, nodes2, conns2):
def distance(self, state, nodes1, conns1, nodes2, conns2):
"""
The distance between two genomes
"""
d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
state, conns1, conns2
)
return d
def node_distance(self, nodes1, nodes2):
def node_distance(self, state, nodes1, nodes2):
"""
The distance of the nodes part for two genomes
"""
@@ -541,7 +556,9 @@ class DefaultSpecies(BaseSpecies):
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(fr, sr)
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
state, fr, sr
) # homologous node distance
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
@@ -550,9 +567,11 @@ class DefaultSpecies(BaseSpecies):
+ homologous_distance * self.compatibility_weight
)
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
def conn_distance(self, conns1, conns2):
return val
def conn_distance(self, state, conns1, conns2):
"""
The distance of the conns part for two genomes
"""
@@ -573,7 +592,9 @@ class DefaultSpecies(BaseSpecies):
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(fr, sr)
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
state, fr, sr
) # homologous connection distance
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
@@ -582,185 +603,6 @@ class DefaultSpecies(BaseSpecies):
+ homologous_distance * self.compatibility_weight
)
return jnp.where(max_cnt == 0, 0, val / max_cnt)
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
def initialize_population(pop_size, genome, randkey, init_method="default"):
rand_keys = jax.random.split(randkey, pop_size)
if init_method == "one_hidden_node":
init_func = init_one_hidden_node
elif init_method == "dense_hideen_layer":
init_func = init_dense_hideen_layer
elif init_method == "no_hidden_random":
init_func = init_no_hidden_random
else:
raise ValueError("Unknown initialization method: {}".format(init_method))
pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys)
return pop_nodes, pop_conns
# one hidden node
def init_one_hidden_node(genome, randkey):
input_idx, output_idx = genome.input_idx, genome.output_idx
new_node_key = max([*input_idx, *output_idx]) + 1
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
nodes = nodes.at[new_node_key, 0].set(new_node_key)
rand_keys_nodes = jax.random.split(
randkey, num=len(input_idx) + len(output_idx) + 1
)
input_keys, output_keys, hidden_key = (
rand_keys_nodes[: len(input_idx)],
rand_keys_nodes[len(input_idx) : len(input_idx) + len(output_idx)],
rand_keys_nodes[-1],
)
node_attr_func = jax.vmap(genome.node_gene.new_attrs, in_axes=(None, 0))
input_attrs = node_attr_func(input_keys)
output_attrs = node_attr_func(output_keys)
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
nodes = nodes.at[input_idx, 1:].set(input_attrs)
nodes = nodes.at[output_idx, 1:].set(output_attrs)
nodes = nodes.at[new_node_key, 1:].set(hidden_attrs)
input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)]
conns = conns.at[input_idx, 0:2].set(input_conns)
conns = conns.at[input_idx, 2].set(True)
output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
conns = conns.at[output_idx, 0:2].set(output_conns)
conns = conns.at[output_idx, 2].set(True)
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
input_conn_keys, output_conn_keys = (
rand_keys_conns[: len(input_idx)],
rand_keys_conns[len(input_idx) :],
)
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None, 0))
input_conn_attrs = conn_attr_func(input_conn_keys)
output_conn_attrs = conn_attr_func(output_conn_keys)
conns = conns.at[input_idx, 3:].set(input_conn_attrs)
conns = conns.at[output_idx, 3:].set(output_conn_attrs)
return nodes, conns
# random dense connections with 1 hidden layer
def init_dense_hideen_layer(genome, randkey, hiddens=20):
k1, k2, k3 = jax.random.split(randkey, num=3)
input_idx, output_idx = genome.input_idx, genome.output_idx
input_size = len(input_idx)
output_size = len(output_idx)
hidden_idx = jnp.arange(
input_size + output_size, input_size + output_size + hiddens
)
nodes = jnp.full(
(genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32
)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
total_idx = input_size + output_size + hiddens
rand_keys_n = jax.random.split(k1, num=total_idx)
input_keys = rand_keys_n[:input_size]
output_keys = rand_keys_n[input_size : input_size + output_size]
hidden_keys = rand_keys_n[input_size + output_size :]
node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
input_attrs = node_attr_func(input_keys)
output_attrs = node_attr_func(output_keys)
hidden_attrs = node_attr_func(hidden_keys)
nodes = nodes.at[input_idx, 1:].set(input_attrs)
nodes = nodes.at[output_idx, 1:].set(output_attrs)
nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs)
total_connections = input_size * hiddens + hiddens * output_size
conns = jnp.full(
(genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32
)
rand_keys_c = jax.random.split(k2, num=total_connections)
conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
conns_attrs = conns_attr_func(rand_keys_c)
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing="ij")
hidden_to_output_ids, output_ids = jnp.meshgrid(
hidden_idx, output_idx, indexing="ij"
)
conns = conns.at[: input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
conns = conns.at[: input_size * hiddens, 1].set(hidden_ids.flatten())
conns = conns.at[input_size * hiddens : total_connections, 0].set(
hidden_to_output_ids.flatten()
)
conns = conns.at[input_size * hiddens : total_connections, 1].set(
output_ids.flatten()
)
conns = conns.at[: input_size * hiddens + hiddens * output_size, 2].set(True)
conns = conns.at[: input_size * hiddens + hiddens * output_size, 3:].set(
conns_attrs
)
return nodes, conns
# random sparse connections with no hidden nodes
def init_no_hidden_random(genome, randkey):
k1, k2, k3 = jax.random.split(randkey, num=3)
input_idx, output_idx = genome.input_idx, genome.output_idx
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
total_idx = len(input_idx) + len(output_idx)
rand_keys_n = jax.random.split(k1, num=total_idx)
input_keys = rand_keys_n[: len(input_idx)]
output_keys = rand_keys_n[len(input_idx) :]
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
input_attrs = node_attr_func(input_keys)
output_attrs = node_attr_func(output_keys)
nodes = nodes.at[input_idx, 1:].set(input_attrs)
nodes = nodes.at[output_idx, 1:].set(output_attrs)
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
num_connections_per_output = 4
total_connections = len(output_idx) * num_connections_per_output
def create_connections_for_output(key):
permuted_inputs = jax.random.permutation(key, input_idx)
selected_inputs = permuted_inputs[:num_connections_per_output]
return selected_inputs
conn_keys = jax.random.split(k2, num=len(output_idx))
connections = jax.vmap(create_connections_for_output)(conn_keys)
connections = connections.flatten()
output_repeats = jnp.repeat(output_idx, num_connections_per_output)
rand_keys_c = jax.random.split(k3, num=total_connections)
conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
conns_attrs = conns_attr_func(rand_keys_c)
conns = conns.at[:total_connections, 0].set(connections)
conns = conns.at[:total_connections, 1].set(output_repeats)
conns = conns.at[:total_connections, 2].set(True) # enabled
conns = conns.at[:total_connections, 3:].set(conns_attrs)
return nodes, conns
return val

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -17,17 +17,17 @@ if __name__ == '__main__':
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=Act.tanh
output_transform=Act.tanh,
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name='ant',
env_name="ant",
),
generation_limit=10000,
fitness_target=5000
fitness_target=5000,
)
# initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -16,17 +16,17 @@ if __name__ == '__main__':
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name='halfcheetah',
env_name="halfcheetah",
),
generation_limit=10000,
fitness_target=5000
fitness_target=5000,
)
# initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -16,17 +16,17 @@ if __name__ == '__main__':
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
),
pop_size=100,
species_size=10,
),
),
problem=BraxEnv(
env_name='reacher',
env_name="reacher",
),
generation_limit=10000,
fitness_target=5000
fitness_target=5000,
)
# initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -16,17 +16,17 @@ if __name__ == '__main__':
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
),
pop_size=10000,
species_size=10,
),
),
problem=BraxEnv(
env_name='walker2d',
env_name="walker2d",
),
generation_limit=10000,
fitness_target=5000
fitness_target=5000,
)
# initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.func_fit import XOR3d
from utils import Act
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -18,22 +18,22 @@ if __name__ == '__main__':
activation_options=(Act.tanh,),
),
output_transform=Act.sigmoid, # the activation function for output node
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
),
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.01, # magic
),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8
fitness_target=-1e-8,
)
# initialize state

View File

@@ -5,17 +5,28 @@ from utils import Act
from problem.func_fit import XOR3d
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=HyperNEAT(
substrate=FullSubstrate(
input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)],
hidden_coors=[
(-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5),
(-1, 0), (0.333, 0), (-0.333, 0), (1, 0),
(-1, 0.5), (0.333, 0.5), (-0.333, 0.5), (1, 0.5),
(-1, -0.5),
(0.333, -0.5),
(-0.333, -0.5),
(1, -0.5),
(-1, 0),
(0.333, 0),
(-0.333, 0),
(1, 0),
(-1, 0.5),
(0.333, 0.5),
(-0.333, 0.5),
(1, 0.5),
],
output_coors=[
(0, 1),
],
output_coors=[(0, 1), ],
),
neat=NEAT(
species=DefaultSpecies(
@@ -42,7 +53,7 @@ if __name__ == '__main__':
),
problem=XOR3d(),
generation_limit=300,
fitness_target=-1e-6
fitness_target=-1e-6,
)
# initialize state

View File

@@ -1,10 +1,11 @@
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from problem.func_fit import XOR3d
from utils.activation import ACT_ALL, Act
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
seed=0,
algorithm=NEAT(
@@ -15,27 +16,26 @@ if __name__ == '__main__':
max_nodes=50,
max_conns=100,
activate_time=5,
node_gene=DefaultNodeGene(
activation_options=ACT_ALL,
activation_replace_rate=0.2
node_gene=NodeGeneWithoutResponse(
activation_options=ACT_ALL, activation_replace_rate=0.2
),
output_transform=Act.sigmoid,
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
),
output_transform=Act.sigmoid
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8
fitness_target=-1e-8,
)
# initialize state

View File

@@ -5,7 +5,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -14,17 +14,19 @@ if __name__ == '__main__':
num_outputs=3,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2}
output_transform=lambda out: jnp.argmax(
out
), # the action of acrobot is {0, 1, 2}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Acrobot-v1',
env_name="Acrobot-v1",
),
generation_limit=10000,
fitness_target=-62
fitness_target=-62,
)
# initialize state

View File

@@ -5,7 +5,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -14,17 +14,19 @@ if __name__ == '__main__':
num_outputs=2,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
output_transform=lambda out: jnp.argmax(
out
), # the action of cartpole is {0, 1}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='CartPole-v1',
env_name="CartPole-v1",
),
generation_limit=10000,
fitness_target=500
fitness_target=500,
)
# initialize state

View File

@@ -10,11 +10,7 @@ from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
),
basic=BasicConfig(seed=42, fitness_target=500, pop_size=10000),
neat=NeatConfig(
inputs=4,
outputs=1,
@@ -23,28 +19,31 @@ def example_conf():
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
hyperneat=HyperNeatConfig(
activation=Act.sigmoid,
inputs=4,
outputs=2
),
hyperneat=HyperNeatConfig(activation=Act.sigmoid, inputs=4, outputs=2),
substrate=NormalSubstrateConfig(
input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)),
hidden_coors=(
# (-1, -0.5), (-0.5, -0.5), (0, -0.5), (0.5, -0.5),
(1, 0), (-1, 0), (-0.5, 0), (0, 0), (0.5, 0), (1, 0),
(1, 0),
(-1, 0),
(-0.5, 0),
(0, 0),
(0.5, 0),
(1, 0),
# (1, 0.5), (-1, 0.5), (-0.5, 0.5), (0, 0.5), (0.5, 0.5), (1, 0.5),
),
output_coors=((-1, 1), (1, 1)),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
)
env_name="CartPole-v1",
output_transform=lambda out: jnp.argmax(
out
), # the action of cartpole is {0, 1}
),
)
if __name__ == '__main__':
if __name__ == "__main__":
conf = example_conf()
algorithm = HyperNEAT(conf, NormalGene, NormalSubstrate)

View File

@@ -5,7 +5,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -14,17 +14,19 @@ if __name__ == '__main__':
num_outputs=3,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2}
output_transform=lambda out: jnp.argmax(
out
), # the action of mountain car is {0, 1, 2}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='MountainCar-v0',
env_name="MountainCar-v0",
),
generation_limit=10000,
fitness_target=0
fitness_target=0,
)
# initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv
from utils import Act
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -14,19 +14,19 @@ if __name__ == '__main__':
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh, ),
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='MountainCarContinuous-v0',
env_name="MountainCarContinuous-v0",
),
generation_limit=10000,
fitness_target=500
fitness_target=500,
)
# initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv
from utils import Act
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -17,17 +17,18 @@ if __name__ == '__main__':
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
output_transform=lambda out: out
* 2, # the action of pendulum is [-2, 2]
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Pendulum-v1',
env_name="Pendulum-v1",
),
generation_limit=10000,
fitness_target=0
fitness_target=0,
)
# initialize state

View File

@@ -5,7 +5,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
@@ -20,10 +20,10 @@ if __name__ == '__main__':
),
),
problem=GymNaxEnv(
env_name='Reacher-misc',
env_name="Reacher-misc",
),
generation_limit=10000,
fitness_target =500
fitness_target=500,
)
# initialize state

View File

@@ -1,4 +1,5 @@
import ray
ray.init(num_gpus=2)
available_resources = ray.available_resources()

View File

@@ -10,7 +10,6 @@ from utils import State
class Pipeline:
def __init__(
self,
algorithm: BaseAlgorithm,
@@ -31,32 +30,35 @@ class Pipeline:
# print(self.problem.input_shape, self.problem.output_shape)
# TODO: make each algorithm's input_num and output_num
assert algorithm.num_inputs == self.problem.input_shape[-1], \
f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
assert (
algorithm.num_inputs == self.problem.input_shape[-1]
), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
self.best_genome = None
self.best_fitness = float('-inf')
self.best_fitness = float("-inf")
self.generation_timestamp = None
def setup(self, state=State()):
print("initializing")
state = state.register(randkey=jax.random.PRNGKey(self.seed))
state = self.algorithm.setup(state)
state = self.problem.setup(state)
print("initializing finished")
return state
def step(self, state):
randkey_, randkey = jax.random.split(state.randkey)
keys = jax.random.split(randkey_, self.pop_size)
state, pop = self.algorithm.ask(state)
pop = self.algorithm.ask(state)
state, pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0), out_axes=(None, 0))(state, pop)
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(
state, pop
)
state, fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0), out_axes=(None, 0))(
keys,
state,
self.algorithm.forward,
pop_transformed
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, keys, self.algorithm.forward, pop_transformed
)
state = self.algorithm.tell(state, fitnesses)
@@ -67,13 +69,15 @@ class Pipeline:
print("start compile")
tic = time.time()
compiled_step = jax.jit(self.step).lower(state).compile()
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
print(
f"compile finished, cost time: {time.time() - tic:.6f}s",
)
for _ in range(self.generation_limit):
self.generation_timestamp = time.time()
state, previous_pop = self.algorithm.ask(state)
previous_pop = self.algorithm.ask(state)
state, fitnesses = compiled_step(state)
@@ -98,7 +102,12 @@ class Pipeline:
def analysis(self, state, pop, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
max_f, min_f, mean_f, std_f = (
max(fitnesses),
min(fitnesses),
np.mean(fitnesses),
np.std(fitnesses),
)
new_timestamp = time.time()
@@ -112,10 +121,14 @@ class Pipeline:
member_count = jax.device_get(self.algorithm.member_count(state))
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.algorithm.generation(state)}",
print(
f"Generation: {self.algorithm.generation(state)}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms",
)
def show(self, state, best, *args, **kwargs):
state, transformed = self.algorithm.transform(state, best)
self.problem.show(state.randkey, state, self.algorithm.forward, transformed, *args, **kwargs)
transformed = self.algorithm.transform(state, best)
self.problem.show(
state, state.randkey, self.algorithm.forward, transformed, *args, **kwargs
)

View File

@@ -10,7 +10,7 @@ class BaseProblem:
"""initialize the state of the problem"""
return state
def evaluate(self, randkey, state: State, act_func: Callable, params):
def evaluate(self, state: State, randkey, act_func: Callable, params):
"""evaluate one individual"""
raise NotImplementedError
@@ -32,7 +32,7 @@ class BaseProblem:
"""
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
def show(self, state: State, randkey, act_func: Callable, params, *args, **kwargs):
"""
show how a genome perform in this problem
"""

View File

@@ -8,42 +8,44 @@ from .. import BaseProblem
class FuncFit(BaseProblem):
jitable = True
def __init__(self,
error_method: str = 'mse'
):
def __init__(self, error_method: str = "mse"):
super().__init__()
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
assert error_method in {"mse", "rmse", "mae", "mape"}
self.error_method = error_method
def setup(self, state: State = State()):
return state
def evaluate(self, randkey, state, act_func, params):
def evaluate(self, state, randkey, act_func, params):
state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params)
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
state, self.inputs, params
)
if self.error_method == 'mse':
if self.error_method == "mse":
loss = jnp.mean((predict - self.targets) ** 2)
elif self.error_method == 'rmse':
elif self.error_method == "rmse":
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
elif self.error_method == 'mae':
elif self.error_method == "mae":
loss = jnp.mean(jnp.abs(predict - self.targets))
elif self.error_method == 'mape':
elif self.error_method == "mape":
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
else:
raise NotImplementedError
return state, -loss
return -loss
def show(self, randkey, state, act_func, params, *args, **kwargs):
state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params)
def show(self, state, randkey, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
state, self.inputs, params
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
state, loss = self.evaluate(randkey, state, act_func, params)
loss = self.evaluate(state, randkey, act_func, params)
loss = -loss
msg = ""

View File

@@ -4,27 +4,16 @@ from .func_fit import FuncFit
class XOR(FuncFit):
def __init__(self, error_method: str = 'mse'):
def __init__(self, error_method: str = "mse"):
super().__init__(error_method)
@property
def inputs(self):
return np.array([
[0, 0],
[0, 1],
[1, 0],
[1, 1]
])
return np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
@property
def targets(self):
return np.array([
[0],
[1],
[1],
[0]
])
return np.array([[0], [1], [1], [0]])
@property
def input_shape(self):

View File

@@ -4,13 +4,13 @@ from .func_fit import FuncFit
class XOR3d(FuncFit):
def __init__(self, error_method: str = 'mse'):
def __init__(self, error_method: str = "mse"):
super().__init__(error_method)
@property
def inputs(self):
return np.array([
return np.array(
[
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
@@ -19,20 +19,12 @@ class XOR3d(FuncFit):
[1, 0, 1],
[1, 1, 0],
[1, 1, 1],
])
]
)
@property
def targets(self):
return np.array([
[0],
[1],
[1],
[0],
[1],
[0],
[0],
[1]
])
return np.array([[0], [1], [1], [0], [1], [0], [0], [1]])
@property
def input_shape(self):

View File

@@ -25,7 +25,19 @@ class BraxEnv(RLEnv):
def output_shape(self):
return (self.env.action_size,)
def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs):
def show(
self,
state,
randkey,
act_func,
params,
save_path=None,
height=512,
width=512,
duration=0.1,
*args,
**kwargs
):
import jax
import imageio
@@ -48,11 +60,13 @@ class BraxEnv(RLEnv):
key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs)
reward += r
imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in
tqdm(state_histories, desc="Rendering")]
imgs = [
image.render_array(sys=self.env.sys, state=s, width=width, height=height)
for s in tqdm(state_histories, desc="Rendering")
]
def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
with imageio.get_writer(gif_name, mode="I", duration=duration) as writer:
for image in image_list:
formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image)
@@ -60,5 +74,3 @@ class BraxEnv(RLEnv):
create_gif(imgs, save_path, duration=0.1)
print("Gif saved to: ", save_path)
print("Total reward: ", reward)

View File

@@ -4,7 +4,6 @@ from .rl_jit import RLEnv
class GymNaxEnv(RLEnv):
def __init__(self, env_name):
super().__init__()
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
@@ -24,5 +23,5 @@ class GymNaxEnv(RLEnv):
def output_shape(self):
return self.env.action_space(self.env_params).shape
def show(self, randkey, state, act_func, params, *args, **kwargs):
def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

View File

@@ -12,28 +12,28 @@ class RLEnv(BaseProblem):
super().__init__()
self.max_step = max_step
def evaluate(self, randkey, state, act_func, params):
def evaluate(self, state, randkey, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset)
def cond_func(carry):
_, _, _, _, done, _, count = carry
_, _, _, done, _, count = carry
return ~done & (count < self.max_step)
def body_func(carry):
state_, obs, env_state, rng, done, tr, count = carry # tr -> total reward
state_, action = act_func(state_, obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
obs, env_state, rng, done, tr, count = carry # tr -> total reward
action = act_func(state, obs, params)
next_obs, next_env_state, reward, done, _ = self.step(
rng, env_state, action
)
next_rng, _ = jax.random.split(rng)
return state_, next_obs, next_env_state, next_rng, done, tr + reward, count + 1
return next_obs, next_env_state, next_rng, done, tr + reward, count + 1
state, _, _, _, _, total_reward, _ = jax.lax.while_loop(
cond_func,
body_func,
(state, init_obs, init_env_state, rng_episode, False, 0.0, 0)
_, _, _, _, total_reward, _ = jax.lax.while_loop(
cond_func, body_func, (init_obs, init_env_state, rng_episode, False, 0.0, 0)
)
return state, total_reward
return total_reward
# @partial(jax.jit, static_argnums=(0,))
def step(self, randkey, env_state, action):
@@ -57,5 +57,5 @@ class RLEnv(BaseProblem):
def output_shape(self):
raise NotImplementedError
def show(self, randkey, state, act_func, params, *args, **kwargs):
def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError

View File

@@ -36,7 +36,9 @@ def main():
elite_mask = jnp.zeros((1000,), dtype=jnp.bool_)
elite_mask = elite_mask.at[:5].set(1)
state = algorithm.create_next_generation(jax.random.key(0), state, winner, losser, elite_mask)
state = algorithm.create_next_generation(
jax.random.key(0), state, winner, losser, elite_mask
)
pop_nodes, pop_conns = algorithm.species.ask(state.species)
transforms = batch_transform(pop_nodes, pop_conns)
@@ -48,5 +50,5 @@ def main():
print(_)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -19,7 +19,7 @@ def main():
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
)
transformed = genome.transform(nodes, conns)
@@ -35,7 +35,7 @@ def main():
print(output)
if __name__ == '__main__':
if __name__ == "__main__":
a = jnp.array([1, 3, 5, 6, 8])
b = jnp.array([1, 2, 3])
print(jnp.isin(a, b))

View File

@@ -2,6 +2,7 @@ from algorithm.neat import *
from utils import Act, Agg, State
import jax, jax.numpy as jnp
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
def test_default():
@@ -135,3 +136,29 @@ def test_recurrent():
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
def test_random_initialize():
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=NodeGeneWithoutResponse(
activation_default=Act.identity,
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum,),
),
)
state = genome.setup()
key = jax.random.PRNGKey(0)
nodes, conns = genome.initialize(state, key)
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
state, inputs, transformed
)
print(outputs)

View File

@@ -19,11 +19,11 @@ def main():
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
print(*transformed, sep="\n")
key = jax.random.key(0)
dummy_input = jnp.zeros((8,))
@@ -31,5 +31,5 @@ def main():
print(output)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -116,3 +116,41 @@ def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx
def add_node(nodes, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def delete_node_by_pos(nodes, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
return nodes.at[pos].set(jnp.nan)
def add_conn(conns, i_key, o_key, enable: bool, attrs):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
return new_conns.at[pos, 3:].set(attrs)
def delete_conn_by_pos(conns, pos):
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)