add aaai exp

This commit is contained in:
wls2002
2023-08-13 12:30:50 +08:00
parent 33e6ef4916
commit 9ae9d9dfdc
2 changed files with 43 additions and 6 deletions

View File

@@ -21,16 +21,16 @@ def conf_cartpole():
pop_size=10000 pop_size=10000
), ),
neat=NeatConfig( neat=NeatConfig(
inputs=4, inputs=3,
outputs=2, outputs=1,
), ),
gene=NormalGeneConfig( gene=NormalGeneConfig(
activation_default=Act.tanh, activation_default=Act.tanh,
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
), ),
problem=GymNaxConfig( problem=GymNaxConfig(
env_name='CartPole-v1', env_name='Pendulum-v1',
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1} output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
) )
) )
@@ -100,6 +100,9 @@ def main():
alg_state = algorithm.setup(alg_key) alg_state = algorithm.setup(alg_key)
for i in range(conf.basic.generation_limit): for i in range(conf.basic.generation_limit):
total_tic = time()
pro_key, _ = jax.random.split(pro_key) pro_key, _ = jax.random.split(pro_key)
fitnesses, a1, env_time, forward_time = batch_evaluate( fitnesses, a1, env_time, forward_time = batch_evaluate(
@@ -118,9 +121,10 @@ def main():
a2 = time() - alg_tic a2 = time() - alg_tic
alg_time = a1 + a2 alg_time = a1 + a2
total_time = time() - total_tic
print(f"generation:{i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, " print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, "
f"max_fitness: {np.max(fitnesses):.2f}") f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}")
if __name__ == '__main__': if __name__ == '__main__':

33
examples/a.py Normal file
View File

@@ -0,0 +1,33 @@
import jax.random
import numpy as np
import jax.numpy as jnp
import time
def random_array(key, size=1000):
    """Draw a 1-D array of standard-normal samples with JAX.

    Args:
        key: JAX PRNG key used to generate the samples.
        size: Number of samples to draw. Defaults to 1000.

    Returns:
        A JAX array of shape ``(size,)``.
    """
    return jax.random.normal(key, (size,))
def random_array_np(size=1000):
    """Draw a 1-D array of standard-normal samples with NumPy.

    Uses NumPy's global random state, so results vary between calls
    unless the caller seeds it.

    Args:
        size: Number of samples to draw. Defaults to 1000.

    Returns:
        A NumPy array of shape ``(size,)``.
    """
    return np.random.normal(size=(size,))
def t_jax():
    """Time 100 rounds of JAX sampling and print the per-round maxima.

    Each round splits the PRNG key, draws an array with ``random_array``,
    fetches it with ``jax.device_get`` and records its maximum. Prints the
    list of maxima followed by the elapsed time in seconds.
    """
    key = jax.random.PRNGKey(42)
    max_li = []
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # and can jump, which makes it a poor choice for benchmarking.
    tic = time.perf_counter()
    for _ in range(100):
        key, sub_key = jax.random.split(key)
        array = jax.device_get(random_array(sub_key))
        # Reduce inside NumPy: the builtin max() would iterate the array
        # element by element in Python and dominate the measured time.
        max_li.append(np.max(array))
    print(max_li, time.perf_counter() - tic)
def t_np():
    """Time 100 rounds of NumPy sampling and print the per-round maxima.

    Each round draws an array with ``random_array_np`` and records its
    maximum. Prints the list of maxima followed by the elapsed time in
    seconds.
    """
    max_li = []
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # and can jump, which makes it a poor choice for benchmarking.
    tic = time.perf_counter()
    for _ in range(100):
        # Reduce inside NumPy: the builtin max() would iterate the array
        # element by element in Python and dominate the measured time.
        max_li.append(np.max(random_array_np()))
    print(max_li, time.perf_counter() - tic)
if __name__ == "__main__":
    # Script entry point: only the NumPy timing is run.
    t_np()