Merge branch 'main' into advance
This commit is contained in:
8
tensorneat/.idea/.gitignore
generated
vendored
Normal file
8
tensorneat/.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import State, Act, Agg
|
||||
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.sigmoid,
|
||||
):
|
||||
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
||||
"Substrate input size should be equal to NEAT input size"
|
||||
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
node_gene=HyperNodeGene(activation, aggregation),
|
||||
conn_gene=HyperNEATConnGene(),
|
||||
activate_time=activate_time,
|
||||
output_transform=output_transform
|
||||
)
|
||||
|
||||
def setup(self, randkey):
|
||||
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
return self.activation(
|
||||
self.aggregation(inputs)
|
||||
)
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
return jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: self.aggregation(inputs), # output node does not need activation
|
||||
lambda: self.activation(self.aggregation(inputs))
|
||||
|
||||
)
|
||||
|
||||
class HyperNEATConnGene(BaseConnGene):
|
||||
custom_attrs = ['weight']
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
class BaseCrossover:
|
||||
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -2,6 +2,7 @@ import jax, jax.numpy as jnp
|
||||
|
||||
from .base import BaseCrossover
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
|
||||
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||
@@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover):
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
# align the homologous genes in nodes2 with those in nodes1
|
||||
nodes2 = self.align_array(keys1, keys2, nodes2, False)
|
||||
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
|
||||
|
||||
# For non-homologous genes, use the value of nodes1 (the winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
||||
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
|
||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
||||
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
@@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover):
|
||||
|
||||
return refactor_ar2
|
||||
|
||||
def crossover_gene(self, rand_key, g1, g2):
|
||||
"""
|
||||
crossover two genes
|
||||
:param rand_key:
|
||||
:param g1:
|
||||
:param g2:
|
||||
:return:
|
||||
Only genes with the same key are crossed over, so the key never needs to change.
|
||||
"""
|
||||
def crossover_gene(self, rand_key, g1, g2, is_conn):
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
new_gene = jnp.where(r > 0.5, g1, g2)
|
||||
if is_conn: # fix enabled
|
||||
enabled = jnp.where(
|
||||
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
|
||||
1,
|
||||
0
|
||||
)
|
||||
new_gene = new_gene.at[:, 2].set(enabled)
|
||||
return new_gene
|
||||
|
||||
@@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation):
|
||||
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||
|
||||
@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -95,11 +95,17 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
(node1[4] != node2[4])
|
||||
)
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
bias, res, act_idx, agg_idx = attrs
|
||||
|
||||
z = agg(agg_idx, inputs, self.aggregation_options)
|
||||
z = bias + res * z
|
||||
z = act(act_idx, z, self.activation_options)
|
||||
|
||||
# output nodes are not passed through the activation function here
|
||||
z = jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: z,
|
||||
lambda: act(act_idx, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
|
||||
@@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
aux = output_transform(jnp.zeros(num_outputs))
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
|
||||
# DONE: Seems like there is a bug in this line
|
||||
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
# modified: exist conn and enable is true
|
||||
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
|
||||
# advanced modified: when and only when enabled is True
|
||||
conn_enable = u_conns[0] == 1
|
||||
|
||||
# remove enable attr
|
||||
@@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
def hit():
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
|
||||
# ins = values * weights[:, i]
|
||||
|
||||
z = self.node_gene.forward(nodes_attrs[i], ins)
|
||||
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
|
||||
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
||||
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
|
||||
|
||||
z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
@@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome):
|
||||
return values
|
||||
|
||||
# the values of input nodes are supplied by the task, not computed here
|
||||
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
|
||||
values = jax.lax.cond(
|
||||
jnp.isin(i, self.input_idx),
|
||||
miss,
|
||||
hit
|
||||
)
|
||||
|
||||
return values, idx + 1
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns
|
||||
|
||||
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = None
|
||||
):
|
||||
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||
self.activate_time = activate_time
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
|
||||
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
|
||||
)(conns, values)
|
||||
|
||||
# calculate nodes
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
|
||||
is_output_nodes = jnp.isin(
|
||||
jnp.arange(N),
|
||||
self.output_idx
|
||||
)
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||
|
||||
36
tensorneat/examples/brax/walker.py
Normal file
36
tensorneat/examples/brax/walker.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
genome=DefaultGenome(
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=(Act.tanh,),
|
||||
activation_default=Act.tanh,
|
||||
)
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=BraxEnv(
|
||||
env_name='walker2d',
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=5000
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
@@ -2,6 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
@@ -10,13 +11,25 @@ if __name__ == '__main__':
|
||||
genome=DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
max_nodes=100,
|
||||
max_conns=200,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
output_transform=Act.sigmoid, # the activation function for output node
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.01, # magic
|
||||
),
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
)
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
|
||||
@@ -28,14 +28,17 @@ if __name__ == '__main__':
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
output_transform=Act.tanh, # the activation function for output node in NEAT
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.03,
|
||||
),
|
||||
),
|
||||
activation=Act.sigmoid,
|
||||
activation=Act.tanh,
|
||||
activate_time=10,
|
||||
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=300,
|
||||
|
||||
@@ -2,8 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils.activation import ACT_ALL
|
||||
from utils.aggregation import AGG_ALL
|
||||
from utils.activation import ACT_ALL, Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
@@ -18,14 +17,21 @@ if __name__ == '__main__':
|
||||
activate_time=5,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=ACT_ALL,
|
||||
# aggregation_options=AGG_ALL,
|
||||
activation_replace_rate=0.2
|
||||
),
|
||||
output_transform=Act.sigmoid
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.03,
|
||||
),
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
)
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
|
||||
@@ -69,7 +69,7 @@ class Pipeline:
|
||||
pop_transformed
|
||||
)
|
||||
|
||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
||||
# fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
||||
|
||||
alg_state = self.algorithm.tell(state.alg, fitnesses)
|
||||
|
||||
@@ -80,9 +80,12 @@ class Pipeline:
|
||||
|
||||
def auto_run(self, ini_state):
|
||||
state = ini_state
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
||||
|
||||
for w in range(self.generation_limit):
|
||||
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
|
||||
for _ in range(self.generation_limit):
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
@@ -92,11 +95,6 @@ class Pipeline:
|
||||
state, fitnesses = compiled_step(state)
|
||||
|
||||
fitnesses = jax.device_get(fitnesses)
|
||||
for idx, fitnesses_i in enumerate(fitnesses):
|
||||
if np.isnan(fitnesses_i):
|
||||
print("Fitness is nan")
|
||||
print(previous_pop[0][idx], previous_pop[1][idx])
|
||||
assert False
|
||||
|
||||
self.analysis(state, previous_pop, fitnesses)
|
||||
|
||||
|
||||
@@ -8,34 +8,10 @@ from .. import BaseProblem
|
||||
class RLEnv(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
# TODO: move output transform to algorithm
|
||||
def __init__(self, max_step=1000):
|
||||
super().__init__()
|
||||
self.max_step = max_step
|
||||
|
||||
# def evaluate(self, randkey, state, act_func, params):
|
||||
# rng_reset, rng_episode = jax.random.split(randkey)
|
||||
# init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
# def cond_func(carry):
|
||||
# _, _, _, done, _ = carry
|
||||
# return ~done
|
||||
|
||||
# def body_func(carry):
|
||||
# obs, env_state, rng, _, tr = carry # total reward
|
||||
# action = act_func(obs, params)
|
||||
# next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
# next_rng, _ = jax.random.split(rng)
|
||||
# return next_obs, next_env_state, next_rng, done, tr + reward
|
||||
|
||||
# _, _, _, _, total_reward = jax.lax.while_loop(
|
||||
# cond_func,
|
||||
# body_func,
|
||||
# (init_obs, init_env_state, rng_episode, False, 0.0)
|
||||
# )
|
||||
|
||||
# return total_reward
|
||||
|
||||
def evaluate(self, randkey, state, act_func, params):
|
||||
rng_reset, rng_episode = jax.random.split(randkey)
|
||||
init_obs, init_env_state = self.reset(rng_reset)
|
||||
@@ -58,6 +34,7 @@ class RLEnv(BaseProblem):
|
||||
)
|
||||
|
||||
return total_reward
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def step(self, randkey, env_state, action):
|
||||
return self.env_step(randkey, env_state, action)
|
||||
|
||||
Reference in New Issue
Block a user