Merge branch 'main' into advance
This commit is contained in:
13
README.md
13
README.md
@@ -23,11 +23,14 @@
|
|||||||
TensorNEAT is a JAX-based library for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framework.
|
TensorNEAT is a JAX-based library for NeuroEvolution of Augmenting Topologies (NEAT) algorithms, focused on harnessing GPU acceleration to enhance the efficiency of evolving neural network structures for complex tasks. Its core mechanism involves the tensorization of network topologies, enabling parallel processing and significantly boosting computational speed and scalability by leveraging modern hardware accelerators. TensorNEAT is compatible with the [EvoX](https://github.com/EMI-Group/evox/) framework.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
TensorNEAT requires:
|
Due to the rapid iteration of JAX versions, configuring the runtime environment for TensorNEAT can be challenging. We recommend the following versions for the relevant libraries:
|
||||||
- jax (version >= 0.4.16)
|
|
||||||
- jaxlib (version >= 0.3.0)
|
- jax (0.4.28)
|
||||||
- brax [optional]
|
- jaxlib (0.4.28+cuda12.cudnn89)
|
||||||
- gymnax [optional]
|
- brax (0.10.3)
|
||||||
|
- gymnax (0.0.8)
|
||||||
|
|
||||||
|
We provide detailed JAX-related environment references in [recommend_environment](recommend_environment.txt). If you encounter any issues while configuring the environment yourself, you can use this as a reference.
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
Simple Example for XOR problem:
|
Simple Example for XOR problem:
|
||||||
|
|||||||
9
recommend_environment.txt
Normal file
9
recommend_environment.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
brax==0.10.3
|
||||||
|
flax==0.8.4
|
||||||
|
gymnax==0.0.8
|
||||||
|
jax==0.4.28
|
||||||
|
jaxlib==0.4.28+cuda12.cudnn89
|
||||||
|
jaxopt==0.8.3
|
||||||
|
mujoco==3.1.4
|
||||||
|
mujoco-mjx==3.1.4
|
||||||
|
optax==0.2.2
|
||||||
8
tensorneat/.idea/.gitignore
generated
vendored
Normal file
8
tensorneat/.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
|
|
||||||
from utils import State, Act, Agg
|
from utils import State, Act, Agg
|
||||||
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
activation=Act.sigmoid,
|
activation=Act.sigmoid,
|
||||||
aggregation=Agg.sum,
|
aggregation=Agg.sum,
|
||||||
activate_time: int = 10,
|
activate_time: int = 10,
|
||||||
|
output_transform: Callable = Act.sigmoid,
|
||||||
):
|
):
|
||||||
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
||||||
"Substrate input size should be equal to NEAT input size"
|
"Substrate input size should be equal to NEAT input size"
|
||||||
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
node_gene=HyperNodeGene(activation, aggregation),
|
node_gene=HyperNodeGene(activation, aggregation),
|
||||||
conn_gene=HyperNEATConnGene(),
|
conn_gene=HyperNEATConnGene(),
|
||||||
activate_time=activate_time,
|
activate_time=activate_time,
|
||||||
|
output_transform=output_transform
|
||||||
)
|
)
|
||||||
|
|
||||||
def setup(self, randkey):
|
def setup(self, randkey):
|
||||||
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
|
|||||||
self.activation = activation
|
self.activation = activation
|
||||||
self.aggregation = aggregation
|
self.aggregation = aggregation
|
||||||
|
|
||||||
def forward(self, attrs, inputs):
|
def forward(self, attrs, inputs, is_output_node=False):
|
||||||
return self.activation(
|
return jax.lax.cond(
|
||||||
self.aggregation(inputs)
|
is_output_node,
|
||||||
)
|
lambda: self.aggregation(inputs), # output node does not need activation
|
||||||
|
lambda: self.activation(self.aggregation(inputs))
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
class HyperNEATConnGene(BaseConnGene):
|
class HyperNEATConnGene(BaseConnGene):
|
||||||
custom_attrs = ['weight']
|
custom_attrs = ['weight']
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import jax, jax.numpy as jnp
|
|||||||
|
|
||||||
from .base import BaseCrossover
|
from .base import BaseCrossover
|
||||||
|
|
||||||
|
|
||||||
class DefaultCrossover(BaseCrossover):
|
class DefaultCrossover(BaseCrossover):
|
||||||
|
|
||||||
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
|
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||||
@@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover):
|
|||||||
# crossover nodes
|
# crossover nodes
|
||||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||||
# make homologous genes align in nodes2 align with nodes1
|
# make homologous genes align in nodes2 align with nodes1
|
||||||
nodes2 = self.align_array(keys1, keys2, nodes2, False)
|
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
|
||||||
|
|
||||||
# For not homologous genes, use the value of nodes1(winner)
|
# For not homologous genes, use the value of nodes1(winner)
|
||||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
|
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
||||||
|
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
|
||||||
|
|
||||||
# crossover connections
|
# crossover connections
|
||||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
|
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||||
|
|
||||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
|
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
||||||
|
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
|
||||||
|
|
||||||
return new_nodes, new_conns
|
return new_nodes, new_conns
|
||||||
|
|
||||||
@@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover):
|
|||||||
|
|
||||||
return refactor_ar2
|
return refactor_ar2
|
||||||
|
|
||||||
def crossover_gene(self, rand_key, g1, g2):
|
def crossover_gene(self, rand_key, g1, g2, is_conn):
|
||||||
"""
|
|
||||||
crossover two genes
|
|
||||||
:param rand_key:
|
|
||||||
:param g1:
|
|
||||||
:param g2:
|
|
||||||
:return:
|
|
||||||
only gene with the same key will be crossover, thus don't need to consider change key
|
|
||||||
"""
|
|
||||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||||
return jnp.where(r > 0.5, g1, g2)
|
new_gene = jnp.where(r > 0.5, g1, g2)
|
||||||
|
if is_conn: # fix enabled
|
||||||
|
enabled = jnp.where(
|
||||||
|
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
|
||||||
|
1,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
new_gene = new_gene.at[:, 2].set(enabled)
|
||||||
|
return new_gene
|
||||||
|
|||||||
@@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation):
|
|||||||
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||||
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||||
|
|
||||||
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
|
||||||
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
|
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
|
||||||
|
|
||||||
# nan nodes not changed
|
# nan nodes not changed
|
||||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||||
|
|||||||
@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, attrs, inputs):
|
def forward(self, attrs, inputs, is_output_node=False):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -95,11 +95,17 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
(node1[4] != node2[4])
|
(node1[4] != node2[4])
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, attrs, inputs):
|
def forward(self, attrs, inputs, is_output_node=False):
|
||||||
bias, res, act_idx, agg_idx = attrs
|
bias, res, act_idx, agg_idx = attrs
|
||||||
|
|
||||||
z = agg(agg_idx, inputs, self.aggregation_options)
|
z = agg(agg_idx, inputs, self.aggregation_options)
|
||||||
z = bias + res * z
|
z = bias + res * z
|
||||||
z = act(act_idx, z, self.activation_options)
|
|
||||||
|
# the last output node should not be activated
|
||||||
|
z = jax.lax.cond(
|
||||||
|
is_output_node,
|
||||||
|
lambda: z,
|
||||||
|
lambda: act(act_idx, z, self.activation_options)
|
||||||
|
)
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|||||||
@@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
if output_transform is not None:
|
if output_transform is not None:
|
||||||
try:
|
try:
|
||||||
aux = output_transform(jnp.zeros(num_outputs))
|
_ = output_transform(jnp.zeros(num_outputs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Output transform function failed: {e}")
|
raise ValueError(f"Output transform function failed: {e}")
|
||||||
self.output_transform = output_transform
|
self.output_transform = output_transform
|
||||||
|
|
||||||
def transform(self, nodes, conns):
|
def transform(self, nodes, conns):
|
||||||
u_conns = unflatten_conns(nodes, conns)
|
u_conns = unflatten_conns(nodes, conns)
|
||||||
|
|
||||||
# DONE: Seems like there is a bug in this line
|
|
||||||
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
|
||||||
# modified: exist conn and enable is true
|
|
||||||
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
|
|
||||||
# advanced modified: when and only when enabled is True
|
|
||||||
conn_enable = u_conns[0] == 1
|
conn_enable = u_conns[0] == 1
|
||||||
|
|
||||||
# remove enable attr
|
# remove enable attr
|
||||||
@@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
def hit():
|
def hit():
|
||||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
|
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
|
||||||
# ins = values * weights[:, i]
|
z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||||
|
|
||||||
z = self.node_gene.forward(nodes_attrs[i], ins)
|
|
||||||
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
|
|
||||||
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
|
||||||
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
|
|
||||||
|
|
||||||
new_values = values.at[i].set(z)
|
new_values = values.at[i].set(z)
|
||||||
return new_values
|
return new_values
|
||||||
|
|
||||||
@@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
# the val of input nodes is obtained by the task, not by calculation
|
# the val of input nodes is obtained by the task, not by calculation
|
||||||
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
|
values = jax.lax.cond(
|
||||||
|
jnp.isin(i, self.input_idx),
|
||||||
|
miss,
|
||||||
|
hit
|
||||||
|
)
|
||||||
|
|
||||||
return values, idx + 1
|
return values, idx + 1
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
from utils import unflatten_conns
|
from utils import unflatten_conns
|
||||||
|
|
||||||
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
|
|||||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||||
activate_time: int = 10,
|
activate_time: int = 10,
|
||||||
|
output_transform: Callable = None
|
||||||
):
|
):
|
||||||
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||||
self.activate_time = activate_time
|
self.activate_time = activate_time
|
||||||
|
|
||||||
|
if output_transform is not None:
|
||||||
|
try:
|
||||||
|
_ = output_transform(jnp.zeros(num_outputs))
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Output transform function failed: {e}")
|
||||||
|
self.output_transform = output_transform
|
||||||
|
|
||||||
def transform(self, nodes, conns):
|
def transform(self, nodes, conns):
|
||||||
u_conns = unflatten_conns(nodes, conns)
|
u_conns = unflatten_conns(nodes, conns)
|
||||||
|
|
||||||
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
|
|||||||
)(conns, values)
|
)(conns, values)
|
||||||
|
|
||||||
# calculate nodes
|
# calculate nodes
|
||||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
|
is_output_nodes = jnp.isin(
|
||||||
|
jnp.arange(N),
|
||||||
|
self.output_idx
|
||||||
|
)
|
||||||
|
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||||
|
|||||||
36
tensorneat/examples/brax/walker.py
Normal file
36
tensorneat/examples/brax/walker.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm.neat import *
|
||||||
|
|
||||||
|
from problem.rl_env import BraxEnv
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=17,
|
||||||
|
num_outputs=6,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
node_gene=DefaultNodeGene(
|
||||||
|
activation_options=(Act.tanh,),
|
||||||
|
activation_default=Act.tanh,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
pop_size=10000,
|
||||||
|
species_size=10,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
problem=BraxEnv(
|
||||||
|
env_name='walker2d',
|
||||||
|
),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=5000
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize state
|
||||||
|
state = pipeline.setup()
|
||||||
|
# print(state)
|
||||||
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
@@ -2,6 +2,7 @@ from pipeline import Pipeline
|
|||||||
from algorithm.neat import *
|
from algorithm.neat import *
|
||||||
|
|
||||||
from problem.func_fit import XOR3d
|
from problem.func_fit import XOR3d
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
@@ -10,13 +11,25 @@ if __name__ == '__main__':
|
|||||||
genome=DefaultGenome(
|
genome=DefaultGenome(
|
||||||
num_inputs=3,
|
num_inputs=3,
|
||||||
num_outputs=1,
|
num_outputs=1,
|
||||||
max_nodes=50,
|
max_nodes=100,
|
||||||
max_conns=100,
|
max_conns=200,
|
||||||
|
node_gene=DefaultNodeGene(
|
||||||
|
activation_default=Act.tanh,
|
||||||
|
activation_options=(Act.tanh,),
|
||||||
|
),
|
||||||
|
output_transform=Act.sigmoid, # the activation function for output node
|
||||||
),
|
),
|
||||||
pop_size=10000,
|
pop_size=10000,
|
||||||
species_size=10,
|
species_size=10,
|
||||||
compatibility_threshold=3.5,
|
compatibility_threshold=3.5,
|
||||||
|
survival_threshold=0.01, # magic
|
||||||
),
|
),
|
||||||
|
mutation=DefaultMutation(
|
||||||
|
node_add=0.05,
|
||||||
|
conn_add=0.2,
|
||||||
|
node_delete=0,
|
||||||
|
conn_delete=0,
|
||||||
|
)
|
||||||
),
|
),
|
||||||
problem=XOR3d(),
|
problem=XOR3d(),
|
||||||
generation_limit=10000,
|
generation_limit=10000,
|
||||||
|
|||||||
@@ -28,14 +28,17 @@ if __name__ == '__main__':
|
|||||||
activation_default=Act.tanh,
|
activation_default=Act.tanh,
|
||||||
activation_options=(Act.tanh,),
|
activation_options=(Act.tanh,),
|
||||||
),
|
),
|
||||||
|
output_transform=Act.tanh, # the activation function for output node in NEAT
|
||||||
),
|
),
|
||||||
pop_size=10000,
|
pop_size=10000,
|
||||||
species_size=10,
|
species_size=10,
|
||||||
compatibility_threshold=3.5,
|
compatibility_threshold=3.5,
|
||||||
|
survival_threshold=0.03,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
activation=Act.sigmoid,
|
activation=Act.tanh,
|
||||||
activate_time=10,
|
activate_time=10,
|
||||||
|
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
|
||||||
),
|
),
|
||||||
problem=XOR3d(),
|
problem=XOR3d(),
|
||||||
generation_limit=300,
|
generation_limit=300,
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ from pipeline import Pipeline
|
|||||||
from algorithm.neat import *
|
from algorithm.neat import *
|
||||||
|
|
||||||
from problem.func_fit import XOR3d
|
from problem.func_fit import XOR3d
|
||||||
from utils.activation import ACT_ALL
|
from utils.activation import ACT_ALL, Act
|
||||||
from utils.aggregation import AGG_ALL
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
@@ -18,14 +17,21 @@ if __name__ == '__main__':
|
|||||||
activate_time=5,
|
activate_time=5,
|
||||||
node_gene=DefaultNodeGene(
|
node_gene=DefaultNodeGene(
|
||||||
activation_options=ACT_ALL,
|
activation_options=ACT_ALL,
|
||||||
# aggregation_options=AGG_ALL,
|
|
||||||
activation_replace_rate=0.2
|
activation_replace_rate=0.2
|
||||||
),
|
),
|
||||||
|
output_transform=Act.sigmoid
|
||||||
),
|
),
|
||||||
pop_size=10000,
|
pop_size=10000,
|
||||||
species_size=10,
|
species_size=10,
|
||||||
compatibility_threshold=3.5,
|
compatibility_threshold=3.5,
|
||||||
|
survival_threshold=0.03,
|
||||||
),
|
),
|
||||||
|
mutation=DefaultMutation(
|
||||||
|
node_add=0.05,
|
||||||
|
conn_add=0.2,
|
||||||
|
node_delete=0,
|
||||||
|
conn_delete=0,
|
||||||
|
)
|
||||||
),
|
),
|
||||||
problem=XOR3d(),
|
problem=XOR3d(),
|
||||||
generation_limit=10000,
|
generation_limit=10000,
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class Pipeline:
|
|||||||
pop_transformed
|
pop_transformed
|
||||||
)
|
)
|
||||||
|
|
||||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
# fitnesses = jnp.where(jnp.isnan(fitnesses), -1e6, fitnesses)
|
||||||
|
|
||||||
alg_state = self.algorithm.tell(state.alg, fitnesses)
|
alg_state = self.algorithm.tell(state.alg, fitnesses)
|
||||||
|
|
||||||
@@ -80,9 +80,12 @@ class Pipeline:
|
|||||||
|
|
||||||
def auto_run(self, ini_state):
|
def auto_run(self, ini_state):
|
||||||
state = ini_state
|
state = ini_state
|
||||||
|
print("start compile")
|
||||||
|
tic = time.time()
|
||||||
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
||||||
|
|
||||||
for w in range(self.generation_limit):
|
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
|
||||||
|
for _ in range(self.generation_limit):
|
||||||
|
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
@@ -92,11 +95,6 @@ class Pipeline:
|
|||||||
state, fitnesses = compiled_step(state)
|
state, fitnesses = compiled_step(state)
|
||||||
|
|
||||||
fitnesses = jax.device_get(fitnesses)
|
fitnesses = jax.device_get(fitnesses)
|
||||||
for idx, fitnesses_i in enumerate(fitnesses):
|
|
||||||
if np.isnan(fitnesses_i):
|
|
||||||
print("Fitness is nan")
|
|
||||||
print(previous_pop[0][idx], previous_pop[1][idx])
|
|
||||||
assert False
|
|
||||||
|
|
||||||
self.analysis(state, previous_pop, fitnesses)
|
self.analysis(state, previous_pop, fitnesses)
|
||||||
|
|
||||||
|
|||||||
@@ -8,34 +8,10 @@ from .. import BaseProblem
|
|||||||
class RLEnv(BaseProblem):
|
class RLEnv(BaseProblem):
|
||||||
jitable = True
|
jitable = True
|
||||||
|
|
||||||
# TODO: move output transform to algorithm
|
|
||||||
def __init__(self, max_step=1000):
|
def __init__(self, max_step=1000):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_step = max_step
|
self.max_step = max_step
|
||||||
|
|
||||||
# def evaluate(self, randkey, state, act_func, params):
|
|
||||||
# rng_reset, rng_episode = jax.random.split(randkey)
|
|
||||||
# init_obs, init_env_state = self.reset(rng_reset)
|
|
||||||
|
|
||||||
# def cond_func(carry):
|
|
||||||
# _, _, _, done, _ = carry
|
|
||||||
# return ~done
|
|
||||||
|
|
||||||
# def body_func(carry):
|
|
||||||
# obs, env_state, rng, _, tr = carry # total reward
|
|
||||||
# action = act_func(obs, params)
|
|
||||||
# next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
|
||||||
# next_rng, _ = jax.random.split(rng)
|
|
||||||
# return next_obs, next_env_state, next_rng, done, tr + reward
|
|
||||||
|
|
||||||
# _, _, _, _, total_reward = jax.lax.while_loop(
|
|
||||||
# cond_func,
|
|
||||||
# body_func,
|
|
||||||
# (init_obs, init_env_state, rng_episode, False, 0.0)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# return total_reward
|
|
||||||
|
|
||||||
def evaluate(self, randkey, state, act_func, params):
|
def evaluate(self, randkey, state, act_func, params):
|
||||||
rng_reset, rng_episode = jax.random.split(randkey)
|
rng_reset, rng_episode = jax.random.split(randkey)
|
||||||
init_obs, init_env_state = self.reset(rng_reset)
|
init_obs, init_env_state = self.reset(rng_reset)
|
||||||
@@ -58,6 +34,7 @@ class RLEnv(BaseProblem):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return total_reward
|
return total_reward
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(0,))
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
def step(self, randkey, env_state, action):
|
def step(self, randkey, env_state, action):
|
||||||
return self.env_step(randkey, env_state, action)
|
return self.env_step(randkey, env_state, action)
|
||||||
|
|||||||
Reference in New Issue
Block a user