complete fully stateful!
use black to format all files!
This commit is contained in:
@@ -5,5 +5,5 @@ class BaseCrossover:
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def __call__(self, state, genome, nodes1, nodes2, conns1, conns2):
|
||||
def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -4,12 +4,12 @@ from .base import BaseCrossover
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
|
||||
def __call__(self, state, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey1, randkey2, randkey = jax.random.split(state.randkey, 3)
|
||||
randkey1, randkey2 = jax.random.split(randkey, 2)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
@@ -34,11 +34,12 @@ class DefaultCrossover(BaseCrossover):
|
||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
|
||||
)
|
||||
|
||||
return state.update(randkey=randkey), new_nodes, new_conns
|
||||
return new_nodes, new_conns
|
||||
|
||||
def align_array(self, seq1, seq2, ar2, is_conn: bool):
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
After I review this code, I found that it is the most difficult part of the code.
|
||||
Please consider carefully before change it!
|
||||
make ar2 align with ar1.
|
||||
:param seq1:
|
||||
:param seq2:
|
||||
@@ -64,8 +65,8 @@ class DefaultCrossover(BaseCrossover):
|
||||
|
||||
return refactor_ar2
|
||||
|
||||
def crossover_gene(self, rand_key, g1, g2, is_conn):
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
def crossover_gene(self, randkey, g1, g2, is_conn):
|
||||
r = jax.random.uniform(randkey, shape=g1.shape)
|
||||
new_gene = jnp.where(r > 0.5, g1, g2)
|
||||
if is_conn: # fix enabled
|
||||
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled
|
||||
|
||||
@@ -5,5 +5,5 @@ class BaseMutation:
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def __call__(self, state, genome, nodes, conns, new_node_key):
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,6 +1,16 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from . import BaseMutation
|
||||
from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles
|
||||
from utils import (
|
||||
fetch_first,
|
||||
fetch_random,
|
||||
I_INF,
|
||||
unflatten_conns,
|
||||
check_cycles,
|
||||
add_node,
|
||||
add_conn,
|
||||
delete_node_by_pos,
|
||||
delete_conn_by_pos,
|
||||
)
|
||||
|
||||
|
||||
class DefaultMutation(BaseMutation):
|
||||
@@ -16,15 +26,17 @@ class DefaultMutation(BaseMutation):
|
||||
self.node_add = node_add
|
||||
self.node_delete = node_delete
|
||||
|
||||
def __call__(self, state, genome, nodes, conns, new_node_key):
|
||||
k1, k2, randkey = jax.random.split(state.randkey)
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
k1, k2 = jax.random.split(randkey)
|
||||
|
||||
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
|
||||
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
|
||||
nodes, conns = self.mutate_structure(
|
||||
state, k1, genome, nodes, conns, new_node_key
|
||||
)
|
||||
nodes, conns = self.mutate_values(state, k2, genome, nodes, conns)
|
||||
|
||||
return state.update(randkey=randkey), nodes, conns
|
||||
return nodes, conns
|
||||
|
||||
def mutate_structure(self, key, genome, nodes, conns, new_node_key):
|
||||
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
def mutate_add_node(key_, nodes_, conns_):
|
||||
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
||||
|
||||
@@ -33,24 +45,24 @@ class DefaultMutation(BaseMutation):
|
||||
new_conns = conns_.at[idx, 2].set(False)
|
||||
|
||||
# add a new node
|
||||
new_nodes = genome.add_node(
|
||||
nodes_, new_node_key, genome.node_gene.new_custom_attrs()
|
||||
new_nodes = add_node(
|
||||
nodes_, new_node_key, genome.node_gene.new_custom_attrs(state)
|
||||
)
|
||||
|
||||
# add two new connections
|
||||
new_conns = genome.add_conn(
|
||||
new_conns = add_conn(
|
||||
new_conns,
|
||||
i_key,
|
||||
new_node_key,
|
||||
True,
|
||||
genome.conn_gene.new_custom_attrs(),
|
||||
genome.conn_gene.new_custom_attrs(state),
|
||||
)
|
||||
new_conns = genome.add_conn(
|
||||
new_conns = add_conn(
|
||||
new_conns,
|
||||
new_node_key,
|
||||
o_key,
|
||||
True,
|
||||
genome.conn_gene.new_custom_attrs(),
|
||||
genome.conn_gene.new_custom_attrs(state),
|
||||
)
|
||||
|
||||
return new_nodes, new_conns
|
||||
@@ -75,7 +87,7 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
new_nodes = genome.delete_node_by_pos(nodes_, idx)
|
||||
new_nodes = delete_node_by_pos(nodes_, idx)
|
||||
|
||||
# delete all connections
|
||||
new_conns = jnp.where(
|
||||
@@ -123,8 +135,8 @@ class DefaultMutation(BaseMutation):
|
||||
return nodes_, conns_
|
||||
|
||||
def successful():
|
||||
return nodes_, genome.add_conn(
|
||||
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs()
|
||||
return nodes_, add_conn(
|
||||
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state)
|
||||
)
|
||||
|
||||
def already_exist():
|
||||
@@ -152,7 +164,7 @@ class DefaultMutation(BaseMutation):
|
||||
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
||||
|
||||
def successfully_delete_connection():
|
||||
return nodes_, genome.delete_conn_by_pos(conns_, idx)
|
||||
return nodes_, delete_conn_by_pos(conns_, idx)
|
||||
|
||||
return jax.lax.cond(
|
||||
idx == I_INF,
|
||||
@@ -160,7 +172,7 @@ class DefaultMutation(BaseMutation):
|
||||
successfully_delete_connection,
|
||||
)
|
||||
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||
|
||||
def no(key_, nodes_, conns_):
|
||||
@@ -181,13 +193,17 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def mutate_values(self, key, genome, nodes, conns):
|
||||
k1, k2 = jax.random.split(key, num=2)
|
||||
def mutate_values(self, state, randkey, genome, nodes, conns):
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, nodes_keys, nodes
|
||||
)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, conns_keys, conns
|
||||
)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||
|
||||
@@ -26,7 +26,7 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_replace_rate = weight_replace_rate
|
||||
|
||||
def new_custom_attrs(self, state):
|
||||
return state, jnp.array([self.weight_init_mean])
|
||||
return jnp.array([self.weight_init_mean])
|
||||
|
||||
def new_random_attrs(self, state, randkey):
|
||||
weight = (
|
||||
|
||||
@@ -109,10 +109,10 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
def distance(self, state, node1, node2):
|
||||
return (
|
||||
jnp.abs(node1[1] - node2[1])
|
||||
+ jnp.abs(node1[2] - node2[2])
|
||||
+ (node1[3] != node2[3])
|
||||
+ (node1[4] != node2[4])
|
||||
jnp.abs(node1[1] - node2[1]) # bias
|
||||
+ jnp.abs(node1[2] - node2[2]) # response
|
||||
+ (node1[3] != node2[3]) # activation
|
||||
+ (node1[4] != node2[4]) # aggregation
|
||||
)
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
|
||||
106
tensorneat/algorithm/neat/gene/node/default_without_response.py
Normal file
106
tensorneat/algorithm/neat/gene/node/default_without_response.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from typing import Tuple
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act, agg, mutate_int, mutate_float
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
class NodeGeneWithoutResponse(BaseNodeGene):
|
||||
"""
|
||||
Default node gene, with the same behavior as in NEAT-python.
|
||||
The attribute response is removed.
|
||||
"""
|
||||
|
||||
custom_attrs = ["bias", "aggregation", "activation"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bias_init_mean: float = 0.0,
|
||||
bias_init_std: float = 1.0,
|
||||
bias_mutate_power: float = 0.5,
|
||||
bias_mutate_rate: float = 0.7,
|
||||
bias_replace_rate: float = 0.1,
|
||||
activation_default: callable = Act.sigmoid,
|
||||
activation_options: Tuple = (Act.sigmoid,),
|
||||
activation_replace_rate: float = 0.1,
|
||||
aggregation_default: callable = Agg.sum,
|
||||
aggregation_options: Tuple = (Agg.sum,),
|
||||
aggregation_replace_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self.bias_init_mean = bias_init_mean
|
||||
self.bias_init_std = bias_init_std
|
||||
self.bias_mutate_power = bias_mutate_power
|
||||
self.bias_mutate_rate = bias_mutate_rate
|
||||
self.bias_replace_rate = bias_replace_rate
|
||||
|
||||
self.activation_default = activation_options.index(activation_default)
|
||||
self.activation_options = activation_options
|
||||
self.activation_indices = jnp.arange(len(activation_options))
|
||||
self.activation_replace_rate = activation_replace_rate
|
||||
|
||||
self.aggregation_default = aggregation_options.index(aggregation_default)
|
||||
self.aggregation_options = aggregation_options
|
||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||
self.aggregation_replace_rate = aggregation_replace_rate
|
||||
|
||||
def new_custom_attrs(self, state):
|
||||
return jnp.array(
|
||||
[
|
||||
self.bias_init_mean,
|
||||
self.activation_default,
|
||||
self.aggregation_default,
|
||||
]
|
||||
)
|
||||
|
||||
def new_random_attrs(self, state, randkey):
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
|
||||
act = jax.random.randint(k3, (), 0, len(self.activation_options))
|
||||
agg = jax.random.randint(k4, (), 0, len(self.aggregation_options))
|
||||
return jnp.array([bias, act, agg])
|
||||
|
||||
def mutate(self, state, randkey, node):
|
||||
k1, k2, k3, k4 = jax.random.split(state.randkey, num=4)
|
||||
index = node[0]
|
||||
|
||||
bias = mutate_float(
|
||||
k1,
|
||||
node[1],
|
||||
self.bias_init_mean,
|
||||
self.bias_init_std,
|
||||
self.bias_mutate_power,
|
||||
self.bias_mutate_rate,
|
||||
self.bias_replace_rate,
|
||||
)
|
||||
|
||||
act = mutate_int(
|
||||
k3, node[3], self.activation_indices, self.activation_replace_rate
|
||||
)
|
||||
|
||||
agg = mutate_int(
|
||||
k4, node[4], self.aggregation_indices, self.aggregation_replace_rate
|
||||
)
|
||||
|
||||
return jnp.array([index, bias, act, agg])
|
||||
|
||||
def distance(self, state, node1, node2):
|
||||
return (
|
||||
jnp.abs(node1[1] - node2[1]) # bias
|
||||
+ (node1[3] != node2[3]) # activation
|
||||
+ (node1[4] != node2[4]) # aggregation
|
||||
)
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
bias, act_idx, agg_idx = attrs
|
||||
|
||||
z = agg(agg_idx, inputs, self.aggregation_options)
|
||||
z = bias + z
|
||||
|
||||
# the last output node should not be activated
|
||||
z = jax.lax.cond(
|
||||
is_output_node, lambda: z, lambda: act(act_idx, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
@@ -1,6 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
from utils import fetch_first, State
|
||||
import jax, jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover
|
||||
from utils import State
|
||||
|
||||
|
||||
class BaseGenome:
|
||||
@@ -12,8 +13,10 @@ class BaseGenome:
|
||||
num_outputs: int,
|
||||
max_nodes: int,
|
||||
max_conns: int,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
node_gene: BaseNodeGene,
|
||||
conn_gene: BaseConnGene,
|
||||
mutation: BaseMutation,
|
||||
crossover: BaseCrossover,
|
||||
):
|
||||
self.num_inputs = num_inputs
|
||||
self.num_outputs = num_outputs
|
||||
@@ -23,10 +26,14 @@ class BaseGenome:
|
||||
self.max_conns = max_conns
|
||||
self.node_gene = node_gene
|
||||
self.conn_gene = conn_gene
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.node_gene.setup(state)
|
||||
state = self.conn_gene.setup(state)
|
||||
state = self.mutation.setup(state)
|
||||
state = self.crossover.setup(state)
|
||||
return state
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
@@ -35,36 +42,81 @@ class BaseGenome:
|
||||
def forward(self, state, inputs, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def add_node(self, nodes, new_key: int, attrs):
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
The new node will place at the first NaN row.
|
||||
"""
|
||||
exist_keys = nodes[:, 0]
|
||||
pos = fetch_first(jnp.isnan(exist_keys))
|
||||
new_nodes = nodes.at[pos, 0].set(new_key)
|
||||
return new_nodes.at[pos, 1:].set(attrs)
|
||||
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
|
||||
return self.mutation(state, randkey, self, nodes, conns, new_node_key)
|
||||
|
||||
def delete_node_by_pos(self, nodes, pos):
|
||||
"""
|
||||
Delete a node from the genome.
|
||||
Delete the node by its pos in nodes.
|
||||
"""
|
||||
return nodes.at[pos].set(jnp.nan)
|
||||
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
|
||||
return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2)
|
||||
|
||||
def add_conn(self, conns, i_key, o_key, enable: bool, attrs):
|
||||
def initialize(self, state, randkey):
|
||||
"""
|
||||
Add a new connection to the genome.
|
||||
The new connection will place at the first NaN row.
|
||||
"""
|
||||
con_keys = conns[:, 0]
|
||||
pos = fetch_first(jnp.isnan(con_keys))
|
||||
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
|
||||
return new_conns.at[pos, 3:].set(attrs)
|
||||
Default initialization method for the genome.
|
||||
Add an extra hidden node.
|
||||
Make all input nodes and output nodes connected to the hidden node.
|
||||
All attributes will be initialized randomly using gene.new_random_attrs method.
|
||||
|
||||
def delete_conn_by_pos(self, conns, pos):
|
||||
For example, a network with 2 inputs and 1 output, the structure will be:
|
||||
nodes:
|
||||
[
|
||||
[0, attrs0], # input node 0
|
||||
[1, attrs1], # input node 1
|
||||
[2, attrs2], # output node 0
|
||||
[3, attrs3], # hidden node
|
||||
[NaN, NaN], # empty node
|
||||
]
|
||||
conns:
|
||||
[
|
||||
[0, 3, attrs0], # input node 0 -> hidden node
|
||||
[1, 3, attrs1], # input node 1 -> hidden node
|
||||
[3, 2, attrs2], # hidden node -> output node 0
|
||||
[NaN, NaN],
|
||||
[NaN, NaN],
|
||||
]
|
||||
"""
|
||||
Delete a connection from the genome.
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns
|
||||
# initialize nodes
|
||||
new_node_key = (
|
||||
max([*self.input_idx, *self.output_idx]) + 1
|
||||
) # the key for the hidden node
|
||||
node_keys = jnp.concatenate(
|
||||
[self.input_idx, self.output_idx, jnp.array([new_node_key])]
|
||||
) # the list of all node keys
|
||||
|
||||
# initialize nodes and connections with NaN
|
||||
nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
|
||||
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
|
||||
|
||||
# set keys for input nodes, output nodes and hidden node
|
||||
nodes = nodes.at[node_keys, 0].set(node_keys)
|
||||
|
||||
# generate random attributes for nodes
|
||||
node_keys = jax.random.split(k1, len(node_keys))
|
||||
random_node_attrs = jax.vmap(
|
||||
self.node_gene.new_random_attrs, in_axes=(None, 0)
|
||||
)(state, node_keys)
|
||||
nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs)
|
||||
|
||||
# initialize conns
|
||||
# input-hidden connections
|
||||
input_conns = jnp.c_[
|
||||
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
|
||||
]
|
||||
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
|
||||
conns = conns.at[self.input_idx, 2].set(True) # enable
|
||||
|
||||
# output-hidden connections
|
||||
output_conns = jnp.c_[
|
||||
jnp.full_like(self.output_idx, new_node_key), self.output_idx
|
||||
]
|
||||
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
|
||||
conns = conns.at[self.output_idx, 2].set(True) # enable
|
||||
|
||||
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
|
||||
# generate random attributes for conns
|
||||
random_conn_attrs = jax.vmap(
|
||||
self.conn_gene.new_random_attrs, in_axes=(None, 0)
|
||||
)(state, conn_keys)
|
||||
conns = conns.at[: len(conn_keys), 3:].set(random_conn_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
@@ -5,6 +5,7 @@ from utils import unflatten_conns, topological_sort, I_INF
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
|
||||
|
||||
|
||||
class DefaultGenome(BaseGenome):
|
||||
@@ -20,10 +21,19 @@ class DefaultGenome(BaseGenome):
|
||||
max_conns=4,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
output_transform: Callable = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene
|
||||
num_inputs,
|
||||
num_outputs,
|
||||
max_nodes,
|
||||
max_conns,
|
||||
node_gene,
|
||||
conn_gene,
|
||||
mutation,
|
||||
crossover,
|
||||
)
|
||||
|
||||
if output_transform is not None:
|
||||
|
||||
@@ -5,6 +5,7 @@ from utils import unflatten_conns
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
|
||||
|
||||
|
||||
class RecurrentGenome(BaseGenome):
|
||||
@@ -20,11 +21,20 @@ class RecurrentGenome(BaseGenome):
|
||||
max_conns: int,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene
|
||||
num_inputs,
|
||||
num_outputs,
|
||||
max_nodes,
|
||||
max_conns,
|
||||
node_gene,
|
||||
conn_gene,
|
||||
mutation,
|
||||
crossover,
|
||||
)
|
||||
self.activate_time = activate_time
|
||||
|
||||
|
||||
@@ -10,18 +10,12 @@ class NEAT(BaseAlgorithm):
|
||||
def __init__(
|
||||
self,
|
||||
species: BaseSpecies,
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
):
|
||||
self.genome: BaseGenome = species.genome
|
||||
self.species = species
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
self.genome = species.genome
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.species.setup(state)
|
||||
state = self.mutation.setup(state)
|
||||
state = self.crossover.setup(state)
|
||||
state = state.register(
|
||||
generation=jnp.array(0.0),
|
||||
next_node_key=jnp.array(
|
||||
@@ -32,18 +26,16 @@ class NEAT(BaseAlgorithm):
|
||||
return state
|
||||
|
||||
def ask(self, state: State):
|
||||
return state, self.species.ask(state.species)
|
||||
return self.species.ask(state)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
state = state.update(generation=state.generation + 1, randkey=randkey)
|
||||
|
||||
state, winner, loser, elite_mask = self.species.update_species(
|
||||
state.species, fitness
|
||||
)
|
||||
state, winner, loser, elite_mask = self.species.update_species(state, fitness)
|
||||
state = self.create_next_generation(state, winner, loser, elite_mask)
|
||||
state = self.species.speciate(state.species)
|
||||
state = self.species.speciate(state)
|
||||
|
||||
return state
|
||||
|
||||
@@ -73,21 +65,25 @@ class NEAT(BaseAlgorithm):
|
||||
new_node_keys = jnp.arange(pop_size) + state.next_node_key
|
||||
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
mutate_rand_keys = jax.random.split(k2, pop_size)
|
||||
crossover_randkeys = jax.random.split(k1, pop_size)
|
||||
mutate_randkeys = jax.random.split(k2, pop_size)
|
||||
|
||||
wpn, wpc = state.species.pop_nodes[winner], state.species.pop_conns[winner]
|
||||
lpn, lpc = state.species.pop_nodes[loser], state.species.pop_conns[loser]
|
||||
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
|
||||
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))(
|
||||
crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc
|
||||
)
|
||||
n_nodes, n_conns = jax.vmap(
|
||||
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
|
||||
)(
|
||||
state, crossover_randkeys, wpn, wpc, lpn, lpc
|
||||
) # new_nodes, new_conns
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))(
|
||||
mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys
|
||||
)
|
||||
m_n_nodes, m_n_conns = jax.vmap(
|
||||
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
|
||||
)(
|
||||
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
|
||||
) # mutated_new_nodes, mutated_new_conns
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
@@ -108,8 +104,8 @@ class NEAT(BaseAlgorithm):
|
||||
)
|
||||
|
||||
def member_count(self, state: State):
|
||||
return state, state.species.member_count
|
||||
return state.member_count
|
||||
|
||||
def generation(self, state: State):
|
||||
# to analysis the algorithm
|
||||
return state, state.generation
|
||||
return state.generation
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State, rank_elements, argmin_with_mask, fetch_first
|
||||
from ..genome import BaseGenome
|
||||
from .base import BaseSpecies
|
||||
|
||||
|
||||
"""
|
||||
Core procedures of NEAT algorithm, contains the following steps:
|
||||
1. Update the fitness of each species;
|
||||
2. Decide which species will be stagnation;
|
||||
3. Decide the number of members of each species in the next generation;
|
||||
4. Choice the crossover pair for each species;
|
||||
5. Divided the whole new population into different species;
|
||||
|
||||
This class use tensor operation to imitate the behavior of NEAT algorithm which implemented in NEAT-python.
|
||||
The code may be hard to understand. Fortunately, we don't need to overwrite it in most cases.
|
||||
"""
|
||||
|
||||
|
||||
class DefaultSpecies(BaseSpecies):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -20,8 +32,6 @@ class DefaultSpecies(BaseSpecies):
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.0,
|
||||
initialize_method: str = "one_hidden_node",
|
||||
# {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||
):
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
@@ -36,15 +46,17 @@ class DefaultSpecies(BaseSpecies):
|
||||
self.survival_threshold = survival_threshold
|
||||
self.min_species_size = min_species_size
|
||||
self.compatibility_threshold = compatibility_threshold
|
||||
self.initialize_method = initialize_method
|
||||
|
||||
self.species_arange = jnp.arange(self.species_size)
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.genome.setup(state)
|
||||
k1, randkey = jax.random.split(state.randkey, 2)
|
||||
pop_nodes, pop_conns = initialize_population(
|
||||
self.pop_size, self.genome, k1, self.initialize_method
|
||||
|
||||
# initialize the population
|
||||
initialize_keys = jax.random.split(randkey, self.pop_size)
|
||||
pop_nodes, pop_conns = jax.vmap(self.genome.initialize, in_axes=(None, 0))(
|
||||
state, initialize_keys
|
||||
)
|
||||
|
||||
species_keys = jnp.full(
|
||||
@@ -82,8 +94,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
state = state.update(randkey=randkey)
|
||||
|
||||
return state.register(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
@@ -97,7 +110,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
)
|
||||
|
||||
def ask(self, state):
|
||||
return state, state.pop_nodes, state.pop_conns
|
||||
return state.pop_nodes, state.pop_conns
|
||||
|
||||
def update_species(self, state, fitness):
|
||||
# update the fitness of each species
|
||||
@@ -122,8 +135,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
k1, k2 = jax.random.split(state.randkey)
|
||||
# crossover info
|
||||
winner, loser, elite_mask = self.create_crossover_pair(
|
||||
state, k1, spawn_number, fitness
|
||||
state, winner, loser, elite_mask = self.create_crossover_pair(
|
||||
state, spawn_number, fitness
|
||||
)
|
||||
|
||||
return state.update(randkey=k2), winner, loser, elite_mask
|
||||
@@ -322,12 +335,12 @@ class DefaultSpecies(BaseSpecies):
|
||||
winner = jnp.where(is_part1_win, part1, part2)
|
||||
loser = jnp.where(is_part1_win, part2, part1)
|
||||
|
||||
return state(randkey=randkey), winner, loser, elite_mask
|
||||
return state.update(randkey=randkey), winner, loser, elite_mask
|
||||
|
||||
def speciate(self, state):
|
||||
# prepare distance functions
|
||||
o2p_distance_func = jax.vmap(
|
||||
self.distance, in_axes=(None, None, 0, 0)
|
||||
self.distance, in_axes=(None, None, None, 0, 0)
|
||||
) # one to population
|
||||
|
||||
# idx to specie key
|
||||
@@ -351,7 +364,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
distances = o2p_distance_func(
|
||||
cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
# find the closest one
|
||||
@@ -434,7 +447,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(
|
||||
cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
close_enough_mask = o2p_distance < self.compatibility_threshold
|
||||
@@ -508,14 +521,16 @@ class DefaultSpecies(BaseSpecies):
|
||||
next_species_key=next_species_key,
|
||||
)
|
||||
|
||||
def distance(self, nodes1, conns1, nodes2, conns2):
|
||||
def distance(self, state, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
|
||||
d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
|
||||
state, conns1, conns2
|
||||
)
|
||||
return d
|
||||
|
||||
def node_distance(self, nodes1, nodes2):
|
||||
def node_distance(self, state, nodes1, nodes2):
|
||||
"""
|
||||
The distance of the nodes part for two genomes
|
||||
"""
|
||||
@@ -541,7 +556,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
# calculate the distance of homologous nodes
|
||||
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(0, 0))(fr, sr)
|
||||
hnd = jax.vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr, sr
|
||||
) # homologous node distance
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
@@ -550,9 +567,11 @@ class DefaultSpecies(BaseSpecies):
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # avoid zero division
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
def conn_distance(self, conns1, conns2):
|
||||
return val
|
||||
|
||||
def conn_distance(self, state, conns1, conns2):
|
||||
"""
|
||||
The distance of the conns part for two genomes
|
||||
"""
|
||||
@@ -573,7 +592,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(0, 0))(fr, sr)
|
||||
hcd = jax.vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr, sr
|
||||
) # homologous connection distance
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
@@ -582,185 +603,6 @@ class DefaultSpecies(BaseSpecies):
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome, randkey, init_method="default"):
|
||||
rand_keys = jax.random.split(randkey, pop_size)
|
||||
|
||||
if init_method == "one_hidden_node":
|
||||
init_func = init_one_hidden_node
|
||||
elif init_method == "dense_hideen_layer":
|
||||
init_func = init_dense_hideen_layer
|
||||
elif init_method == "no_hidden_random":
|
||||
init_func = init_no_hidden_random
|
||||
else:
|
||||
raise ValueError("Unknown initialization method: {}".format(init_method))
|
||||
|
||||
pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys)
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
|
||||
|
||||
# one hidden node
|
||||
def init_one_hidden_node(genome, randkey):
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
new_node_key = max([*input_idx, *output_idx]) + 1
|
||||
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[new_node_key, 0].set(new_node_key)
|
||||
|
||||
rand_keys_nodes = jax.random.split(
|
||||
randkey, num=len(input_idx) + len(output_idx) + 1
|
||||
)
|
||||
input_keys, output_keys, hidden_key = (
|
||||
rand_keys_nodes[: len(input_idx)],
|
||||
rand_keys_nodes[len(input_idx) : len(input_idx) + len(output_idx)],
|
||||
rand_keys_nodes[-1],
|
||||
)
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_attrs, in_axes=(None, 0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
|
||||
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[new_node_key, 1:].set(hidden_attrs)
|
||||
|
||||
input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)]
|
||||
conns = conns.at[input_idx, 0:2].set(input_conns)
|
||||
conns = conns.at[input_idx, 2].set(True)
|
||||
|
||||
output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
|
||||
conns = conns.at[output_idx, 0:2].set(output_conns)
|
||||
conns = conns.at[output_idx, 2].set(True)
|
||||
|
||||
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
|
||||
input_conn_keys, output_conn_keys = (
|
||||
rand_keys_conns[: len(input_idx)],
|
||||
rand_keys_conns[len(input_idx) :],
|
||||
)
|
||||
|
||||
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None, 0))
|
||||
input_conn_attrs = conn_attr_func(input_conn_keys)
|
||||
output_conn_attrs = conn_attr_func(output_conn_keys)
|
||||
|
||||
conns = conns.at[input_idx, 3:].set(input_conn_attrs)
|
||||
conns = conns.at[output_idx, 3:].set(output_conn_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
# random dense connections with 1 hidden layer
|
||||
def init_dense_hideen_layer(genome, randkey, hiddens=20):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
input_size = len(input_idx)
|
||||
output_size = len(output_idx)
|
||||
|
||||
hidden_idx = jnp.arange(
|
||||
input_size + output_size, input_size + output_size + hiddens
|
||||
)
|
||||
nodes = jnp.full(
|
||||
(genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
|
||||
|
||||
total_idx = input_size + output_size + hiddens
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:input_size]
|
||||
output_keys = rand_keys_n[input_size : input_size + output_size]
|
||||
hidden_keys = rand_keys_n[input_size + output_size :]
|
||||
|
||||
node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = node_attr_func(hidden_keys)
|
||||
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs)
|
||||
|
||||
total_connections = input_size * hiddens + hiddens * output_size
|
||||
conns = jnp.full(
|
||||
(genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
|
||||
rand_keys_c = jax.random.split(k2, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing="ij")
|
||||
hidden_to_output_ids, output_ids = jnp.meshgrid(
|
||||
hidden_idx, output_idx, indexing="ij"
|
||||
)
|
||||
|
||||
conns = conns.at[: input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
|
||||
conns = conns.at[: input_size * hiddens, 1].set(hidden_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens : total_connections, 0].set(
|
||||
hidden_to_output_ids.flatten()
|
||||
)
|
||||
conns = conns.at[input_size * hiddens : total_connections, 1].set(
|
||||
output_ids.flatten()
|
||||
)
|
||||
conns = conns.at[: input_size * hiddens + hiddens * output_size, 2].set(True)
|
||||
conns = conns.at[: input_size * hiddens + hiddens * output_size, 3:].set(
|
||||
conns_attrs
|
||||
)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
# random sparse connections with no hidden nodes
|
||||
def init_no_hidden_random(genome, randkey):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
|
||||
total_idx = len(input_idx) + len(output_idx)
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[: len(input_idx)]
|
||||
output_keys = rand_keys_n[len(input_idx) :]
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
num_connections_per_output = 4
|
||||
total_connections = len(output_idx) * num_connections_per_output
|
||||
|
||||
def create_connections_for_output(key):
|
||||
permuted_inputs = jax.random.permutation(key, input_idx)
|
||||
selected_inputs = permuted_inputs[:num_connections_per_output]
|
||||
return selected_inputs
|
||||
|
||||
conn_keys = jax.random.split(k2, num=len(output_idx))
|
||||
connections = jax.vmap(create_connections_for_output)(conn_keys)
|
||||
connections = connections.flatten()
|
||||
|
||||
output_repeats = jnp.repeat(output_idx, num_connections_per_output)
|
||||
|
||||
rand_keys_c = jax.random.split(k3, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
conns = conns.at[:total_connections, 0].set(connections)
|
||||
conns = conns.at[:total_connections, 1].set(output_repeats)
|
||||
conns = conns.at[:total_connections, 2].set(True) # enabled
|
||||
conns = conns.at[:total_connections, 3:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
return val
|
||||
|
||||
@@ -4,7 +4,7 @@ from algorithm.neat import *
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -17,21 +17,21 @@ if __name__ == '__main__':
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
),
|
||||
output_transform=Act.tanh
|
||||
output_transform=Act.tanh,
|
||||
),
|
||||
pop_size=1000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name='ant',
|
||||
env_name="ant",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=5000
|
||||
fitness_target=5000,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -4,7 +4,7 @@ from algorithm.neat import *
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -16,21 +16,21 @@ if __name__ == '__main__':
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
)
|
||||
),
|
||||
),
|
||||
pop_size=1000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name='halfcheetah',
|
||||
env_name="halfcheetah",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=5000
|
||||
fitness_target=5000,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -4,7 +4,7 @@ from algorithm.neat import *
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -16,21 +16,21 @@ if __name__ == '__main__':
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
)
|
||||
),
|
||||
),
|
||||
pop_size=100,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name='reacher',
|
||||
env_name="reacher",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=5000
|
||||
fitness_target=5000,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -4,7 +4,7 @@ from algorithm.neat import *
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -16,21 +16,21 @@ if __name__ == '__main__':
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
)
|
||||
),
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name='walker2d',
|
||||
env_name="walker2d",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=5000
|
||||
fitness_target=5000,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -4,7 +4,7 @@ from algorithm.neat import *
|
||||
from problem.func_fit import XOR3d
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -18,22 +18,22 @@ if __name__ == '__main__':
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
output_transform=Act.sigmoid, # the activation function for output node
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
),
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.01, # magic
|
||||
),
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
)
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
fitness_target=-1e-8
|
||||
fitness_target=-1e-8,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
|
||||
@@ -5,17 +5,28 @@ from utils import Act
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=HyperNEAT(
|
||||
substrate=FullSubstrate(
|
||||
input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)],
|
||||
hidden_coors=[
|
||||
(-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5),
|
||||
(-1, 0), (0.333, 0), (-0.333, 0), (1, 0),
|
||||
(-1, 0.5), (0.333, 0.5), (-0.333, 0.5), (1, 0.5),
|
||||
(-1, -0.5),
|
||||
(0.333, -0.5),
|
||||
(-0.333, -0.5),
|
||||
(1, -0.5),
|
||||
(-1, 0),
|
||||
(0.333, 0),
|
||||
(-0.333, 0),
|
||||
(1, 0),
|
||||
(-1, 0.5),
|
||||
(0.333, 0.5),
|
||||
(-0.333, 0.5),
|
||||
(1, 0.5),
|
||||
],
|
||||
output_coors=[
|
||||
(0, 1),
|
||||
],
|
||||
output_coors=[(0, 1), ],
|
||||
),
|
||||
neat=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -42,7 +53,7 @@ if __name__ == '__main__':
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=300,
|
||||
fitness_target=-1e-6
|
||||
fitness_target=-1e-6,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils.activation import ACT_ALL, Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
seed=0,
|
||||
algorithm=NEAT(
|
||||
@@ -15,27 +16,26 @@ if __name__ == '__main__':
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
activate_time=5,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=ACT_ALL,
|
||||
activation_replace_rate=0.2
|
||||
node_gene=NodeGeneWithoutResponse(
|
||||
activation_options=ACT_ALL, activation_replace_rate=0.2
|
||||
),
|
||||
output_transform=Act.sigmoid,
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
),
|
||||
output_transform=Act.sigmoid
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.03,
|
||||
),
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
)
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
fitness_target=-1e-8
|
||||
fitness_target=-1e-8,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
|
||||
@@ -5,7 +5,7 @@ from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -14,21 +14,23 @@ if __name__ == '__main__':
|
||||
num_outputs=3,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2}
|
||||
output_transform=lambda out: jnp.argmax(
|
||||
out
|
||||
), # the action of acrobot is {0, 1, 2}
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=GymNaxEnv(
|
||||
env_name='Acrobot-v1',
|
||||
env_name="Acrobot-v1",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=-62
|
||||
fitness_target=-62,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -5,7 +5,7 @@ from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -14,21 +14,23 @@ if __name__ == '__main__':
|
||||
num_outputs=2,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
|
||||
output_transform=lambda out: jnp.argmax(
|
||||
out
|
||||
), # the action of cartpole is {0, 1}
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=GymNaxEnv(
|
||||
env_name='CartPole-v1',
|
||||
env_name="CartPole-v1",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=500
|
||||
fitness_target=500,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -10,11 +10,7 @@ from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
def example_conf():
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=500,
|
||||
pop_size=10000
|
||||
),
|
||||
basic=BasicConfig(seed=42, fitness_target=500, pop_size=10000),
|
||||
neat=NeatConfig(
|
||||
inputs=4,
|
||||
outputs=1,
|
||||
@@ -23,28 +19,31 @@ def example_conf():
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
hyperneat=HyperNeatConfig(
|
||||
activation=Act.sigmoid,
|
||||
inputs=4,
|
||||
outputs=2
|
||||
),
|
||||
hyperneat=HyperNeatConfig(activation=Act.sigmoid, inputs=4, outputs=2),
|
||||
substrate=NormalSubstrateConfig(
|
||||
input_coors=((-1, -1), (-0.5, -1), (0, -1), (0.5, -1), (1, -1)),
|
||||
hidden_coors=(
|
||||
# (-1, -0.5), (-0.5, -0.5), (0, -0.5), (0.5, -0.5),
|
||||
(1, 0), (-1, 0), (-0.5, 0), (0, 0), (0.5, 0), (1, 0),
|
||||
(1, 0),
|
||||
(-1, 0),
|
||||
(-0.5, 0),
|
||||
(0, 0),
|
||||
(0.5, 0),
|
||||
(1, 0),
|
||||
# (1, 0.5), (-1, 0.5), (-0.5, 0.5), (0, 0.5), (0.5, 0.5), (1, 0.5),
|
||||
),
|
||||
output_coors=((-1, 1), (1, 1)),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
|
||||
)
|
||||
env_name="CartPole-v1",
|
||||
output_transform=lambda out: jnp.argmax(
|
||||
out
|
||||
), # the action of cartpole is {0, 1}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
conf = example_conf()
|
||||
|
||||
algorithm = HyperNEAT(conf, NormalGene, NormalSubstrate)
|
||||
|
||||
@@ -5,7 +5,7 @@ from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -14,21 +14,23 @@ if __name__ == '__main__':
|
||||
num_outputs=3,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2}
|
||||
output_transform=lambda out: jnp.argmax(
|
||||
out
|
||||
), # the action of mountain car is {0, 1, 2}
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=GymNaxEnv(
|
||||
env_name='MountainCar-v0',
|
||||
env_name="MountainCar-v0",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=0
|
||||
fitness_target=0,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -4,7 +4,7 @@ from algorithm.neat import *
|
||||
from problem.rl_env import GymNaxEnv
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -14,23 +14,23 @@ if __name__ == '__main__':
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh, ),
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
)
|
||||
),
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=GymNaxEnv(
|
||||
env_name='MountainCarContinuous-v0',
|
||||
env_name="MountainCarContinuous-v0",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=500
|
||||
fitness_target=500,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -4,7 +4,7 @@ from algorithm.neat import *
|
||||
from problem.rl_env import GymNaxEnv
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -17,21 +17,22 @@ if __name__ == '__main__':
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
),
|
||||
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
|
||||
output_transform=lambda out: out
|
||||
* 2, # the action of pendulum is [-2, 2]
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=GymNaxEnv(
|
||||
env_name='Pendulum-v1',
|
||||
env_name="Pendulum-v1",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=0
|
||||
fitness_target=0,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -5,7 +5,7 @@ from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
@@ -20,14 +20,14 @@ if __name__ == '__main__':
|
||||
),
|
||||
),
|
||||
problem=GymNaxEnv(
|
||||
env_name='Reacher-misc',
|
||||
env_name="Reacher-misc",
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target =500
|
||||
fitness_target=500,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import ray
|
||||
|
||||
ray.init(num_gpus=2)
|
||||
|
||||
available_resources = ray.available_resources()
|
||||
|
||||
@@ -10,14 +10,13 @@ from utils import State
|
||||
|
||||
|
||||
class Pipeline:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: BaseAlgorithm,
|
||||
problem: BaseProblem,
|
||||
seed: int = 42,
|
||||
fitness_target: float = 1,
|
||||
generation_limit: int = 1000,
|
||||
self,
|
||||
algorithm: BaseAlgorithm,
|
||||
problem: BaseProblem,
|
||||
seed: int = 42,
|
||||
fitness_target: float = 1,
|
||||
generation_limit: int = 1000,
|
||||
):
|
||||
assert problem.jitable, "Currently, problem must be jitable"
|
||||
|
||||
@@ -31,32 +30,35 @@ class Pipeline:
|
||||
# print(self.problem.input_shape, self.problem.output_shape)
|
||||
|
||||
# TODO: make each algorithm's input_num and output_num
|
||||
assert algorithm.num_inputs == self.problem.input_shape[-1], \
|
||||
f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
|
||||
assert (
|
||||
algorithm.num_inputs == self.problem.input_shape[-1]
|
||||
), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
|
||||
|
||||
self.best_genome = None
|
||||
self.best_fitness = float('-inf')
|
||||
self.best_fitness = float("-inf")
|
||||
self.generation_timestamp = None
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
||||
state = self.algorithm.setup(state)
|
||||
state = self.problem.setup(state)
|
||||
print("initializing finished")
|
||||
return state
|
||||
|
||||
def step(self, state):
|
||||
|
||||
randkey_, randkey = jax.random.split(state.randkey)
|
||||
keys = jax.random.split(randkey_, self.pop_size)
|
||||
|
||||
state, pop = self.algorithm.ask(state)
|
||||
pop = self.algorithm.ask(state)
|
||||
|
||||
state, pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0), out_axes=(None, 0))(state, pop)
|
||||
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(
|
||||
state, pop
|
||||
)
|
||||
|
||||
state, fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0), out_axes=(None, 0))(
|
||||
keys,
|
||||
state,
|
||||
self.algorithm.forward,
|
||||
pop_transformed
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, keys, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
|
||||
state = self.algorithm.tell(state, fitnesses)
|
||||
@@ -67,13 +69,15 @@ class Pipeline:
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
compiled_step = jax.jit(self.step).lower(state).compile()
|
||||
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
|
||||
print(
|
||||
f"compile finished, cost time: {time.time() - tic:.6f}s",
|
||||
)
|
||||
|
||||
for _ in range(self.generation_limit):
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
state, previous_pop = self.algorithm.ask(state)
|
||||
previous_pop = self.algorithm.ask(state)
|
||||
|
||||
state, fitnesses = compiled_step(state)
|
||||
|
||||
@@ -98,7 +102,12 @@ class Pipeline:
|
||||
|
||||
def analysis(self, state, pop, fitnesses):
|
||||
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
max_f, min_f, mean_f, std_f = (
|
||||
max(fitnesses),
|
||||
min(fitnesses),
|
||||
np.mean(fitnesses),
|
||||
np.std(fitnesses),
|
||||
)
|
||||
|
||||
new_timestamp = time.time()
|
||||
|
||||
@@ -112,10 +121,14 @@ class Pipeline:
|
||||
member_count = jax.device_get(self.algorithm.member_count(state))
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
print(f"Generation: {self.algorithm.generation(state)}",
|
||||
f"species: {len(species_sizes)}, {species_sizes}",
|
||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
|
||||
print(
|
||||
f"Generation: {self.algorithm.generation(state)}",
|
||||
f"species: {len(species_sizes)}, {species_sizes}",
|
||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms",
|
||||
)
|
||||
|
||||
def show(self, state, best, *args, **kwargs):
|
||||
state, transformed = self.algorithm.transform(state, best)
|
||||
self.problem.show(state.randkey, state, self.algorithm.forward, transformed, *args, **kwargs)
|
||||
transformed = self.algorithm.transform(state, best)
|
||||
self.problem.show(
|
||||
state, state.randkey, self.algorithm.forward, transformed, *args, **kwargs
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ class BaseProblem:
|
||||
"""initialize the state of the problem"""
|
||||
return state
|
||||
|
||||
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
||||
def evaluate(self, state: State, randkey, act_func: Callable, params):
|
||||
"""evaluate one individual"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -32,7 +32,7 @@ class BaseProblem:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
|
||||
def show(self, state: State, randkey, act_func: Callable, params, *args, **kwargs):
|
||||
"""
|
||||
show how a genome perform in this problem
|
||||
"""
|
||||
|
||||
@@ -8,42 +8,44 @@ from .. import BaseProblem
|
||||
class FuncFit(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
def __init__(self,
|
||||
error_method: str = 'mse'
|
||||
):
|
||||
def __init__(self, error_method: str = "mse"):
|
||||
super().__init__()
|
||||
|
||||
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
|
||||
assert error_method in {"mse", "rmse", "mae", "mape"}
|
||||
self.error_method = error_method
|
||||
|
||||
def setup(self, state: State = State()):
|
||||
return state
|
||||
|
||||
def evaluate(self, randkey, state, act_func, params):
|
||||
def evaluate(self, state, randkey, act_func, params):
|
||||
|
||||
state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params)
|
||||
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
|
||||
state, self.inputs, params
|
||||
)
|
||||
|
||||
if self.error_method == 'mse':
|
||||
if self.error_method == "mse":
|
||||
loss = jnp.mean((predict - self.targets) ** 2)
|
||||
|
||||
elif self.error_method == 'rmse':
|
||||
elif self.error_method == "rmse":
|
||||
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
|
||||
|
||||
elif self.error_method == 'mae':
|
||||
elif self.error_method == "mae":
|
||||
loss = jnp.mean(jnp.abs(predict - self.targets))
|
||||
|
||||
elif self.error_method == 'mape':
|
||||
elif self.error_method == "mape":
|
||||
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return state, -loss
|
||||
return -loss
|
||||
|
||||
def show(self, randkey, state, act_func, params, *args, **kwargs):
|
||||
state, predict = jax.vmap(act_func, in_axes=(None, 0, None), out_axes=(None, 0))(state, self.inputs, params)
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
|
||||
state, self.inputs, params
|
||||
)
|
||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||
state, loss = self.evaluate(randkey, state, act_func, params)
|
||||
loss = self.evaluate(state, randkey, act_func, params)
|
||||
loss = -loss
|
||||
|
||||
msg = ""
|
||||
|
||||
@@ -4,27 +4,16 @@ from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR(FuncFit):
|
||||
|
||||
def __init__(self, error_method: str = 'mse'):
|
||||
def __init__(self, error_method: str = "mse"):
|
||||
super().__init__(error_method)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return np.array([
|
||||
[0, 0],
|
||||
[0, 1],
|
||||
[1, 0],
|
||||
[1, 1]
|
||||
])
|
||||
return np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return np.array([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[0]
|
||||
])
|
||||
return np.array([[0], [1], [1], [0]])
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
|
||||
@@ -4,35 +4,27 @@ from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR3d(FuncFit):
|
||||
|
||||
def __init__(self, error_method: str = 'mse'):
|
||||
def __init__(self, error_method: str = "mse"):
|
||||
super().__init__(error_method)
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return np.array([
|
||||
[0, 0, 0],
|
||||
[0, 0, 1],
|
||||
[0, 1, 0],
|
||||
[0, 1, 1],
|
||||
[1, 0, 0],
|
||||
[1, 0, 1],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1],
|
||||
])
|
||||
return np.array(
|
||||
[
|
||||
[0, 0, 0],
|
||||
[0, 0, 1],
|
||||
[0, 1, 0],
|
||||
[0, 1, 1],
|
||||
[1, 0, 0],
|
||||
[1, 0, 1],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1],
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def targets(self):
|
||||
return np.array([
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[0],
|
||||
[1],
|
||||
[0],
|
||||
[0],
|
||||
[1]
|
||||
])
|
||||
return np.array([[0], [1], [1], [0], [1], [0], [0], [1]])
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
|
||||
@@ -25,7 +25,19 @@ class BraxEnv(RLEnv):
|
||||
def output_shape(self):
|
||||
return (self.env.action_size,)
|
||||
|
||||
def show(self, randkey, state, act_func, params, save_path=None, height=512, width=512, duration=0.1, *args, **kwargs):
|
||||
def show(
|
||||
self,
|
||||
state,
|
||||
randkey,
|
||||
act_func,
|
||||
params,
|
||||
save_path=None,
|
||||
height=512,
|
||||
width=512,
|
||||
duration=0.1,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
import jax
|
||||
import imageio
|
||||
@@ -48,11 +60,13 @@ class BraxEnv(RLEnv):
|
||||
key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs)
|
||||
reward += r
|
||||
|
||||
imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in
|
||||
tqdm(state_histories, desc="Rendering")]
|
||||
imgs = [
|
||||
image.render_array(sys=self.env.sys, state=s, width=width, height=height)
|
||||
for s in tqdm(state_histories, desc="Rendering")
|
||||
]
|
||||
|
||||
def create_gif(image_list, gif_name, duration):
|
||||
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
|
||||
with imageio.get_writer(gif_name, mode="I", duration=duration) as writer:
|
||||
for image in image_list:
|
||||
formatted_image = np.array(image, dtype=np.uint8)
|
||||
writer.append_data(formatted_image)
|
||||
@@ -60,5 +74,3 @@ class BraxEnv(RLEnv):
|
||||
create_gif(imgs, save_path, duration=0.1)
|
||||
print("Gif saved to: ", save_path)
|
||||
print("Total reward: ", reward)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from .rl_jit import RLEnv
|
||||
|
||||
|
||||
class GymNaxEnv(RLEnv):
|
||||
|
||||
def __init__(self, env_name):
|
||||
super().__init__()
|
||||
assert env_name in gymnax.registered_envs, f"Env {env_name} not registered"
|
||||
@@ -24,5 +23,5 @@ class GymNaxEnv(RLEnv):
|
||||
def output_shape(self):
|
||||
return self.env.action_space(self.env_params).shape
|
||||
|
||||
def show(self, randkey, state, act_func, params, *args, **kwargs):
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
|
||||
|
||||
@@ -12,29 +12,29 @@ class RLEnv(BaseProblem):
|
||||
super().__init__()
|
||||
self.max_step = max_step
|
||||
|
||||
def evaluate(self, randkey, state, act_func, params):
|
||||
def evaluate(self, state, randkey, act_func, params):
|
||||
rng_reset, rng_episode = jax.random.split(randkey)
|
||||
init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
def cond_func(carry):
|
||||
_, _, _, _, done, _, count = carry
|
||||
return ~done & (count < self.max_step)
|
||||
_, _, _, done, _, count = carry
|
||||
return ~done & (count < self.max_step)
|
||||
|
||||
def body_func(carry):
|
||||
state_, obs, env_state, rng, done, tr, count = carry # tr -> total reward
|
||||
state_, action = act_func(state_, obs, params)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
next_rng, _ = jax.random.split(rng)
|
||||
return state_, next_obs, next_env_state, next_rng, done, tr + reward, count + 1
|
||||
obs, env_state, rng, done, tr, count = carry # tr -> total reward
|
||||
action = act_func(state, obs, params)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(
|
||||
rng, env_state, action
|
||||
)
|
||||
next_rng, _ = jax.random.split(rng)
|
||||
return next_obs, next_env_state, next_rng, done, tr + reward, count + 1
|
||||
|
||||
state, _, _, _, _, total_reward, _ = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(state, init_obs, init_env_state, rng_episode, False, 0.0, 0)
|
||||
_, _, _, _, total_reward, _ = jax.lax.while_loop(
|
||||
cond_func, body_func, (init_obs, init_env_state, rng_episode, False, 0.0, 0)
|
||||
)
|
||||
|
||||
return state, total_reward
|
||||
|
||||
return total_reward
|
||||
|
||||
# @partial(jax.jit, static_argnums=(0,))
|
||||
def step(self, randkey, env_state, action):
|
||||
return self.env_step(randkey, env_state, action)
|
||||
@@ -57,5 +57,5 @@ class RLEnv(BaseProblem):
|
||||
def output_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def show(self, randkey, state, act_func, params, *args, **kwargs):
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -36,7 +36,9 @@ def main():
|
||||
elite_mask = jnp.zeros((1000,), dtype=jnp.bool_)
|
||||
elite_mask = elite_mask.at[:5].set(1)
|
||||
|
||||
state = algorithm.create_next_generation(jax.random.key(0), state, winner, losser, elite_mask)
|
||||
state = algorithm.create_next_generation(
|
||||
jax.random.key(0), state, winner, losser, elite_mask
|
||||
)
|
||||
pop_nodes, pop_conns = algorithm.species.ask(state.species)
|
||||
|
||||
transforms = batch_transform(pop_nodes, pop_conns)
|
||||
@@ -48,5 +50,5 @@ def main():
|
||||
print(_)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -19,7 +19,7 @@ def main():
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
@@ -35,7 +35,7 @@ def main():
|
||||
print(output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
a = jnp.array([1, 3, 5, 6, 8])
|
||||
b = jnp.array([1, 2, 3])
|
||||
print(jnp.isin(a, b))
|
||||
|
||||
@@ -2,6 +2,7 @@ from algorithm.neat import *
|
||||
from utils import Act, Agg, State
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
|
||||
|
||||
|
||||
def test_default():
|
||||
@@ -135,3 +136,29 @@ def test_recurrent():
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||
|
||||
|
||||
def test_random_initialize():
|
||||
genome = DefaultGenome(
|
||||
num_inputs=2,
|
||||
num_outputs=1,
|
||||
max_nodes=5,
|
||||
max_conns=4,
|
||||
node_gene=NodeGeneWithoutResponse(
|
||||
activation_default=Act.identity,
|
||||
activation_options=(Act.identity,),
|
||||
aggregation_default=Agg.sum,
|
||||
aggregation_options=(Agg.sum,),
|
||||
),
|
||||
)
|
||||
state = genome.setup()
|
||||
key = jax.random.PRNGKey(0)
|
||||
nodes, conns = genome.initialize(state, key)
|
||||
transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep="\n")
|
||||
|
||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
|
||||
state, inputs, transformed
|
||||
)
|
||||
print(outputs)
|
||||
|
||||
@@ -19,11 +19,11 @@ def main():
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
print(*transformed, sep="\n")
|
||||
|
||||
key = jax.random.key(0)
|
||||
dummy_input = jnp.zeros((8,))
|
||||
@@ -31,5 +31,5 @@ def main():
|
||||
print(output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -2,4 +2,4 @@ from .activation import Act, act
|
||||
from .aggregation import Agg, agg
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
from .state import State
|
||||
from .state import State
|
||||
|
||||
@@ -116,3 +116,41 @@ def argmin_with_mask(arr, mask):
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
|
||||
|
||||
def add_node(nodes, new_key: int, attrs):
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
The new node will place at the first NaN row.
|
||||
"""
|
||||
exist_keys = nodes[:, 0]
|
||||
pos = fetch_first(jnp.isnan(exist_keys))
|
||||
new_nodes = nodes.at[pos, 0].set(new_key)
|
||||
return new_nodes.at[pos, 1:].set(attrs)
|
||||
|
||||
|
||||
def delete_node_by_pos(nodes, pos):
|
||||
"""
|
||||
Delete a node from the genome.
|
||||
Delete the node by its pos in nodes.
|
||||
"""
|
||||
return nodes.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def add_conn(conns, i_key, o_key, enable: bool, attrs):
|
||||
"""
|
||||
Add a new connection to the genome.
|
||||
The new connection will place at the first NaN row.
|
||||
"""
|
||||
con_keys = conns[:, 0]
|
||||
pos = fetch_first(jnp.isnan(con_keys))
|
||||
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
|
||||
return new_conns.at[pos, 3:].set(attrs)
|
||||
|
||||
|
||||
def delete_conn_by_pos(conns, pos):
|
||||
"""
|
||||
Delete a connection from the genome.
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
Reference in New Issue
Block a user