change repo structure; modify readme

This commit is contained in:
wls2002
2024-03-26 21:58:27 +08:00
parent 6970e6a6d5
commit 47dbcbea80
69 changed files with 74 additions and 60 deletions

View File

@@ -0,0 +1,2 @@
from .base import BaseAlgorithm
from .neat import NEAT

View File

@@ -0,0 +1,45 @@
from utils import State
class BaseAlgorithm:
def setup(self, randkey):
"""initialize the state of the algorithm"""
raise NotImplementedError
def ask(self, state: State):
"""require the population to be evaluated"""
raise NotImplementedError
def tell(self, state: State, fitness):
"""update the state of the algorithm"""
raise NotImplementedError
def transform(self, individual):
"""transform the genome into a neural network"""
raise NotImplementedError
def forward(self, inputs, transformed):
raise NotImplementedError
@property
def num_inputs(self):
raise NotImplementedError
@property
def num_outputs(self):
raise NotImplementedError
@property
def pop_size(self):
raise NotImplementedError
def member_count(self, state: State):
# to analysis the species
raise NotImplementedError
def generation(self, state: State):
# to analysis the algorithm
raise NotImplementedError

View File

@@ -0,0 +1,2 @@
from .hyperneat import HyperNEAT
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate

View File

@@ -0,0 +1,116 @@
import jax, jax.numpy as jnp
from utils import State, Act, Agg
from .. import BaseAlgorithm, NEAT
from ..neat.gene import BaseNodeGene, BaseConnGene
from ..neat.genome import RecurrentGenome
from .substrate import *
class HyperNEAT(BaseAlgorithm):
def __init__(
self,
substrate: BaseSubstrate,
neat: NEAT,
below_threshold: float = 0.3,
max_weight: float = 5.,
activation=Act.sigmoid,
aggregation=Agg.sum,
activate_time: int = 10,
):
assert substrate.query_coors.shape[1] == neat.num_inputs, \
"Substrate input size should be equal to NEAT input size"
self.substrate = substrate
self.neat = neat
self.below_threshold = below_threshold
self.max_weight = max_weight
self.hyper_genome = RecurrentGenome(
num_inputs=substrate.num_inputs,
num_outputs=substrate.num_outputs,
max_nodes=substrate.nodes_cnt,
max_conns=substrate.conns_cnt,
node_gene=HyperNodeGene(activation, aggregation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
)
def setup(self, randkey):
return State(
neat_state=self.neat.setup(randkey)
)
def ask(self, state: State):
return self.neat.ask(state.neat_state)
def tell(self, state: State, fitness):
return state.update(
neat_state=self.neat.tell(state.neat_state, fitness)
)
def transform(self, individual):
transformed = self.neat.transform(individual)
query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(self.substrate.query_coors, transformed)
# mute the connection with weight below threshold
query_res = jnp.where(
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
0.,
query_res
)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(query_res > 0, query_res - self.below_threshold, query_res)
query_res = jnp.where(query_res < 0, query_res + self.below_threshold, query_res)
query_res = query_res / (1 - self.below_threshold) * self.max_weight
h_nodes, h_conns = self.substrate.make_nodes(query_res), self.substrate.make_conn(query_res)
return self.hyper_genome.transform(h_nodes, h_conns)
def forward(self, inputs, transformed):
# add bias
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
return self.hyper_genome.forward(inputs_with_bias, transformed)
@property
def num_inputs(self):
return self.substrate.num_inputs - 1 # remove bias
@property
def num_outputs(self):
return self.substrate.num_outputs
@property
def pop_size(self):
return self.neat.pop_size
def member_count(self, state: State):
return self.neat.member_count(state.neat_state)
def generation(self, state: State):
return self.neat.generation(state.neat_state)
class HyperNodeGene(BaseNodeGene):
def __init__(self,
activation=Act.sigmoid,
aggregation=Agg.sum,
):
super().__init__()
self.activation = activation
self.aggregation = aggregation
def forward(self, attrs, inputs):
return self.activation(
self.aggregation(inputs)
)
class HyperNEATConnGene(BaseConnGene):
custom_attrs = ['weight']
def forward(self, attrs, inputs):
weight = attrs[0]
return inputs * weight

View File

@@ -0,0 +1,3 @@
from .base import BaseSubstrate
from .default import DefaultSubstrate
from .full import FullSubstrate

View File

@@ -0,0 +1,27 @@
class BaseSubstrate:
def make_nodes(self, query_res):
raise NotImplementedError
def make_conn(self, query_res):
raise NotImplementedError
@property
def query_coors(self):
raise NotImplementedError
@property
def num_inputs(self):
raise NotImplementedError
@property
def num_outputs(self):
raise NotImplementedError
@property
def nodes_cnt(self):
raise NotImplementedError
@property
def conns_cnt(self):
raise NotImplementedError

View File

@@ -0,0 +1,38 @@
import jax.numpy as jnp
from . import BaseSubstrate
class DefaultSubstrate(BaseSubstrate):
def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
self.inputs = num_inputs
self.outputs = num_outputs
self.coors = jnp.array(coors)
self.nodes = jnp.array(nodes)
self.conns = jnp.array(conns)
def make_nodes(self, query_res):
return self.nodes
def make_conn(self, query_res):
return self.conns.at[:, 3:].set(query_res) # change weight
@property
def query_coors(self):
return self.coors
@property
def num_inputs(self):
return self.inputs
@property
def num_outputs(self):
return self.outputs
@property
def nodes_cnt(self):
return self.nodes.shape[0]
@property
def conns_cnt(self):
return self.conns.shape[0]

View File

@@ -0,0 +1,76 @@
import numpy as np
from .default import DefaultSubstrate
class FullSubstrate(DefaultSubstrate):
def __init__(self,
input_coors=((-1, -1), (0, -1), (1, -1)),
hidden_coors=((-1, 0), (0, 0), (1, 0)),
output_coors=((0, 1),),
):
query_coors, nodes, conns = analysis_substrate(input_coors, output_coors, hidden_coors)
super().__init__(
len(input_coors),
len(output_coors),
query_coors,
nodes,
conns
)
def analysis_substrate(input_coors, output_coors, hidden_coors):
input_coors = np.array(input_coors)
output_coors = np.array(output_coors)
hidden_coors = np.array(hidden_coors)
cd = input_coors.shape[1] # coordinate dimensions
si = input_coors.shape[0] # input coordinate size
so = output_coors.shape[0] # output coordinate size
sh = hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, input_coors, hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, hidden_coors, hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, hidden_coors, output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
conns = np.zeros((correspond_keys.shape[0], 4), dtype=np.float32) # input_idx, output_idx, enabled, weight
conns[:, 0:2] = correspond_keys
conns[:, 2] = 1 # enabled is True
return query_coors, nodes, conns
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys

View File

@@ -0,0 +1,5 @@
from .gene import *
from .genome import *
from .species import *
from .neat import NEAT

View File

@@ -0,0 +1,2 @@
from .crossover import BaseCrossover, DefaultCrossover
from .mutation import BaseMutation, DefaultMutation

View File

@@ -0,0 +1,2 @@
from .base import BaseCrossover
from .default import DefaultCrossover

View File

@@ -0,0 +1,3 @@
class BaseCrossover:
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError

View File

@@ -0,0 +1,67 @@
import jax, jax.numpy as jnp
from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
"""
randkey_1, randkey_2, key = jax.random.split(randkey, 3)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = self.align_array(keys1, keys2, nodes2, False)
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
return new_nodes, new_conns
def align_array(self, seq1, seq2, ar2, is_conn: bool):
"""
After I review this code, I found that it is the most difficult part of the code. Please never change it!
make ar2 align with ar1.
:param seq1:
:param seq2:
:param ar2:
:param is_conn:
:return:
align means to intersect part of ar2 will be at the same position as ar1,
non-intersect part of ar2 will be set to Nan
"""
seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
mask = (seq1 == seq2) & (~jnp.isnan(seq1))
if is_conn:
mask = jnp.all(mask, axis=2)
intersect_mask = mask.any(axis=1)
idx = jnp.arange(0, len(seq1))
idx_fixed = jnp.dot(mask, idx)
refactor_ar2 = jnp.where(intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan)
return refactor_ar2
def crossover_gene(self, rand_key, g1, g2):
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)

View File

@@ -0,0 +1,2 @@
from .base import BaseMutation
from .default import DefaultMutation

View File

@@ -0,0 +1,3 @@
class BaseMutation:
def __call__(self, key, genome, nodes, conns, new_node_key):
raise NotImplementedError

View File

@@ -0,0 +1,202 @@
import jax, jax.numpy as jnp
from . import BaseMutation
from utils import fetch_first, fetch_random, I_INT, unflatten_conns, check_cycles
class DefaultMutation(BaseMutation):
def __init__(
self,
conn_add: float = 0.4,
conn_delete: float = 0,
node_add: float = 0.2,
node_delete: float = 0,
):
self.conn_add = conn_add
self.conn_delete = conn_delete
self.node_add = node_add
self.node_delete = node_delete
def __call__(self, randkey, genome, nodes, conns, new_node_key):
k1, k2 = jax.random.split(randkey)
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
return nodes, conns
def mutate_structure(self, randkey, genome, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_):
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
def successful_add_node():
# disable the connection
new_conns = conns_.at[idx, 2].set(False)
# add a new node
new_nodes = genome.add_node(nodes_, new_node_key, genome.node_gene.new_custom_attrs())
# add two new connections
new_conns = genome.add_conn(new_conns, i_key, new_node_key, True, genome.conn_gene.new_custom_attrs())
new_conns = genome.add_conn(new_conns, new_node_key, o_key, True, genome.conn_gene.new_custom_attrs())
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INT,
lambda: (nodes_, conns_), # do nothing
successful_add_node
)
def mutate_delete_node(key_, nodes_, conns_):
# randomly choose a node
key, idx = self.choice_node_key(key_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=False)
def successful_delete_node():
# delete the node
new_nodes = genome.delete_node_by_pos(nodes_, idx)
# delete all connections
new_conns = jnp.where(
((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None],
jnp.nan,
conns_
)
return new_nodes, new_conns
return jax.lax.cond(
idx == I_INT,
lambda: (nodes_, conns_), # do nothing
successful_delete_node
)
def mutate_add_conn(key_, nodes_, conns_):
# randomly choose two nodes
k1_, k2_ = jax.random.split(key_, num=2)
# input node of the connection can be any node
i_key, from_idx = self.choice_node_key(k1_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=True, allow_output_keys=True)
# output node of the connection can be any node except input node
o_key, to_idx = self.choice_node_key(k2_, nodes_, genome.input_idx, genome.output_idx,
allow_input_keys=False, allow_output_keys=True)
conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
is_already_exist = conn_pos != I_INT
def nothing():
return nodes_, conns_
def successful():
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())
def already_exist():
return nodes_, conns_.at[conn_pos, 2].set(True)
if genome.network_type == 'feedforward':
u_cons = unflatten_conns(nodes_, conns_)
cons_exist = ~jnp.isnan(u_cons[0, :, :])
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
return jax.lax.cond(
is_already_exist,
already_exist,
lambda:
jax.lax.cond(
is_cycle,
nothing,
successful
)
)
elif genome.network_type == 'recurrent':
return jax.lax.cond(
is_already_exist,
already_exist,
successful
)
else:
raise ValueError(f"Invalid network type: {genome.network_type}")
def mutate_delete_conn(key_, nodes_, conns_):
# randomly choose a connection
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
def successfully_delete_connection():
return nodes_, genome.delete_conn_by_pos(conns_, idx)
return jax.lax.cond(
idx == I_INT,
lambda: (nodes_, conns_), # nothing
successfully_delete_connection
)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(key_, nodes_, conns_):
return nodes_, conns_
nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
return nodes, conns
def mutate_values(self, randkey, genome, nodes, conns):
k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns)
return new_nodes, new_conns
def choice_node_key(self, rand_key, nodes, input_idx, output_idx,
allow_input_keys: bool = False, allow_output_keys: bool = False):
"""
Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
:param rand_key:
:param nodes:
:param input_idx:
:param output_idx:
:param allow_input_keys:
:param allow_output_keys:
:return: return its key and position(idx)
"""
node_keys = nodes[:, 0]
mask = ~jnp.isnan(node_keys)
if not allow_input_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx))
if not allow_output_keys:
mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))
idx = fetch_random(rand_key, mask)
key = jnp.where(idx != I_INT, nodes[idx, 0], jnp.nan)
return key, idx
def choice_connection_key(self, rand_key, conns):
"""
Randomly choose a connection key from the given connections.
:return: i_key, o_key, idx
"""
idx = fetch_random(rand_key, ~jnp.isnan(conns[:, 0]))
i_key = jnp.where(idx != I_INT, conns[idx, 0], jnp.nan)
o_key = jnp.where(idx != I_INT, conns[idx, 1], jnp.nan)
return i_key, o_key, idx

View File

@@ -0,0 +1,3 @@
from .base import BaseGene
from .conn import *
from .node import *

View File

@@ -0,0 +1,23 @@
class BaseGene:
"Base class for node genes or connection genes."
fixed_attrs = []
custom_attrs = []
def __init__(self):
pass
def new_custom_attrs(self):
raise NotImplementedError
def mutate(self, randkey, gene):
raise NotImplementedError
def distance(self, gene1, gene2):
raise NotImplementedError
def forward(self, attrs, inputs):
raise NotImplementedError
@property
def length(self):
return len(self.fixed_attrs) + len(self.custom_attrs)

View File

@@ -0,0 +1,2 @@
from .base import BaseConnGene
from .default import DefaultConnGene

View File

@@ -0,0 +1,12 @@
from .. import BaseGene
class BaseConnGene(BaseGene):
"Base class for connection genes."
fixed_attrs = ['input_index', 'output_index', 'enabled']
def __init__(self):
super().__init__()
def forward(self, attrs, inputs):
raise NotImplementedError

View File

@@ -0,0 +1,50 @@
import jax.numpy as jnp
from utils import mutate_float
from . import BaseConnGene
class DefaultConnGene(BaseConnGene):
"Default connection gene, with the same behavior as in NEAT-python."
custom_attrs = ['weight']
def __init__(
self,
weight_init_mean: float = 0.0,
weight_init_std: float = 1.0,
weight_mutate_power: float = 0.5,
weight_mutate_rate: float = 0.8,
weight_replace_rate: float = 0.1,
):
super().__init__()
self.weight_init_mean = weight_init_mean
self.weight_init_std = weight_init_std
self.weight_mutate_power = weight_mutate_power
self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate
def new_custom_attrs(self):
return jnp.array([self.weight_init_mean])
def mutate(self, key, conn):
input_index = conn[0]
output_index = conn[1]
enabled = conn[2]
weight = mutate_float(key,
conn[3],
self.weight_init_mean,
self.weight_init_std,
self.weight_mutate_power,
self.weight_mutate_rate,
self.weight_replace_rate
)
return jnp.array([input_index, output_index, enabled, weight])
def distance(self, attrs1, attrs2):
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
def forward(self, attrs, inputs):
weight = attrs[0]
return inputs * weight

View File

@@ -0,0 +1,2 @@
from .base import BaseNodeGene
from .default import DefaultNodeGene

View File

@@ -0,0 +1,12 @@
from .. import BaseGene
class BaseNodeGene(BaseGene):
"Base class for node genes."
fixed_attrs = ["index"]
def __init__(self):
super().__init__()
def forward(self, attrs, inputs):
raise NotImplementedError

View File

@@ -0,0 +1,95 @@
from typing import Tuple
import jax, jax.numpy as jnp
from utils import Act, Agg, act, agg, mutate_int, mutate_float
from . import BaseNodeGene
class DefaultNodeGene(BaseNodeGene):
"Default node gene, with the same behavior as in NEAT-python."
custom_attrs = ['bias', 'response', 'aggregation', 'activation']
def __init__(
self,
bias_init_mean: float = 0.0,
bias_init_std: float = 1.0,
bias_mutate_power: float = 0.5,
bias_mutate_rate: float = 0.7,
bias_replace_rate: float = 0.1,
response_init_mean: float = 1.0,
response_init_std: float = 0.0,
response_mutate_power: float = 0.5,
response_mutate_rate: float = 0.7,
response_replace_rate: float = 0.1,
activation_default: callable = Act.sigmoid,
activation_options: Tuple = (Act.sigmoid,),
activation_replace_rate: float = 0.1,
aggregation_default: callable = Agg.sum,
aggregation_options: Tuple = (Agg.sum,),
aggregation_replace_rate: float = 0.1,
):
super().__init__()
self.bias_init_mean = bias_init_mean
self.bias_init_std = bias_init_std
self.bias_mutate_power = bias_mutate_power
self.bias_mutate_rate = bias_mutate_rate
self.bias_replace_rate = bias_replace_rate
self.response_init_mean = response_init_mean
self.response_init_std = response_init_std
self.response_mutate_power = response_mutate_power
self.response_mutate_rate = response_mutate_rate
self.response_replace_rate = response_replace_rate
self.activation_default = activation_options.index(activation_default)
self.activation_options = activation_options
self.activation_indices = jnp.arange(len(activation_options))
self.activation_replace_rate = activation_replace_rate
self.aggregation_default = aggregation_options.index(aggregation_default)
self.aggregation_options = aggregation_options
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
def new_custom_attrs(self):
return jnp.array(
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
)
def mutate(self, key, node):
k1, k2, k3, k4 = jax.random.split(key, num=4)
index = node[0]
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate)
res = mutate_float(k2, node[2], self.response_init_mean, self.response_init_std,
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate)
act = mutate_int(k3, node[3], self.activation_indices, self.activation_replace_rate)
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
return jnp.array([index, bias, res, act, agg])
def distance(self, node1, node2):
return (
jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) +
(node1[3] != node2[3]) +
(node1[4] != node2[4])
)
def forward(self, attrs, inputs):
bias, res, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options)
z = bias + res * z
z = act(act_idx, z, self.activation_options)
return z

View File

@@ -0,0 +1,3 @@
from .base import BaseGenome
from .default import DefaultGenome
from .recurrent import RecurrentGenome

View File

@@ -0,0 +1,65 @@
import jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from utils import fetch_first
class BaseGenome:
network_type = None
def __init__(
self,
num_inputs: int,
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.input_idx = jnp.arange(num_inputs)
self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)
self.max_nodes = max_nodes
self.max_conns = max_conns
self.node_gene = node_gene
self.conn_gene = conn_gene
def transform(self, nodes, conns):
raise NotImplementedError
def forward(self, inputs, transformed):
raise NotImplementedError
def add_node(self, nodes, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def delete_node_by_pos(self, nodes, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
return nodes.at[pos].set(jnp.nan)
def add_conn(self, conns, i_key, o_key, enable: bool, attrs):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
return new_conns.at[pos, 3:].set(attrs)
def delete_conn_by_pos(self, conns, pos):
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)

View File

@@ -0,0 +1,90 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns, topological_sort, I_INT
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
class DefaultGenome(BaseGenome):
"""Default genome class, with the same behavior as the NEAT-Python"""
network_type = 'feedforward'
def __init__(self,
num_inputs: int,
num_outputs: int,
max_nodes=5,
max_conns=4,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
output_transform: Callable = None
):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
if output_transform is not None:
try:
aux = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# DONE: Seems like there is a bug in this line
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# modified: exist conn and enable is true
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
# advanced modified: when and only when enabled is True
conn_enable = u_conns[0] == 1
# remove enable attr
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns
def forward(self, inputs, transformed):
cal_seqs, nodes, conns = transformed
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[self.input_idx].set(inputs)
nodes_attrs = nodes[:, 1:]
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
# ins = values * weights[:, i]
z = self.node_gene.forward(nodes_attrs[i], ins)
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
new_values = values.at[i].set(z)
return new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
return values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
if self.output_transform is None:
return vals[self.output_idx]
else:
return self.output_transform(vals[self.output_idx])

View File

@@ -0,0 +1,60 @@
import jax, jax.numpy as jnp
from utils import unflatten_conns
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
class RecurrentGenome(BaseGenome):
"""Default genome class, with the same behavior as the NEAT-Python"""
network_type = 'recurrent'
def __init__(self,
num_inputs: int,
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# remove un-enable connections and remove enable attr
conn_enable = u_conns[0] == 1
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns
def forward(self, inputs, transformed):
nodes, conns = transformed
N = nodes.shape[0]
vals = jnp.full((N,), jnp.nan)
nodes_attrs = nodes[:, 1:]
def body_func(_, values):
# set input values
values = values.at[self.input_idx].set(inputs)
# calculate connections
node_ins = jax.vmap(
jax.vmap(
self.conn_gene.forward,
in_axes=(1, None)
),
in_axes=(1, 0)
)(conns, values)
# calculate nodes
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
return vals[self.output_idx]

View File

@@ -0,0 +1,113 @@
import jax, jax.numpy as jnp
from utils import State
from .. import BaseAlgorithm
from .species import *
from .ga import *
class NEAT(BaseAlgorithm):
def __init__(
self,
species: BaseSpecies,
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
):
self.genome = species.genome
self.species = species
self.mutation = mutation
self.crossover = crossover
def setup(self, randkey):
k1, k2 = jax.random.split(randkey, 2)
return State(
randkey=k1,
generation=jnp.array(0.),
next_node_key=jnp.array(max(*self.genome.input_idx, *self.genome.output_idx) + 2, dtype=jnp.float32),
# inputs nodes, output nodes, 1 hidden node
species=self.species.setup(k2),
)
def ask(self, state: State):
return self.species.ask(state.species)
def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(
generation=state.generation + 1,
randkey=randkey
)
species_state, winner, loser, elite_mask = self.species.update_species(state.species, fitness, state.generation)
state = state.update(species=species_state)
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
species_state = self.species.speciate(state.species, state.generation)
state = state.update(species=species_state)
return state
def transform(self, individual):
"""transform the genome into a neural network"""
nodes, conns = individual
return self.genome.transform(nodes, conns)
def forward(self, inputs, transformed):
return self.genome.forward(inputs, transformed)
@property
def num_inputs(self):
return self.genome.num_inputs
@property
def num_outputs(self):
return self.genome.num_outputs
@property
def pop_size(self):
return self.species.pop_size
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
# prepare random keys
pop_size = self.species.pop_size
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2 = jax.random.split(randkey, 2)
crossover_rand_keys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size)
wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner]
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser]
# batch crossover
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
# batch mutation
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys))
next_node_key = max_node_key + 1
return state.update(
species=state.species.update(
pop_nodes=pop_nodes,
pop_conns=pop_conns,
),
next_node_key=next_node_key,
)
def member_count(self, state: State):
return state.species.member_count
def generation(self, state: State):
# to analysis the algorithm
return state.generation

View File

@@ -0,0 +1,2 @@
from .base import BaseSpecies
from .default import DefaultSpecies

View File

@@ -0,0 +1,14 @@
from utils import State
class BaseSpecies:
def setup(self, randkey):
raise NotImplementedError
def ask(self, state: State):
raise NotImplementedError
def update_species(self, state, fitness, generation):
raise NotImplementedError
def speciate(self, state, generation):
raise NotImplementedError

View File

@@ -0,0 +1,519 @@
import numpy as np
import jax, jax.numpy as jnp
from utils import State, rank_elements, argmin_with_mask, fetch_first
from ..genome import BaseGenome
from .base import BaseSpecies
class DefaultSpecies(BaseSpecies):
def __init__(self,
genome: BaseGenome,
pop_size,
species_size,
compatibility_disjoint: float = 1.0,
compatibility_weight: float = 0.4,
max_stagnation: int = 15,
species_elitism: int = 2,
spawn_number_change_rate: float = 0.5,
genome_elitism: int = 2,
survival_threshold: float = 0.2,
min_species_size: int = 1,
compatibility_threshold: float = 3.
):
self.genome = genome
self.pop_size = pop_size
self.species_size = species_size
self.compatibility_disjoint = compatibility_disjoint
self.compatibility_weight = compatibility_weight
self.max_stagnation = max_stagnation
self.species_elitism = species_elitism
self.spawn_number_change_rate = spawn_number_change_rate
self.genome_elitism = genome_elitism
self.survival_threshold = survival_threshold
self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold
self.species_arange = jnp.arange(self.species_size)
def setup(self, randkey):
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome)
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
last_improved = jnp.full((self.species_size,), jnp.nan) # the last generation that the species improved
member_count = jnp.full((self.species_size,), jnp.nan) # the number of members of each species
idx2species = jnp.zeros(self.pop_size) # the species index of each individual
# nodes for each center genome of each species
center_nodes = jnp.full((self.species_size, self.genome.max_nodes, self.genome.node_gene.length), jnp.nan)
# connections for each center genome of each species
center_conns = jnp.full((self.species_size, self.genome.max_conns, self.genome.conn_gene.length), jnp.nan)
species_keys = species_keys.at[0].set(0)
best_fitness = best_fitness.at[0].set(-jnp.inf)
last_improved = last_improved.at[0].set(0)
member_count = member_count.at[0].set(self.pop_size)
center_nodes = center_nodes.at[0].set(pop_nodes[0])
center_conns = center_conns.at[0].set(pop_conns[0])
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
return State(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=jnp.array(1), # 0 is reserved for the first species
)
def ask(self, state):
return state.pop_nodes, state.pop_conns
def update_species(self, state, fitness, generation):
# update the fitness of each species
species_fitness = self.update_species_fitness(state, fitness)
# stagnation species
state, species_fitness = self.stagnation(state, generation, species_fitness)
# sort species_info by their fitness. (also push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1]
state = state.update(
species_keys=state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices],
last_improved=state.last_improved[sort_indices],
member_count=state.member_count[sort_indices],
center_nodes=state.center_nodes[sort_indices],
center_conns=state.center_conns[sort_indices],
)
# decide the number of members of each species by their fitness
spawn_number = self.cal_spawn_numbers(state)
k1, k2 = jax.random.split(state.randkey)
# crossover info
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
return state.update(randkey=k2), winner, loser, elite_mask
def update_species_fitness(self, state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
use max criterion.
"""
def aux_func(idx):
s_fitness = jnp.where(state.idx2species == state.species_keys[idx], fitness, -jnp.inf)
val = jnp.max(s_fitness)
return val
return jax.vmap(aux_func)(self.species_arange)
def stagnation(self, state, generation, species_fitness):
"""
stagnation species.
those species whose fitness is not better than the best fitness of the species for a long time will be stagnation.
elitism species never stagnation
generation: the current generation
"""
def check_stagnation(idx):
# determine whether the species stagnation
st = (
(species_fitness[idx] <= state.best_fitness[
idx]) & # not better than the best fitness of the species
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
)
# update last_improved and best_fitness
li, bf = jax.lax.cond(
species_fitness[idx] > state.best_fitness[idx],
lambda: (generation, species_fitness[idx]), # update
lambda: (state.last_improved[idx], state.best_fitness[idx]) # not update
)
return st, bf, li
spe_st, best_fitness, last_improved = jax.vmap(check_stagnation)(self.species_arange)
# elite species will not be stagnation
species_rank = rank_elements(species_fitness)
spe_st = jnp.where(species_rank < self.species_elitism, False, spe_st) # elitism never stagnation
# set stagnation species to nan
def update_func(idx):
return jax.lax.cond(
spe_st[idx],
lambda: (
jnp.nan, # species_key
jnp.nan, # best_fitness
jnp.nan, # last_improved
jnp.nan, # member_count
-jnp.inf, # species_fitness
jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes
jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns
), # stagnation species
lambda: (
state.species_keys[idx],
best_fitness[idx],
last_improved[idx],
state.member_count[idx],
species_fitness[idx],
state.center_nodes[idx],
state.center_conns[idx]
) # not stagnation species
)
(
species_keys,
best_fitness,
last_improved,
member_count,
species_fitness,
center_nodes,
center_conns
) = (
jax.vmap(update_func)(self.species_arange))
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
center_nodes=center_nodes,
center_conns=center_conns,
), species_fitness
def cal_spawn_numbers(self, state):
"""
decide the number of members of each species by their fitness rank.
the species with higher fitness will have more members
Linear ranking selection
e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
"""
species_keys = state.species_keys
is_species_valid = ~jnp.isnan(species_keys)
valid_species_num = jnp.sum(is_species_valid)
denominator = (valid_species_num + 1) * valid_species_num / 2 # obtain 3 + 2 + 1 = 6
rank_score = valid_species_num - self.species_arange # obtain [3, 2, 1]
spawn_number_rate = rank_score / denominator # obtain [0.5, 0.33, 0.17]
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * self.pop_size) # calculate member
# Avoid too much variation of numbers for a species
previous_size = state.member_count
spawn_number = previous_size + (target_spawn_number - previous_size) * self.spawn_number_change_rate
spawn_number = spawn_number.astype(jnp.int32)
# must control the sum of spawn_number to be equal to pop_size
error = self.pop_size - jnp.sum(spawn_number)
# add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(error)
return spawn_number
def create_crossover_pair(self, state, randkey, spawn_number, fitness):
s_idx = self.species_arange
p_idx = jnp.arange(self.pop_size)
def aux_func(key, idx):
members = state.idx2species == state.species_keys[idx]
members_num = jnp.sum(members)
members_fitness = jnp.where(members, fitness, -jnp.inf)
sorted_member_indices = jnp.argsort(members_fitness)[::-1]
survive_size = jnp.floor(self.survival_threshold * members_num).astype(jnp.int32)
select_pro = (p_idx < survive_size) / survive_size
fa, ma = jax.random.choice(key, sorted_member_indices, shape=(2, self.pop_size), replace=True, p=select_pro)
# elite
fa = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, fa)
ma = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, ma)
elite = jnp.where(p_idx < self.genome_elitism, True, False)
return fa, ma, elite
fas, mas, elites = jax.vmap(aux_func)(jax.random.split(randkey, self.species_size), s_idx)
spawn_number_cum = jnp.cumsum(spawn_number)
def aux_func(idx):
loc = jnp.argmax(idx < spawn_number_cum)
# elite genomes are at the beginning of the species
idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
return fas[loc, idx_in_species], mas[loc, idx_in_species], elites[loc, idx_in_species]
part1, part2, elite_mask = jax.vmap(aux_func)(p_idx)
is_part1_win = fitness[part1] >= fitness[part2]
winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1)
return winner, loser, elite_mask
def speciate(self, state, generation):
# prepare distance functions
o2p_distance_func = jax.vmap(self.distance, in_axes=(None, None, 0, 0)) # one to population
# idx to specie key
idx2species = jnp.full((self.pop_size,), jnp.nan) # NaN means not assigned to any species
# the distance between genomes to its center genomes
o2c_distances = jnp.full((self.pop_size,), jnp.inf)
# step 1: find new centers
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, o2c_distances
i, i2s, cns, ccs, o2c = carry
return (
(i < self.species_size) &
(~jnp.isnan(state.species_keys[i]))
) # current species is existing
def body_func(carry):
i, i2s, cns, ccs, o2c = carry
distances = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
# find the closest one
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
i2s = i2s.at[closest_idx].set(state.species_keys[i])
cns = cns.at[i].set(state.pop_nodes[closest_idx])
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
# the genome with closest_idx will become the new center, thus its distance to center is 0.
o2c = o2c.at[closest_idx].set(0)
return i + 1, i2s, cns, ccs, o2c
_, idx2species, center_nodes, center_conns, o2c_distances = \
jax.lax.while_loop(cond_func, body_func,
(0, idx2species, state.center_nodes, state.center_conns, o2c_distances))
state = state.update(
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
)
# part 2: assign members to each species
def cond_func(carry):
# i, idx2species, center_nodes, center_conns, species_keys, o2c_distances, next_species_key
i, i2s, cns, ccs, sk, o2c, nsk = carry
current_species_existed = ~jnp.isnan(sk[i])
not_all_assigned = jnp.any(jnp.isnan(i2s))
not_reach_species_upper_bounds = i < self.species_size
return not_reach_species_upper_bounds & (current_species_existed | not_all_assigned)
def body_func(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
_, i2s, cns, ccs, sk, o2c, nsk = jax.lax.cond(
jnp.isnan(sk[i]), # whether the current species is existing or not
create_new_species, # if not existing, create a new specie
update_exist_specie, # if existing, update the specie
(i, i2s, cns, ccs, sk, o2c, nsk)
)
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def create_new_species(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
# pick the first one who has not been assigned to any species
idx = fetch_first(jnp.isnan(i2s))
# assign it to the new species
# [key, best score, last update generation, member_count]
sk = sk.at[i].set(nsk) # nsk -> next species key
i2s = i2s.at[idx].set(nsk)
o2c = o2c.at[idx].set(0)
# update center genomes
cns = cns.at[i].set(state.pop_nodes[idx])
ccs = ccs.at[i].set(state.pop_conns[idx])
# find the members for the new species
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
return i, i2s, cns, ccs, sk, o2c, nsk + 1 # change to next new speciate key
def update_exist_specie(carry):
i, i2s, cns, ccs, sk, o2c, nsk = carry
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
# turn to next species
return i + 1, i2s, cns, ccs, sk, o2c, nsk
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
# distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
close_enough_mask = o2p_distance < self.compatibility_threshold
# when a genome is not assigned or the distance between its current center is bigger than this center
catchable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
mask = close_enough_mask & catchable_mask
# update species info
i2s = jnp.where(mask, sk[i], i2s)
# update distance between centers
o2c = jnp.where(mask, o2p_distance, o2c)
return i2s, o2c
# update idx2species
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
cond_func,
body_func,
(0, state.idx2species, center_nodes, center_conns, state.species_keys, o2c_distances,
state.next_species_key)
)
# if there are still some pop genomes not assigned to any species, add them to the last genome
# this condition can only happen when the number of species is reached species upper bounds
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
# complete info of species which is created in this generation
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.best_fitness)
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.best_fitness)
last_improved = jnp.where(new_created_mask, generation, state.last_improved)
# update members count
def count_members(idx):
return jax.lax.cond(
jnp.isnan(species_keys[idx]), # if the species is not existing
lambda: jnp.nan, # nan
lambda: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
)
member_count = jax.vmap(count_members)(self.species_arange)
return state.update(
species_keys=species_keys,
best_fitness=best_fitness,
last_improved=last_improved,
member_count=member_count,
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=next_species_key
)
def distance(self, nodes1, conns1, nodes2, conns2):
"""
The distance between two genomes
"""
d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
return d
def node_distance(self, nodes1, nodes2):
"""
The distance of the nodes part for two genomes
"""
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(fr, sr)
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
def conn_distance(self, conns1, conns2):
"""
The distance of the conns part for two genomes
"""
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((conns1, conns2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(fr, sr)
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = non_homologous_cnt * self.compatibility_disjoint + homologous_distance * self.compatibility_weight
return jnp.where(max_cnt == 0, 0, val / max_cnt)
def initialize_population(pop_size, genome):
o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes
o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections
input_idx, output_idx = genome.input_idx, genome.output_idx
new_node_key = max([*input_idx, *output_idx]) + 1
o_nodes[input_idx, 0] = genome.input_idx
o_nodes[output_idx, 0] = genome.output_idx
o_nodes[new_node_key, 0] = new_node_key # one hidden node
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
o_conns[input_idx, 0:2] = input_conns # in key, out key
o_conns[input_idx, 2] = True # enabled
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
o_conns[output_idx, 0:2] = output_conns # in key, out key
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
# repeat origin genome for P times to create population
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
pop_conns = np.tile(o_conns, (pop_size, 1, 1))
return pop_nodes, pop_conns

View File

@@ -0,0 +1,36 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=27,
num_outputs=8,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name='ant',
),
generation_limit=10000,
fitness_target=5000
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 MiB

View File

@@ -0,0 +1,36 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=17,
num_outputs=6,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name='halhcheetah',
),
generation_limit=10000,
fitness_target=5000
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,36 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=11,
num_outputs=2,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=100,
species_size=10,
),
),
problem=BraxEnv(
env_name='reacher',
),
generation_limit=10000,
fitness_target=5000
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,32 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

View File

@@ -0,0 +1,51 @@
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.hyperneat import *
from utils import Act
from problem.func_fit import XOR3d
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=HyperNEAT(
substrate=FullSubstrate(
input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)],
hidden_coors=[
(-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5),
(-1, 0), (0.333, 0), (-0.333, 0), (1, 0),
(-1, 0.5), (0.333, 0.5), (-0.333, 0.5), (1, 0.5),
],
output_coors=[(0, 1), ],
),
neat=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=4, # [-1, -1, -1, 0]
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
),
activation=Act.sigmoid,
activate_time=10,
),
problem=XOR3d(),
generation_limit=300,
fitness_target=-1e-6
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

View File

@@ -0,0 +1,41 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
from utils.activation import ACT_ALL
from utils.aggregation import AGG_ALL
if __name__ == '__main__':
pipeline = Pipeline(
seed=0,
algorithm=NEAT(
species=DefaultSpecies(
genome=RecurrentGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
activate_time=5,
node_gene=DefaultNodeGene(
activation_options=ACT_ALL,
# aggregation_options=AGG_ALL,
activation_replace_rate=0.2
),
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
),
),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

View File

@@ -0,0 +1,34 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=6,
num_outputs=3,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Acrobot-v1',
),
generation_limit=10000,
fitness_target=-62
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,34 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=4,
num_outputs=2,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='CartPole-v1',
),
generation_limit=10000,
fitness_target=500
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,54 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrateConfig, NormalSubstrate
from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
pop_size=10000
),
neat=NeatConfig(
inputs=4,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
hyperneat=HyperNeatConfig(
activation=Act.sigmoid,
inputs=4,
outputs=2
),
substrate=NormalSubstrateConfig(
input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)),
hidden_coors=(
# (-1, -0.5), (-0.5, -0.5), (0, -0.5), (0.5, -0.5),
(1, 0), (-1, 0), (-0.5, 0), (0, 0), (0.5, 0), (1, 0),
# (1, 0.5), (-1, 0.5), (-0.5, 0.5), (0, 0.5), (0.5, 0.5), (1, 0.5),
),
output_coors=((-1, 1), (1, 1)),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
)
)
if __name__ == '__main__':
conf = example_conf()
algorithm = HyperNEAT(conf, NormalGene, NormalSubstrate)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup()
pipeline.pre_compile(state)
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,34 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=2,
num_outputs=3,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='MountainCar-v0',
),
generation_limit=10000,
fitness_target=0
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,36 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh, ),
activation_default=Act.tanh,
)
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='MountainCarContinuous-v0',
),
generation_limit=10000,
fitness_target=500
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,37 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Pendulum-v1',
),
generation_limit=10000,
fitness_target=0
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,33 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=8,
num_outputs=2,
max_nodes=50,
max_conns=100,
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name='Reacher-misc',
),
generation_limit=10000,
fitness_target =500
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

131
tensorneat/pipeline.py Normal file
View File

@@ -0,0 +1,131 @@
from functools import partial
import jax, jax.numpy as jnp
import time
import numpy as np
from algorithm import BaseAlgorithm
from problem import BaseProblem
from utils import State
class Pipeline:
def __init__(
self,
algorithm: BaseAlgorithm,
problem: BaseProblem,
seed: int = 42,
fitness_target: float = 1,
generation_limit: int = 1000,
):
assert problem.jitable, "Currently, problem must be jitable"
self.algorithm = algorithm
self.problem = problem
self.seed = seed
self.fitness_target = fitness_target
self.generation_limit = generation_limit
self.pop_size = self.algorithm.pop_size
print(self.problem.input_shape, self.problem.output_shape)
# TODO: make each algorithm's input_num and output_num
assert algorithm.num_inputs == self.problem.input_shape[-1], \
f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
# self.act_func = self.algorithm.act
# for _ in range(len(self.problem.input_shape) - 1):
# self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
self.best_genome = None
self.best_fitness = float('-inf')
self.generation_timestamp = None
def setup(self):
key = jax.random.PRNGKey(self.seed)
key, algorithm_key, evaluate_key = jax.random.split(key, 3)
# TODO: Problem should has setup function to maintain state
return State(
randkey=key,
alg=self.algorithm.setup(algorithm_key),
pro=self.problem.setup(evaluate_key),
)
def step(self, state):
key, sub_key = jax.random.split(state.randkey)
keys = jax.random.split(key, self.pop_size)
pop = self.algorithm.ask(state.alg)
pop_transformed = jax.vmap(self.algorithm.transform)(pop)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(
keys,
state.pro,
self.algorithm.forward,
pop_transformed
)
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
alg_state = self.algorithm.tell(state.alg, fitnesses)
return state.update(
randkey=sub_key,
alg=alg_state,
), fitnesses
def auto_run(self, ini_state):
state = ini_state
compiled_step = jax.jit(self.step).lower(ini_state).compile()
for _ in range(self.generation_limit):
self.generation_timestamp = time.time()
previous_pop = self.algorithm.ask(state.alg)
state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses)
for idx, fitnesses_i in enumerate(fitnesses):
if np.isnan(fitnesses_i):
print("Fitness is nan")
print(previous_pop[0][idx], previous_pop[1][idx])
assert False
self.analysis(state, previous_pop, fitnesses)
if max(fitnesses) >= self.fitness_target:
print("Fitness limit reached!")
return state, self.best_genome
print("Generation limit reached!")
return state, self.best_genome
def analysis(self, state, pop, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx]
member_count = jax.device_get(self.algorithm.member_count(state.alg))
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.algorithm.generation(state.alg)}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(best)
self.problem.show(state.randkey, state.pro, self.algorithm.forward, transformed, *args, **kwargs)

View File

@@ -0,0 +1 @@
from .base import BaseProblem

View File

@@ -0,0 +1,39 @@
from typing import Callable
from utils import State
class BaseProblem:
jitable = None
def setup(self, randkey, state: State = State()):
"""initialize the state of the problem"""
pass
def evaluate(self, randkey, state: State, act_func: Callable, params):
"""evaluate one individual"""
raise NotImplementedError
@property
def input_shape(self):
"""
The input shape for the problem to evaluate
In RL problem, it is the observation space
In function fitting problem, it is the input shape of the function
"""
raise NotImplementedError
@property
def output_shape(self):
"""
The output shape for the problem to evaluate
In RL problem, it is the action space
In function fitting problem, it is the output shape of the function
"""
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
"""
show how a genome perform in this problem
"""
raise NotImplementedError

View File

@@ -0,0 +1,3 @@
from .func_fit import FuncFit
from .xor import XOR
from .xor3d import XOR3d

View File

@@ -0,0 +1,67 @@
import jax
import jax.numpy as jnp
from utils import State
from .. import BaseProblem
class FuncFit(BaseProblem):
jitable = True
def __init__(self,
error_method: str = 'mse'
):
super().__init__()
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
self.error_method = error_method
def setup(self, randkey, state: State = State()):
return state
def evaluate(self, randkey, state, act_func, params):
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
if self.error_method == 'mse':
loss = jnp.mean((predict - self.targets) ** 2)
elif self.error_method == 'rmse':
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
elif self.error_method == 'mae':
loss = jnp.mean(jnp.abs(predict - self.targets))
elif self.error_method == 'mape':
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
else:
raise NotImplementedError
return -loss
def show(self, randkey, state, act_func, params, *args, **kwargs):
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params)
msg = ""
for i in range(inputs.shape[0]):
msg += f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}\n"
msg += f"loss: {loss}\n"
print(msg)
@property
def inputs(self):
raise NotImplementedError
@property
def targets(self):
raise NotImplementedError
@property
def input_shape(self):
raise NotImplementedError
@property
def output_shape(self):
raise NotImplementedError

View File

@@ -0,0 +1,35 @@
import numpy as np
from .func_fit import FuncFit
class XOR(FuncFit):
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):
return np.array([
[0, 0],
[0, 1],
[1, 0],
[1, 1]
])
@property
def targets(self):
return np.array([
[0],
[1],
[1],
[0]
])
@property
def input_shape(self):
return 4, 2
@property
def output_shape(self):
return 4, 1

View File

@@ -0,0 +1,43 @@
import numpy as np
from .func_fit import FuncFit
class XOR3d(FuncFit):
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method)
@property
def inputs(self):
return np.array([
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
[1, 1, 1],
])
@property
def targets(self):
return np.array([
[0],
[1],
[1],
[0],
[1],
[0],
[0],
[1]
])
@property
def input_shape(self):
return 8, 3
@property
def output_shape(self):
return 8, 1

View File

@@ -0,0 +1,2 @@
from .gymnax_env import GymNaxEnv
from .brax_env import BraxEnv

View File

@@ -0,0 +1,64 @@
import jax.numpy as jnp
from brax import envs
from .rl_jit import RLEnv
class BraxEnv(RLEnv):
def __init__(self, env_name: str = "ant", backend: str = "generalized"):
super().__init__()
self.env = envs.create(env_name=env_name, backend=backend)
def env_step(self, randkey, env_state, action):
state = self.env.step(env_state, action)
return state.obs, state, state.reward, state.done.astype(jnp.bool_), state.info
def env_reset(self, randkey):
init_state = self.env.reset(randkey)
return init_state.obs, init_state
@property
def input_shape(self):
return (self.env.observation_size,)
@property
def output_shape(self):
return (self.env.action_size,)
def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs):
import jax
import imageio
import numpy as np
from brax.io import image
from tqdm import tqdm
obs, env_state = self.reset(randkey)
reward, done = 0.0, False
state_histories = []
def step(key, env_state, obs):
key, _ = jax.random.split(key)
action = act_func(state, obs, params)
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
return key, env_state, obs, r, done
while not done:
state_histories.append(env_state.pipeline_state)
key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs)
reward += r
imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in
tqdm(state_histories, desc="Rendering")]
def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
for image in image_list:
formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image)
create_gif(imgs, save_path, duration=0.1)
print("Gif saved to: ", save_path)
print("Total reward: ", reward)

View File

@@ -0,0 +1,28 @@
import gymnax
from .rl_jit import RLEnv
class GymNaxEnv(RLEnv):
def __init__(self, env_name):
super().__init__()
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
self.env, self.env_params = gymnax.make(env_name)
def env_step(self, randkey, env_state, action):
return self.env.step(randkey, env_state, action, self.env_params)
def env_reset(self, randkey):
return self.env.reset(randkey, self.env_params)
@property
def input_shape(self):
return self.env.observation_space(self.env_params).shape
@property
def output_shape(self):
return self.env.action_space(self.env_params).shape
def show(self, randkey, state, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

View File

@@ -0,0 +1,61 @@
from functools import partial
import jax
from .. import BaseProblem
class RLEnv(BaseProblem):
jitable = True
# TODO: move output transform to algorithm
def __init__(self):
super().__init__()
def evaluate(self, randkey, state, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset)
def cond_func(carry):
_, _, _, done, _ = carry
return ~done
def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward
action = act_func(obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng)
return next_obs, next_env_state, next_rng, done, tr + reward
_, _, _, _, total_reward = jax.lax.while_loop(
cond_func,
body_func,
(init_obs, init_env_state, rng_episode, False, 0.0)
)
return total_reward
@partial(jax.jit, static_argnums=(0,))
def step(self, randkey, env_state, action):
return self.env_step(randkey, env_state, action)
@partial(jax.jit, static_argnums=(0,))
def reset(self, randkey):
return self.env_reset(randkey)
def env_step(self, randkey, env_state, action):
raise NotImplementedError
def env_reset(self, randkey):
raise NotImplementedError
@property
def input_shape(self):
raise NotImplementedError
@property
def output_shape(self):
raise NotImplementedError
def show(self, randkey, state, act_func, params, *args, **kwargs):
raise NotImplementedError

View File

View File

@@ -0,0 +1,117 @@
from algorithm.neat import *
from utils import Act, Agg
import jax, jax.numpy as jnp
def test_default():
# index, bias, response, activation, aggregation
nodes = jnp.array([
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
])
# in_node, out_node, enable, weight
conns = jnp.array([
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
])
genome = DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
),
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
print('\n-------------------------------------------------------\n')
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
def test_recurrent():
# index, bias, response, activation, aggregation
nodes = jnp.array([
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
])
# in_node, out_node, enable, weight
conns = jnp.array([
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
])
genome = RecurrentGenome(
num_inputs=2,
num_outputs=1,
max_nodes=5,
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
),
activate_time=3,
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
print('\n-------------------------------------------------------\n')
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]

View File

@@ -0,0 +1,5 @@
from .activation import Act, act
from .aggregation import Agg, agg
from .tools import *
from .graph import *
from .state import State

View File

@@ -0,0 +1,83 @@
import jax
import jax.numpy as jnp
class Act:
@staticmethod
def sigmoid(z):
z = jnp.clip(5 * z, -10, 10)
return 1 / (1 + jnp.exp(-z))
@staticmethod
def tanh(z):
return jnp.tanh(z)
@staticmethod
def sin(z):
return jnp.sin(z)
@staticmethod
def relu(z):
return jnp.maximum(z, 0)
@staticmethod
def lelu(z):
leaky = 0.005
return jnp.where(z > 0, z, leaky * z)
@staticmethod
def identity(z):
return z
@staticmethod
def clamped(z):
return jnp.clip(z, -1, 1)
@staticmethod
def inv(z):
z = jnp.where(
z > 0,
jnp.maximum(z, 1e-7),
jnp.minimum(z, -1e-7)
)
return 1 / z
@staticmethod
def log(z):
z = jnp.maximum(z, 1e-7)
return jnp.log(z)
@staticmethod
def exp(z):
z = jnp.clip(z, -10, 10)
return jnp.exp(z)
@staticmethod
def abs(z):
return jnp.abs(z)
ACT_ALL = (
Act.sigmoid,
Act.tanh,
Act.sin,
Act.relu,
Act.lelu,
Act.identity,
Act.clamped,
Act.inv,
Act.log,
Act.exp,
Act.abs,
)
def act(idx, z, act_funcs):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, act_funcs, z)
return res

View File

@@ -0,0 +1,67 @@
import jax
import jax.numpy as jnp
class Agg:
@staticmethod
def sum(z):
z = jnp.where(jnp.isnan(z), 0, z)
return jnp.sum(z, axis=0)
@staticmethod
def product(z):
z = jnp.where(jnp.isnan(z), 1, z)
return jnp.prod(z, axis=0)
@staticmethod
def max(z):
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
return jnp.max(z, axis=0)
@staticmethod
def min(z):
z = jnp.where(jnp.isnan(z), jnp.inf, z)
return jnp.min(z, axis=0)
@staticmethod
def maxabs(z):
z = jnp.where(jnp.isnan(z), 0, z)
abs_z = jnp.abs(z)
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
@staticmethod
def median(z):
n = jnp.sum(~jnp.isnan(z), axis=0)
z = jnp.sort(z) # sort
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
@staticmethod
def mean(z):
aux = jnp.where(jnp.isnan(z), 0, z)
valid_values_sum = jnp.sum(aux, axis=0)
valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
mean_without_zeros = valid_values_sum / valid_values_count
return mean_without_zeros
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
def agg(idx, z, agg_funcs):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
return jax.lax.cond(
jnp.all(jnp.isnan(z)),
lambda: jnp.nan, # all inputs are nan
lambda: jax.lax.switch(idx, agg_funcs, z) # otherwise
)

68
tensorneat/utils/graph.py Normal file
View File

@@ -0,0 +1,68 @@
"""
Some graph algorithm implemented in jax.
Only used in feed-forward networks.
"""
import jax
from jax import jit, Array, numpy as jnp
from .tools import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, conns: Array) -> Array:
"""
a jit-able version of topological_sort!
conns: Array[N, N]
"""
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
res = jnp.full(in_degree.shape, I_INT)
def cond_fun(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
return i != I_INT
def body_func(carry):
res_, idx_, in_degree_ = carry
i = fetch_first(in_degree_ == 0.)
# add to res and flag it is already in it
res_ = res_.at[idx_].set(i)
in_degree_ = in_degree_.at[i].set(-1)
# decrease in_degree of all its children
children = conns[i, :]
in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
return res_, idx_ + 1, in_degree_
res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
return res
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
"""
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
"""
conns = conns.at[from_idx, to_idx].set(True)
visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True)
def cond_func(carry):
visited_, new_visited_ = carry
end_cond1 = jnp.all(visited_ == new_visited_) # no new nodes been visited
end_cond2 = new_visited_[from_idx] # the starting node has been visited
return jnp.logical_not(end_cond1 | end_cond2)
def body_func(carry):
_, visited_ = carry
new_visited_ = jnp.dot(visited_, conns)
new_visited_ = jnp.logical_or(visited_, new_visited_)
return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx]

29
tensorneat/utils/state.py Normal file
View File

@@ -0,0 +1,29 @@
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class State:
def __init__(self, **kwargs):
self.__dict__['state_dict'] = kwargs
def update(self, **kwargs):
return State(**{**self.state_dict, **kwargs})
def __getattr__(self, name):
return self.state_dict[name]
def __setattr__(self, name, value):
raise AttributeError("State is immutable")
def __repr__(self):
return f"State ({self.state_dict})"
def tree_flatten(self):
children = list(self.state_dict.values())
aux_data = list(self.state_dict.keys())
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(**dict(zip(aux_data, children)))

106
tensorneat/utils/tools.py Normal file
View File

@@ -0,0 +1,106 @@
from functools import partial
import numpy as np
import jax
from jax import numpy as jnp, Array, jit, vmap
I_INT = np.iinfo(jnp.int32).max # infinite int
def unflatten_conns(nodes, conns):
"""
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index)
:return:
"""
N = nodes.shape[0]
CL = conns.shape[1]
node_keys = nodes[:, 0]
i_keys, o_keys = conns[:, 0], conns[:, 1]
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
res = jnp.full((CL - 2, N, N), jnp.nan)
# Is interesting that jax use clip when attach data in array
# however, it will do nothing set values in an array
# put all attributes include enable in res
res = res.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
return res
def key_to_indices(key, keys):
return fetch_first(key == keys)
@jit
def fetch_first(mask, default=I_INT) -> Array:
"""
fetch the first True index
:param mask: array of bool
:param default: the default value if no element satisfying the condition
:return: the index of the first element satisfying the condition. if no element satisfying the condition, return default value
"""
idx = jnp.argmax(mask)
return jnp.where(mask[idx], idx, default)
@jit
def fetch_random(rand_key, mask, default=I_INT) -> Array:
"""
similar to fetch_first, but fetch a random True index
"""
true_cnt = jnp.sum(mask)
cumsum = jnp.cumsum(mask)
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
"""
rank the element in the array.
if reverse is True, the rank is from small to large. default large to small
"""
if not reverse:
array = -array
return jnp.argsort(jnp.argsort(array))
@jit
def mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):
k1, k2, k3 = jax.random.split(key, num=3)
noise = jax.random.normal(k1, ()) * mutate_power
replace = jax.random.normal(k2, ()) * init_std + init_mean
r = jax.random.uniform(k3, ())
val = jnp.where(
r < mutate_rate,
val + noise,
jnp.where(
(mutate_rate < r) & (r < mutate_rate + replace_rate),
replace,
val
)
)
return val
@jit
def mutate_int(key, val, options, replace_rate):
k1, k2 = jax.random.split(key, num=2)
r = jax.random.uniform(k1, ())
val = jnp.where(
r < replace_rate,
jax.random.choice(k2, options),
val
)
return val
def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx