add aaai exp
This commit is contained in:
16
aaai_exp.py
16
aaai_exp.py
@@ -21,16 +21,16 @@ def conf_cartpole():
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=4,
|
||||
outputs=2,
|
||||
inputs=3,
|
||||
outputs=1,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
|
||||
env_name='Pendulum-v1',
|
||||
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -100,6 +100,9 @@ def main():
|
||||
alg_state = algorithm.setup(alg_key)
|
||||
|
||||
for i in range(conf.basic.generation_limit):
|
||||
|
||||
total_tic = time()
|
||||
|
||||
pro_key, _ = jax.random.split(pro_key)
|
||||
|
||||
fitnesses, a1, env_time, forward_time = batch_evaluate(
|
||||
@@ -118,9 +121,10 @@ def main():
|
||||
a2 = time() - alg_tic
|
||||
|
||||
alg_time = a1 + a2
|
||||
total_time = time() - total_tic
|
||||
|
||||
print(f"generation:{i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, "
|
||||
f"max_fitness: {np.max(fitnesses):.2f}")
|
||||
print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, "
|
||||
f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
33
examples/a.py
Normal file
33
examples/a.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import jax.random
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import time
|
||||
|
||||
|
||||
def random_array(key):
|
||||
return jax.random.normal(key, (1000,))
|
||||
|
||||
def random_array_np():
|
||||
return np.random.normal(size=(1000,))
|
||||
|
||||
|
||||
def t_jax():
|
||||
key = jax.random.PRNGKey(42)
|
||||
max_li = []
|
||||
tic = time.time()
|
||||
for _ in range(100):
|
||||
key, sub_key = jax.random.split(key)
|
||||
array = random_array(sub_key)
|
||||
array = jax.device_get(array)
|
||||
max_li.append(max(array))
|
||||
print(max_li, time.time() - tic)
|
||||
|
||||
def t_np():
|
||||
max_li = []
|
||||
tic = time.time()
|
||||
for _ in range(100):
|
||||
max_li.append(max(random_array_np()))
|
||||
print(max_li, time.time() - tic)
|
||||
|
||||
if __name__ == '__main__':
|
||||
t_np()
|
||||
Reference in New Issue
Block a user