Using Evox to deal with RL tasks! With distributed Gym environment!

Three simple tasks in Gym[classical] are tested.
This commit is contained in:
wls2002
2023-07-04 15:44:08 +08:00
parent c4d34e877b
commit 7bf46575f4
18 changed files with 547 additions and 43 deletions

View File

@@ -2,12 +2,16 @@ import jax
from jax import Array, numpy as jnp, jit, vmap
from .utils import I_INT
from .activations import act_name2func
from .aggregations import agg_name2func
def create_forward_function(config):
"""
meta method to create forward function
"""
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
def act(idx, z):
"""
@@ -92,12 +96,11 @@ def create_forward_function(config):
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
if config['forward_way'] == 'single':
return jit(batch_forward)
return jit(forward)
# return jit(batch_forward)
elif config['forward_way'] == 'pop':
return jit(pop_batch_forward)
elif config['forward_way'] == 'common':
return jit(common_forward)
return jit(forward)

View File

@@ -1,5 +1,5 @@
"""
Some graph algorithms implemented in jax.
Some graph algorithm implemented in jax.
Only used in feed-forward networks.
"""