finish all refactoring
This commit is contained in:
@@ -16,9 +16,30 @@ class BaseAlgorithm:
|
|||||||
"""update the state of the algorithm"""
|
"""update the state of the algorithm"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def transform(self, state: State):
|
def transform(self, individual):
|
||||||
"""transform the genome into a neural network"""
|
"""transform the genome into a neural network"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward(self, inputs, transformed):
|
def forward(self, inputs, transformed):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_inputs(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_outputs(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pop_size(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def member_count(self, state: State):
|
||||||
|
# to analysis the species
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def generation(self, state: State):
|
||||||
|
# to analysis the algorithm
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
2
algorithm/hyperneat/__init__.py
Normal file
2
algorithm/hyperneat/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .hyperneat import HyperNEAT
|
||||||
|
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate
|
||||||
116
algorithm/hyperneat/hyperneat.py
Normal file
116
algorithm/hyperneat/hyperneat.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import jax, jax.numpy as jnp
|
||||||
|
|
||||||
|
from utils import State, Act, Agg
|
||||||
|
from .. import BaseAlgorithm, NEAT
|
||||||
|
from ..neat.gene import BaseNodeGene, BaseConnGene
|
||||||
|
from ..neat.genome import RecurrentGenome
|
||||||
|
from .substrate import *
|
||||||
|
|
||||||
|
|
||||||
|
class HyperNEAT(BaseAlgorithm):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
substrate: BaseSubstrate,
|
||||||
|
neat: NEAT,
|
||||||
|
below_threshold: float = 0.3,
|
||||||
|
max_weight: float = 5.,
|
||||||
|
activation=Act.sigmoid,
|
||||||
|
aggregation=Agg.sum,
|
||||||
|
activate_time: int = 10,
|
||||||
|
):
|
||||||
|
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
||||||
|
"Substrate input size should be equal to NEAT input size"
|
||||||
|
|
||||||
|
self.substrate = substrate
|
||||||
|
self.neat = neat
|
||||||
|
self.below_threshold = below_threshold
|
||||||
|
self.max_weight = max_weight
|
||||||
|
self.hyper_genome = RecurrentGenome(
|
||||||
|
num_inputs=substrate.num_inputs,
|
||||||
|
num_outputs=substrate.num_outputs,
|
||||||
|
max_nodes=substrate.nodes_cnt,
|
||||||
|
max_conns=substrate.conns_cnt,
|
||||||
|
node_gene=HyperNodeGene(activation, aggregation),
|
||||||
|
conn_gene=HyperNEATConnGene(),
|
||||||
|
activate_time=activate_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
def setup(self, randkey):
|
||||||
|
return State(
|
||||||
|
neat_state=self.neat.setup(randkey)
|
||||||
|
)
|
||||||
|
|
||||||
|
def ask(self, state: State):
|
||||||
|
return self.neat.ask(state.neat_state)
|
||||||
|
|
||||||
|
def tell(self, state: State, fitness):
|
||||||
|
return state.update(
|
||||||
|
neat_state=self.neat.tell(state.neat_state, fitness)
|
||||||
|
)
|
||||||
|
|
||||||
|
def transform(self, individual):
|
||||||
|
transformed = self.neat.transform(individual)
|
||||||
|
query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(self.substrate.query_coors, transformed)
|
||||||
|
|
||||||
|
# mute the connection with weight below threshold
|
||||||
|
query_res = jnp.where(
|
||||||
|
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
|
||||||
|
0.,
|
||||||
|
query_res
|
||||||
|
)
|
||||||
|
|
||||||
|
# make query res in range [-max_weight, max_weight]
|
||||||
|
query_res = jnp.where(query_res > 0, query_res - self.below_threshold, query_res)
|
||||||
|
query_res = jnp.where(query_res < 0, query_res + self.below_threshold, query_res)
|
||||||
|
query_res = query_res / (1 - self.below_threshold) * self.max_weight
|
||||||
|
|
||||||
|
h_nodes, h_conns = self.substrate.make_nodes(query_res), self.substrate.make_conn(query_res)
|
||||||
|
return self.hyper_genome.transform(h_nodes, h_conns)
|
||||||
|
|
||||||
|
def forward(self, inputs, transformed):
|
||||||
|
# add bias
|
||||||
|
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
||||||
|
return self.hyper_genome.forward(inputs_with_bias, transformed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_inputs(self):
|
||||||
|
return self.substrate.num_inputs - 1 # remove bias
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_outputs(self):
|
||||||
|
return self.substrate.num_outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pop_size(self):
|
||||||
|
return self.neat.pop_size
|
||||||
|
|
||||||
|
def member_count(self, state: State):
|
||||||
|
return self.neat.member_count(state.neat_state)
|
||||||
|
|
||||||
|
def generation(self, state: State):
|
||||||
|
return self.neat.generation(state.neat_state)
|
||||||
|
|
||||||
|
|
||||||
|
class HyperNodeGene(BaseNodeGene):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
activation=Act.sigmoid,
|
||||||
|
aggregation=Agg.sum,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.activation = activation
|
||||||
|
self.aggregation = aggregation
|
||||||
|
|
||||||
|
def forward(self, attrs, inputs):
|
||||||
|
return self.activation(
|
||||||
|
self.aggregation(inputs)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HyperNEATConnGene(BaseConnGene):
|
||||||
|
custom_attrs = ['weight']
|
||||||
|
|
||||||
|
def forward(self, attrs, inputs):
|
||||||
|
weight = attrs[0]
|
||||||
|
return inputs * weight
|
||||||
3
algorithm/hyperneat/substrate/__init__.py
Normal file
3
algorithm/hyperneat/substrate/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .base import BaseSubstrate
|
||||||
|
from .default import DefaultSubstrate
|
||||||
|
from .full import FullSubstrate
|
||||||
27
algorithm/hyperneat/substrate/base.py
Normal file
27
algorithm/hyperneat/substrate/base.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
class BaseSubstrate:
|
||||||
|
|
||||||
|
def make_nodes(self, query_res):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def make_conn(self, query_res):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query_coors(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_inputs(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_outputs(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nodes_cnt(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def conns_cnt(self):
|
||||||
|
raise NotImplementedError
|
||||||
38
algorithm/hyperneat/substrate/default.py
Normal file
38
algorithm/hyperneat/substrate/default.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
from . import BaseSubstrate
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultSubstrate(BaseSubstrate):
|
||||||
|
|
||||||
|
def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
|
||||||
|
self.inputs = num_inputs
|
||||||
|
self.outputs = num_outputs
|
||||||
|
self.coors = jnp.array(coors)
|
||||||
|
self.nodes = jnp.array(nodes)
|
||||||
|
self.conns = jnp.array(conns)
|
||||||
|
|
||||||
|
def make_nodes(self, query_res):
|
||||||
|
return self.nodes
|
||||||
|
|
||||||
|
def make_conn(self, query_res):
|
||||||
|
return self.conns.at[:, 3:].set(query_res) # change weight
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query_coors(self):
|
||||||
|
return self.coors
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_inputs(self):
|
||||||
|
return self.inputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_outputs(self):
|
||||||
|
return self.outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nodes_cnt(self):
|
||||||
|
return self.nodes.shape[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def conns_cnt(self):
|
||||||
|
return self.conns.shape[0]
|
||||||
76
algorithm/hyperneat/substrate/full.py
Normal file
76
algorithm/hyperneat/substrate/full.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import numpy as np
|
||||||
|
from .default import DefaultSubstrate
|
||||||
|
|
||||||
|
|
||||||
|
class FullSubstrate(DefaultSubstrate):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
input_coors=((-1, -1), (0, -1), (1, -1)),
|
||||||
|
hidden_coors=((-1, 0), (0, 0), (1, 0)),
|
||||||
|
output_coors=((0, 1),),
|
||||||
|
):
|
||||||
|
query_coors, nodes, conns = analysis_substrate(input_coors, output_coors, hidden_coors)
|
||||||
|
super().__init__(
|
||||||
|
len(input_coors),
|
||||||
|
len(output_coors),
|
||||||
|
query_coors,
|
||||||
|
nodes,
|
||||||
|
conns
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def analysis_substrate(input_coors, output_coors, hidden_coors):
|
||||||
|
input_coors = np.array(input_coors)
|
||||||
|
output_coors = np.array(output_coors)
|
||||||
|
hidden_coors = np.array(hidden_coors)
|
||||||
|
|
||||||
|
cd = input_coors.shape[1] # coordinate dimensions
|
||||||
|
si = input_coors.shape[0] # input coordinate size
|
||||||
|
so = output_coors.shape[0] # output coordinate size
|
||||||
|
sh = hidden_coors.shape[0] # hidden coordinate size
|
||||||
|
|
||||||
|
input_idx = np.arange(si)
|
||||||
|
output_idx = np.arange(si, si + so)
|
||||||
|
hidden_idx = np.arange(si + so, si + so + sh)
|
||||||
|
|
||||||
|
total_conns = si * sh + sh * sh + sh * so
|
||||||
|
query_coors = np.zeros((total_conns, cd * 2))
|
||||||
|
correspond_keys = np.zeros((total_conns, 2))
|
||||||
|
|
||||||
|
# connect input to hidden
|
||||||
|
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, input_coors, hidden_coors)
|
||||||
|
query_coors[0: si * sh, :] = aux_coors
|
||||||
|
correspond_keys[0: si * sh, :] = aux_keys
|
||||||
|
|
||||||
|
# connect hidden to hidden
|
||||||
|
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, hidden_coors, hidden_coors)
|
||||||
|
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
|
||||||
|
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
|
||||||
|
|
||||||
|
# connect hidden to output
|
||||||
|
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, hidden_coors, output_coors)
|
||||||
|
query_coors[si * sh + sh * sh:, :] = aux_coors
|
||||||
|
correspond_keys[si * sh + sh * sh:, :] = aux_keys
|
||||||
|
|
||||||
|
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
|
||||||
|
conns = np.zeros((correspond_keys.shape[0], 4), dtype=np.float32) # input_idx, output_idx, enabled, weight
|
||||||
|
conns[:, 0:2] = correspond_keys
|
||||||
|
conns[:, 2] = 1 # enabled is True
|
||||||
|
|
||||||
|
return query_coors, nodes, conns
|
||||||
|
|
||||||
|
|
||||||
|
def cartesian_product(keys1, keys2, coors1, coors2):
|
||||||
|
len1 = keys1.shape[0]
|
||||||
|
len2 = keys2.shape[0]
|
||||||
|
|
||||||
|
repeated_coors1 = np.repeat(coors1, len2, axis=0)
|
||||||
|
repeated_keys1 = np.repeat(keys1, len2)
|
||||||
|
|
||||||
|
tiled_coors2 = np.tile(coors2, (len1, 1))
|
||||||
|
tiled_keys2 = np.tile(keys2, len1)
|
||||||
|
|
||||||
|
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
|
||||||
|
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
|
||||||
|
|
||||||
|
return new_coors, correspond_keys
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
from .gene import *
|
from .gene import *
|
||||||
from .genome import *
|
from .genome import *
|
||||||
|
from .species import *
|
||||||
from .neat import NEAT
|
from .neat import NEAT
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import jax, jax.numpy as jnp
|
|||||||
from .base import BaseCrossover
|
from .base import BaseCrossover
|
||||||
|
|
||||||
class DefaultCrossover(BaseCrossover):
|
class DefaultCrossover(BaseCrossover):
|
||||||
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
|
|
||||||
|
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||||
"""
|
"""
|
||||||
use genome1 and genome2 to generate a new genome
|
use genome1 and genome2 to generate a new genome
|
||||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class DefaultMutation(BaseMutation):
|
|||||||
return nodes_, conns_
|
return nodes_, conns_
|
||||||
|
|
||||||
def successful():
|
def successful():
|
||||||
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conns.new_custom_attrs())
|
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())
|
||||||
|
|
||||||
def already_exist():
|
def already_exist():
|
||||||
return nodes_, conns_.at[conn_pos, 2].set(True)
|
return nodes_, conns_.at[conn_pos, 2].set(True)
|
||||||
@@ -105,11 +105,12 @@ class DefaultMutation(BaseMutation):
|
|||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
is_already_exist,
|
is_already_exist,
|
||||||
already_exist,
|
already_exist,
|
||||||
jax.lax.cond(
|
lambda:
|
||||||
is_cycle,
|
jax.lax.cond(
|
||||||
nothing,
|
is_cycle,
|
||||||
successful
|
nothing,
|
||||||
)
|
successful
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif genome.network_type == 'recurrent':
|
elif genome.network_type == 'recurrent':
|
||||||
@@ -138,23 +139,23 @@ class DefaultMutation(BaseMutation):
|
|||||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||||
|
|
||||||
def no(k, g):
|
def no(key_, nodes_, conns_):
|
||||||
return g
|
return nodes_, conns_
|
||||||
|
|
||||||
genome = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
|
nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
|
||||||
genome = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
|
nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
|
||||||
genome = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
|
nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
|
||||||
genome = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
|
nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
|
||||||
|
|
||||||
return genome
|
return nodes, conns
|
||||||
|
|
||||||
def mutate_values(self, randkey, genome, nodes, conns):
|
def mutate_values(self, randkey, genome, nodes, conns):
|
||||||
k1, k2 = jax.random.split(randkey, num=2)
|
k1, k2 = jax.random.split(randkey, num=2)
|
||||||
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0])
|
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||||
conns_keys = jax.random.split(k2, num=genome.conns.shape[0])
|
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||||
|
|
||||||
new_nodes = jax.vmap(genome.nodes.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
||||||
new_conns = jax.vmap(genome.conns.mutate, in_axes=(0, 0))(conns_keys, conns)
|
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
|
||||||
|
|
||||||
# nan nodes not changed
|
# nan nodes not changed
|
||||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ from . import BaseConnGene
|
|||||||
class DefaultConnGene(BaseConnGene):
|
class DefaultConnGene(BaseConnGene):
|
||||||
"Default connection gene, with the same behavior as in NEAT-python."
|
"Default connection gene, with the same behavior as in NEAT-python."
|
||||||
|
|
||||||
fixed_attrs = ['input_index', 'output_index', 'enabled']
|
custom_attrs = ['weight']
|
||||||
attrs = ['weight']
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from . import BaseNodeGene
|
|||||||
class DefaultNodeGene(BaseNodeGene):
|
class DefaultNodeGene(BaseNodeGene):
|
||||||
"Default node gene, with the same behavior as in NEAT-python."
|
"Default node gene, with the same behavior as in NEAT-python."
|
||||||
|
|
||||||
fixed_attrs = ['index']
|
|
||||||
custom_attrs = ['bias', 'response', 'aggregation', 'activation']
|
custom_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -82,8 +81,8 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
return (
|
return (
|
||||||
jnp.abs(node1[1] - node2[1]) +
|
jnp.abs(node1[1] - node2[1]) +
|
||||||
jnp.abs(node1[2] - node2[2]) +
|
jnp.abs(node1[2] - node2[2]) +
|
||||||
node1[3] != node2[3] +
|
(node1[3] != node2[3]) +
|
||||||
node1[4] != node2[4]
|
(node1[4] != node2[4])
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, attrs, inputs):
|
def forward(self, attrs, inputs):
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from utils import fetch_first
|
|||||||
|
|
||||||
|
|
||||||
class BaseGenome:
|
class BaseGenome:
|
||||||
|
|
||||||
network_type = None
|
network_type = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
from utils import unflatten_conns, topological_sort, I_INT
|
from utils import unflatten_conns, topological_sort, I_INT
|
||||||
|
|
||||||
@@ -13,10 +15,20 @@ class DefaultGenome(BaseGenome):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_inputs: int,
|
num_inputs: int,
|
||||||
num_outputs: int,
|
num_outputs: int,
|
||||||
|
max_nodes=5,
|
||||||
|
max_conns=4,
|
||||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||||
|
output_transform: Callable = None
|
||||||
):
|
):
|
||||||
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
|
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||||
|
|
||||||
|
if output_transform is not None:
|
||||||
|
try:
|
||||||
|
aux = output_transform(jnp.zeros(num_outputs))
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Output transform function failed: {e}")
|
||||||
|
self.output_transform = output_transform
|
||||||
|
|
||||||
def transform(self, nodes, conns):
|
def transform(self, nodes, conns):
|
||||||
u_conns = unflatten_conns(nodes, conns)
|
u_conns = unflatten_conns(nodes, conns)
|
||||||
@@ -72,4 +84,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||||
|
|
||||||
return vals[self.output_idx]
|
if self.output_transform is None:
|
||||||
|
return vals[self.output_idx]
|
||||||
|
else:
|
||||||
|
return self.output_transform(vals[self.output_idx])
|
||||||
|
|||||||
@@ -13,11 +13,13 @@ class RecurrentGenome(BaseGenome):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_inputs: int,
|
num_inputs: int,
|
||||||
num_outputs: int,
|
num_outputs: int,
|
||||||
|
max_nodes: int,
|
||||||
|
max_conns: int,
|
||||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||||
activate_time: int = 10,
|
activate_time: int = 10,
|
||||||
):
|
):
|
||||||
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
|
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||||
self.activate_time = activate_time
|
self.activate_time = activate_time
|
||||||
|
|
||||||
def transform(self, nodes, conns):
|
def transform(self, nodes, conns):
|
||||||
|
|||||||
@@ -1,20 +1,19 @@
|
|||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
from utils import State
|
from utils import State
|
||||||
from .. import BaseAlgorithm
|
from .. import BaseAlgorithm
|
||||||
from .genome import *
|
|
||||||
from .species import *
|
from .species import *
|
||||||
from .ga import *
|
from .ga import *
|
||||||
|
|
||||||
|
|
||||||
class NEAT(BaseAlgorithm):
|
class NEAT(BaseAlgorithm):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
genome: BaseGenome,
|
|
||||||
species: BaseSpecies,
|
species: BaseSpecies,
|
||||||
mutation: BaseMutation = DefaultMutation(),
|
mutation: BaseMutation = DefaultMutation(),
|
||||||
crossover: BaseCrossover = DefaultCrossover(),
|
crossover: BaseCrossover = DefaultCrossover(),
|
||||||
):
|
):
|
||||||
self.genome = genome
|
self.genome = species.genome
|
||||||
self.species = species
|
self.species = species
|
||||||
self.mutation = mutation
|
self.mutation = mutation
|
||||||
self.crossover = crossover
|
self.crossover = crossover
|
||||||
@@ -23,14 +22,14 @@ class NEAT(BaseAlgorithm):
|
|||||||
k1, k2 = jax.random.split(randkey, 2)
|
k1, k2 = jax.random.split(randkey, 2)
|
||||||
return State(
|
return State(
|
||||||
randkey=k1,
|
randkey=k1,
|
||||||
generation=0,
|
generation=jnp.array(0.),
|
||||||
next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2,
|
next_node_key=jnp.array(max(*self.genome.input_idx, *self.genome.output_idx) + 2, dtype=jnp.float32),
|
||||||
# inputs nodes, output nodes, 1 hidden node
|
# inputs nodes, output nodes, 1 hidden node
|
||||||
species=self.species.setup(k2),
|
species=self.species.setup(k2),
|
||||||
)
|
)
|
||||||
|
|
||||||
def ask(self, state: State):
|
def ask(self, state: State):
|
||||||
return self.species.ask(state)
|
return self.species.ask(state.species)
|
||||||
|
|
||||||
def tell(self, state: State, fitness):
|
def tell(self, state: State, fitness):
|
||||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||||
@@ -40,25 +39,39 @@ class NEAT(BaseAlgorithm):
|
|||||||
randkey=randkey
|
randkey=randkey
|
||||||
)
|
)
|
||||||
|
|
||||||
state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation)
|
species_state, winner, loser, elite_mask = self.species.update_species(state.species, fitness, state.generation)
|
||||||
|
state = state.update(species=species_state)
|
||||||
|
|
||||||
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
|
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
|
||||||
|
|
||||||
state = self.species.speciate(state, state.generation)
|
species_state = self.species.speciate(state.species, state.generation)
|
||||||
|
state = state.update(species=species_state)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def transform(self, state: State):
|
def transform(self, individual):
|
||||||
"""transform the genome into a neural network"""
|
"""transform the genome into a neural network"""
|
||||||
raise NotImplementedError
|
nodes, conns = individual
|
||||||
|
return self.genome.transform(nodes, conns)
|
||||||
|
|
||||||
def forward(self, inputs, transformed):
|
def forward(self, inputs, transformed):
|
||||||
raise NotImplementedError
|
return self.genome.forward(inputs, transformed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_inputs(self):
|
||||||
|
return self.genome.num_inputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_outputs(self):
|
||||||
|
return self.genome.num_outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pop_size(self):
|
||||||
|
return self.species.pop_size
|
||||||
|
|
||||||
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
|
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
|
||||||
# prepare random keys
|
# prepare random keys
|
||||||
pop_size = self.species.pop_size
|
pop_size = self.species.pop_size
|
||||||
new_node_keys = jnp.arange(pop_size) + state.species.next_node_key
|
new_node_keys = jnp.arange(pop_size) + state.next_node_key
|
||||||
|
|
||||||
k1, k2 = jax.random.split(randkey, 2)
|
k1, k2 = jax.random.split(randkey, 2)
|
||||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||||
@@ -69,11 +82,11 @@ class NEAT(BaseAlgorithm):
|
|||||||
|
|
||||||
# batch crossover
|
# batch crossover
|
||||||
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
|
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
|
||||||
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
|
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
|
||||||
|
|
||||||
# batch mutation
|
# batch mutation
|
||||||
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
|
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
|
||||||
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
|
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
|
||||||
|
|
||||||
# elitism don't mutate
|
# elitism don't mutate
|
||||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||||
@@ -92,3 +105,9 @@ class NEAT(BaseAlgorithm):
|
|||||||
next_node_key=next_node_key,
|
next_node_key=next_node_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def member_count(self, state: State):
|
||||||
|
return state.species.member_count
|
||||||
|
|
||||||
|
def generation(self, state: State):
|
||||||
|
# to analysis the algorithm
|
||||||
|
return state.generation
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ import numpy as np
|
|||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
from utils import State, rank_elements, argmin_with_mask, fetch_first
|
from utils import State, rank_elements, argmin_with_mask, fetch_first
|
||||||
from ..genome import BaseGenome
|
from ..genome import BaseGenome
|
||||||
|
from .base import BaseSpecies
|
||||||
|
|
||||||
|
|
||||||
class DefaultSpecies:
|
class DefaultSpecies(BaseSpecies):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
genome: BaseGenome,
|
genome: BaseGenome,
|
||||||
@@ -18,9 +19,8 @@ class DefaultSpecies:
|
|||||||
genome_elitism: int = 2,
|
genome_elitism: int = 2,
|
||||||
survival_threshold: float = 0.2,
|
survival_threshold: float = 0.2,
|
||||||
min_species_size: int = 1,
|
min_species_size: int = 1,
|
||||||
compatibility_threshold: float = 3.5
|
compatibility_threshold: float = 3.
|
||||||
):
|
):
|
||||||
|
|
||||||
self.genome = genome
|
self.genome = genome
|
||||||
self.pop_size = pop_size
|
self.pop_size = pop_size
|
||||||
self.species_size = species_size
|
self.species_size = species_size
|
||||||
@@ -59,8 +59,12 @@ class DefaultSpecies:
|
|||||||
center_nodes = center_nodes.at[0].set(pop_nodes[0])
|
center_nodes = center_nodes.at[0].set(pop_nodes[0])
|
||||||
center_conns = center_conns.at[0].set(pop_conns[0])
|
center_conns = center_conns.at[0].set(pop_conns[0])
|
||||||
|
|
||||||
|
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||||
|
|
||||||
return State(
|
return State(
|
||||||
randkey=randkey,
|
randkey=randkey,
|
||||||
|
pop_nodes=pop_nodes,
|
||||||
|
pop_conns=pop_conns,
|
||||||
species_keys=species_keys,
|
species_keys=species_keys,
|
||||||
best_fitness=best_fitness,
|
best_fitness=best_fitness,
|
||||||
last_improved=last_improved,
|
last_improved=last_improved,
|
||||||
@@ -68,7 +72,7 @@ class DefaultSpecies:
|
|||||||
idx2species=idx2species,
|
idx2species=idx2species,
|
||||||
center_nodes=center_nodes,
|
center_nodes=center_nodes,
|
||||||
center_conns=center_conns,
|
center_conns=center_conns,
|
||||||
next_species_key=1, # 0 is reserved for the first species
|
next_species_key=jnp.array(1), # 0 is reserved for the first species
|
||||||
)
|
)
|
||||||
|
|
||||||
def ask(self, state):
|
def ask(self, state):
|
||||||
@@ -99,7 +103,7 @@ class DefaultSpecies:
|
|||||||
# crossover info
|
# crossover info
|
||||||
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
|
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
|
||||||
|
|
||||||
return state(randkey=k2), winner, loser, elite_mask
|
return state.update(randkey=k2), winner, loser, elite_mask
|
||||||
|
|
||||||
def update_species_fitness(self, state, fitness):
|
def update_species_fitness(self, state, fitness):
|
||||||
"""
|
"""
|
||||||
@@ -156,17 +160,17 @@ class DefaultSpecies:
|
|||||||
jnp.nan, # last_improved
|
jnp.nan, # last_improved
|
||||||
jnp.nan, # member_count
|
jnp.nan, # member_count
|
||||||
-jnp.inf, # species_fitness
|
-jnp.inf, # species_fitness
|
||||||
jnp.full_like(center_nodes[idx], jnp.nan), # center_nodes
|
jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes
|
||||||
jnp.full_like(center_conns[idx], jnp.nan), # center_conns
|
jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns
|
||||||
), # stagnation species
|
), # stagnation species
|
||||||
lambda: (
|
lambda: (
|
||||||
species_keys[idx],
|
state.species_keys[idx],
|
||||||
best_fitness[idx],
|
best_fitness[idx],
|
||||||
last_improved[idx],
|
last_improved[idx],
|
||||||
state.member_count[idx],
|
state.member_count[idx],
|
||||||
species_fitness[idx],
|
species_fitness[idx],
|
||||||
center_nodes[idx],
|
state.center_nodes[idx],
|
||||||
center_conns[idx]
|
state.center_conns[idx]
|
||||||
) # not stagnation species
|
) # not stagnation species
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -216,7 +220,7 @@ class DefaultSpecies:
|
|||||||
spawn_number = spawn_number.astype(jnp.int32)
|
spawn_number = spawn_number.astype(jnp.int32)
|
||||||
|
|
||||||
# must control the sum of spawn_number to be equal to pop_size
|
# must control the sum of spawn_number to be equal to pop_size
|
||||||
error = state.P - jnp.sum(spawn_number)
|
error = self.pop_size - jnp.sum(spawn_number)
|
||||||
|
|
||||||
# add error to the first species to control the sum of spawn_number
|
# add error to the first species to control the sum of spawn_number
|
||||||
spawn_number = spawn_number.at[0].add(error)
|
spawn_number = spawn_number.at[0].add(error)
|
||||||
@@ -287,14 +291,14 @@ class DefaultSpecies:
|
|||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
i, i2s, cns, ccs, o2c = carry
|
i, i2s, cns, ccs, o2c = carry
|
||||||
|
|
||||||
distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns)
|
distances = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
|
||||||
|
|
||||||
# find the closest one
|
# find the closest one
|
||||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||||
|
|
||||||
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
|
i2s = i2s.at[closest_idx].set(state.species_keys[i])
|
||||||
cns = cns.set(i, state.pop_nodes[closest_idx])
|
cns = cns.at[i].set(state.pop_nodes[closest_idx])
|
||||||
ccs = ccs.set(i, state.pop_conns[closest_idx])
|
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
|
||||||
|
|
||||||
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||||
o2c = o2c.at[closest_idx].set(0)
|
o2c = o2c.at[closest_idx].set(0)
|
||||||
@@ -346,8 +350,8 @@ class DefaultSpecies:
|
|||||||
o2c = o2c.at[idx].set(0)
|
o2c = o2c.at[idx].set(0)
|
||||||
|
|
||||||
# update center genomes
|
# update center genomes
|
||||||
cns = cns.set(i, state.pop_nodes[idx])
|
cns = cns.at[i].set(state.pop_nodes[idx])
|
||||||
ccs = ccs.set(i, state.pop_conns[idx])
|
ccs = ccs.at[i].set(state.pop_conns[idx])
|
||||||
|
|
||||||
# find the members for the new species
|
# find the members for the new species
|
||||||
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
|
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
|
||||||
@@ -384,7 +388,7 @@ class DefaultSpecies:
|
|||||||
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
|
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
|
||||||
cond_func,
|
cond_func,
|
||||||
body_func,
|
body_func,
|
||||||
(0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances,
|
(0, state.idx2species, center_nodes, center_conns, state.species_keys, o2c_distances,
|
||||||
state.next_species_key)
|
state.next_species_key)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -401,8 +405,8 @@ class DefaultSpecies:
|
|||||||
def count_members(idx):
|
def count_members(idx):
|
||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
jnp.isnan(species_keys[idx]), # if the species is not existing
|
jnp.isnan(species_keys[idx]), # if the species is not existing
|
||||||
lambda _: jnp.nan, # nan
|
lambda: jnp.nan, # nan
|
||||||
lambda _: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
|
lambda: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
|
||||||
)
|
)
|
||||||
|
|
||||||
member_count = jax.vmap(count_members)(self.species_arange)
|
member_count = jax.vmap(count_members)(self.species_arange)
|
||||||
@@ -422,7 +426,8 @@ class DefaultSpecies:
|
|||||||
"""
|
"""
|
||||||
The distance between two genomes
|
The distance between two genomes
|
||||||
"""
|
"""
|
||||||
return self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
|
d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
|
||||||
|
return d
|
||||||
|
|
||||||
def node_distance(self, nodes1, nodes2):
|
def node_distance(self, nodes1, nodes2):
|
||||||
"""
|
"""
|
||||||
@@ -494,18 +499,18 @@ def initialize_population(pop_size, genome):
|
|||||||
o_nodes[input_idx, 0] = genome.input_idx
|
o_nodes[input_idx, 0] = genome.input_idx
|
||||||
o_nodes[output_idx, 0] = genome.output_idx
|
o_nodes[output_idx, 0] = genome.output_idx
|
||||||
o_nodes[new_node_key, 0] = new_node_key # one hidden node
|
o_nodes[new_node_key, 0] = new_node_key # one hidden node
|
||||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_attrs()
|
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
|
||||||
o_nodes[new_node_key, 1:] = genome.node_gene.new_attrs() # one hidden node
|
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
|
||||||
|
|
||||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
|
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
|
||||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
||||||
o_conns[input_idx, 2] = True # enabled
|
o_conns[input_idx, 2] = True # enabled
|
||||||
o_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs()
|
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||||
|
|
||||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
|
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
|
||||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
||||||
o_conns[output_idx, 2] = True # enabled
|
o_conns[output_idx, 2] = True # enabled
|
||||||
o_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs()
|
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||||
|
|
||||||
# repeat origin genome for P times to create population
|
# repeat origin genome for P times to create population
|
||||||
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
|
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
|
||||||
|
|||||||
@@ -1,38 +1,36 @@
|
|||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.rl_env import BraxEnv, BraxConfig
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=10000,
|
|
||||||
pop_size=100
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=27,
|
|
||||||
outputs=8,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
problem=BraxConfig(
|
|
||||||
env_name="ant"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from problem.rl_env import BraxEnv
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
conf = example_conf()
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=27,
|
||||||
|
num_outputs=8,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
node_gene=DefaultNodeGene(
|
||||||
|
activation_options=(Act.tanh,),
|
||||||
|
activation_default=Act.tanh,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
pop_size=1000,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=BraxEnv(
|
||||||
|
env_name='ant',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=5000
|
||||||
|
)
|
||||||
|
|
||||||
algorithm = NEAT(conf, NormalGene)
|
# initialize state
|
||||||
pipeline = Pipeline(conf, algorithm, BraxEnv)
|
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
pipeline.pre_compile(state)
|
# print(state)
|
||||||
state, best = pipeline.auto_run(state)
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,42 +1,36 @@
|
|||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.rl_env import BraxEnv, BraxConfig
|
|
||||||
|
|
||||||
|
|
||||||
# ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=10000,
|
|
||||||
generation_limit=10,
|
|
||||||
pop_size=100
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=17,
|
|
||||||
outputs=6,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
problem=BraxConfig(
|
|
||||||
env_name="halfcheetah"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from problem.rl_env import BraxEnv
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
conf = example_conf()
|
pipeline = Pipeline(
|
||||||
algorithm = NEAT(conf, NormalGene)
|
algorithm=NEAT(
|
||||||
pipeline = Pipeline(conf, algorithm, BraxEnv)
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=17,
|
||||||
|
num_outputs=6,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
node_gene=DefaultNodeGene(
|
||||||
|
activation_options=(Act.tanh,),
|
||||||
|
activation_default=Act.tanh,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
pop_size=1000,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=BraxEnv(
|
||||||
|
env_name='halhcheetah',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=5000
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize state
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
pipeline.pre_compile(state)
|
# print(state)
|
||||||
state, best = pipeline.auto_run(state)
|
# run until terminate
|
||||||
pipeline.show(state, best, save_path="half_cheetah.gif", )
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,38 +1,36 @@
|
|||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.rl_env import BraxEnv, BraxConfig
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=10000,
|
|
||||||
pop_size=1000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=11,
|
|
||||||
outputs=2,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
problem=BraxConfig(
|
|
||||||
env_name="reacher"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from problem.rl_env import BraxEnv
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
conf = example_conf()
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=11,
|
||||||
|
num_outputs=2,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
node_gene=DefaultNodeGene(
|
||||||
|
activation_options=(Act.tanh,),
|
||||||
|
activation_default=Act.tanh,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
pop_size=100,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=BraxEnv(
|
||||||
|
env_name='reacher',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=5000
|
||||||
|
)
|
||||||
|
|
||||||
algorithm = NEAT(conf, NormalGene)
|
# initialize state
|
||||||
pipeline = Pipeline(conf, algorithm, BraxEnv)
|
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
pipeline.pre_compile(state)
|
# print(state)
|
||||||
state, best = pipeline.auto_run(state)
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
import imageio
|
|
||||||
import jax
|
|
||||||
|
|
||||||
import brax
|
|
||||||
from brax import envs
|
|
||||||
from brax.io import image
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
import time
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def inference_func(key, *args):
|
|
||||||
return jax.random.normal(key, shape=(env.action_size,))
|
|
||||||
|
|
||||||
|
|
||||||
env_name = "ant"
|
|
||||||
backend = "generalized"
|
|
||||||
|
|
||||||
env = envs.create(env_name=env_name, backend=backend)
|
|
||||||
|
|
||||||
jit_env_reset = jax.jit(env.reset)
|
|
||||||
jit_env_step = jax.jit(env.step)
|
|
||||||
jit_inference_fn = jax.jit(inference_func)
|
|
||||||
|
|
||||||
rng = jax.random.PRNGKey(seed=1)
|
|
||||||
ori_state = jit_env_reset(rng=rng)
|
|
||||||
state = ori_state
|
|
||||||
|
|
||||||
render_history = []
|
|
||||||
|
|
||||||
for i in range(100):
|
|
||||||
act_rng, rng = jax.random.split(rng)
|
|
||||||
|
|
||||||
tic = time.time()
|
|
||||||
act = jit_inference_fn(act_rng, state.obs)
|
|
||||||
state = jit_env_step(state, act)
|
|
||||||
print("step time: ", time.time() - tic)
|
|
||||||
|
|
||||||
render_history.append(state.pipeline_state)
|
|
||||||
|
|
||||||
# img = image.render_array(sys=env.sys, state=pipeline_state, width=512, height=512)
|
|
||||||
# print("render time: ", time.time() - tic)
|
|
||||||
|
|
||||||
# plt.imsave("../images/ant_{}.png".format(i), img)
|
|
||||||
|
|
||||||
reward = state.reward
|
|
||||||
done = state.done
|
|
||||||
print(i, reward)
|
|
||||||
|
|
||||||
render_history = jax.device_get(render_history)
|
|
||||||
# print(render_history)
|
|
||||||
|
|
||||||
imgs = [image.render_array(sys=env.sys, state=s, width=512, height=512) for s in tqdm(render_history)]
|
|
||||||
|
|
||||||
|
|
||||||
# for i, s in enumerate(tqdm(render_history)):
|
|
||||||
# img = image.render_array(sys=env.sys, state=s, width=512, height=512)
|
|
||||||
# print(img.shape)
|
|
||||||
# # print(type(img))
|
|
||||||
# plt.imsave("../images/ant_{}.png".format(i), img)
|
|
||||||
|
|
||||||
|
|
||||||
def create_gif(image_list, gif_name, duration):
|
|
||||||
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
|
|
||||||
for image in image_list:
|
|
||||||
# 确保图像的数据类型正确
|
|
||||||
formatted_image = np.array(image, dtype=np.uint8)
|
|
||||||
writer.append_data(formatted_image)
|
|
||||||
|
|
||||||
|
|
||||||
create_gif(imgs, "../images/ant.gif", 0.1)
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
import brax
|
|
||||||
from brax import envs
|
|
||||||
from brax.envs.wrappers import gym as gym_wrapper
|
|
||||||
from brax.io import image
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
# print(f"Using Brax {brax.__version__}, Jax {jax.__version__}")
|
|
||||||
# print("From GymWrapper, env.reset()")
|
|
||||||
# try:
|
|
||||||
# env = envs.create("inverted_pendulum",
|
|
||||||
# batch_size=1,
|
|
||||||
# episode_length=150,
|
|
||||||
# backend='generalized')
|
|
||||||
# env = gym_wrapper.GymWrapper(env)
|
|
||||||
# env.reset()
|
|
||||||
# img = env.render(mode='rgb_array')
|
|
||||||
# plt.imshow(img)
|
|
||||||
# except Exception:
|
|
||||||
# traceback.print_exc()
|
|
||||||
#
|
|
||||||
# print("From GymWrapper, env.reset() and action")
|
|
||||||
# try:
|
|
||||||
# env = envs.create("inverted_pendulum",
|
|
||||||
# batch_size=1,
|
|
||||||
# episode_length=150,
|
|
||||||
# backend='generalized')
|
|
||||||
# env = gym_wrapper.GymWrapper(env)
|
|
||||||
# env.reset()
|
|
||||||
# action = jnp.zeros(env.action_space.shape)
|
|
||||||
# env.step(action)
|
|
||||||
# img = env.render(mode='rgb_array')
|
|
||||||
# plt.imshow(img)
|
|
||||||
# except Exception:
|
|
||||||
# traceback.print_exc()
|
|
||||||
|
|
||||||
print("From brax env")
|
|
||||||
try:
|
|
||||||
env = envs.create("inverted_pendulum",
|
|
||||||
batch_size=1,
|
|
||||||
episode_length=150,
|
|
||||||
backend='generalized')
|
|
||||||
key = jax.random.PRNGKey(0)
|
|
||||||
initial_env_state = env.reset(key)
|
|
||||||
base_state = initial_env_state.pipeline_state
|
|
||||||
pipeline_state = env.pipeline_init(base_state.q.ravel(), base_state.qd.ravel())
|
|
||||||
img = image.render_array(sys=env.sys, state=pipeline_state, width=256, height=256)
|
|
||||||
print(f"pixel values: [{img.min()}, {img.max()}]")
|
|
||||||
plt.imshow(img)
|
|
||||||
plt.show()
|
|
||||||
except Exception:
|
|
||||||
traceback.print_exc()
|
|
||||||
@@ -1,32 +1,31 @@
|
|||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.func_fit import XOR, FuncFitConfig
|
from problem.func_fit import XOR3d
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# running config
|
pipeline = Pipeline(
|
||||||
config = Config(
|
algorithm=NEAT(
|
||||||
basic=BasicConfig(
|
species=DefaultSpecies(
|
||||||
seed=42,
|
genome=DefaultGenome(
|
||||||
fitness_target=-1e-2,
|
num_inputs=3,
|
||||||
pop_size=10000
|
num_outputs=1,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
compatibility_threshold=3.5,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
neat=NeatConfig(
|
problem=XOR3d(),
|
||||||
inputs=2,
|
generation_limit=10000,
|
||||||
outputs=1
|
fitness_target=-1e-8
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(),
|
|
||||||
problem=FuncFitConfig(
|
|
||||||
error_method='rmse'
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# define algorithm: NEAT with NormalGene
|
|
||||||
algorithm = NEAT(config, NormalGene)
|
|
||||||
# full pipeline
|
|
||||||
pipeline = Pipeline(config, algorithm, XOR)
|
|
||||||
# initialize state
|
# initialize state
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
|
# print(state)
|
||||||
# run until terminate
|
# run until terminate
|
||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
# show result
|
# show result
|
||||||
|
|||||||
51
examples/func_fit/xor3d_hyperneat.py
Normal file
51
examples/func_fit/xor3d_hyperneat.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm.neat import *
|
||||||
|
from algorithm.hyperneat import *
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
|
from problem.func_fit import XOR3d
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm=HyperNEAT(
|
||||||
|
substrate=FullSubstrate(
|
||||||
|
input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)],
|
||||||
|
hidden_coors=[
|
||||||
|
(-1, -0.5), (0.333, -0.5), (-0.333, -0.5), (1, -0.5),
|
||||||
|
(-1, 0), (0.333, 0), (-0.333, 0), (1, 0),
|
||||||
|
(-1, 0.5), (0.333, 0.5), (-0.333, 0.5), (1, 0.5),
|
||||||
|
],
|
||||||
|
output_coors=[(0, 1), ],
|
||||||
|
),
|
||||||
|
neat=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=4, # [-1, -1, -1, 0]
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
node_gene=DefaultNodeGene(
|
||||||
|
activation_default=Act.tanh,
|
||||||
|
activation_options=(Act.tanh,),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
compatibility_threshold=3.5,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
activation=Act.sigmoid,
|
||||||
|
activate_time=10,
|
||||||
|
),
|
||||||
|
problem=XOR3d(),
|
||||||
|
generation_limit=300,
|
||||||
|
fitness_target=-1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize state
|
||||||
|
state = pipeline.setup()
|
||||||
|
# print(state)
|
||||||
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
|
# show result
|
||||||
|
pipeline.show(state, best)
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm.neat import NormalGene, NormalGeneConfig
|
|
||||||
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
|
|
||||||
from problem.func_fit import XOR3d, FuncFitConfig
|
|
||||||
from utils import Act
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
config = Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=0,
|
|
||||||
pop_size=1000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
max_nodes=50,
|
|
||||||
max_conns=100,
|
|
||||||
max_species=30,
|
|
||||||
inputs=4,
|
|
||||||
outputs=1
|
|
||||||
),
|
|
||||||
hyperneat=HyperNeatConfig(
|
|
||||||
inputs=3,
|
|
||||||
outputs=1
|
|
||||||
),
|
|
||||||
substrate=NormalSubstrateConfig(
|
|
||||||
input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
activation_options=(Act.tanh, ),
|
|
||||||
),
|
|
||||||
problem=FuncFitConfig()
|
|
||||||
)
|
|
||||||
|
|
||||||
algorithm = HyperNEAT(config, NormalGene, NormalSubstrate)
|
|
||||||
pipeline = Pipeline(config, algorithm, XOR3d)
|
|
||||||
state = pipeline.setup()
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
pipeline.show(state, best)
|
|
||||||
@@ -1,41 +1,41 @@
|
|||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
|
|
||||||
from problem.func_fit import XOR3d, FuncFitConfig
|
|
||||||
|
|
||||||
|
from problem.func_fit import XOR3d
|
||||||
|
from utils.activation import ACT_ALL
|
||||||
|
from utils.aggregation import AGG_ALL
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
config = Config(
|
pipeline = Pipeline(
|
||||||
basic=BasicConfig(
|
seed=0,
|
||||||
seed=42,
|
algorithm=NEAT(
|
||||||
fitness_target=-1e-2,
|
species=DefaultSpecies(
|
||||||
generation_limit=300,
|
genome=RecurrentGenome(
|
||||||
pop_size=1000
|
num_inputs=3,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
activate_time=5,
|
||||||
|
node_gene=DefaultNodeGene(
|
||||||
|
activation_options=ACT_ALL,
|
||||||
|
# aggregation_options=AGG_ALL,
|
||||||
|
activation_replace_rate=0.2
|
||||||
|
),
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
compatibility_threshold=3.5,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
neat=NeatConfig(
|
problem=XOR3d(),
|
||||||
network_type="recurrent",
|
generation_limit=10000,
|
||||||
max_nodes=50,
|
fitness_target=-1e-8
|
||||||
max_conns=100,
|
|
||||||
max_species=30,
|
|
||||||
conn_add=0.5,
|
|
||||||
conn_delete=0.5,
|
|
||||||
node_add=0.4,
|
|
||||||
node_delete=0.4,
|
|
||||||
inputs=3,
|
|
||||||
outputs=1
|
|
||||||
),
|
|
||||||
gene=RecurrentGeneConfig(
|
|
||||||
activate_times=10
|
|
||||||
),
|
|
||||||
problem=FuncFitConfig(
|
|
||||||
error_method='rmse'
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
algorithm = NEAT(config, RecurrentGene)
|
# initialize state
|
||||||
pipeline = Pipeline(config, algorithm, XOR3d)
|
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
pipeline.pre_compile(state)
|
# print(state)
|
||||||
|
# run until terminate
|
||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
|
# show result
|
||||||
pipeline.show(state, best)
|
pipeline.show(state, best)
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm import NEAT
|
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.func_fit import XOR, FuncFitConfig
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
config = Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=-1e-2,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
max_nodes=50,
|
|
||||||
max_conns=100,
|
|
||||||
max_species=30,
|
|
||||||
conn_add=0.8,
|
|
||||||
conn_delete=0,
|
|
||||||
node_add=0.4,
|
|
||||||
node_delete=0,
|
|
||||||
inputs=2,
|
|
||||||
outputs=1
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(),
|
|
||||||
problem=FuncFitConfig(
|
|
||||||
error_method='rmse'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
algorithm = NEAT(config, NormalGene)
|
|
||||||
pipeline = Pipeline(config, algorithm, XOR)
|
|
||||||
state = pipeline.setup()
|
|
||||||
pipeline.pre_compile(state)
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
pipeline.show(state, best)
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm import NEAT
|
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=0,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=6,
|
|
||||||
outputs=3,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
problem=GymNaxConfig(
|
|
||||||
env_name='Acrobot-v1',
|
|
||||||
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
conf = example_conf()
|
|
||||||
|
|
||||||
algorithm = NEAT(conf, NormalGene)
|
|
||||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
|
||||||
state = pipeline.setup()
|
|
||||||
pipeline.pre_compile(state)
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
34
examples/gymnax/arcbot.py
Normal file
34
examples/gymnax/arcbot.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm.neat import *
|
||||||
|
|
||||||
|
from problem.rl_env import GymNaxEnv
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=6,
|
||||||
|
num_outputs=3,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
output_transform=lambda out: jnp.argmax(out) # the action of acrobot is {0, 1, 2}
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=GymNaxEnv(
|
||||||
|
env_name='Acrobot-v1',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=-62
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize state
|
||||||
|
state = pipeline.setup()
|
||||||
|
# print(state)
|
||||||
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,84 +1,34 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf1():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=500,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=4,
|
|
||||||
outputs=1,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.sigmoid,
|
|
||||||
activation_options=(Act.sigmoid,),
|
|
||||||
),
|
|
||||||
problem=GymNaxConfig(
|
|
||||||
env_name='CartPole-v1',
|
|
||||||
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf2():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=500,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=4,
|
|
||||||
outputs=1,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
problem=GymNaxConfig(
|
|
||||||
env_name='CartPole-v1',
|
|
||||||
output_transform=lambda out: jnp.where(out[0] > 0, 1, 0) # the action of cartpole is {0, 1}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf3():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=501,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=4,
|
|
||||||
outputs=2,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
problem=GymNaxConfig(
|
|
||||||
env_name='CartPole-v1',
|
|
||||||
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from problem.rl_env import GymNaxEnv
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# all config files above can solve cartpole
|
pipeline = Pipeline(
|
||||||
conf = example_conf3()
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=4,
|
||||||
|
num_outputs=2,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=GymNaxEnv(
|
||||||
|
env_name='CartPole-v1',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=500
|
||||||
|
)
|
||||||
|
|
||||||
algorithm = NEAT(conf, NormalGene)
|
# initialize state
|
||||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
pipeline.pre_compile(state)
|
# print(state)
|
||||||
state, best = pipeline.auto_run(state)
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,39 +1,34 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=0,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=2,
|
|
||||||
outputs=3,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.sigmoid,
|
|
||||||
activation_options=(Act.sigmoid,),
|
|
||||||
),
|
|
||||||
problem=GymNaxConfig(
|
|
||||||
env_name='MountainCar-v0',
|
|
||||||
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1, 2}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from problem.rl_env import GymNaxEnv
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
conf = example_conf()
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=2,
|
||||||
|
num_outputs=3,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
output_transform=lambda out: jnp.argmax(out) # the action of mountain car is {0, 1, 2}
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=GymNaxEnv(
|
||||||
|
env_name='MountainCar-v0',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=0
|
||||||
|
)
|
||||||
|
|
||||||
algorithm = NEAT(conf, NormalGene)
|
# initialize state
|
||||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
pipeline.pre_compile(state)
|
# print(state)
|
||||||
state, best = pipeline.auto_run(state)
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,38 +1,36 @@
|
|||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=100,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=2,
|
|
||||||
outputs=1,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
problem=GymNaxConfig(
|
|
||||||
env_name='MountainCarContinuous-v0'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from problem.rl_env import GymNaxEnv
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
conf = example_conf()
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=2,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
node_gene=DefaultNodeGene(
|
||||||
|
activation_options=(Act.tanh, ),
|
||||||
|
activation_default=Act.tanh,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=GymNaxEnv(
|
||||||
|
env_name='MountainCarContinuous-v0',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=500
|
||||||
|
)
|
||||||
|
|
||||||
algorithm = NEAT(conf, NormalGene)
|
# initialize state
|
||||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
pipeline.pre_compile(state)
|
# print(state)
|
||||||
state, best = pipeline.auto_run(state)
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,40 +1,37 @@
|
|||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
from config import *
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=0,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=3,
|
|
||||||
outputs=1,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
problem=GymNaxConfig(
|
|
||||||
env_name='Pendulum-v1',
|
|
||||||
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from problem.rl_env import GymNaxEnv
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
conf = example_conf()
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=3,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
node_gene=DefaultNodeGene(
|
||||||
|
activation_options=(Act.tanh,),
|
||||||
|
activation_default=Act.tanh,
|
||||||
|
),
|
||||||
|
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=GymNaxEnv(
|
||||||
|
env_name='Pendulum-v1',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=0
|
||||||
|
)
|
||||||
|
|
||||||
algorithm = NEAT(conf, NormalGene)
|
# initialize state
|
||||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
pipeline.pre_compile(state)
|
# print(state)
|
||||||
state, best = pipeline.auto_run(state)
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,36 +1,33 @@
|
|||||||
from config import *
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from pipeline import Pipeline
|
from pipeline import Pipeline
|
||||||
from algorithm import NEAT
|
from algorithm.neat import *
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
|
||||||
|
|
||||||
|
|
||||||
def example_conf():
|
|
||||||
return Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
seed=42,
|
|
||||||
fitness_target=500,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
inputs=8,
|
|
||||||
outputs=2,
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig(
|
|
||||||
activation_default=Act.sigmoid,
|
|
||||||
activation_options=(Act.sigmoid,),
|
|
||||||
),
|
|
||||||
problem=GymNaxConfig(
|
|
||||||
env_name='Reacher-misc',
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from problem.rl_env import GymNaxEnv
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
conf = example_conf()
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=8,
|
||||||
|
num_outputs=2,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=GymNaxEnv(
|
||||||
|
env_name='Reacher-misc',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target =500
|
||||||
|
)
|
||||||
|
|
||||||
algorithm = NEAT(conf, NormalGene)
|
# initialize state
|
||||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
|
||||||
state = pipeline.setup()
|
state = pipeline.setup()
|
||||||
pipeline.pre_compile(state)
|
# print(state)
|
||||||
state, best = pipeline.auto_run(state)
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
90
pipeline.py
90
pipeline.py
@@ -1,25 +1,23 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import jax
|
import jax, jax.numpy as jnp
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithm import NEAT, HyperNEAT
|
from algorithm import BaseAlgorithm
|
||||||
from config import Config
|
from problem import BaseProblem
|
||||||
from core import State, Algorithm, Problem
|
from utils import State
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
algorithm: Algorithm,
|
algorithm: BaseAlgorithm,
|
||||||
problem: Problem,
|
problem: BaseProblem,
|
||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
fitness_target: float = 1,
|
fitness_target: float = 1,
|
||||||
generation_limit: int = 1000,
|
generation_limit: int = 1000,
|
||||||
pop_size: int = 100,
|
|
||||||
):
|
):
|
||||||
assert problem.jitable, "Currently, problem must be jitable"
|
assert problem.jitable, "Currently, problem must be jitable"
|
||||||
|
|
||||||
@@ -28,17 +26,18 @@ class Pipeline:
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.fitness_target = fitness_target
|
self.fitness_target = fitness_target
|
||||||
self.generation_limit = generation_limit
|
self.generation_limit = generation_limit
|
||||||
self.pop_size = pop_size
|
self.pop_size = self.algorithm.pop_size
|
||||||
|
|
||||||
print(self.problem.input_shape, self.problem.output_shape)
|
print(self.problem.input_shape, self.problem.output_shape)
|
||||||
|
|
||||||
# TODO: make each algorithm's input_num and output_num
|
# TODO: make each algorithm's input_num and output_num
|
||||||
assert algorithm.input_num == self.problem.input_shape[-1], f"problem input shape {self.problem.input_shape}"
|
assert algorithm.num_inputs == self.problem.input_shape[-1], \
|
||||||
|
f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
|
||||||
|
|
||||||
self.act_func = self.algorithm.act
|
# self.act_func = self.algorithm.act
|
||||||
|
|
||||||
for _ in range(len(self.problem.input_shape) - 1):
|
# for _ in range(len(self.problem.input_shape) - 1):
|
||||||
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
|
# self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
|
||||||
|
|
||||||
self.best_genome = None
|
self.best_genome = None
|
||||||
self.best_fitness = float('-inf')
|
self.best_fitness = float('-inf')
|
||||||
@@ -46,41 +45,57 @@ class Pipeline:
|
|||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
key = jax.random.PRNGKey(self.seed)
|
key = jax.random.PRNGKey(self.seed)
|
||||||
algorithm_key, evaluate_key = jax.random.split(key, 2)
|
key, algorithm_key, evaluate_key = jax.random.split(key, 3)
|
||||||
|
|
||||||
# TODO: Problem should has setup function to maintain state
|
# TODO: Problem should has setup function to maintain state
|
||||||
return State(
|
return State(
|
||||||
|
randkey=key,
|
||||||
alg=self.algorithm.setup(algorithm_key),
|
alg=self.algorithm.setup(algorithm_key),
|
||||||
pro=self.problem.setup(evaluate_key),
|
pro=self.problem.setup(evaluate_key),
|
||||||
)
|
)
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(0,))
|
|
||||||
def step(self, state):
|
def step(self, state):
|
||||||
key, sub_key = jax.random.split(state.evaluate_key)
|
key, sub_key = jax.random.split(state.randkey)
|
||||||
keys = jax.random.split(key, self.pop_size)
|
keys = jax.random.split(key, self.pop_size)
|
||||||
|
|
||||||
pop = self.algorithm.ask(state)
|
pop = self.algorithm.ask(state.alg)
|
||||||
|
|
||||||
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
|
pop_transformed = jax.vmap(self.algorithm.transform)(pop)
|
||||||
|
|
||||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
|
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(
|
||||||
pop_transformed)
|
keys,
|
||||||
|
state.pro,
|
||||||
|
self.algorithm.forward,
|
||||||
|
pop_transformed
|
||||||
|
)
|
||||||
|
|
||||||
state = self.algorithm.tell(state, fitnesses)
|
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
||||||
|
|
||||||
return state.update(evaluate_key=sub_key), fitnesses
|
alg_state = self.algorithm.tell(state.alg, fitnesses)
|
||||||
|
|
||||||
|
return state.update(
|
||||||
|
randkey=sub_key,
|
||||||
|
alg=alg_state,
|
||||||
|
), fitnesses
|
||||||
|
|
||||||
def auto_run(self, ini_state):
|
def auto_run(self, ini_state):
|
||||||
state = ini_state
|
state = ini_state
|
||||||
|
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
||||||
|
|
||||||
for _ in range(self.generation_limit):
|
for _ in range(self.generation_limit):
|
||||||
|
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
previous_pop = self.algorithm.ask(state)
|
previous_pop = self.algorithm.ask(state.alg)
|
||||||
|
|
||||||
state, fitnesses = self.step(state)
|
state, fitnesses = compiled_step(state)
|
||||||
|
|
||||||
fitnesses = jax.device_get(fitnesses)
|
fitnesses = jax.device_get(fitnesses)
|
||||||
|
for idx, fitnesses_i in enumerate(fitnesses):
|
||||||
|
if np.isnan(fitnesses_i):
|
||||||
|
print("Fitness is nan")
|
||||||
|
print(previous_pop[0][idx], previous_pop[1][idx])
|
||||||
|
assert False
|
||||||
|
|
||||||
self.analysis(state, previous_pop, fitnesses)
|
self.analysis(state, previous_pop, fitnesses)
|
||||||
|
|
||||||
@@ -102,22 +117,15 @@ class Pipeline:
|
|||||||
max_idx = np.argmax(fitnesses)
|
max_idx = np.argmax(fitnesses)
|
||||||
if fitnesses[max_idx] > self.best_fitness:
|
if fitnesses[max_idx] > self.best_fitness:
|
||||||
self.best_fitness = fitnesses[max_idx]
|
self.best_fitness = fitnesses[max_idx]
|
||||||
self.best_genome = pop[max_idx]
|
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||||
|
|
||||||
member_count = jax.device_get(state.species_info.member_count)
|
member_count = jax.device_get(self.algorithm.member_count(state.alg))
|
||||||
species_sizes = [int(i) for i in member_count if i > 0]
|
species_sizes = [int(i) for i in member_count if i > 0]
|
||||||
|
|
||||||
print(f"Generation: {state.generation}",
|
print(f"Generation: {self.algorithm.generation(state.alg)}",
|
||||||
f"species: {len(species_sizes)}, {species_sizes}",
|
f"species: {len(species_sizes)}, {species_sizes}",
|
||||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
|
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
|
||||||
|
|
||||||
def show(self, state, genome, *args, **kwargs):
|
def show(self, state, best, *args, **kwargs):
|
||||||
transformed = self.algorithm.transform(state, genome)
|
transformed = self.algorithm.transform(best)
|
||||||
self.problem.show(state.evaluate_key, state, self.act_func, transformed, *args, **kwargs)
|
self.problem.show(state.randkey, state.pro, self.algorithm.forward, transformed, *args, **kwargs)
|
||||||
|
|
||||||
def pre_compile(self, state):
|
|
||||||
tic = time.time()
|
|
||||||
print("start compile")
|
|
||||||
self.step.lower(self, state).compile()
|
|
||||||
print(f"compile finished, cost time: {time.time() - tic}s")
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,14 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from config import ProblemConfig
|
from utils import State
|
||||||
from core.state import State
|
|
||||||
|
|
||||||
|
|
||||||
class BaseProblem:
|
class BaseProblem:
|
||||||
|
|
||||||
jitable = None
|
jitable = None
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setup(self, randkey, state: State = State()):
|
def setup(self, randkey, state: State = State()):
|
||||||
"""initialize the state of the problem"""
|
"""initialize the state of the problem"""
|
||||||
raise NotImplementedError
|
pass
|
||||||
|
|
||||||
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
||||||
"""evaluate one individual"""
|
"""evaluate one individual"""
|
||||||
|
|||||||
@@ -1,24 +1,27 @@
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from utils import State
|
||||||
from .. import BaseProblem
|
from .. import BaseProblem
|
||||||
|
|
||||||
class FuncFit(BaseProblem):
|
|
||||||
|
|
||||||
|
class FuncFit(BaseProblem):
|
||||||
jitable = True
|
jitable = True
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
error_method: str = 'mse'
|
error_method: str = 'mse'
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
|
assert error_method in {'mse', 'rmse', 'mae', 'mape'}
|
||||||
self.error_method = error_method
|
self.error_method = error_method
|
||||||
|
|
||||||
|
def setup(self, randkey, state: State = State()):
|
||||||
|
return state
|
||||||
|
|
||||||
def evaluate(self, randkey, state, act_func, params):
|
def evaluate(self, randkey, state, act_func, params):
|
||||||
|
|
||||||
predict = act_func(state, self.inputs, params)
|
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
|
||||||
|
|
||||||
if self.error_method == 'mse':
|
if self.error_method == 'mse':
|
||||||
loss = jnp.mean((predict - self.targets) ** 2)
|
loss = jnp.mean((predict - self.targets) ** 2)
|
||||||
@@ -38,7 +41,7 @@ class FuncFit(BaseProblem):
|
|||||||
return -loss
|
return -loss
|
||||||
|
|
||||||
def show(self, randkey, state, act_func, params, *args, **kwargs):
|
def show(self, randkey, state, act_func, params, *args, **kwargs):
|
||||||
predict = act_func(state, self.inputs, params)
|
predict = jax.vmap(act_func, in_axes=(0, None))(self.inputs, params)
|
||||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||||
loss = -self.evaluate(randkey, state, act_func, params)
|
loss = -self.evaluate(randkey, state, act_func, params)
|
||||||
msg = ""
|
msg = ""
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
from .gymnax_env import GymNaxEnv, GymNaxConfig
|
from .gymnax_env import GymNaxEnv
|
||||||
from .brax_env import BraxEnv, BraxConfig
|
from .brax_env import BraxEnv
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import gymnax
|
|||||||
from .rl_jit import RLEnv
|
from .rl_jit import RLEnv
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GymNaxEnv(RLEnv):
|
class GymNaxEnv(RLEnv):
|
||||||
|
|
||||||
def __init__(self, env_name):
|
def __init__(self, env_name):
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import jax
|
|||||||
|
|
||||||
from .. import BaseProblem
|
from .. import BaseProblem
|
||||||
|
|
||||||
class RLEnv(BaseProblem):
|
|
||||||
|
|
||||||
|
class RLEnv(BaseProblem):
|
||||||
jitable = True
|
jitable = True
|
||||||
|
|
||||||
# TODO: move output transform to algorithm
|
# TODO: move output transform to algorithm
|
||||||
@@ -19,9 +19,10 @@ class RLEnv(BaseProblem):
|
|||||||
def cond_func(carry):
|
def cond_func(carry):
|
||||||
_, _, _, done, _ = carry
|
_, _, _, done, _ = carry
|
||||||
return ~done
|
return ~done
|
||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
obs, env_state, rng, _, tr = carry # total reward
|
obs, env_state, rng, _, tr = carry # total reward
|
||||||
action = act_func(state, obs, params)
|
action = act_func(obs, params)
|
||||||
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||||
next_rng, _ = jax.random.split(rng)
|
next_rng, _ = jax.random.split(rng)
|
||||||
return next_obs, next_env_state, next_rng, done, tr + reward
|
return next_obs, next_env_state, next_rng, done, tr + reward
|
||||||
|
|||||||
66
t.py
66
t.py
@@ -1,64 +1,4 @@
|
|||||||
from algorithm.neat import *
|
import jax.numpy as jnp
|
||||||
from utils import Act, Agg
|
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
a = jnp.zeros((0, 9, 9))
|
||||||
|
print(a)
|
||||||
|
|
||||||
def main():
|
|
||||||
|
|
||||||
# index, bias, response, activation, aggregation
|
|
||||||
nodes = jnp.array([
|
|
||||||
[0, 0, 1, 0, 0], # in[0]
|
|
||||||
[1, 0, 1, 0, 0], # in[1]
|
|
||||||
[2, 0.5, 1, 0, 0], # out[0],
|
|
||||||
[3, 1, 1, 0, 0], # hidden[0],
|
|
||||||
[4, -1, 1, 0, 0], # hidden[1],
|
|
||||||
])
|
|
||||||
|
|
||||||
# in_node, out_node, enable, weight
|
|
||||||
conns = jnp.array([
|
|
||||||
[0, 3, 1, 0.5], # in[0] -> hidden[0]
|
|
||||||
[1, 4, 1, 0.5], # in[1] -> hidden[1]
|
|
||||||
[3, 2, 1, 0.5], # hidden[0] -> out[0]
|
|
||||||
[4, 2, 1, 0.5], # hidden[1] -> out[0]
|
|
||||||
])
|
|
||||||
|
|
||||||
genome = RecurrentGenome(
|
|
||||||
num_inputs=2,
|
|
||||||
num_outputs=1,
|
|
||||||
node_gene=DefaultNodeGene(
|
|
||||||
activation_default=Act.identity,
|
|
||||||
activation_options=(Act.identity, ),
|
|
||||||
aggregation_default=Agg.sum,
|
|
||||||
aggregation_options=(Agg.sum, ),
|
|
||||||
),
|
|
||||||
activate_time=3
|
|
||||||
)
|
|
||||||
|
|
||||||
transformed = genome.transform(nodes, conns)
|
|
||||||
print(*transformed, sep='\n')
|
|
||||||
|
|
||||||
inputs = jnp.array([0, 0])
|
|
||||||
outputs = genome.forward(inputs, transformed)
|
|
||||||
print(outputs)
|
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
|
||||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
|
|
||||||
print(outputs)
|
|
||||||
expected: [[0.5], [0.75], [0.75], [1]]
|
|
||||||
|
|
||||||
print('\n-------------------------------------------------------\n')
|
|
||||||
|
|
||||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
|
||||||
print(conns)
|
|
||||||
|
|
||||||
transformed = genome.transform(nodes, conns)
|
|
||||||
print(*transformed, sep='\n')
|
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
|
||||||
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
|
|
||||||
print(outputs)
|
|
||||||
expected: [[0.5], [0.75], [0.5], [0.75]]
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
@@ -26,6 +26,8 @@ def test_default():
|
|||||||
genome = DefaultGenome(
|
genome = DefaultGenome(
|
||||||
num_inputs=2,
|
num_inputs=2,
|
||||||
num_outputs=1,
|
num_outputs=1,
|
||||||
|
max_nodes=5,
|
||||||
|
max_conns=4,
|
||||||
node_gene=DefaultNodeGene(
|
node_gene=DefaultNodeGene(
|
||||||
activation_default=Act.identity,
|
activation_default=Act.identity,
|
||||||
activation_options=(Act.identity, ),
|
activation_options=(Act.identity, ),
|
||||||
@@ -80,6 +82,8 @@ def test_recurrent():
|
|||||||
genome = RecurrentGenome(
|
genome = RecurrentGenome(
|
||||||
num_inputs=2,
|
num_inputs=2,
|
||||||
num_outputs=1,
|
num_outputs=1,
|
||||||
|
max_nodes=5,
|
||||||
|
max_conns=4,
|
||||||
node_gene=DefaultNodeGene(
|
node_gene=DefaultNodeGene(
|
||||||
activation_default=Act.identity,
|
activation_default=Act.identity,
|
||||||
activation_options=(Act.identity, ),
|
activation_options=(Act.identity, ),
|
||||||
|
|||||||
@@ -6,48 +6,26 @@ class Act:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sigmoid(z):
|
def sigmoid(z):
|
||||||
z = jnp.clip(z * 5, -60, 60)
|
z = jnp.clip(5 * z, -10, 10)
|
||||||
return 1 / (1 + jnp.exp(-z))
|
return 1 / (1 + jnp.exp(-z))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tanh(z):
|
def tanh(z):
|
||||||
z = jnp.clip(z * 2.5, -60, 60)
|
|
||||||
return jnp.tanh(z)
|
return jnp.tanh(z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sin(z):
|
def sin(z):
|
||||||
z = jnp.clip(z * 5, -60, 60)
|
|
||||||
return jnp.sin(z)
|
return jnp.sin(z)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def gauss(z):
|
|
||||||
z = jnp.clip(z * 5, -3.4, 3.4)
|
|
||||||
return jnp.exp(-z ** 2)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def relu(z):
|
def relu(z):
|
||||||
return jnp.maximum(z, 0)
|
return jnp.maximum(z, 0)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def elu(z):
|
|
||||||
return jnp.where(z > 0, z, jnp.exp(z) - 1)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def lelu(z):
|
def lelu(z):
|
||||||
leaky = 0.005
|
leaky = 0.005
|
||||||
return jnp.where(z > 0, z, leaky * z)
|
return jnp.where(z > 0, z, leaky * z)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def selu(z):
|
|
||||||
lam = 1.0507009873554804934193349852946
|
|
||||||
alpha = 1.6732632423543772848170429916717
|
|
||||||
return jnp.where(z > 0, lam * z, lam * alpha * (jnp.exp(z) - 1))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def softplus(z):
|
|
||||||
z = jnp.clip(z * 5, -60, 60)
|
|
||||||
return 0.2 * jnp.log(1 + jnp.exp(z))
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def identity(z):
|
def identity(z):
|
||||||
return z
|
return z
|
||||||
@@ -58,7 +36,11 @@ class Act:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def inv(z):
|
def inv(z):
|
||||||
z = jnp.maximum(z, 1e-7)
|
z = jnp.where(
|
||||||
|
z > 0,
|
||||||
|
jnp.maximum(z, 1e-7),
|
||||||
|
jnp.minimum(z, -1e-7)
|
||||||
|
)
|
||||||
return 1 / z
|
return 1 / z
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -68,24 +50,27 @@ class Act:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def exp(z):
|
def exp(z):
|
||||||
z = jnp.clip(z, -60, 60)
|
z = jnp.clip(z, -10, 10)
|
||||||
return jnp.exp(z)
|
return jnp.exp(z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def abs(z):
|
def abs(z):
|
||||||
return jnp.abs(z)
|
return jnp.abs(z)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def hat(z):
|
|
||||||
return jnp.maximum(0, 1 - jnp.abs(z))
|
|
||||||
|
|
||||||
@staticmethod
|
ACT_ALL = (
|
||||||
def square(z):
|
Act.sigmoid,
|
||||||
return z ** 2
|
Act.tanh,
|
||||||
|
Act.sin,
|
||||||
@staticmethod
|
Act.relu,
|
||||||
def cube(z):
|
Act.lelu,
|
||||||
return z ** 3
|
Act.identity,
|
||||||
|
Act.clamped,
|
||||||
|
Act.inv,
|
||||||
|
Act.log,
|
||||||
|
Act.exp,
|
||||||
|
Act.abs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def act(idx, z, act_funcs):
|
def act(idx, z, act_funcs):
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ class Agg:
|
|||||||
return mean_without_zeros
|
return mean_without_zeros
|
||||||
|
|
||||||
|
|
||||||
|
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
|
||||||
|
|
||||||
|
|
||||||
def agg(idx, z, agg_funcs):
|
def agg(idx, z, agg_funcs):
|
||||||
"""
|
"""
|
||||||
calculate activation function for inputs of node
|
calculate activation function for inputs of node
|
||||||
|
|||||||
Reference in New Issue
Block a user