add aaai exp
This commit is contained in:
16
aaai_exp.py
16
aaai_exp.py
@@ -21,16 +21,16 @@ def conf_cartpole():
|
|||||||
pop_size=10000
|
pop_size=10000
|
||||||
),
|
),
|
||||||
neat=NeatConfig(
|
neat=NeatConfig(
|
||||||
inputs=4,
|
inputs=3,
|
||||||
outputs=2,
|
outputs=1,
|
||||||
),
|
),
|
||||||
gene=NormalGeneConfig(
|
gene=NormalGeneConfig(
|
||||||
activation_default=Act.tanh,
|
activation_default=Act.tanh,
|
||||||
activation_options=(Act.tanh,),
|
activation_options=(Act.tanh,),
|
||||||
),
|
),
|
||||||
problem=GymNaxConfig(
|
problem=GymNaxConfig(
|
||||||
env_name='CartPole-v1',
|
env_name='Pendulum-v1',
|
||||||
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
|
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -100,6 +100,9 @@ def main():
|
|||||||
alg_state = algorithm.setup(alg_key)
|
alg_state = algorithm.setup(alg_key)
|
||||||
|
|
||||||
for i in range(conf.basic.generation_limit):
|
for i in range(conf.basic.generation_limit):
|
||||||
|
|
||||||
|
total_tic = time()
|
||||||
|
|
||||||
pro_key, _ = jax.random.split(pro_key)
|
pro_key, _ = jax.random.split(pro_key)
|
||||||
|
|
||||||
fitnesses, a1, env_time, forward_time = batch_evaluate(
|
fitnesses, a1, env_time, forward_time = batch_evaluate(
|
||||||
@@ -118,9 +121,10 @@ def main():
|
|||||||
a2 = time() - alg_tic
|
a2 = time() - alg_tic
|
||||||
|
|
||||||
alg_time = a1 + a2
|
alg_time = a1 + a2
|
||||||
|
total_time = time() - total_tic
|
||||||
|
|
||||||
print(f"generation:{i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, "
|
print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, "
|
||||||
f"max_fitness: {np.max(fitnesses):.2f}")
|
f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
33
examples/a.py
Normal file
33
examples/a.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import jax.random
|
||||||
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def random_array(key):
|
||||||
|
return jax.random.normal(key, (1000,))
|
||||||
|
|
||||||
|
def random_array_np():
|
||||||
|
return np.random.normal(size=(1000,))
|
||||||
|
|
||||||
|
|
||||||
|
def t_jax():
|
||||||
|
key = jax.random.PRNGKey(42)
|
||||||
|
max_li = []
|
||||||
|
tic = time.time()
|
||||||
|
for _ in range(100):
|
||||||
|
key, sub_key = jax.random.split(key)
|
||||||
|
array = random_array(sub_key)
|
||||||
|
array = jax.device_get(array)
|
||||||
|
max_li.append(max(array))
|
||||||
|
print(max_li, time.time() - tic)
|
||||||
|
|
||||||
|
def t_np():
|
||||||
|
max_li = []
|
||||||
|
tic = time.time()
|
||||||
|
for _ in range(100):
|
||||||
|
max_li.append(max(random_array_np()))
|
||||||
|
print(max_li, time.time() - tic)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
t_np()
|
||||||
Reference in New Issue
Block a user