Merge branch 'main' into advance

This commit is contained in:
WLS2002
2024-05-24 19:42:03 +08:00
committed by GitHub
17 changed files with 156 additions and 82 deletions

View File

@@ -23,11 +23,14 @@
TensorNEAT is a JAX-based library for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framework.
## Requirements
TensorNEAT requires:
- jax (version >= 0.4.16)
- jaxlib (version >= 0.3.0)
- brax [optional]
- gymnax [optional]
Due to the rapid iteration of JAX versions, configuring the runtime environment for TensorNEAT can be challenging. We recommend the following versions for the relevant libraries:
- jax (0.4.28)
- jaxlib (0.4.28+cuda12.cudnn89)
- brax (0.10.3)
- gymnax (0.0.8)
We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference.
## Example
Simple Example for XOR problem:

View File

@@ -0,0 +1,9 @@
brax==0.10.3
flax==0.8.4
gymnax==0.0.8
jax==0.4.28
jaxlib==0.4.28+cuda12.cudnn89
jaxopt==0.8.3
mujoco==3.1.4
mujoco-mjx==3.1.4
optax==0.2.2

8
tensorneat/.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import State, Act, Agg
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
activation=Act.sigmoid,
aggregation=Agg.sum,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
):
assert substrate.query_coors.shape[1] == neat.num_inputs, \
"Substrate input size should be equal to NEAT input size"
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
node_gene=HyperNodeGene(activation, aggregation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
output_transform=output_transform
)
def setup(self, randkey):
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
self.activation = activation
self.aggregation = aggregation
def forward(self, attrs, inputs):
return self.activation(
self.aggregation(inputs)
)
def forward(self, attrs, inputs, is_output_node=False):
return jax.lax.cond(
is_output_node,
lambda: self.aggregation(inputs), # output node does not need activation
lambda: self.activation(self.aggregation(inputs))
)
class HyperNEATConnGene(BaseConnGene):
custom_attrs = ['weight']

View File

@@ -2,6 +2,7 @@ import jax, jax.numpy as jnp
from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
@@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover):
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes in nodes2 align with nodes1
nodes2 = self.align_array(keys1, keys2, nodes2, False)
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
return new_nodes, new_conns
@@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover):
return refactor_ar2
def crossover_gene(self, rand_key, g1, g2):
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
def crossover_gene(self, rand_key, g1, g2, is_conn):
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled
enabled = jnp.where(
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
1,
0
)
new_gene = new_gene.at[:, 2].set(enabled)
return new_gene

View File

@@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation):
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
def __init__(self):
super().__init__()
def forward(self, attrs, inputs):
def forward(self, attrs, inputs, is_output_node=False):
raise NotImplementedError

View File

@@ -95,11 +95,17 @@ class DefaultNodeGene(BaseNodeGene):
(node1[4] != node2[4])
)
def forward(self, attrs, inputs):
def forward(self, attrs, inputs, is_output_node=False):
bias, res, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options)
z = bias + res * z
z = act(act_idx, z, self.activation_options)
# the last output node should not be activated
z = jax.lax.cond(
is_output_node,
lambda: z,
lambda: act(act_idx, z, self.activation_options)
)
return z

View File

@@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome):
if output_transform is not None:
try:
aux = output_transform(jnp.zeros(num_outputs))
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# DONE: Seems like there is a bug in this line
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# modified: exist conn and enable is true
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
# advanced modified: when and only when enabled is True
conn_enable = u_conns[0] == 1
# remove enable attr
@@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome):
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
# ins = values * weights[:, i]
z = self.node_gene.forward(nodes_attrs[i], ins)
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
new_values = values.at[i].set(z)
return new_values
@@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome):
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
values = jax.lax.cond(
jnp.isin(i, self.input_idx),
miss,
hit
)
return values, idx + 1

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
output_transform: Callable = None
):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time
if output_transform is not None:
try:
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
)(conns, values)
# calculate nodes
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)

View File

@@ -0,0 +1,36 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=17,
num_outputs=6,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=10000,
species_size=10,
),
),
problem=BraxEnv(
env_name='walker2d',
),
generation_limit=10000,
fitness_target=5000
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -2,6 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
@@ -10,13 +11,25 @@ if __name__ == '__main__':
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
max_nodes=100,
max_conns=200,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.sigmoid, # the activation function for output node
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.01, # magic
),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
),
problem=XOR3d(),
generation_limit=10000,

View File

@@ -28,14 +28,17 @@ if __name__ == '__main__':
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.tanh, # the activation function for output node in NEAT
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
),
),
activation=Act.sigmoid,
activation=Act.tanh,
activate_time=10,
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
),
problem=XOR3d(),
generation_limit=300,

View File

@@ -2,8 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
from utils.activation import ACT_ALL
from utils.aggregation import AGG_ALL
from utils.activation import ACT_ALL, Act
if __name__ == '__main__':
pipeline = Pipeline(
@@ -18,14 +17,21 @@ if __name__ == '__main__':
activate_time=5,
node_gene=DefaultNodeGene(
activation_options=ACT_ALL,
# aggregation_options=AGG_ALL,
activation_replace_rate=0.2
),
output_transform=Act.sigmoid
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
),
problem=XOR3d(),
generation_limit=10000,

View File

@@ -69,7 +69,7 @@ class Pipeline:
pop_transformed
)
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
# fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
alg_state = self.algorithm.tell(state.alg, fitnesses)
@@ -80,9 +80,12 @@ class Pipeline:
def auto_run(self, ini_state):
state = ini_state
print("start compile")
tic = time.time()
compiled_step = jax.jit(self.step).lower(ini_state).compile()
for w in range(self.generation_limit):
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
for _ in range(self.generation_limit):
self.generation_timestamp = time.time()
@@ -92,11 +95,6 @@ class Pipeline:
state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses)
for idx, fitnesses_i in enumerate(fitnesses):
if np.isnan(fitnesses_i):
print("Fitness is nan")
print(previous_pop[0][idx], previous_pop[1][idx])
assert False
self.analysis(state, previous_pop, fitnesses)

View File

@@ -8,34 +8,10 @@ from .. import BaseProblem
class RLEnv(BaseProblem):
jitable = True
# TODO: move output transform to algorithm
def __init__(self, max_step=1000):
super().__init__()
self.max_step = max_step
# def evaluate(self, randkey, state, act_func, params):
# rng_reset, rng_episode = jax.random.split(randkey)
# init_obs, init_env_state = self.reset(rng_reset)
# def cond_func(carry):
# _, _, _, done, _ = carry
# return ~done
# def body_func(carry):
# obs, env_state, rng, _, tr = carry # total reward
# action = act_func(obs, params)
# next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
# next_rng, _ = jax.random.split(rng)
# return next_obs, next_env_state, next_rng, done, tr + reward
# _, _, _, _, total_reward = jax.lax.while_loop(
# cond_func,
# body_func,
# (init_obs, init_env_state, rng_episode, False, 0.0)
# )
# return total_reward
def evaluate(self, randkey, state, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset)
@@ -58,6 +34,7 @@ class RLEnv(BaseProblem):
)
return total_reward
@partial(jax.jit, static_argnums=(0,))
def step(self, randkey, env_state, action):
return self.env_step(randkey, env_state, action)