update hyperneat and related examples

This commit is contained in:
root
2024-07-11 15:08:02 +08:00
parent 9bad577d89
commit 3cb5fbf581
7 changed files with 102 additions and 136 deletions

View File

@@ -1,60 +0,0 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from tensorneat.common import Act
import jax, jax.numpy as jnp
def split_right_left(randkey, forward_func, obs):
right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13])
left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16])
right_action_keys = jnp.array([0, 1, 2])
left_action_keys = jnp.array([3, 4, 5])
right_foot_obs = obs
left_foot_obs = obs
left_foot_obs = left_foot_obs.at[right_obs_keys].set(obs[left_obs_keys])
left_foot_obs = left_foot_obs.at[left_obs_keys].set(obs[right_obs_keys])
right_action, left_action = jax.vmap(forward_func)(jnp.stack([right_foot_obs, left_foot_obs]))
# print(right_action.shape)
# print(left_action.shape)
return jnp.concatenate([right_action, left_action])
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=17,
num_outputs=3,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=Act.tanh,
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name="walker2d",
max_step=1000,
action_policy=split_right_left
),
generation_limit=10000,
fitness_target=5000,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

51
examples/brax/walker2d.py Normal file
View File

@@ -0,0 +1,51 @@
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.genome import DefaultGenome, BiasNode
from tensorneat.problem.rl_env import BraxEnv
from tensorneat.common import Act, Agg
import jax, jax.numpy as jnp
def random_sample_policy(randkey, obs):
return jax.random.uniform(randkey, (6,))
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
pop_size=1000,
species_size=20,
survival_threshold=0.1,
compatibility_threshold=1.0,
genome=DefaultGenome(
max_nodes=100,
max_conns=200,
num_inputs=17,
num_outputs=6,
init_hidden_layers=(),
node_gene=BiasNode(
activation_options=Act.tanh,
aggregation_options=Agg.sum,
),
output_transform=Act.standard_tanh,
),
),
problem=BraxEnv(
env_name="walker2d",
max_step=1000,
obs_normalization=True,
sample_episodes=1000,
sample_policy=random_sample_policy,
),
seed=42,
generation_limit=100,
fitness_target=5000,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -1,53 +1,33 @@
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.hyperneat import *
from tensorneat.pipeline import Pipeline
from tensorneat.algorithm.neat import NEAT
from tensorneat.algorithm.hyperneat import HyperNEAT, FullSubstrate
from tensorneat.genome import DefaultGenome
from tensorneat.common import Act
from problem.func_fit import XOR3d
from tensorneat.problem.func_fit import XOR3d
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=HyperNEAT(
substrate=FullSubstrate(
input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)], # 3(XOR3d inputs) + 1(bias)
hidden_coors=[
(-1, -0.5), (0.333, -0.5), (-0.333, -0.5),
(1, -0.5),
(-1, 0),
(0.333, 0),
(-0.333, 0),
(1, 0),
(-1, 0.5),
(0.333, 0.5),
(-0.333, 0.5),
(1, 0.5),
],
output_coors=[
(0, 1), # one output
],
input_coors=((-1, -1), (-0.33, -1), (0.33, -1), (1, -1)),
hidden_coors=((-1, 0), (0, 0), (1, 0)),
output_coors=((0, 1),),
),
neat=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=4, # [*coor1, *coor2]
num_outputs=1, # the weight of connection between two coor1 and coor2
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.tanh, # the activation function for output node in NEAT
),
pop_size=1000,
species_size=10,
compatibility_threshold=2,
survival_threshold=0.03,
pop_size=10000,
species_size=20,
survival_threshold=0.01,
genome=DefaultGenome(
num_inputs=4, # size of query coors
num_outputs=1,
init_hidden_layers=(),
output_transform=Act.standard_tanh,
),
),
activation=Act.tanh,
activate_time=10,
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
output_transform=Act.standard_sigmoid,
),
problem=XOR3d(),
generation_limit=300,

View File

@@ -1,2 +1,3 @@
from .base import BaseAlgorithm
from .neat import NEAT
from .hyperneat import HyperNEAT

View File

@@ -1,12 +1,12 @@
from typing import Callable
import jax, jax.numpy as jnp
import jax
from jax import vmap, numpy as jnp
from tensorneat.common import State, Act, Agg
from .. import BaseAlgorithm, NEAT
from ..neat.gene import BaseNodeGene, BaseConnGene
from ..neat.genome import RecurrentGenome
from .substrate import *
from tensorneat.common import State, Act, Agg
from tensorneat.algorithm import BaseAlgorithm, NEAT
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
class HyperNEAT(BaseAlgorithm):
@@ -14,64 +14,65 @@ class HyperNEAT(BaseAlgorithm):
self,
substrate: BaseSubstrate,
neat: NEAT,
below_threshold: float = 0.3,
weight_threshold: float = 0.3,
max_weight: float = 5.0,
aggregation=Agg.sum,
activation=Act.sigmoid,
aggregation: Callable = Agg.sum,
activation: Callable = Act.sigmoid,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
output_transform: Callable = Act.standard_sigmoid,
):
assert (
substrate.query_coors.shape[1] == neat.num_inputs
), "Substrate input size should be equal to NEAT input size"
), "Query coors of Substrate should be equal to NEAT input size"
self.substrate = substrate
self.neat = neat
self.below_threshold = below_threshold
self.weight_threshold = weight_threshold
self.max_weight = max_weight
self.hyper_genome = RecurrentGenome(
num_inputs=substrate.num_inputs,
num_outputs=substrate.num_outputs,
max_nodes=substrate.nodes_cnt,
max_conns=substrate.conns_cnt,
node_gene=HyperNodeGene(aggregation, activation),
conn_gene=HyperNEATConnGene(),
node_gene=HyperNEATNode(aggregation, activation),
conn_gene=HyperNEATConn(),
activate_time=activate_time,
output_transform=output_transform,
)
self.pop_size = neat.pop_size
def setup(self, state=State()):
state = self.neat.setup(state)
state = self.substrate.setup(state)
return self.hyper_genome.setup(state)
def ask(self, state: State):
def ask(self, state):
return self.neat.ask(state)
def tell(self, state: State, fitness):
def tell(self, state, fitness):
state = self.neat.tell(state, fitness)
return state
def transform(self, state, individual):
transformed = self.neat.transform(state, individual)
query_res = jax.vmap(self.neat.forward, in_axes=(None, None, 0))(
query_res = vmap(self.neat.forward, in_axes=(None, None, 0))(
state, transformed, self.substrate.query_coors
)
# mute the connection with weight below threshold
# mute the connection with weight weight threshold
query_res = jnp.where(
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
(-self.weight_threshold < query_res) & (query_res < self.weight_threshold),
0.0,
query_res,
)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(
query_res > 0, query_res - self.below_threshold, query_res
query_res > 0, query_res - self.weight_threshold, query_res
)
query_res = jnp.where(
query_res < 0, query_res + self.below_threshold, query_res
query_res < 0, query_res + self.weight_threshold, query_res
)
query_res = query_res / (1 - self.below_threshold) * self.max_weight
query_res = query_res / (1 - self.weight_threshold) * self.max_weight
h_nodes, h_conns = self.substrate.make_nodes(
query_res
@@ -79,11 +80,11 @@ class HyperNEAT(BaseAlgorithm):
return self.hyper_genome.transform(state, h_nodes, h_conns)
def forward(self, state, inputs, transformed):
def forward(self, state, transformed, inputs):
# add bias
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
res = self.hyper_genome.forward(state, inputs_with_bias, transformed)
res = self.hyper_genome.forward(state, transformed, inputs_with_bias)
return res
@property
@@ -94,18 +95,11 @@ class HyperNEAT(BaseAlgorithm):
def num_outputs(self):
return self.substrate.num_outputs
@property
def pop_size(self):
return self.neat.pop_size
def member_count(self, state: State):
return self.neat.member_count(state)
def generation(self, state: State):
return self.neat.generation(state)
def show_details(self, state, fitness):
return self.neat.show_details(state, fitness)
class HyperNodeGene(BaseNodeGene):
class HyperNEATNode(BaseNode):
def __init__(
self,
aggregation=Agg.sum,
@@ -123,7 +117,7 @@ class HyperNodeGene(BaseNodeGene):
)
class HyperNEATConnGene(BaseConnGene):
class HyperNEATConn(BaseConn):
custom_attrs = ["weight"]
def forward(self, state, attrs, inputs):

View File

@@ -23,10 +23,10 @@ from ...utils import (
class DefaultMutation(BaseMutation):
def __init__(
self,
conn_add: float = 0.1,
conn_delete: float = 0,
conn_add: float = 0.2,
conn_delete: float = 0.2,
node_add: float = 0.1,
node_delete: float = 0,
node_delete: float = 0.1,
):
self.conn_add = conn_add
self.conn_delete = conn_delete

View File

@@ -34,7 +34,7 @@ class Pipeline(StatefulBaseClass):
self.generation_limit = generation_limit
self.pop_size = self.algorithm.pop_size
# print(self.problem.input_shape, self.problem.output_shape)
np.random.seed(self.seed)
# TODO: make each algorithm's input_num and output_num
assert (