Merge branch 'main' into advance

This commit is contained in:
WLS2002
2024-05-24 19:42:03 +08:00
committed by GitHub
17 changed files with 156 additions and 82 deletions

8
tensorneat/.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import State, Act, Agg
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
activation=Act.sigmoid,
aggregation=Agg.sum,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
):
assert substrate.query_coors.shape[1] == neat.num_inputs, \
"Substrate input size should be equal to NEAT input size"
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
node_gene=HyperNodeGene(activation, aggregation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
output_transform=output_transform
)
def setup(self, randkey):
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
self.activation = activation
self.aggregation = aggregation
def forward(self, attrs, inputs):
return self.activation(
self.aggregation(inputs)
)
def forward(self, attrs, inputs, is_output_node=False):
return jax.lax.cond(
is_output_node,
lambda: self.aggregation(inputs), # output node does not need activation
lambda: self.activation(self.aggregation(inputs))
)
class HyperNEATConnGene(BaseConnGene):
custom_attrs = ['weight']

View File

@@ -1,3 +1,3 @@
class BaseCrossover:
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError
raise NotImplementedError

View File

@@ -2,6 +2,7 @@ import jax, jax.numpy as jnp
from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
@@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover):
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = self.align_array(keys1, keys2, nodes2, False)
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
return new_nodes, new_conns
@@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover):
return refactor_ar2
def crossover_gene(self, rand_key, g1, g2):
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
def crossover_gene(self, rand_key, g1, g2, is_conn):
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled
enabled = jnp.where(
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
1,
0
)
new_gene = new_gene.at[:, 2].set(enabled)
return new_gene

View File

@@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation):
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
def __init__(self):
super().__init__()
def forward(self, attrs, inputs):
def forward(self, attrs, inputs, is_output_node=False):
raise NotImplementedError

View File

@@ -95,11 +95,17 @@ class DefaultNodeGene(BaseNodeGene):
(node1[4] != node2[4])
)
def forward(self, attrs, inputs):
def forward(self, attrs, inputs, is_output_node=False):
bias, res, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options)
z = bias + res * z
z = act(act_idx, z, self.activation_options)
# the last output node should not be activated
z = jax.lax.cond(
is_output_node,
lambda: z,
lambda: act(act_idx, z, self.activation_options)
)
return z

View File

@@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome):
if output_transform is not None:
try:
aux = output_transform(jnp.zeros(num_outputs))
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# DONE: Seems like there is a bug in this line
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# modified: exist conn and enable is true
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
# advanced modified: when and only when enabled is True
conn_enable = u_conns[0] == 1
# remove enable attr
@@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome):
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
# ins = values * weights[:, i]
z = self.node_gene.forward(nodes_attrs[i], ins)
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
new_values = values.at[i].set(z)
return new_values
@@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome):
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
values = jax.lax.cond(
jnp.isin(i, self.input_idx),
miss,
hit
)
return values, idx + 1

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
output_transform: Callable = None
):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time
if output_transform is not None:
try:
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
)(conns, values)
# calculate nodes
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)

View File

@@ -0,0 +1,36 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=17,
num_outputs=6,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
),
pop_size=10000,
species_size=10,
),
),
problem=BraxEnv(
env_name='walker2d',
),
generation_limit=10000,
fitness_target=5000
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -2,6 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
@@ -10,13 +11,25 @@ if __name__ == '__main__':
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
max_nodes=100,
max_conns=200,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.sigmoid, # the activation function for output node
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.01, # magic
),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
),
problem=XOR3d(),
generation_limit=10000,

View File

@@ -28,14 +28,17 @@ if __name__ == '__main__':
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.tanh, # the activation function for output node in NEAT
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
),
),
activation=Act.sigmoid,
activation=Act.tanh,
activate_time=10,
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
),
problem=XOR3d(),
generation_limit=300,

View File

@@ -2,8 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
from utils.activation import ACT_ALL
from utils.aggregation import AGG_ALL
from utils.activation import ACT_ALL, Act
if __name__ == '__main__':
pipeline = Pipeline(
@@ -18,14 +17,21 @@ if __name__ == '__main__':
activate_time=5,
node_gene=DefaultNodeGene(
activation_options=ACT_ALL,
# aggregation_options=AGG_ALL,
activation_replace_rate=0.2
),
output_transform=Act.sigmoid
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
),
problem=XOR3d(),
generation_limit=10000,

View File

@@ -69,7 +69,7 @@ class Pipeline:
pop_transformed
)
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
# fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
alg_state = self.algorithm.tell(state.alg, fitnesses)
@@ -80,9 +80,12 @@ class Pipeline:
def auto_run(self, ini_state):
state = ini_state
print("start compile")
tic = time.time()
compiled_step = jax.jit(self.step).lower(ini_state).compile()
for w in range(self.generation_limit):
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
for _ in range(self.generation_limit):
self.generation_timestamp = time.time()
@@ -92,11 +95,6 @@ class Pipeline:
state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses)
for idx, fitnesses_i in enumerate(fitnesses):
if np.isnan(fitnesses_i):
print("Fitness is nan")
print(previous_pop[0][idx], previous_pop[1][idx])
assert False
self.analysis(state, previous_pop, fitnesses)

View File

@@ -8,34 +8,10 @@ from .. import BaseProblem
class RLEnv(BaseProblem):
jitable = True
# TODO: move output transform to algorithm
def __init__(self, max_step=1000):
super().__init__()
self.max_step = max_step
# def evaluate(self, randkey, state, act_func, params):
# rng_reset, rng_episode = jax.random.split(randkey)
# init_obs, init_env_state = self.reset(rng_reset)
# def cond_func(carry):
# _, _, _, done, _ = carry
# return ~done
# def body_func(carry):
# obs, env_state, rng, _, tr = carry # total reward
# action = act_func(obs, params)
# next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
# next_rng, _ = jax.random.split(rng)
# return next_obs, next_env_state, next_rng, done, tr + reward
# _, _, _, _, total_reward = jax.lax.while_loop(
# cond_func,
# body_func,
# (init_obs, init_env_state, rng_episode, False, 0.0)
# )
# return total_reward
def evaluate(self, randkey, state, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset)
@@ -58,6 +34,7 @@ class RLEnv(BaseProblem):
)
return total_reward
@partial(jax.jit, static_argnums=(0,))
def step(self, randkey, env_state, action):
return self.env_step(randkey, env_state, action)