Using Evox to deal with RL tasks! With distributed Gym environment!
Three simple tasks in Gym[classical] are tested.
This commit is contained in:
@@ -2,12 +2,16 @@ import jax
|
||||
from jax import Array, numpy as jnp, jit, vmap
|
||||
|
||||
from .utils import I_INT
|
||||
from .activations import act_name2func
|
||||
from .aggregations import agg_name2func
|
||||
|
||||
|
||||
def create_forward_function(config):
|
||||
"""
|
||||
meta method to create forward function
|
||||
"""
|
||||
config['activation_funcs'] = [act_name2func[name] for name in config['activation_option_names']]
|
||||
config['aggregation_funcs'] = [agg_name2func[name] for name in config['aggregation_option_names']]
|
||||
|
||||
def act(idx, z):
|
||||
"""
|
||||
@@ -92,12 +96,11 @@ def create_forward_function(config):
|
||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||
|
||||
if config['forward_way'] == 'single':
|
||||
return jit(batch_forward)
|
||||
return jit(forward)
|
||||
# return jit(batch_forward)
|
||||
|
||||
elif config['forward_way'] == 'pop':
|
||||
return jit(pop_batch_forward)
|
||||
|
||||
elif config['forward_way'] == 'common':
|
||||
return jit(common_forward)
|
||||
|
||||
return jit(forward)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Some graph algorithms implemented in jax.
|
||||
Some graph algorithm implemented in jax.
|
||||
Only used in feed-forward networks.
|
||||
"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user