update hyperneat and related examples
This commit is contained in:
@@ -1,60 +0,0 @@
|
|||||||
from pipeline import Pipeline
|
|
||||||
from algorithm.neat import *
|
|
||||||
|
|
||||||
from problem.rl_env import BraxEnv
|
|
||||||
from tensorneat.common import Act
|
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
|
||||||
|
|
||||||
|
|
||||||
def split_right_left(randkey, forward_func, obs):
|
|
||||||
right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13])
|
|
||||||
left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16])
|
|
||||||
right_action_keys = jnp.array([0, 1, 2])
|
|
||||||
left_action_keys = jnp.array([3, 4, 5])
|
|
||||||
|
|
||||||
right_foot_obs = obs
|
|
||||||
left_foot_obs = obs
|
|
||||||
left_foot_obs = left_foot_obs.at[right_obs_keys].set(obs[left_obs_keys])
|
|
||||||
left_foot_obs = left_foot_obs.at[left_obs_keys].set(obs[right_obs_keys])
|
|
||||||
|
|
||||||
right_action, left_action = jax.vmap(forward_func)(jnp.stack([right_foot_obs, left_foot_obs]))
|
|
||||||
# print(right_action.shape)
|
|
||||||
# print(left_action.shape)
|
|
||||||
|
|
||||||
return jnp.concatenate([right_action, left_action])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pipeline = Pipeline(
|
|
||||||
algorithm=NEAT(
|
|
||||||
species=DefaultSpecies(
|
|
||||||
genome=DefaultGenome(
|
|
||||||
num_inputs=17,
|
|
||||||
num_outputs=3,
|
|
||||||
max_nodes=50,
|
|
||||||
max_conns=100,
|
|
||||||
node_gene=DefaultNodeGene(
|
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
activation_default=Act.tanh,
|
|
||||||
),
|
|
||||||
output_transform=Act.tanh,
|
|
||||||
),
|
|
||||||
pop_size=1000,
|
|
||||||
species_size=10,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
problem=BraxEnv(
|
|
||||||
env_name="walker2d",
|
|
||||||
max_step=1000,
|
|
||||||
action_policy=split_right_left
|
|
||||||
),
|
|
||||||
generation_limit=10000,
|
|
||||||
fitness_target=5000,
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize state
|
|
||||||
state = pipeline.setup()
|
|
||||||
# print(state)
|
|
||||||
# run until terminate
|
|
||||||
state, best = pipeline.auto_run(state)
|
|
||||||
51
examples/brax/walker2d.py
Normal file
51
examples/brax/walker2d.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from tensorneat.pipeline import Pipeline
|
||||||
|
from tensorneat.algorithm.neat import NEAT
|
||||||
|
from tensorneat.genome import DefaultGenome, BiasNode
|
||||||
|
|
||||||
|
from tensorneat.problem.rl_env import BraxEnv
|
||||||
|
from tensorneat.common import Act, Agg
|
||||||
|
|
||||||
|
import jax, jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
|
def random_sample_policy(randkey, obs):
|
||||||
|
return jax.random.uniform(randkey, (6,))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
pop_size=1000,
|
||||||
|
species_size=20,
|
||||||
|
survival_threshold=0.1,
|
||||||
|
compatibility_threshold=1.0,
|
||||||
|
genome=DefaultGenome(
|
||||||
|
max_nodes=100,
|
||||||
|
max_conns=200,
|
||||||
|
num_inputs=17,
|
||||||
|
num_outputs=6,
|
||||||
|
init_hidden_layers=(),
|
||||||
|
node_gene=BiasNode(
|
||||||
|
activation_options=Act.tanh,
|
||||||
|
aggregation_options=Agg.sum,
|
||||||
|
),
|
||||||
|
output_transform=Act.standard_tanh,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=BraxEnv(
|
||||||
|
env_name="walker2d",
|
||||||
|
max_step=1000,
|
||||||
|
obs_normalization=True,
|
||||||
|
sample_episodes=1000,
|
||||||
|
sample_policy=random_sample_policy,
|
||||||
|
),
|
||||||
|
seed=42,
|
||||||
|
generation_limit=100,
|
||||||
|
fitness_target=5000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize state
|
||||||
|
state = pipeline.setup()
|
||||||
|
# print(state)
|
||||||
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -1,53 +1,33 @@
|
|||||||
from pipeline import Pipeline
|
from tensorneat.pipeline import Pipeline
|
||||||
from algorithm.neat import *
|
from tensorneat.algorithm.neat import NEAT
|
||||||
from algorithm.hyperneat import *
|
from tensorneat.algorithm.hyperneat import HyperNEAT, FullSubstrate
|
||||||
|
from tensorneat.genome import DefaultGenome
|
||||||
from tensorneat.common import Act
|
from tensorneat.common import Act
|
||||||
|
|
||||||
from problem.func_fit import XOR3d
|
from tensorneat.problem.func_fit import XOR3d
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
algorithm=HyperNEAT(
|
algorithm=HyperNEAT(
|
||||||
substrate=FullSubstrate(
|
substrate=FullSubstrate(
|
||||||
input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)], # 3(XOR3d inputs) + 1(bias)
|
input_coors=((-1, -1), (-0.33, -1), (0.33, -1), (1, -1)),
|
||||||
hidden_coors=[
|
hidden_coors=((-1, 0), (0, 0), (1, 0)),
|
||||||
(-1, -0.5), (0.333, -0.5), (-0.333, -0.5),
|
output_coors=((0, 1),),
|
||||||
(1, -0.5),
|
|
||||||
(-1, 0),
|
|
||||||
(0.333, 0),
|
|
||||||
(-0.333, 0),
|
|
||||||
(1, 0),
|
|
||||||
(-1, 0.5),
|
|
||||||
(0.333, 0.5),
|
|
||||||
(-0.333, 0.5),
|
|
||||||
(1, 0.5),
|
|
||||||
],
|
|
||||||
output_coors=[
|
|
||||||
(0, 1), # one output
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
neat=NEAT(
|
neat=NEAT(
|
||||||
species=DefaultSpecies(
|
pop_size=10000,
|
||||||
genome=DefaultGenome(
|
species_size=20,
|
||||||
num_inputs=4, # [*coor1, *coor2]
|
survival_threshold=0.01,
|
||||||
num_outputs=1, # the weight of connection between two coor1 and coor2
|
genome=DefaultGenome(
|
||||||
max_nodes=50,
|
num_inputs=4, # size of query coors
|
||||||
max_conns=100,
|
num_outputs=1,
|
||||||
node_gene=DefaultNodeGene(
|
init_hidden_layers=(),
|
||||||
activation_default=Act.tanh,
|
output_transform=Act.standard_tanh,
|
||||||
activation_options=(Act.tanh,),
|
|
||||||
),
|
|
||||||
output_transform=Act.tanh, # the activation function for output node in NEAT
|
|
||||||
),
|
|
||||||
pop_size=1000,
|
|
||||||
species_size=10,
|
|
||||||
compatibility_threshold=2,
|
|
||||||
survival_threshold=0.03,
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
activation=Act.tanh,
|
activation=Act.tanh,
|
||||||
activate_time=10,
|
activate_time=10,
|
||||||
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
|
output_transform=Act.standard_sigmoid,
|
||||||
),
|
),
|
||||||
problem=XOR3d(),
|
problem=XOR3d(),
|
||||||
generation_limit=300,
|
generation_limit=300,
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
from .base import BaseAlgorithm
|
from .base import BaseAlgorithm
|
||||||
from .neat import NEAT
|
from .neat import NEAT
|
||||||
|
from .hyperneat import HyperNEAT
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax
|
||||||
|
from jax import vmap, numpy as jnp
|
||||||
|
|
||||||
from tensorneat.common import State, Act, Agg
|
|
||||||
from .. import BaseAlgorithm, NEAT
|
|
||||||
from ..neat.gene import BaseNodeGene, BaseConnGene
|
|
||||||
from ..neat.genome import RecurrentGenome
|
|
||||||
from .substrate import *
|
from .substrate import *
|
||||||
|
from tensorneat.common import State, Act, Agg
|
||||||
|
from tensorneat.algorithm import BaseAlgorithm, NEAT
|
||||||
|
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
|
||||||
|
|
||||||
|
|
||||||
class HyperNEAT(BaseAlgorithm):
|
class HyperNEAT(BaseAlgorithm):
|
||||||
@@ -14,64 +14,65 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
self,
|
self,
|
||||||
substrate: BaseSubstrate,
|
substrate: BaseSubstrate,
|
||||||
neat: NEAT,
|
neat: NEAT,
|
||||||
below_threshold: float = 0.3,
|
weight_threshold: float = 0.3,
|
||||||
max_weight: float = 5.0,
|
max_weight: float = 5.0,
|
||||||
aggregation=Agg.sum,
|
aggregation: Callable = Agg.sum,
|
||||||
activation=Act.sigmoid,
|
activation: Callable = Act.sigmoid,
|
||||||
activate_time: int = 10,
|
activate_time: int = 10,
|
||||||
output_transform: Callable = Act.sigmoid,
|
output_transform: Callable = Act.standard_sigmoid,
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
substrate.query_coors.shape[1] == neat.num_inputs
|
substrate.query_coors.shape[1] == neat.num_inputs
|
||||||
), "Substrate input size should be equal to NEAT input size"
|
), "Query coors of Substrate should be equal to NEAT input size"
|
||||||
|
|
||||||
self.substrate = substrate
|
self.substrate = substrate
|
||||||
self.neat = neat
|
self.neat = neat
|
||||||
self.below_threshold = below_threshold
|
self.weight_threshold = weight_threshold
|
||||||
self.max_weight = max_weight
|
self.max_weight = max_weight
|
||||||
self.hyper_genome = RecurrentGenome(
|
self.hyper_genome = RecurrentGenome(
|
||||||
num_inputs=substrate.num_inputs,
|
num_inputs=substrate.num_inputs,
|
||||||
num_outputs=substrate.num_outputs,
|
num_outputs=substrate.num_outputs,
|
||||||
max_nodes=substrate.nodes_cnt,
|
max_nodes=substrate.nodes_cnt,
|
||||||
max_conns=substrate.conns_cnt,
|
max_conns=substrate.conns_cnt,
|
||||||
node_gene=HyperNodeGene(aggregation, activation),
|
node_gene=HyperNEATNode(aggregation, activation),
|
||||||
conn_gene=HyperNEATConnGene(),
|
conn_gene=HyperNEATConn(),
|
||||||
activate_time=activate_time,
|
activate_time=activate_time,
|
||||||
output_transform=output_transform,
|
output_transform=output_transform,
|
||||||
)
|
)
|
||||||
|
self.pop_size = neat.pop_size
|
||||||
|
|
||||||
def setup(self, state=State()):
|
def setup(self, state=State()):
|
||||||
state = self.neat.setup(state)
|
state = self.neat.setup(state)
|
||||||
state = self.substrate.setup(state)
|
state = self.substrate.setup(state)
|
||||||
return self.hyper_genome.setup(state)
|
return self.hyper_genome.setup(state)
|
||||||
|
|
||||||
def ask(self, state: State):
|
def ask(self, state):
|
||||||
return self.neat.ask(state)
|
return self.neat.ask(state)
|
||||||
|
|
||||||
def tell(self, state: State, fitness):
|
def tell(self, state, fitness):
|
||||||
state = self.neat.tell(state, fitness)
|
state = self.neat.tell(state, fitness)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def transform(self, state, individual):
|
def transform(self, state, individual):
|
||||||
transformed = self.neat.transform(state, individual)
|
transformed = self.neat.transform(state, individual)
|
||||||
query_res = jax.vmap(self.neat.forward, in_axes=(None, None, 0))(
|
query_res = vmap(self.neat.forward, in_axes=(None, None, 0))(
|
||||||
state, transformed, self.substrate.query_coors
|
state, transformed, self.substrate.query_coors
|
||||||
)
|
)
|
||||||
# mute the connection with weight below threshold
|
# mute the connection with weight weight threshold
|
||||||
query_res = jnp.where(
|
query_res = jnp.where(
|
||||||
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
|
(-self.weight_threshold < query_res) & (query_res < self.weight_threshold),
|
||||||
0.0,
|
0.0,
|
||||||
query_res,
|
query_res,
|
||||||
)
|
)
|
||||||
|
|
||||||
# make query res in range [-max_weight, max_weight]
|
# make query res in range [-max_weight, max_weight]
|
||||||
query_res = jnp.where(
|
query_res = jnp.where(
|
||||||
query_res > 0, query_res - self.below_threshold, query_res
|
query_res > 0, query_res - self.weight_threshold, query_res
|
||||||
)
|
)
|
||||||
query_res = jnp.where(
|
query_res = jnp.where(
|
||||||
query_res < 0, query_res + self.below_threshold, query_res
|
query_res < 0, query_res + self.weight_threshold, query_res
|
||||||
)
|
)
|
||||||
query_res = query_res / (1 - self.below_threshold) * self.max_weight
|
query_res = query_res / (1 - self.weight_threshold) * self.max_weight
|
||||||
|
|
||||||
h_nodes, h_conns = self.substrate.make_nodes(
|
h_nodes, h_conns = self.substrate.make_nodes(
|
||||||
query_res
|
query_res
|
||||||
@@ -79,11 +80,11 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
|
|
||||||
return self.hyper_genome.transform(state, h_nodes, h_conns)
|
return self.hyper_genome.transform(state, h_nodes, h_conns)
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, transformed, inputs):
|
||||||
# add bias
|
# add bias
|
||||||
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
||||||
|
|
||||||
res = self.hyper_genome.forward(state, inputs_with_bias, transformed)
|
res = self.hyper_genome.forward(state, transformed, inputs_with_bias)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -94,18 +95,11 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
def num_outputs(self):
|
def num_outputs(self):
|
||||||
return self.substrate.num_outputs
|
return self.substrate.num_outputs
|
||||||
|
|
||||||
@property
|
def show_details(self, state, fitness):
|
||||||
def pop_size(self):
|
return self.neat.show_details(state, fitness)
|
||||||
return self.neat.pop_size
|
|
||||||
|
|
||||||
def member_count(self, state: State):
|
|
||||||
return self.neat.member_count(state)
|
|
||||||
|
|
||||||
def generation(self, state: State):
|
|
||||||
return self.neat.generation(state)
|
|
||||||
|
|
||||||
|
|
||||||
class HyperNodeGene(BaseNodeGene):
|
class HyperNEATNode(BaseNode):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
aggregation=Agg.sum,
|
aggregation=Agg.sum,
|
||||||
@@ -123,7 +117,7 @@ class HyperNodeGene(BaseNodeGene):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HyperNEATConnGene(BaseConnGene):
|
class HyperNEATConn(BaseConn):
|
||||||
custom_attrs = ["weight"]
|
custom_attrs = ["weight"]
|
||||||
|
|
||||||
def forward(self, state, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
|
|||||||
@@ -23,10 +23,10 @@ from ...utils import (
|
|||||||
class DefaultMutation(BaseMutation):
|
class DefaultMutation(BaseMutation):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
conn_add: float = 0.1,
|
conn_add: float = 0.2,
|
||||||
conn_delete: float = 0,
|
conn_delete: float = 0.2,
|
||||||
node_add: float = 0.1,
|
node_add: float = 0.1,
|
||||||
node_delete: float = 0,
|
node_delete: float = 0.1,
|
||||||
):
|
):
|
||||||
self.conn_add = conn_add
|
self.conn_add = conn_add
|
||||||
self.conn_delete = conn_delete
|
self.conn_delete = conn_delete
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class Pipeline(StatefulBaseClass):
|
|||||||
self.generation_limit = generation_limit
|
self.generation_limit = generation_limit
|
||||||
self.pop_size = self.algorithm.pop_size
|
self.pop_size = self.algorithm.pop_size
|
||||||
|
|
||||||
# print(self.problem.input_shape, self.problem.output_shape)
|
np.random.seed(self.seed)
|
||||||
|
|
||||||
# TODO: make each algorithm's input_num and output_num
|
# TODO: make each algorithm's input_num and output_num
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
Reference in New Issue
Block a user