complete fully stateful!

use black to format all files!
This commit is contained in:
wls2002
2024-05-26 18:08:43 +08:00
parent cf69b916af
commit 18c3d44c79
41 changed files with 620 additions and 495 deletions

View File

@@ -5,5 +5,5 @@ class BaseCrossover:
def setup(self, state=State()): def setup(self, state=State()):
return state return state
def __call__(self, state, genome, nodes1, nodes2, conns1, conns2): def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError raise NotImplementedError

View File

@@ -4,12 +4,12 @@ from .base import BaseCrossover
class DefaultCrossover(BaseCrossover): class DefaultCrossover(BaseCrossover):
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2): def __call__(self, state, randkey, genome, nodes1, conns1, nodes2, conns2):
""" """
use genome1 and genome2 to generate a new genome use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!) notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
""" """
randkey1, randkey2, randkey = jax.random.split(state.randkey, 3) randkey1, randkey2 = jax.random.split(randkey, 2)
# crossover nodes # crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
@@ -34,11 +34,12 @@ class DefaultCrossover(BaseCrossover):
self.crossover_gene(randkey2, conns1, conns2, is_conn=True), self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
) )
return state.update(randkey=randkey), new_nodes, new_conns return new_nodes, new_conns
def align_array(self, seq1, seq2, ar2, is_conn: bool): def align_array(self, seq1, seq2, ar2, is_conn: bool):
""" """
After I review this code, I found that it is the most difficult part of the code. Please never change it! After I review this code, I found that it is the most difficult part of the code.
Please consider carefully before change it!
make ar2 align with ar1. make ar2 align with ar1.
:param seq1: :param seq1:
:param seq2: :param seq2:
@@ -64,8 +65,8 @@ class DefaultCrossover(BaseCrossover):
return refactor_ar2 return refactor_ar2
def crossover_gene(self, rand_key, g1, g2, is_conn): def crossover_gene(self, randkey, g1, g2, is_conn):
r = jax.random.uniform(rand_key, shape=g1.shape) r = jax.random.uniform(randkey, shape=g1.shape)
new_gene = jnp.where(r > 0.5, g1, g2) new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled if is_conn: # fix enabled
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled

View File

@@ -5,5 +5,5 @@ class BaseMutation:
def setup(self, state=State()): def setup(self, state=State()):
return state return state
def __call__(self, state, genome, nodes, conns, new_node_key): def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
raise NotImplementedError raise NotImplementedError

View File

@@ -1,6 +1,16 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from . import BaseMutation from . import BaseMutation
from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles from utils import (
fetch_first,
fetch_random,
I_INF,
unflatten_conns,
check_cycles,
add_node,
add_conn,
delete_node_by_pos,
delete_conn_by_pos,
)
class DefaultMutation(BaseMutation): class DefaultMutation(BaseMutation):
@@ -16,15 +26,17 @@ class DefaultMutation(BaseMutation):
self.node_add = node_add self.node_add = node_add
self.node_delete = node_delete self.node_delete = node_delete
def __call__(self, state, genome, nodes, conns, new_node_key): def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
k1, k2, randkey = jax.random.split(state.randkey) k1, k2 = jax.random.split(randkey)
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key) nodes, conns = self.mutate_structure(
nodes, conns = self.mutate_values(k2, genome, nodes, conns) state, k1, genome, nodes, conns, new_node_key
)
nodes, conns = self.mutate_values(state, k2, genome, nodes, conns)
return state.update(randkey=randkey), nodes, conns return nodes, conns
def mutate_structure(self, key, genome, nodes, conns, new_node_key): def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_): def mutate_add_node(key_, nodes_, conns_):
i_key, o_key, idx = self.choice_connection_key(key_, conns_) i_key, o_key, idx = self.choice_connection_key(key_, conns_)
@@ -33,24 +45,24 @@ class DefaultMutation(BaseMutation):
new_conns = conns_.at[idx, 2].set(False) new_conns = conns_.at[idx, 2].set(False)
# add a new node # add a new node
new_nodes = genome.add_node( new_nodes = add_node(
nodes_, new_node_key, genome.node_gene.new_custom_attrs() nodes_, new_node_key, genome.node_gene.new_custom_attrs(state)
) )
# add two new connections # add two new connections
new_conns = genome.add_conn( new_conns = add_conn(
new_conns, new_conns,
i_key, i_key,
new_node_key, new_node_key,
True, True,
genome.conn_gene.new_custom_attrs(), genome.conn_gene.new_custom_attrs(state),
) )
new_conns = genome.add_conn( new_conns = add_conn(
new_conns, new_conns,
new_node_key, new_node_key,
o_key, o_key,
True, True,
genome.conn_gene.new_custom_attrs(), genome.conn_gene.new_custom_attrs(state),
) )
return new_nodes, new_conns return new_nodes, new_conns
@@ -75,7 +87,7 @@ class DefaultMutation(BaseMutation):
def successful_delete_node(): def successful_delete_node():
# delete the node # delete the node
new_nodes = genome.delete_node_by_pos(nodes_, idx) new_nodes = delete_node_by_pos(nodes_, idx)
# delete all connections # delete all connections
new_conns = jnp.where( new_conns = jnp.where(
@@ -123,8 +135,8 @@ class DefaultMutation(BaseMutation):
return nodes_, conns_ return nodes_, conns_
def successful(): def successful():
return nodes_, genome.add_conn( return nodes_, add_conn(
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs() conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state)
) )
def already_exist(): def already_exist():
@@ -152,7 +164,7 @@ class DefaultMutation(BaseMutation):
i_key, o_key, idx = self.choice_connection_key(key_, conns_) i_key, o_key, idx = self.choice_connection_key(key_, conns_)
def successfully_delete_connection(): def successfully_delete_connection():
return nodes_, genome.delete_conn_by_pos(conns_, idx) return nodes_, delete_conn_by_pos(conns_, idx)
return jax.lax.cond( return jax.lax.cond(
idx == I_INF, idx == I_INF,
@@ -160,7 +172,7 @@ class DefaultMutation(BaseMutation):
successfully_delete_connection, successfully_delete_connection,
) )
k1, k2, k3, k4 = jax.random.split(key, num=4) k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,)) r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(key_, nodes_, conns_): def no(key_, nodes_, conns_):
@@ -181,13 +193,17 @@ class DefaultMutation(BaseMutation):
return nodes, conns return nodes, conns
def mutate_values(self, key, genome, nodes, conns): def mutate_values(self, state, randkey, genome, nodes, conns):
k1, k2 = jax.random.split(key, num=2) k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=nodes.shape[0]) nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0]) conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes) new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns) state, nodes_keys, nodes
)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
state, conns_keys, conns
)
# nan nodes not changed # nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -26,7 +26,7 @@ class DefaultConnGene(BaseConnGene):
self.weight_replace_rate = weight_replace_rate self.weight_replace_rate = weight_replace_rate
def new_custom_attrs(self, state): def new_custom_attrs(self, state):
return state, jnp.array([self.weight_init_mean]) return jnp.array([self.weight_init_mean])
def new_random_attrs(self, state, randkey): def new_random_attrs(self, state, randkey):
weight = ( weight = (

View File

@@ -109,10 +109,10 @@ class DefaultNodeGene(BaseNodeGene):
def distance(self, state, node1, node2): def distance(self, state, node1, node2):
return ( return (
jnp.abs(node1[1] - node2[1]) jnp.abs(node1[1] - node2[1]) # bias
+ jnp.abs(node1[2] - node2[2]) + jnp.abs(node1[2] - node2[2]) # response
+ (node1[3] != node2[3]) + (node1[3] != node2[3]) # activation
+ (node1[4] != node2[4]) + (node1[4] != node2[4]) # aggregation
) )
def forward(self, state, attrs, inputs, is_output_node=False): def forward(self, state, attrs, inputs, is_output_node=False):

View File

@@ -0,0 +1,106 @@
from typing import Tuple
import jax, jax.numpy as jnp
from utils import Act, Agg, act, agg, mutate_int, mutate_float
from . import BaseNodeGene
class NodeGeneWithoutResponse(BaseNodeGene):
    """
    Node gene that mirrors the default NEAT-python node gene, with the
    ``response`` attribute removed.

    A node row is laid out as ``[key, bias, activation, aggregation]``: that is
    the order in which ``mutate`` rebuilds a row, and the attribute part
    ``[bias, activation, aggregation]`` is what ``new_custom_attrs`` and
    ``new_random_attrs`` produce and what ``forward`` unpacks.
    """

    # NOTE(review): the names below list aggregation before activation, whereas
    # the attribute vectors built in this class store activation first.
    # Left unchanged; confirm whether any caller relies on this ordering.
    custom_attrs = ["bias", "aggregation", "activation"]

    def __init__(
        self,
        bias_init_mean: float = 0.0,
        bias_init_std: float = 1.0,
        bias_mutate_power: float = 0.5,
        bias_mutate_rate: float = 0.7,
        bias_replace_rate: float = 0.1,
        activation_default: callable = Act.sigmoid,
        activation_options: Tuple = (Act.sigmoid,),
        activation_replace_rate: float = 0.1,
        aggregation_default: callable = Agg.sum,
        aggregation_options: Tuple = (Agg.sum,),
        aggregation_replace_rate: float = 0.1,
    ):
        """
        Store the hyper-parameters of the gene.

        :param bias_init_mean: mean used when a bias is (re)initialised.
        :param bias_init_std: standard deviation used when a bias is (re)initialised.
        :param bias_mutate_power: forwarded to ``mutate_float``.
        :param bias_mutate_rate: forwarded to ``mutate_float``.
        :param bias_replace_rate: forwarded to ``mutate_float``.
        :param activation_default: default activation; must be a member of ``activation_options``.
        :param activation_options: tuple of selectable activation functions.
        :param activation_replace_rate: forwarded to ``mutate_int``.
        :param aggregation_default: default aggregation; must be a member of ``aggregation_options``.
        :param aggregation_options: tuple of selectable aggregation functions.
        :param aggregation_replace_rate: forwarded to ``mutate_int``.
        :raises ValueError: if a default is not contained in its options tuple
            (raised by ``tuple.index``).
        """
        super().__init__()
        self.bias_init_mean = bias_init_mean
        self.bias_init_std = bias_init_std
        self.bias_mutate_power = bias_mutate_power
        self.bias_mutate_rate = bias_mutate_rate
        self.bias_replace_rate = bias_replace_rate

        # Functions are stored as their position inside the options tuple so
        # that they can live in a numeric array next to the bias.
        self.activation_default = activation_options.index(activation_default)
        self.activation_options = activation_options
        self.activation_indices = jnp.arange(len(activation_options))
        self.activation_replace_rate = activation_replace_rate

        self.aggregation_default = aggregation_options.index(aggregation_default)
        self.aggregation_options = aggregation_options
        self.aggregation_indices = jnp.arange(len(aggregation_options))
        self.aggregation_replace_rate = aggregation_replace_rate

    def new_custom_attrs(self, state):
        """
        Return the default attribute vector ``[bias, activation, aggregation]``.

        ``state`` is accepted for interface compatibility and is not read.
        """
        return jnp.array(
            [
                self.bias_init_mean,
                self.activation_default,
                self.aggregation_default,
            ]
        )

    def new_random_attrs(self, state, randkey):
        """
        Return a randomly drawn attribute vector ``[bias, activation, aggregation]``.

        :param state: not read here.
        :param randkey: JAX PRNG key used for all draws.
        """
        # Four sub-keys are still requested (the second one is unused) so that
        # the values drawn for a given ``randkey`` stay the same as before.
        k1, _, k3, k4 = jax.random.split(randkey, num=4)
        bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
        act_idx = jax.random.randint(k3, (), 0, len(self.activation_options))
        agg_idx = jax.random.randint(k4, (), 0, len(self.aggregation_options))
        return jnp.array([bias, act_idx, agg_idx])

    def mutate(self, state, randkey, node):
        """
        Mutate one node row ``[key, bias, activation, aggregation]``.

        :param state: not read here.
        :param randkey: JAX PRNG key for this node. The key passed by the
            caller is the one that is used, so that nodes mutated in a batch
            with distinct keys receive independent mutations.
        :param node: the node row to mutate.
        :return: a new row with the same key and mutated attributes.
        """
        k1, _, k3, k4 = jax.random.split(randkey, num=4)
        index = node[0]

        bias = mutate_float(
            k1,
            node[1],
            self.bias_init_mean,
            self.bias_init_std,
            self.bias_mutate_power,
            self.bias_mutate_rate,
            self.bias_replace_rate,
        )

        # Without the response column, activation sits at position 2 and
        # aggregation at position 3 of the row.
        act_idx = mutate_int(
            k3, node[2], self.activation_indices, self.activation_replace_rate
        )
        agg_idx = mutate_int(
            k4, node[3], self.aggregation_indices, self.aggregation_replace_rate
        )

        return jnp.array([index, bias, act_idx, agg_idx])

    def distance(self, state, node1, node2):
        """
        Distance between two node rows: absolute bias difference plus one for
        each of activation / aggregation that differs.
        """
        return (
            jnp.abs(node1[1] - node2[1])  # bias
            + (node1[2] != node2[2])  # activation
            + (node1[3] != node2[3])  # aggregation
        )

    def forward(self, state, attrs, inputs, is_output_node=False):
        """
        Compute the node output from its attribute vector and its inputs.

        :param state: not read here.
        :param attrs: ``[bias, activation index, aggregation index]``.
        :param inputs: values handed to the aggregation function.
        :param is_output_node: when true the activation function is skipped.
        :return: the aggregated, biased and (unless output node) activated value.
        """
        bias, act_idx, agg_idx = attrs

        z = agg(agg_idx, inputs, self.aggregation_options)
        z = bias + z

        # the last output node should not be activated
        z = jax.lax.cond(
            is_output_node, lambda: z, lambda: act(act_idx, z, self.activation_options)
        )

        return z

View File

@@ -1,6 +1,7 @@
import jax.numpy as jnp import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene from ..gene import BaseNodeGene, BaseConnGene
from utils import fetch_first, State from ..ga import BaseMutation, BaseCrossover
from utils import State
class BaseGenome: class BaseGenome:
@@ -12,8 +13,10 @@ class BaseGenome:
num_outputs: int, num_outputs: int,
max_nodes: int, max_nodes: int,
max_conns: int, max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(), node_gene: BaseNodeGene,
conn_gene: BaseConnGene = DefaultConnGene(), conn_gene: BaseConnGene,
mutation: BaseMutation,
crossover: BaseCrossover,
): ):
self.num_inputs = num_inputs self.num_inputs = num_inputs
self.num_outputs = num_outputs self.num_outputs = num_outputs
@@ -23,10 +26,14 @@ class BaseGenome:
self.max_conns = max_conns self.max_conns = max_conns
self.node_gene = node_gene self.node_gene = node_gene
self.conn_gene = conn_gene self.conn_gene = conn_gene
self.mutation = mutation
self.crossover = crossover
def setup(self, state=State()): def setup(self, state=State()):
state = self.node_gene.setup(state) state = self.node_gene.setup(state)
state = self.conn_gene.setup(state) state = self.conn_gene.setup(state)
state = self.mutation.setup(state)
state = self.crossover.setup(state)
return state return state
def transform(self, state, nodes, conns): def transform(self, state, nodes, conns):
@@ -35,36 +42,81 @@ class BaseGenome:
def forward(self, state, inputs, transformed): def forward(self, state, inputs, transformed):
raise NotImplementedError raise NotImplementedError
def add_node(self, nodes, new_key: int, attrs): def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
""" return self.mutation(state, randkey, self, nodes, conns, new_node_key)
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def delete_node_by_pos(self, nodes, pos): def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
""" return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2)
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
return nodes.at[pos].set(jnp.nan)
def add_conn(self, conns, i_key, o_key, enable: bool, attrs): def initialize(self, state, randkey):
""" """
Add a new connection to the genome. Default initialization method for the genome.
The new connection will place at the first NaN row. Add an extra hidden node.
""" Make all input nodes and output nodes connected to the hidden node.
con_keys = conns[:, 0] All attributes will be initialized randomly using gene.new_random_attrs method.
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
return new_conns.at[pos, 3:].set(attrs)
def delete_conn_by_pos(self, conns, pos): For example, a network with 2 inputs and 1 output, the structure will be:
nodes:
[
[0, attrs0], # input node 0
[1, attrs1], # input node 1
[2, attrs2], # output node 0
[3, attrs3], # hidden node
[NaN, NaN], # empty node
]
conns:
[
[0, 3, attrs0], # input node 0 -> hidden node
[1, 3, attrs1], # input node 1 -> hidden node
[3, 2, attrs2], # hidden node -> output node 0
[NaN, NaN],
[NaN, NaN],
]
""" """
Delete a connection from the genome.
Delete the connection by its idx. k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns
""" # initialize nodes
return conns.at[pos].set(jnp.nan) new_node_key = (
max([*self.input_idx, *self.output_idx]) + 1
) # the key for the hidden node
node_keys = jnp.concatenate(
[self.input_idx, self.output_idx, jnp.array([new_node_key])]
) # the list of all node keys
# initialize nodes and connections with NaN
nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
# set keys for input nodes, output nodes and hidden node
nodes = nodes.at[node_keys, 0].set(node_keys)
# generate random attributes for nodes
node_keys = jax.random.split(k1, len(node_keys))
random_node_attrs = jax.vmap(
self.node_gene.new_random_attrs, in_axes=(None, 0)
)(state, node_keys)
nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs)
# initialize conns
# input-hidden connections
input_conns = jnp.c_[
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
]
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
conns = conns.at[self.input_idx, 2].set(True) # enable
# output-hidden connections
output_conns = jnp.c_[
jnp.full_like(self.output_idx, new_node_key), self.output_idx
]
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
conns = conns.at[self.output_idx, 2].set(True) # enable
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
# generate random attributes for conns
random_conn_attrs = jax.vmap(
self.conn_gene.new_random_attrs, in_axes=(None, 0)
)(state, conn_keys)
conns = conns.at[: len(conn_keys), 3:].set(random_conn_attrs)
return nodes, conns

View File

@@ -5,6 +5,7 @@ from utils import unflatten_conns, topological_sort, I_INF
from . import BaseGenome from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
class DefaultGenome(BaseGenome): class DefaultGenome(BaseGenome):
@@ -20,10 +21,19 @@ class DefaultGenome(BaseGenome):
max_conns=4, max_conns=4,
node_gene: BaseNodeGene = DefaultNodeGene(), node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(), conn_gene: BaseConnGene = DefaultConnGene(),
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
output_transform: Callable = None, output_transform: Callable = None,
): ):
super().__init__( super().__init__(
num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene num_inputs,
num_outputs,
max_nodes,
max_conns,
node_gene,
conn_gene,
mutation,
crossover,
) )
if output_transform is not None: if output_transform is not None:

View File

@@ -5,6 +5,7 @@ from utils import unflatten_conns
from . import BaseGenome from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
class RecurrentGenome(BaseGenome): class RecurrentGenome(BaseGenome):
@@ -20,11 +21,20 @@ class RecurrentGenome(BaseGenome):
max_conns: int, max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(), node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(), conn_gene: BaseConnGene = DefaultConnGene(),
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
activate_time: int = 10, activate_time: int = 10,
output_transform: Callable = None, output_transform: Callable = None,
): ):
super().__init__( super().__init__(
num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene num_inputs,
num_outputs,
max_nodes,
max_conns,
node_gene,
conn_gene,
mutation,
crossover,
) )
self.activate_time = activate_time self.activate_time = activate_time

View File

@@ -10,18 +10,12 @@ class NEAT(BaseAlgorithm):
def __init__( def __init__(
self, self,
species: BaseSpecies, species: BaseSpecies,
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
): ):
self.genome: BaseGenome = species.genome
self.species = species self.species = species
self.mutation = mutation self.genome = species.genome
self.crossover = crossover
def setup(self, state=State()): def setup(self, state=State()):
state = self.species.setup(state) state = self.species.setup(state)
state = self.mutation.setup(state)
state = self.crossover.setup(state)
state = state.register( state = state.register(
generation=jnp.array(0.0), generation=jnp.array(0.0),
next_node_key=jnp.array( next_node_key=jnp.array(
@@ -32,18 +26,16 @@ class NEAT(BaseAlgorithm):
return state return state
def ask(self, state: State): def ask(self, state: State):
return state, self.species.ask(state.species) return self.species.ask(state)
def tell(self, state: State, fitness): def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3) k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(generation=state.generation + 1, randkey=randkey) state = state.update(generation=state.generation + 1, randkey=randkey)
state, winner, loser, elite_mask = self.species.update_species( state, winner, loser, elite_mask = self.species.update_species(state, fitness)
state.species, fitness
)
state = self.create_next_generation(state, winner, loser, elite_mask) state = self.create_next_generation(state, winner, loser, elite_mask)
state = self.species.speciate(state.species) state = self.species.speciate(state)
return state return state
@@ -73,21 +65,25 @@ class NEAT(BaseAlgorithm):
new_node_keys = jnp.arange(pop_size) + state.next_node_key new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2, randkey = jax.random.split(state.randkey, 3) k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_rand_keys = jax.random.split(k1, pop_size) crossover_randkeys = jax.random.split(k1, pop_size)
mutate_rand_keys = jax.random.split(k2, pop_size) mutate_randkeys = jax.random.split(k2, pop_size)
wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner] wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser] lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover # batch crossover
n_nodes, n_conns = jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))( n_nodes, n_conns = jax.vmap(
crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
) )(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation # batch mutation
m_n_nodes, m_n_conns = jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))( m_n_nodes, m_n_conns = jax.vmap(
mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
) )(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate # elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes) pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
@@ -108,8 +104,8 @@ class NEAT(BaseAlgorithm):
) )
def member_count(self, state: State): def member_count(self, state: State):
return state, state.species.member_count return state.member_count
def generation(self, state: State): def generation(self, state: State):
# to analysis the algorithm # to analysis the algorithm
return state, state.generation return state.generation

View File

@@ -1,10 +1,22 @@
import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import State, rank_elements, argmin_with_mask, fetch_first from utils import State, rank_elements, argmin_with_mask, fetch_first
from ..genome import BaseGenome from ..genome import BaseGenome
from .base import BaseSpecies from .base import BaseSpecies
"""
Core procedures of the NEAT algorithm, consisting of the following steps:
1. Update the fitness of each species;
2. Decide which species are stagnant;
3. Decide the number of members of each species in the next generation;
4. Choose the crossover pairs for each species;
5. Divide the whole new population into different species.
This class uses tensor operations to imitate the behavior of the NEAT algorithm as implemented in NEAT-python.
The code may be hard to understand. Fortunately, we don't need to override it in most cases.
"""
class DefaultSpecies(BaseSpecies): class DefaultSpecies(BaseSpecies):
def __init__( def __init__(
self, self,
@@ -20,8 +32,6 @@ class DefaultSpecies(BaseSpecies):
survival_threshold: float = 0.2, survival_threshold: float = 0.2,
min_species_size: int = 1, min_species_size: int = 1,
compatibility_threshold: float = 3.0, compatibility_threshold: float = 3.0,
initialize_method: str = "one_hidden_node",
# {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
): ):
self.genome = genome self.genome = genome
self.pop_size = pop_size self.pop_size = pop_size
@@ -36,15 +46,17 @@ class DefaultSpecies(BaseSpecies):
self.survival_threshold = survival_threshold self.survival_threshold = survival_threshold
self.min_species_size = min_species_size self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold self.compatibility_threshold = compatibility_threshold
self.initialize_method = initialize_method
self.species_arange = jnp.arange(self.species_size) self.species_arange = jnp.arange(self.species_size)
def setup(self, state=State()): def setup(self, state=State()):
state = self.genome.setup(state) state = self.genome.setup(state)
k1, randkey = jax.random.split(state.randkey, 2) k1, randkey = jax.random.split(state.randkey, 2)
pop_nodes, pop_conns = initialize_population(
self.pop_size, self.genome, k1, self.initialize_method # initialize the population
initialize_keys = jax.random.split(randkey, self.pop_size)
pop_nodes, pop_conns = jax.vmap(self.genome.initialize, in_axes=(None, 0))(
state, initialize_keys
) )
species_keys = jnp.full( species_keys = jnp.full(
@@ -82,8 +94,9 @@ class DefaultSpecies(BaseSpecies):
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns)) pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
state = state.update(randkey=randkey)
return state.register( return state.register(
randkey=randkey,
pop_nodes=pop_nodes, pop_nodes=pop_nodes,
pop_conns=pop_conns, pop_conns=pop_conns,
species_keys=species_keys, species_keys=species_keys,
@@ -97,7 +110,7 @@ class DefaultSpecies(BaseSpecies):
) )
def ask(self, state): def ask(self, state):
return state, state.pop_nodes, state.pop_conns return state.pop_nodes, state.pop_conns
def update_species(self, state, fitness): def update_species(self, state, fitness):
# update the fitness of each species # update the fitness of each species
@@ -122,8 +135,8 @@ class DefaultSpecies(BaseSpecies):
k1, k2 = jax.random.split(state.randkey) k1, k2 = jax.random.split(state.randkey)
# crossover info # crossover info
winner, loser, elite_mask = self.create_crossover_pair( state, winner, loser, elite_mask = self.create_crossover_pair(
state, k1, spawn_number, fitness state, spawn_number, fitness
) )
return state.update(randkey=k2), winner, loser, elite_mask return state.update(randkey=k2), winner, loser, elite_mask
@@ -322,12 +335,12 @@ class DefaultSpecies(BaseSpecies):
winner = jnp.where(is_part1_win, part1, part2) winner = jnp.where(is_part1_win, part1, part2)
loser = jnp.where(is_part1_win, part2, part1) loser = jnp.where(is_part1_win, part2, part1)
return state(randkey=randkey), winner, loser, elite_mask return state.update(randkey=randkey), winner, loser, elite_mask
def speciate(self, state): def speciate(self, state):
# prepare distance functions # prepare distance functions
o2p_distance_func = jax.vmap( o2p_distance_func = jax.vmap(
self.distance, in_axes=(None, None, 0, 0) self.distance, in_axes=(None, None, None, 0, 0)
) # one to population ) # one to population
# idx to specie key # idx to specie key
@@ -351,7 +364,7 @@ class DefaultSpecies(BaseSpecies):
i, i2s, cns, ccs, o2c = carry i, i2s, cns, ccs, o2c = carry
distances = o2p_distance_func( distances = o2p_distance_func(
cns[i], ccs[i], state.pop_nodes, state.pop_conns state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
) )
# find the closest one # find the closest one
@@ -434,7 +447,7 @@ class DefaultSpecies(BaseSpecies):
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c): def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
# distance between such center genome and ppo genomes # distance between such center genome and ppo genomes
o2p_distance = o2p_distance_func( o2p_distance = o2p_distance_func(
cns[i], ccs[i], state.pop_nodes, state.pop_conns state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
) )
close_enough_mask = o2p_distance < self.compatibility_threshold close_enough_mask = o2p_distance < self.compatibility_threshold
@@ -508,14 +521,16 @@ class DefaultSpecies(BaseSpecies):
next_species_key=next_species_key, next_species_key=next_species_key,
) )
def distance(self, nodes1, conns1, nodes2, conns2): def distance(self, state, nodes1, conns1, nodes2, conns2):
""" """
The distance between two genomes The distance between two genomes
""" """
d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2) d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
state, conns1, conns2
)
return d return d
def node_distance(self, nodes1, nodes2): def node_distance(self, state, nodes1, nodes2):
""" """
The distance of the nodes part for two genomes The distance of the nodes part for two genomes
""" """
@@ -541,7 +556,9 @@ class DefaultSpecies(BaseSpecies):
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask) non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes # calculate the distance of homologous nodes
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(fr, sr) hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
state, fr, sr
) # homologous node distance
hnd = jnp.where(jnp.isnan(hnd), 0, hnd) hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask) homologous_distance = jnp.sum(hnd * intersect_mask)
@@ -550,9 +567,11 @@ class DefaultSpecies(BaseSpecies):
+ homologous_distance * self.compatibility_weight + homologous_distance * self.compatibility_weight
) )
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
def conn_distance(self, conns1, conns2): return val
def conn_distance(self, state, conns1, conns2):
""" """
The distance of the conns part for two genomes The distance of the conns part for two genomes
""" """
@@ -573,7 +592,9 @@ class DefaultSpecies(BaseSpecies):
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0]) intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask) non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(fr, sr) hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
state, fr, sr
) # homologous connection distance
hcd = jnp.where(jnp.isnan(hcd), 0, hcd) hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask) homologous_distance = jnp.sum(hcd * intersect_mask)
@@ -582,185 +603,6 @@ class DefaultSpecies(BaseSpecies):
+ homologous_distance * self.compatibility_weight + homologous_distance * self.compatibility_weight
) )
return jnp.where(max_cnt == 0, 0, val / max_cnt) val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val
def initialize_population(pop_size, genome, randkey, init_method="one_hidden_node"):
    """
    Create the initial population by running one initializer per individual.

    :param pop_size: number of genomes to create; ``randkey`` is split into
        this many sub-keys, one per genome.
    :param genome: genome definition, passed unchanged to the initializer.
    :param randkey: JAX PRNG key.
    :param init_method: one of ``"one_hidden_node"``, ``"dense_hideen_layer"``
        or ``"no_hidden_random"`` (spelling kept as accepted by existing callers).
        The default is now a value that is actually accepted; the former
        default ``"default"`` matched no branch and always raised.
    :return: ``(pop_nodes, pop_conns)`` as produced by the vmapped initializer.
    :raises ValueError: if ``init_method`` is not one of the accepted names.
    """
    # Resolve the initializer first so an invalid name fails before any work.
    if init_method == "one_hidden_node":
        init_func = init_one_hidden_node
    elif init_method == "dense_hideen_layer":
        init_func = init_dense_hideen_layer
    elif init_method == "no_hidden_random":
        init_func = init_no_hidden_random
    else:
        raise ValueError("Unknown initialization method: {}".format(init_method))

    rand_keys = jax.random.split(randkey, pop_size)

    # The genome is shared by every individual; only the keys are mapped over.
    pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys)
    return pop_nodes, pop_conns
# one hidden node
def init_one_hidden_node(genome, randkey):
    """
    Build a genome in which every input feeds one hidden node, and that hidden
    node feeds every output.

    Args:
        genome: object exposing ``input_idx``, ``output_idx``, ``max_nodes``,
            ``max_conns``, ``node_gene`` and ``conn_gene``.
        randkey: JAX PRNG key used to sample node and connection attributes.

    Returns:
        ``(nodes, conns)``: NaN-padded arrays of shape
        ``(max_nodes, node_gene.length)`` and ``(max_conns, conn_gene.length)``.
    """
    input_idx, output_idx = genome.input_idx, genome.output_idx
    num_inputs, num_outputs = len(input_idx), len(output_idx)

    # The hidden node takes the first key after the largest input/output index.
    new_node_key = max([*input_idx, *output_idx]) + 1

    # Derive independent sub-keys for nodes and connections instead of
    # splitting the same key twice (PRNG keys must not be reused).
    node_randkey, conn_randkey = jax.random.split(randkey)

    nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
    conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)

    # Column 0 of a node row stores the node key.
    nodes = nodes.at[input_idx, 0].set(input_idx)
    nodes = nodes.at[output_idx, 0].set(output_idx)
    nodes = nodes.at[new_node_key, 0].set(new_node_key)

    rand_keys_nodes = jax.random.split(node_randkey, num=num_inputs + num_outputs + 1)
    input_keys = rand_keys_nodes[:num_inputs]
    output_keys = rand_keys_nodes[num_inputs : num_inputs + num_outputs]
    hidden_key = rand_keys_nodes[-1]

    # The mapped function is called with a single argument (the keys), so
    # in_axes must describe exactly one positional argument.
    # NOTE(review): `new_attrs` / `new_custom_attrs` differ from the
    # `new_random_attrs` used by the other initializers in this file --
    # confirm these methods exist and accept a PRNG key.
    node_attr_func = jax.vmap(genome.node_gene.new_attrs, in_axes=0)
    input_attrs = node_attr_func(input_keys)
    output_attrs = node_attr_func(output_keys)
    hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)

    nodes = nodes.at[input_idx, 1:].set(input_attrs)
    nodes = nodes.at[output_idx, 1:].set(output_attrs)
    nodes = nodes.at[new_node_key, 1:].set(hidden_attrs)

    # Connection rows are addressed by the node indices themselves:
    # row i holds (input i -> hidden), row o holds (hidden -> output o).
    input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)]
    conns = conns.at[input_idx, 0:2].set(input_conns)
    conns = conns.at[input_idx, 2].set(True)  # enabled

    output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
    conns = conns.at[output_idx, 0:2].set(output_conns)
    conns = conns.at[output_idx, 2].set(True)  # enabled

    rand_keys_conns = jax.random.split(conn_randkey, num=num_inputs + num_outputs)
    input_conn_keys = rand_keys_conns[:num_inputs]
    output_conn_keys = rand_keys_conns[num_inputs:]

    conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=0)
    input_conn_attrs = conn_attr_func(input_conn_keys)
    output_conn_attrs = conn_attr_func(output_conn_keys)

    conns = conns.at[input_idx, 3:].set(input_conn_attrs)
    conns = conns.at[output_idx, 3:].set(output_conn_attrs)

    return nodes, conns
# random dense connections with 1 hidden layer
def init_dense_hideen_layer(genome, randkey, hiddens=20):
    """
    Build a genome with one fully connected hidden layer: every input is
    connected to every hidden node and every hidden node to every output.

    Args:
        genome: object exposing ``input_idx``, ``output_idx``, ``max_nodes``,
            ``max_conns``, ``node_gene`` and ``conn_gene``.
        randkey: JAX PRNG key used to sample node and connection attributes.
        hiddens: number of hidden nodes to create.

    Returns:
        ``(nodes, conns)``: NaN-padded float32 arrays of shape
        ``(max_nodes, node_gene.length)`` and ``(max_conns, conn_gene.length)``.
    """
    # The third sub-key is not needed; the three-way split is kept so the
    # node/connection key streams stay the same as before.
    node_randkey, conn_randkey, _ = jax.random.split(randkey, num=3)

    input_idx, output_idx = genome.input_idx, genome.output_idx
    input_size = len(input_idx)
    output_size = len(output_idx)
    total_idx = input_size + output_size + hiddens

    # Hidden nodes take the keys right after the input and output nodes.
    # NOTE(review): assumes input/output keys occupy
    # 0 .. input_size + output_size - 1 -- confirm against the genome class.
    hidden_idx = jnp.arange(input_size + output_size, total_idx)

    nodes = jnp.full(
        (genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32
    )
    # Column 0 of a node row stores the node key.
    nodes = nodes.at[input_idx, 0].set(input_idx)
    nodes = nodes.at[output_idx, 0].set(output_idx)
    nodes = nodes.at[hidden_idx, 0].set(hidden_idx)

    rand_keys_n = jax.random.split(node_randkey, num=total_idx)
    input_keys = rand_keys_n[:input_size]
    output_keys = rand_keys_n[input_size : input_size + output_size]
    hidden_keys = rand_keys_n[input_size + output_size :]

    # Node attributes come from the node gene (previously sampled from the
    # connection gene by mistake).
    node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=0)
    nodes = nodes.at[input_idx, 1:].set(node_attr_func(input_keys))
    nodes = nodes.at[output_idx, 1:].set(node_attr_func(output_keys))
    nodes = nodes.at[hidden_idx, 1:].set(node_attr_func(hidden_keys))

    input_to_hidden = input_size * hiddens
    total_connections = input_to_hidden + hiddens * output_size

    conns = jnp.full(
        (genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32
    )

    # Connection attributes come from the connection gene (previously sampled
    # from the node gene by mistake).
    rand_keys_c = jax.random.split(conn_randkey, num=total_connections)
    conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=0)
    conns_attrs = conns_attr_func(rand_keys_c)

    # Enumerate every (input, hidden) and (hidden, output) pair.
    input_to_hidden_ids, hidden_ids = jnp.meshgrid(
        input_idx, hidden_idx, indexing="ij"
    )
    hidden_to_output_ids, output_ids = jnp.meshgrid(
        hidden_idx, output_idx, indexing="ij"
    )

    # Rows [0, input_to_hidden) hold input -> hidden connections,
    # rows [input_to_hidden, total_connections) hold hidden -> output ones.
    conns = conns.at[:input_to_hidden, 0].set(input_to_hidden_ids.flatten())
    conns = conns.at[:input_to_hidden, 1].set(hidden_ids.flatten())
    conns = conns.at[input_to_hidden:total_connections, 0].set(
        hidden_to_output_ids.flatten()
    )
    conns = conns.at[input_to_hidden:total_connections, 1].set(output_ids.flatten())
    conns = conns.at[:total_connections, 2].set(True)  # enabled
    conns = conns.at[:total_connections, 3:].set(conns_attrs)

    return nodes, conns
# random sparse connections with no hidden nodes
def init_no_hidden_random(genome, randkey, num_connections_per_output=4):
    """
    Build a genome without hidden nodes in which each output is connected to
    a random subset of the inputs.

    Args:
        genome: object exposing ``input_idx``, ``output_idx``, ``max_nodes``,
            ``max_conns``, ``node_gene`` and ``conn_gene``.
        randkey: JAX PRNG key used for attribute sampling and input selection.
        num_connections_per_output: how many distinct inputs feed each output.
            Clamped to the number of inputs, since inputs are drawn without
            replacement from a permutation.

    Returns:
        ``(nodes, conns)``: NaN-padded arrays of shape
        ``(max_nodes, node_gene.length)`` and ``(max_conns, conn_gene.length)``.
    """
    node_randkey, select_randkey, conn_randkey = jax.random.split(randkey, num=3)

    input_idx, output_idx = genome.input_idx, genome.output_idx
    num_inputs, num_outputs = len(input_idx), len(output_idx)

    nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
    # Column 0 of a node row stores the node key.
    nodes = nodes.at[input_idx, 0].set(input_idx)
    nodes = nodes.at[output_idx, 0].set(output_idx)

    rand_keys_n = jax.random.split(node_randkey, num=num_inputs + num_outputs)
    input_keys = rand_keys_n[:num_inputs]
    output_keys = rand_keys_n[num_inputs:]

    node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=0)
    nodes = nodes.at[input_idx, 1:].set(node_attr_func(input_keys))
    nodes = nodes.at[output_idx, 1:].set(node_attr_func(output_keys))

    conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)

    # A permutation of the inputs yields at most num_inputs distinct picks.
    # Without the clamp, a genome with fewer inputs than requested produces
    # fewer rows than total_connections and the writes below fail on shape.
    per_output = min(num_connections_per_output, num_inputs)
    total_connections = num_outputs * per_output

    def create_connections_for_output(key):
        # Pick `per_output` distinct inputs for one output node.
        permuted_inputs = jax.random.permutation(key, input_idx)
        return permuted_inputs[:per_output]

    conn_keys = jax.random.split(select_randkey, num=num_outputs)
    connections = jax.vmap(create_connections_for_output)(conn_keys).flatten()

    # Each output key repeated once per incoming connection, matching the
    # flattened order of `connections`.
    output_repeats = jnp.repeat(output_idx, per_output)

    rand_keys_c = jax.random.split(conn_randkey, num=total_connections)
    conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=0)
    conns_attrs = conns_attr_func(rand_keys_c)

    conns = conns.at[:total_connections, 0].set(connections)
    conns = conns.at[:total_connections, 1].set(output_repeats)
    conns = conns.at[:total_connections, 2].set(True)  # enabled
    conns = conns.at[:total_connections, 3:].set(conns_attrs)

    return nodes, conns

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from utils import Act
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -17,17 +17,17 @@ if __name__ == '__main__':
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
activation_default=Act.tanh, activation_default=Act.tanh,
), ),
output_transform=Act.tanh output_transform=Act.tanh,
), ),
pop_size=1000, pop_size=1000,
species_size=10, species_size=10,
), ),
), ),
problem=BraxEnv( problem=BraxEnv(
env_name='ant', env_name="ant",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=5000 fitness_target=5000,
) )
# initialize state # initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from utils import Act
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -16,17 +16,17 @@ if __name__ == '__main__':
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
activation_default=Act.tanh, activation_default=Act.tanh,
) ),
), ),
pop_size=1000, pop_size=1000,
species_size=10, species_size=10,
), ),
), ),
problem=BraxEnv( problem=BraxEnv(
env_name='halfcheetah', env_name="halfcheetah",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=5000 fitness_target=5000,
) )
# initialize state # initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from utils import Act
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -16,17 +16,17 @@ if __name__ == '__main__':
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
activation_default=Act.tanh, activation_default=Act.tanh,
) ),
), ),
pop_size=100, pop_size=100,
species_size=10, species_size=10,
), ),
), ),
problem=BraxEnv( problem=BraxEnv(
env_name='reacher', env_name="reacher",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=5000 fitness_target=5000,
) )
# initialize state # initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from utils import Act
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -16,17 +16,17 @@ if __name__ == '__main__':
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
activation_default=Act.tanh, activation_default=Act.tanh,
) ),
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
), ),
), ),
problem=BraxEnv( problem=BraxEnv(
env_name='walker2d', env_name="walker2d",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=5000 fitness_target=5000,
) )
# initialize state # initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.func_fit import XOR3d from problem.func_fit import XOR3d
from utils import Act from utils import Act
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -18,22 +18,22 @@ if __name__ == '__main__':
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
), ),
output_transform=Act.sigmoid, # the activation function for output node output_transform=Act.sigmoid, # the activation function for output node
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
),
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
compatibility_threshold=3.5, compatibility_threshold=3.5,
survival_threshold=0.01, # magic survival_threshold=0.01, # magic
), ),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=10000, generation_limit=10000,
fitness_target=-1e-8 fitness_target=-1e-8,
) )
# initialize state # initialize state

View File

@@ -5,17 +5,28 @@ from utils import Act
from problem.func_fit import XOR3d from problem.func_fit import XOR3d
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=HyperNEAT( algorithm=HyperNEAT(
substrate=FullSubstrate( substrate=FullSubstrate(
input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)], input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)],
hidden_coors=[ hidden_coors=[
(-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5), (-1, -0.5),
(-1, 0), (0.333, 0), (-0.333, 0), (1, 0), (0.333, -0.5),
(-1, 0.5), (0.333, 0.5), (-0.333, 0.5), (1, 0.5), (-0.333, -0.5),
(1, -0.5),
(-1, 0),
(0.333, 0),
(-0.333, 0),
(1, 0),
(-1, 0.5),
(0.333, 0.5),
(-0.333, 0.5),
(1, 0.5),
],
output_coors=[
(0, 1),
], ],
output_coors=[(0, 1), ],
), ),
neat=NEAT( neat=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -42,7 +53,7 @@ if __name__ == '__main__':
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=300, generation_limit=300,
fitness_target=-1e-6 fitness_target=-1e-6,
) )
# initialize state # initialize state

View File

@@ -1,10 +1,11 @@
from pipeline import Pipeline from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from problem.func_fit import XOR3d from problem.func_fit import XOR3d
from utils.activation import ACT_ALL, Act from utils.activation import ACT_ALL, Act
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
seed=0, seed=0,
algorithm=NEAT( algorithm=NEAT(
@@ -15,27 +16,26 @@ if __name__ == '__main__':
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
activate_time=5, activate_time=5,
node_gene=DefaultNodeGene( node_gene=NodeGeneWithoutResponse(
activation_options=ACT_ALL, activation_options=ACT_ALL, activation_replace_rate=0.2
activation_replace_rate=0.2 ),
output_transform=Act.sigmoid,
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
), ),
output_transform=Act.sigmoid
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
compatibility_threshold=3.5, compatibility_threshold=3.5,
survival_threshold=0.03, survival_threshold=0.03,
), ),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=10000, generation_limit=10000,
fitness_target=-1e-8 fitness_target=-1e-8,
) )
# initialize state # initialize state

View File

@@ -5,7 +5,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -14,17 +14,19 @@ if __name__ == '__main__':
num_outputs=3, num_outputs=3,
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2} output_transform=lambda out: jnp.argmax(
out
), # the action of acrobot is {0, 1, 2}
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
), ),
), ),
problem=GymNaxEnv( problem=GymNaxEnv(
env_name='Acrobot-v1', env_name="Acrobot-v1",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=-62 fitness_target=-62,
) )
# initialize state # initialize state

View File

@@ -5,7 +5,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -14,17 +14,19 @@ if __name__ == '__main__':
num_outputs=2, num_outputs=2,
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} output_transform=lambda out: jnp.argmax(
out
), # the action of cartpole is {0, 1}
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
), ),
), ),
problem=GymNaxEnv( problem=GymNaxEnv(
env_name='CartPole-v1', env_name="CartPole-v1",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=500 fitness_target=500,
) )
# initialize state # initialize state

View File

@@ -10,11 +10,7 @@ from problem.rl_env import GymNaxConfig, GymNaxEnv
def example_conf(): def example_conf():
return Config( return Config(
basic=BasicConfig( basic=BasicConfig(seed=42, fitness_target=500, pop_size=10000),
seed=42,
fitness_target=500,
pop_size=10000
),
neat=NeatConfig( neat=NeatConfig(
inputs=4, inputs=4,
outputs=1, outputs=1,
@@ -23,28 +19,31 @@ def example_conf():
activation_default=Act.tanh, activation_default=Act.tanh,
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
), ),
hyperneat=HyperNeatConfig( hyperneat=HyperNeatConfig(activation=Act.sigmoid, inputs=4, outputs=2),
activation=Act.sigmoid,
inputs=4,
outputs=2
),
substrate=NormalSubstrateConfig( substrate=NormalSubstrateConfig(
input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)), input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)),
hidden_coors=( hidden_coors=(
# (-1, -0.5), (-0.5, -0.5), (0, -0.5), (0.5, -0.5), # (-1, -0.5), (-0.5, -0.5), (0, -0.5), (0.5, -0.5),
(1, 0), (-1, 0), (-0.5, 0), (0, 0), (0.5, 0), (1, 0), (1, 0),
(-1, 0),
(-0.5, 0),
(0, 0),
(0.5, 0),
(1, 0),
# (1, 0.5), (-1, 0.5), (-0.5, 0.5), (0, 0.5), (0.5, 0.5), (1, 0.5), # (1, 0.5), (-1, 0.5), (-0.5, 0.5), (0, 0.5), (0.5, 0.5), (1, 0.5),
), ),
output_coors=((-1, 1), (1, 1)), output_coors=((-1, 1), (1, 1)),
), ),
problem=GymNaxConfig( problem=GymNaxConfig(
env_name='CartPole-v1', env_name="CartPole-v1",
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} output_transform=lambda out: jnp.argmax(
) out
), # the action of cartpole is {0, 1}
),
) )
if __name__ == '__main__': if __name__ == "__main__":
conf = example_conf() conf = example_conf()
algorithm = HyperNEAT(conf, NormalGene, NormalSubstrate) algorithm = HyperNEAT(conf, NormalGene, NormalSubstrate)

View File

@@ -5,7 +5,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -14,17 +14,19 @@ if __name__ == '__main__':
num_outputs=3, num_outputs=3,
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2} output_transform=lambda out: jnp.argmax(
out
), # the action of mountain car is {0, 1, 2}
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
), ),
), ),
problem=GymNaxEnv( problem=GymNaxEnv(
env_name='MountainCar-v0', env_name="MountainCar-v0",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=0 fitness_target=0,
) )
# initialize state # initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
from utils import Act from utils import Act
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -14,19 +14,19 @@ if __name__ == '__main__':
max_nodes=50, max_nodes=50,
max_conns=100, max_conns=100,
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_options=(Act.tanh, ), activation_options=(Act.tanh,),
activation_default=Act.tanh, activation_default=Act.tanh,
) ),
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
), ),
), ),
problem=GymNaxEnv( problem=GymNaxEnv(
env_name='MountainCarContinuous-v0', env_name="MountainCarContinuous-v0",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=500 fitness_target=500,
) )
# initialize state # initialize state

View File

@@ -4,7 +4,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
from utils import Act from utils import Act
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -17,17 +17,18 @@ if __name__ == '__main__':
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
activation_default=Act.tanh, activation_default=Act.tanh,
), ),
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2] output_transform=lambda out: out
* 2, # the action of pendulum is [-2, 2]
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
), ),
), ),
problem=GymNaxEnv( problem=GymNaxEnv(
env_name='Pendulum-v1', env_name="Pendulum-v1",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target=0 fitness_target=0,
) )
# initialize state # initialize state

View File

@@ -5,7 +5,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
if __name__ == '__main__': if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(
algorithm=NEAT( algorithm=NEAT(
species=DefaultSpecies( species=DefaultSpecies(
@@ -20,10 +20,10 @@ if __name__ == '__main__':
), ),
), ),
problem=GymNaxEnv( problem=GymNaxEnv(
env_name='Reacher-misc', env_name="Reacher-misc",
), ),
generation_limit=10000, generation_limit=10000,
fitness_target =500 fitness_target=500,
) )
# initialize state # initialize state

View File

@@ -1,4 +1,5 @@
import ray import ray
ray.init(num_gpus=2) ray.init(num_gpus=2)
available_resources = ray.available_resources() available_resources = ray.available_resources()

View File

@@ -10,7 +10,6 @@ from utils import State
class Pipeline: class Pipeline:
def __init__( def __init__(
self, self,
algorithm: BaseAlgorithm, algorithm: BaseAlgorithm,
@@ -31,32 +30,35 @@ class Pipeline:
# print(self.problem.input_shape, self.problem.output_shape) # print(self.problem.input_shape, self.problem.output_shape)
# TODO: make each algorithm's input_num and output_num # TODO: make each algorithm's input_num and output_num
assert algorithm.num_inputs == self.problem.input_shape[-1], \ assert (
f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}" algorithm.num_inputs == self.problem.input_shape[-1]
), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
self.best_genome = None self.best_genome = None
self.best_fitness = float('-inf') self.best_fitness = float("-inf")
self.generation_timestamp = None self.generation_timestamp = None
def setup(self, state=State()): def setup(self, state=State()):
print("initializing")
state = state.register(randkey=jax.random.PRNGKey(self.seed)) state = state.register(randkey=jax.random.PRNGKey(self.seed))
state = self.algorithm.setup(state) state = self.algorithm.setup(state)
state = self.problem.setup(state) state = self.problem.setup(state)
print("initializing finished")
return state return state
def step(self, state): def step(self, state):
randkey_, randkey = jax.random.split(state.randkey) randkey_, randkey = jax.random.split(state.randkey)
keys = jax.random.split(randkey_, self.pop_size) keys = jax.random.split(randkey_, self.pop_size)
state, pop = self.algorithm.ask(state) pop = self.algorithm.ask(state)
state, pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0), out_axes=(None, 0))(state, pop) pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(
state, pop
)
state, fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0), out_axes=(None, 0))( fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
keys, state, keys, self.algorithm.forward, pop_transformed
state,
self.algorithm.forward,
pop_transformed
) )
state = self.algorithm.tell(state, fitnesses) state = self.algorithm.tell(state, fitnesses)
@@ -67,13 +69,15 @@ class Pipeline:
print("start compile") print("start compile")
tic = time.time() tic = time.time()
compiled_step = jax.jit(self.step).lower(state).compile() compiled_step = jax.jit(self.step).lower(state).compile()
print(f"compile finished, cost time: {time.time() - tic:.6f}s", ) print(
f"compile finished, cost time: {time.time() - tic:.6f}s",
)
for _ in range(self.generation_limit): for _ in range(self.generation_limit):
self.generation_timestamp = time.time() self.generation_timestamp = time.time()
state, previous_pop = self.algorithm.ask(state) previous_pop = self.algorithm.ask(state)
state, fitnesses = compiled_step(state) state, fitnesses = compiled_step(state)
@@ -98,7 +102,12 @@ class Pipeline:
def analysis(self, state, pop, fitnesses): def analysis(self, state, pop, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) max_f, min_f, mean_f, std_f = (
max(fitnesses),
min(fitnesses),
np.mean(fitnesses),
np.std(fitnesses),
)
new_timestamp = time.time() new_timestamp = time.time()
@@ -112,10 +121,14 @@ class Pipeline:
member_count = jax.device_get(self.algorithm.member_count(state)) member_count = jax.device_get(self.algorithm.member_count(state))
species_sizes = [int(i) for i in member_count if i > 0] species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.algorithm.generation(state)}", print(
f"Generation: {self.algorithm.generation(state)}",
f"species: {len(species_sizes)}, {species_sizes}", f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms",
)
def show(self, state, best, *args, **kwargs): def show(self, state, best, *args, **kwargs):
state, transformed = self.algorithm.transform(state, best) transformed = self.algorithm.transform(state, best)
self.problem.show(state.randkey, state, self.algorithm.forward, transformed, *args, **kwargs) self.problem.show(
state, state.randkey, self.algorithm.forward, transformed, *args, **kwargs
)

View File

@@ -10,7 +10,7 @@ class BaseProblem:
"""initialize the state of the problem""" """initialize the state of the problem"""
return state return state
def evaluate(self, randkey, state: State, act_func: Callable, params): def evaluate(self, state: State, randkey, act_func: Callable, params):
"""evaluate one individual""" """evaluate one individual"""
raise NotImplementedError raise NotImplementedError
@@ -32,7 +32,7 @@ class BaseProblem:
""" """
raise NotImplementedError raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs): def show(self, state: State, randkey, act_func: Callable, params, *args, **kwargs):
""" """
show how a genome perform in this problem show how a genome perform in this problem
""" """

View File

@@ -8,42 +8,44 @@ from .. import BaseProblem
class FuncFit(BaseProblem): class FuncFit(BaseProblem):
jitable = True jitable = True
def __init__(self, def __init__(self, error_method: str = "mse"):
error_method: str = 'mse'
):
super().__init__() super().__init__()
assert error_method in {'mse', 'rmse', 'mae', 'mape'} assert error_method in {"mse", "rmse", "mae", "mape"}
self.error_method = error_method self.error_method = error_method
def setup(self, state: State = State()): def setup(self, state: State = State()):
return state return state
def evaluate(self, randkey, state, act_func, params): def evaluate(self, state, randkey, act_func, params):
state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params) predict = jax.vmap(act_func, in_axes=(None, 0, None))(
state, self.inputs, params
)
if self.error_method == 'mse': if self.error_method == "mse":
loss = jnp.mean((predict - self.targets) ** 2) loss = jnp.mean((predict - self.targets) ** 2)
elif self.error_method == 'rmse': elif self.error_method == "rmse":
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2)) loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
elif self.error_method == 'mae': elif self.error_method == "mae":
loss = jnp.mean(jnp.abs(predict - self.targets)) loss = jnp.mean(jnp.abs(predict - self.targets))
elif self.error_method == 'mape': elif self.error_method == "mape":
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets)) loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
else: else:
raise NotImplementedError raise NotImplementedError
return state, -loss return -loss
def show(self, randkey, state, act_func, params, *args, **kwargs): def show(self, state, randkey, act_func, params, *args, **kwargs):
state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params) predict = jax.vmap(act_func, in_axes=(None, 0, None))(
state, self.inputs, params
)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
state, loss = self.evaluate(randkey, state, act_func, params) loss = self.evaluate(state, randkey, act_func, params)
loss = -loss loss = -loss
msg = "" msg = ""

View File

@@ -4,27 +4,16 @@ from .func_fit import FuncFit
class XOR(FuncFit): class XOR(FuncFit):
def __init__(self, error_method: str = "mse"):
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method) super().__init__(error_method)
@property @property
def inputs(self): def inputs(self):
return np.array([ return np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
[0, 0],
[0, 1],
[1, 0],
[1, 1]
])
@property @property
def targets(self): def targets(self):
return np.array([ return np.array([[0], [1], [1], [0]])
[0],
[1],
[1],
[0]
])
@property @property
def input_shape(self): def input_shape(self):

View File

@@ -4,13 +4,13 @@ from .func_fit import FuncFit
class XOR3d(FuncFit): class XOR3d(FuncFit):
def __init__(self, error_method: str = "mse"):
def __init__(self, error_method: str = 'mse'):
super().__init__(error_method) super().__init__(error_method)
@property @property
def inputs(self): def inputs(self):
return np.array([ return np.array(
[
[0, 0, 0], [0, 0, 0],
[0, 0, 1], [0, 0, 1],
[0, 1, 0], [0, 1, 0],
@@ -19,20 +19,12 @@ class XOR3d(FuncFit):
[1, 0, 1], [1, 0, 1],
[1, 1, 0], [1, 1, 0],
[1, 1, 1], [1, 1, 1],
]) ]
)
@property @property
def targets(self): def targets(self):
return np.array([ return np.array([[0], [1], [1], [0], [1], [0], [0], [1]])
[0],
[1],
[1],
[0],
[1],
[0],
[0],
[1]
])
@property @property
def input_shape(self): def input_shape(self):

View File

@@ -25,7 +25,19 @@ class BraxEnv(RLEnv):
def output_shape(self): def output_shape(self):
return (self.env.action_size,) return (self.env.action_size,)
def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs): def show(
self,
state,
randkey,
act_func,
params,
save_path=None,
height=512,
width=512,
duration=0.1,
*args,
**kwargs
):
import jax import jax
import imageio import imageio
@@ -48,11 +60,13 @@ class BraxEnv(RLEnv):
key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs) key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs)
reward += r reward += r
imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in imgs = [
tqdm(state_histories, desc="Rendering")] image.render_array(sys=self.env.sys, state=s, width=width, height=height)
for s in tqdm(state_histories, desc="Rendering")
]
def create_gif(image_list, gif_name, duration): def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer: with imageio.get_writer(gif_name, mode="I", duration=duration) as writer:
for image in image_list: for image in image_list:
formatted_image = np.array(image, dtype=np.uint8) formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image) writer.append_data(formatted_image)
@@ -60,5 +74,3 @@ class BraxEnv(RLEnv):
create_gif(imgs, save_path, duration=0.1) create_gif(imgs, save_path, duration=0.1)
print("Gif saved to: ", save_path) print("Gif saved to: ", save_path)
print("Total reward: ", reward) print("Total reward: ", reward)

View File

@@ -4,7 +4,6 @@ from .rl_jit import RLEnv
class GymNaxEnv(RLEnv): class GymNaxEnv(RLEnv):
def __init__(self, env_name): def __init__(self, env_name):
super().__init__() super().__init__()
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered" assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
@@ -24,5 +23,5 @@ class GymNaxEnv(RLEnv):
def output_shape(self): def output_shape(self):
return self.env.action_space(self.env_params).shape return self.env.action_space(self.env_params).shape
def show(self, randkey, state, act_func, params, *args, **kwargs): def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).") raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

View File

@@ -12,28 +12,28 @@ class RLEnv(BaseProblem):
super().__init__() super().__init__()
self.max_step = max_step self.max_step = max_step
def evaluate(self, randkey, state, act_func, params): def evaluate(self, state, randkey, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey) rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset) init_obs, init_env_state = self.reset(rng_reset)
def cond_func(carry): def cond_func(carry):
_, _, _, _, done, _, count = carry _, _, _, done, _, count = carry
return ~done & (count < self.max_step) return ~done & (count < self.max_step)
def body_func(carry): def body_func(carry):
state_, obs, env_state, rng, done, tr, count = carry # tr -> total reward obs, env_state, rng, done, tr, count = carry # tr -> total reward
state_, action = act_func(state_, obs, params) action = act_func(state, obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) next_obs, next_env_state, reward, done, _ = self.step(
rng, env_state, action
)
next_rng, _ = jax.random.split(rng) next_rng, _ = jax.random.split(rng)
return state_, next_obs, next_env_state, next_rng, done, tr + reward, count + 1 return next_obs, next_env_state, next_rng, done, tr + reward, count + 1
state, _, _, _, _, total_reward, _ = jax.lax.while_loop( _, _, _, _, total_reward, _ = jax.lax.while_loop(
cond_func, cond_func, body_func, (init_obs, init_env_state, rng_episode, False, 0.0, 0)
body_func,
(state, init_obs, init_env_state, rng_episode, False, 0.0, 0)
) )
return state, total_reward return total_reward
# @partial(jax.jit, static_argnums=(0,)) # @partial(jax.jit, static_argnums=(0,))
def step(self, randkey, env_state, action): def step(self, randkey, env_state, action):
@@ -57,5 +57,5 @@ class RLEnv(BaseProblem):
def output_shape(self): def output_shape(self):
raise NotImplementedError raise NotImplementedError
def show(self, randkey, state, act_func, params, *args, **kwargs): def show(self, state, randkey, act_func, params, *args, **kwargs):
raise NotImplementedError raise NotImplementedError

View File

@@ -36,7 +36,9 @@ def main():
elite_mask = jnp.zeros((1000,), dtype=jnp.bool_) elite_mask = jnp.zeros((1000,), dtype=jnp.bool_)
elite_mask = elite_mask.at[:5].set(1) elite_mask = elite_mask.at[:5].set(1)
state = algorithm.create_next_generation(jax.random.key(0), state, winner, losser, elite_mask) state = algorithm.create_next_generation(
jax.random.key(0), state, winner, losser, elite_mask
)
pop_nodes, pop_conns = algorithm.species.ask(state.species) pop_nodes, pop_conns = algorithm.species.ask(state.species)
transforms = batch_transform(pop_nodes, pop_conns) transforms = batch_transform(pop_nodes, pop_conns)
@@ -48,5 +50,5 @@ def main():
print(_) print(_)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@@ -19,7 +19,7 @@ def main():
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
activation_default=Act.tanh, activation_default=Act.tanh,
) ),
) )
transformed = genome.transform(nodes, conns) transformed = genome.transform(nodes, conns)
@@ -35,7 +35,7 @@ def main():
print(output) print(output)
if __name__ == '__main__': if __name__ == "__main__":
a = jnp.array([1, 3, 5, 6, 8]) a = jnp.array([1, 3, 5, 6, 8])
b = jnp.array([1, 2, 3]) b = jnp.array([1, 2, 3])
print(jnp.isin(a, b)) print(jnp.isin(a, b))

View File

@@ -2,6 +2,7 @@ from algorithm.neat import *
from utils import Act, Agg, State from utils import Act, Agg, State
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
def test_default(): def test_default():
@@ -135,3 +136,29 @@ def test_recurrent():
print(outputs) print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]] # expected: [[0.5], [0.75], [0.5], [0.75]]
def test_random_initialize():
    """Smoke-test a genome produced by ``genome.initialize``.

    Builds a 2-input / 1-output ``DefaultGenome`` whose node gene is
    restricted to identity activation and sum aggregation, initializes it
    from a fixed PRNG key, transforms it, and runs a jitted, vmapped forward
    pass over the four binary input pairs. Results are only printed; nothing
    is asserted.
    """
    node_gene = NodeGeneWithoutResponse(
        activation_default=Act.identity,
        activation_options=(Act.identity,),
        aggregation_default=Agg.sum,
        aggregation_options=(Agg.sum,),
    )
    genome = DefaultGenome(
        num_inputs=2,
        num_outputs=1,
        max_nodes=5,
        max_conns=4,
        node_gene=node_gene,
    )
    state = genome.setup()

    randkey = jax.random.PRNGKey(0)
    nodes, conns = genome.initialize(state, randkey)
    transformed = genome.transform(state, nodes, conns)
    print(*transformed, sep="\n")

    # Batch only over the inputs (axis 0); state and the transformed genome
    # are shared across the batch.
    batch_inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    batched_forward = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))
    outputs = batched_forward(state, batch_inputs, transformed)
    print(outputs)

View File

@@ -19,11 +19,11 @@ def main():
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
activation_default=Act.tanh, activation_default=Act.tanh,
) ),
) )
transformed = genome.transform(nodes, conns) transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n') print(*transformed, sep="\n")
key = jax.random.key(0) key = jax.random.key(0)
dummy_input = jnp.zeros((8,)) dummy_input = jnp.zeros((8,))
@@ -31,5 +31,5 @@ def main():
print(output) print(output)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@@ -116,3 +116,41 @@ def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf) masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr) min_idx = jnp.argmin(masked_arr)
return min_idx return min_idx
def add_node(nodes, new_key: int, attrs):
    """Insert a node into a free row of ``nodes`` and return the new array.

    A row counts as free when its key (column 0) is NaN; the target row is
    whatever ``fetch_first`` selects from that NaN mask (the first NaN row,
    per the original docstring).

    Args:
        nodes: 2-D node array; column 0 holds the node key.
        new_key: key written to column 0 of the chosen row.
        attrs: values written to columns 1 onward of the chosen row.

    Returns:
        A copy of ``nodes`` with the chosen row filled in.
    """
    free_mask = jnp.isnan(nodes[:, 0])
    row = fetch_first(free_mask)
    with_key = nodes.at[row, 0].set(new_key)
    return with_key.at[row, 1:].set(attrs)
def delete_node_by_pos(nodes, pos):
    """Remove a node by blanking its row.

    Args:
        nodes: node array to delete from.
        pos: position (row index) of the node inside ``nodes``.

    Returns:
        A copy of ``nodes`` in which everything at ``pos`` is NaN, which
        marks the row as empty.
    """
    cleared = nodes.at[pos].set(jnp.nan)
    return cleared
def add_conn(conns, i_key, o_key, enable: bool, attrs):
    """Insert a connection into a free row of ``conns`` and return the new array.

    A row counts as free when its first column is NaN; the target row is
    whatever ``fetch_first`` selects from that NaN mask (the first NaN row,
    per the original docstring).

    Args:
        conns: 2-D connection array; columns 0-2 hold
            ``(i_key, o_key, enable)``.
        i_key: value written to column 0.
        o_key: value written to column 1.
        enable: enabled flag written to column 2.
        attrs: values written to columns 3 onward of the chosen row.

    Returns:
        A copy of ``conns`` with the chosen row filled in.
    """
    free_mask = jnp.isnan(conns[:, 0])
    row = fetch_first(free_mask)
    header = jnp.array([i_key, o_key, enable])
    with_header = conns.at[row, 0:3].set(header)
    return with_header.at[row, 3:].set(attrs)
def delete_conn_by_pos(conns, pos):
    """Remove a connection by blanking its row.

    Args:
        conns: connection array to delete from.
        pos: position (row index) of the connection inside ``conns``.

    Returns:
        A copy of ``conns`` in which everything at ``pos`` is NaN, which
        marks the row as empty.
    """
    cleared = conns.at[pos].set(jnp.nan)
    return cleared