Merge branch 'main' into advance
This commit is contained in:
13
README.md
13
README.md
@@ -23,11 +23,14 @@
|
||||
TensorNEAT is a JAX-based libaray for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framewrok.
|
||||
|
||||
## Requirements
|
||||
TensorNEAT requires:
|
||||
- jax (version >= 0.4.16)
|
||||
- jaxlib (version >= 0.3.0)
|
||||
- brax [optional]
|
||||
- gymnax [optional]
|
||||
Due to the rapid iteration of JAX versions, configuring the runtime environment for TensorNEAT can be challenging. We recommend the following versions for the relevant libraries:
|
||||
|
||||
- jax (0.4.28)
|
||||
- jaxlib (0.4.28+cuda12.cudnn89)
|
||||
- brax (0.10.3)
|
||||
- gymnax (0.0.8)
|
||||
|
||||
We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference.
|
||||
|
||||
## Example
|
||||
Simple Example for XOR problem:
|
||||
|
||||
9
recommend_environment.txt
Normal file
9
recommend_environment.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
brax==0.10.3
|
||||
flax==0.8.4
|
||||
gymnax==0.0.8
|
||||
jax==0.4.28
|
||||
jaxlib==0.4.28+cuda12.cudnn89
|
||||
jaxopt==0.8.3
|
||||
mujoco==3.1.4
|
||||
mujoco-mjx==3.1.4
|
||||
optax==0.2.2
|
||||
8
tensorneat/.idea/.gitignore
generated
vendored
Normal file
8
tensorneat/.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# 基于编辑器的 HTTP 客户端请求
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import State, Act, Agg
|
||||
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.sigmoid,
|
||||
):
|
||||
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
||||
"Substrate input size should be equal to NEAT input size"
|
||||
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
node_gene=HyperNodeGene(activation, aggregation),
|
||||
conn_gene=HyperNEATConnGene(),
|
||||
activate_time=activate_time,
|
||||
output_transform=output_transform
|
||||
)
|
||||
|
||||
def setup(self, randkey):
|
||||
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
return self.activation(
|
||||
self.aggregation(inputs)
|
||||
)
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
return jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: self.aggregation(inputs), # output node does not need activation
|
||||
lambda: self.activation(self.aggregation(inputs))
|
||||
|
||||
)
|
||||
|
||||
class HyperNEATConnGene(BaseConnGene):
|
||||
custom_attrs = ['weight']
|
||||
|
||||
@@ -2,6 +2,7 @@ import jax, jax.numpy as jnp
|
||||
|
||||
from .base import BaseCrossover
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
|
||||
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||
@@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover):
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
# make homologous genes align in nodes2 align with nodes1
|
||||
nodes2 = self.align_array(keys1, keys2, nodes2, False)
|
||||
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
||||
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
|
||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
||||
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
@@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover):
|
||||
|
||||
return refactor_ar2
|
||||
|
||||
def crossover_gene(self, rand_key, g1, g2):
|
||||
"""
|
||||
crossover two genes
|
||||
:param rand_key:
|
||||
:param g1:
|
||||
:param g2:
|
||||
:return:
|
||||
only gene with the same key will be crossover, thus don't need to consider change key
|
||||
"""
|
||||
def crossover_gene(self, rand_key, g1, g2, is_conn):
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
new_gene = jnp.where(r > 0.5, g1, g2)
|
||||
if is_conn: # fix enabled
|
||||
enabled = jnp.where(
|
||||
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
|
||||
1,
|
||||
0
|
||||
)
|
||||
new_gene = new_gene.at[:, 2].set(enabled)
|
||||
return new_gene
|
||||
|
||||
@@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation):
|
||||
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||
|
||||
@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -95,11 +95,17 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
(node1[4] != node2[4])
|
||||
)
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
bias, res, act_idx, agg_idx = attrs
|
||||
|
||||
z = agg(agg_idx, inputs, self.aggregation_options)
|
||||
z = bias + res * z
|
||||
z = act(act_idx, z, self.activation_options)
|
||||
|
||||
# the last output node should not be activated
|
||||
z = jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: z,
|
||||
lambda: act(act_idx, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
|
||||
@@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
aux = output_transform(jnp.zeros(num_outputs))
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
|
||||
# DONE: Seems like there is a bug in this line
|
||||
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
# modified: exist conn and enable is true
|
||||
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
|
||||
# advanced modified: when and only when enabled is True
|
||||
conn_enable = u_conns[0] == 1
|
||||
|
||||
# remove enable attr
|
||||
@@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
def hit():
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
|
||||
# ins = values * weights[:, i]
|
||||
|
||||
z = self.node_gene.forward(nodes_attrs[i], ins)
|
||||
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
|
||||
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
||||
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
|
||||
|
||||
z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
@@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome):
|
||||
return values
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
|
||||
values = jax.lax.cond(
|
||||
jnp.isin(i, self.input_idx),
|
||||
miss,
|
||||
hit
|
||||
)
|
||||
|
||||
return values, idx + 1
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns
|
||||
|
||||
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = None
|
||||
):
|
||||
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||
self.activate_time = activate_time
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
|
||||
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
|
||||
)(conns, values)
|
||||
|
||||
# calculate nodes
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
|
||||
is_output_nodes = jnp.isin(
|
||||
jnp.arange(N),
|
||||
self.output_idx
|
||||
)
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||
|
||||
36
tensorneat/examples/brax/walker.py
Normal file
36
tensorneat/examples/brax/walker.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
genome=DefaultGenome(
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
)
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name='walker2d',
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=5000
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
@@ -2,6 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
@@ -10,13 +11,25 @@ if __name__ == '__main__':
|
||||
genome=DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
max_nodes=100,
|
||||
max_conns=200,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
output_transform=Act.sigmoid, # the activation function for output node
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.01, # magic
|
||||
),
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
)
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
|
||||
@@ -28,14 +28,17 @@ if __name__ == '__main__':
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
output_transform=Act.tanh, # the activation function for output node in NEAT
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.03,
|
||||
),
|
||||
),
|
||||
activation=Act.sigmoid,
|
||||
activation=Act.tanh,
|
||||
activate_time=10,
|
||||
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=300,
|
||||
|
||||
@@ -2,8 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils.activation import ACT_ALL
|
||||
from utils.aggregation import AGG_ALL
|
||||
from utils.activation import ACT_ALL, Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
@@ -18,14 +17,21 @@ if __name__ == '__main__':
|
||||
activate_time=5,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=ACT_ALL,
|
||||
# aggregation_options=AGG_ALL,
|
||||
activation_replace_rate=0.2
|
||||
),
|
||||
output_transform=Act.sigmoid
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.03,
|
||||
),
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
)
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
|
||||
@@ -69,7 +69,7 @@ class Pipeline:
|
||||
pop_transformed
|
||||
)
|
||||
|
||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
||||
# fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
||||
|
||||
alg_state = self.algorithm.tell(state.alg, fitnesses)
|
||||
|
||||
@@ -80,9 +80,12 @@ class Pipeline:
|
||||
|
||||
def auto_run(self, ini_state):
|
||||
state = ini_state
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
||||
|
||||
for w in range(self.generation_limit):
|
||||
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
|
||||
for _ in range(self.generation_limit):
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
@@ -92,11 +95,6 @@ class Pipeline:
|
||||
state, fitnesses = compiled_step(state)
|
||||
|
||||
fitnesses = jax.device_get(fitnesses)
|
||||
for idx, fitnesses_i in enumerate(fitnesses):
|
||||
if np.isnan(fitnesses_i):
|
||||
print("Fitness is nan")
|
||||
print(previous_pop[0][idx], previous_pop[1][idx])
|
||||
assert False
|
||||
|
||||
self.analysis(state, previous_pop, fitnesses)
|
||||
|
||||
|
||||
@@ -8,34 +8,10 @@ from .. import BaseProblem
|
||||
class RLEnv(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
# TODO: move output transform to algorithm
|
||||
def __init__(self, max_step=1000):
|
||||
super().__init__()
|
||||
self.max_step = max_step
|
||||
|
||||
# def evaluate(self, randkey, state, act_func, params):
|
||||
# rng_reset, rng_episode = jax.random.split(randkey)
|
||||
# init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
# def cond_func(carry):
|
||||
# _, _, _, done, _ = carry
|
||||
# return ~done
|
||||
|
||||
# def body_func(carry):
|
||||
# obs, env_state, rng, _, tr = carry # total reward
|
||||
# action = act_func(obs, params)
|
||||
# next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
# next_rng, _ = jax.random.split(rng)
|
||||
# return next_obs, next_env_state, next_rng, done, tr + reward
|
||||
|
||||
# _, _, _, _, total_reward = jax.lax.while_loop(
|
||||
# cond_func,
|
||||
# body_func,
|
||||
# (init_obs, init_env_state, rng_episode, False, 0.0)
|
||||
# )
|
||||
|
||||
# return total_reward
|
||||
|
||||
def evaluate(self, randkey, state, act_func, params):
|
||||
rng_reset, rng_episode = jax.random.split(randkey)
|
||||
init_obs, init_env_state = self.reset(rng_reset)
|
||||
@@ -58,6 +34,7 @@ class RLEnv(BaseProblem):
|
||||
)
|
||||
|
||||
return total_reward
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def step(self, randkey, env_state, action):
|
||||
return self.env_step(randkey, env_state, action)
|
||||
|
||||
Reference in New Issue
Block a user