refactor names;
delete useless
This commit is contained in:
131
aaai_exp.py
131
aaai_exp.py
@@ -1,131 +0,0 @@
|
||||
from typing import Callable
|
||||
from time import time
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap, jit
|
||||
import gymnax
|
||||
import numpy as np
|
||||
|
||||
from config import *
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
|
||||
def conf_cartpole():
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=42,
|
||||
fitness_target=500,
|
||||
generation_limit=150,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=3,
|
||||
outputs=1,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='Pendulum-v1',
|
||||
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def batch_evaluate(
|
||||
key,
|
||||
alg_state,
|
||||
genomes,
|
||||
env_params,
|
||||
batch_transform: Callable,
|
||||
batch_act: Callable,
|
||||
batch_reset: Callable,
|
||||
batch_step: Callable,
|
||||
):
|
||||
alg_time, env_time, forward_time = 0, 0, 0
|
||||
pop_size = genomes.nodes.shape[0]
|
||||
|
||||
alg_tic = time()
|
||||
genomes_transform = batch_transform(alg_state, genomes)
|
||||
alg_time += time() - alg_tic
|
||||
|
||||
reset_keys = jax.random.split(key, pop_size)
|
||||
observations, states = batch_reset(reset_keys, env_params)
|
||||
|
||||
done = np.zeros(pop_size, dtype=bool)
|
||||
fitnesses = np.zeros(pop_size)
|
||||
|
||||
while not np.all(done):
|
||||
key, _ = jax.random.split(key)
|
||||
vmap_keys = jax.random.split(key, pop_size)
|
||||
|
||||
forward_tic = time()
|
||||
actions = batch_act(alg_state, observations, genomes_transform).block_until_ready()
|
||||
forward_time += time() - forward_tic
|
||||
|
||||
env_tic = time()
|
||||
observations, states, reward, current_done, _ = batch_step(vmap_keys, states, actions, env_params)
|
||||
reward, current_done = jax.device_get([reward, current_done])
|
||||
env_time += time() - env_tic
|
||||
|
||||
fitnesses += reward * np.logical_not(done)
|
||||
done = np.logical_or(done, current_done)
|
||||
|
||||
return fitnesses, alg_time, env_time, forward_time
|
||||
|
||||
|
||||
def main():
|
||||
conf = conf_cartpole()
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
|
||||
def act(state, inputs, genome):
|
||||
res = algorithm.act(state, inputs, genome)
|
||||
return conf.problem.output_transform(res)
|
||||
|
||||
batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
|
||||
# (state, obs, genome_transform) -> action
|
||||
batch_act = jit(vmap(act, in_axes=(None, 0, 0)))
|
||||
|
||||
env, env_params = gymnax.make(conf.problem.env_name)
|
||||
# (seed, params) -> (ini_obs, ini_state)
|
||||
batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
|
||||
# (seed, state, action, params) -> (obs, state, reward, done, info)
|
||||
batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))
|
||||
|
||||
key = jax.random.PRNGKey(conf.basic.seed)
|
||||
alg_key, pro_key = jax.random.split(key)
|
||||
alg_state = algorithm.setup(alg_key)
|
||||
|
||||
for i in range(conf.basic.generation_limit):
|
||||
|
||||
total_tic = time()
|
||||
|
||||
pro_key, _ = jax.random.split(pro_key)
|
||||
|
||||
fitnesses, a1, env_time, forward_time = batch_evaluate(
|
||||
pro_key,
|
||||
alg_state,
|
||||
algorithm.ask(alg_state),
|
||||
env_params,
|
||||
batch_transform,
|
||||
batch_act,
|
||||
batch_reset,
|
||||
batch_step
|
||||
)
|
||||
alg_tic = time()
|
||||
alg_state = algorithm.tell(alg_state, fitnesses)
|
||||
alg_state = jax.tree_map(lambda x: x.block_until_ready(), alg_state)
|
||||
a2 = time() - alg_tic
|
||||
|
||||
alg_time = a1 + a2
|
||||
total_time = time() - total_tic
|
||||
|
||||
print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, "
|
||||
f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,33 +0,0 @@
|
||||
import jax.random
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import time
|
||||
|
||||
|
||||
def random_array(key):
|
||||
return jax.random.normal(key, (1000,))
|
||||
|
||||
def random_array_np():
|
||||
return np.random.normal(size=(1000,))
|
||||
|
||||
|
||||
def t_jax():
|
||||
key = jax.random.PRNGKey(42)
|
||||
max_li = []
|
||||
tic = time.time()
|
||||
for _ in range(100):
|
||||
key, sub_key = jax.random.split(key)
|
||||
array = random_array(sub_key)
|
||||
array = jax.device_get(array)
|
||||
max_li.append(max(array))
|
||||
print(max_li, time.time() - tic)
|
||||
|
||||
def t_np():
|
||||
max_li = []
|
||||
tic = time.time()
|
||||
for _ in range(100):
|
||||
max_li.append(max(random_array_np()))
|
||||
print(max_li, time.time() - tic)
|
||||
|
||||
if __name__ == '__main__':
|
||||
t_np()
|
||||
@@ -1,5 +1,5 @@
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.func_fit import XOR, FuncFitConfig
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import NormalGene, NormalGeneConfig
|
||||
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
|
||||
from problem.func_fit import XOR3d, FuncFitConfig
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
|
||||
from problem.func_fit import XOR3d, FuncFitConfig
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from algorithm.hyperneat import HyperNEAT, NormalSubstrateConfig, NormalSubstrate
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
import time
|
||||
|
||||
import jax.numpy as jnp
|
||||
from config import *
|
||||
from pipeline_jitable_env import Pipeline
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||
|
||||
|
||||
def conf_with_seed(seed):
|
||||
return Config(
|
||||
basic=BasicConfig(
|
||||
seed=seed,
|
||||
fitness_target=501,
|
||||
pop_size=10000,
|
||||
generation_limit=100
|
||||
),
|
||||
neat=NeatConfig(
|
||||
inputs=4,
|
||||
outputs=1,
|
||||
max_species=10000
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default=Act.sigmoid,
|
||||
activation_options=(Act.sigmoid,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
times = []
|
||||
|
||||
for seed in range(10):
|
||||
conf = conf_with_seed(seed)
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
tic = time.time()
|
||||
state, best = pipeline.auto_run(state)
|
||||
time_cost = time.time() - tic
|
||||
times.append(time_cost)
|
||||
print(times)
|
||||
|
||||
print(f"total_times: {times}")
|
||||
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
mean_totaltime,Average Fitness
|
||||
16.23358095,-1531.406262
|
||||
34.16262472,-1469.014171
|
||||
52.56341431,-1429.524865
|
||||
68.70164006,-1398.485718
|
||||
87.61529462,-1369.068788
|
||||
105.4480726,-1338.810888
|
||||
120.3260972,-1308.414005
|
||||
134.2059872,-1280.022147
|
||||
148.1975725,-1258.681994
|
||||
161.0129987,-1248.612688
|
||||
174.4836712,-1242.987934
|
||||
187.5842981,-1245.195745
|
||||
201.7285712,-1244.295623
|
||||
217.0596706,-1241.958828
|
||||
233.1745317,-1242.23715
|
||||
249.9489457,-1240.606505
|
||||
265.2109861,-1241.584935
|
||||
282.1818033,-1242.254389
|
||||
301.1290779,-1240.280126
|
||||
317.5243011,-1238.729867
|
||||
334.9861856,-1238.403718
|
||||
354.7175949,-1236.6442
|
||||
373.5419451,-1238.018048
|
||||
391.36524,-1236.910213
|
||||
410.9625152,-1235.673599
|
||||
430.0237711,-1234.652077
|
||||
449.6226187,-1236.544195
|
||||
468.2591698,-1233.919256
|
||||
489.2314327,-1236.904311
|
||||
510.1973371,-1236.444161
|
||||
530.6520383,-1234.882298
|
||||
552.1866553,-1236.077152
|
||||
574.0524314,-1235.339124
|
||||
595.665635,-1233.053687
|
||||
618.1490554,-1235.346718
|
||||
640.9171209,-1234.058986
|
||||
663.9380282,-1233.082789
|
||||
687.5950083,-1233.227056
|
||||
711.6376721,-1232.342045
|
||||
737.2806357,-1232.589748
|
||||
762.8286344,-1229.058886
|
||||
787.2488623,-1233.670237
|
||||
812.7839067,-1232.203953
|
||||
837.8952895,-1231.658721
|
||||
864.4482835,-1232.48289
|
||||
891.955594,-1230.740013
|
||||
920.0526731,-1229.984369
|
||||
949.259579,-1230.266798
|
||||
980.3165732,-1228.807429
|
||||
1010.551222,-1229.953455
|
||||
1042.59788,-1230.538116
|
||||
1073.075853,-1228.559482
|
||||
1106.577571,-1229.979631
|
||||
1139.959677,-1229.714391
|
||||
1173.958232,-1228.932949
|
||||
1209.781649,-1226.865015
|
||||
1245.958032,-1227.553255
|
||||
1281.796618,-1228.989607
|
||||
1321.521287,-1227.193127
|
||||
1361.788322,-1227.00891
|
||||
1401.135148,-1226.937434
|
||||
1440.664442,-1228.277831
|
||||
1480.485734,-1226.27974
|
||||
1523.433305,-1224.812807
|
||||
1566.113366,-1226.587234
|
||||
1610.083268,-1227.333781
|
||||
1654.933877,-1227.395313
|
||||
1700.679548,-1223.551766
|
||||
1747.142431,-1227.437868
|
||||
1794.300387,-1225.468609
|
||||
1840.876058,-1225.929573
|
||||
1888.193045,-1227.768213
|
||||
1936.7367,-1224.390808
|
||||
1984.760717,-1227.475326
|
||||
2033.411831,-1223.673373
|
||||
2084.356429,-1223.815094
|
||||
2135.077312,-1224.983295
|
||||
2186.35972,-1224.084667
|
||||
2238.44057,-1226.530143
|
||||
2292.238697,-1223.684731
|
||||
2345.740624,-1224.418812
|
||||
2399.445412,-1224.068882
|
||||
2453.074848,-1223.038722
|
||||
2506.191465,-1224.688486
|
||||
2560.467831,-1224.369871
|
||||
2615.123451,-1223.741346
|
||||
2669.90867,-1224.06301
|
||||
2723.421276,-1223.571537
|
||||
2777.560479,-1221.698283
|
||||
2832.592864,-1223.216613
|
||||
2887.017148,-1223.993198
|
||||
2941.483496,-1224.1121
|
||||
2995.423772,-1223.169763
|
||||
3050.948137,-1223.459228
|
||||
3106.069071,-1223.111025
|
||||
3159.795594,-1224.609992
|
||||
3214.050797,-1223.487228
|
||||
3269.429128,-1223.715971
|
||||
3324.581811,-1223.355198
|
||||
3380.680088,-1223.137084
|
||||
3436.805942,-1222.142975
|
||||
3492.414253,-1222.925149
|
||||
3549.505774,-1223.379747
|
||||
3607.501727,-1222.272351
|
||||
3664.639311,-1224.368308
|
||||
3722.034017,-1223.901903
|
||||
3779.537015,-1222.94451
|
||||
3836.6152,-1221.838935
|
||||
3894.218441,-1223.663879
|
||||
3951.011881,-1224.02154
|
||||
4007.885495,-1223.465968
|
||||
4063.988944,-1222.439965
|
||||
4122.273734,-1221.786334
|
||||
4179.429916,-1222.157602
|
||||
4237.573861,-1223.250346
|
||||
4295.552148,-1222.545104
|
||||
4353.574564,-1222.514246
|
||||
4411.656075,-1222.108242
|
||||
4468.160373,-1222.456373
|
||||
4523.770186,-1221.893127
|
||||
4578.243049,-1222.369885
|
||||
4633.660001,-1221.74155
|
||||
4689.995503,-1221.485722
|
||||
4745.046551,-1222.821859
|
||||
4799.297272,-1220.567495
|
||||
4855.528112,-1221.962994
|
||||
4912.181782,-1223.706512
|
||||
4968.106065,-1223.249642
|
||||
5023.349784,-1223.88398
|
||||
5077.95836,-1221.843378
|
||||
5132.109917,-1220.167506
|
||||
5185.437263,-1222.633295
|
||||
5239.11293,-1220.070837
|
||||
5292.425083,-1222.163853
|
||||
5344.871339,-1221.905685
|
||||
5398.5643,-1221.195668
|
||||
5453.522891,-1220.820716
|
||||
5510.54199,-1221.871458
|
||||
5565.892794,-1221.127944
|
||||
5619.692913,-1221.370239
|
||||
5672.505697,-1220.411562
|
||||
5726.742534,-1219.956466
|
||||
5780.314767,-1223.574739
|
||||
5831.786025,-1220.038731
|
||||
5883.706538,-1222.285694
|
||||
5935.963855,-1219.789742
|
||||
5990.031264,-1220.348486
|
||||
6042.056406,-1221.695037
|
||||
6094.931383,-1222.573256
|
||||
6149.866284,-1220.799972
|
||||
6201.410016,-1223.62492
|
||||
6254.863233,-1221.173897
|
||||
6307.299802,-1218.384144
|
||||
6358.419653,-1221.949964
|
||||
6410.39371,-1220.695003
|
||||
6462.294865,-1221.354209
|
||||
6514.29837,-1220.013649
|
||||
6566.028445,-1221.057066
|
||||
6619.112563,-1220.250728
|
||||
6671.586666,-1220.74064
|
||||
6724.056905,-1220.696191
|
||||
6778.867936,-1221.342228
|
||||
6834.066508,-1220.67166
|
||||
6887.848139,-1221.681324
|
||||
6940.725238,-1221.786548
|
||||
6993.728177,-1220.293248
|
||||
7047.068145,-1220.784974
|
||||
7101.058736,-1221.296277
|
||||
7155.695031,-1220.314099
|
||||
7211.131008,-1220.001403
|
||||
7264.672096,-1222.582639
|
||||
7319.855532,-1220.218512
|
||||
7373.416443,-1222.413372
|
||||
7426.826476,-1221.740957
|
||||
7481.033519,-1220.12891
|
||||
7535.098597,-1220.954482
|
||||
7589.86269,-1221.591681
|
||||
7645.189188,-1219.780434
|
||||
7700.019942,-1219.86627
|
||||
7754.408913,-1221.254077
|
||||
7809.17354,-1220.560716
|
||||
7863.955751,-1219.303522
|
||||
7919.377039,-1220.023321
|
||||
7973.91837,-1221.021169
|
||||
8030.191454,-1221.020316
|
||||
8086.216218,-1219.728618
|
||||
8143.879354,-1221.815614
|
||||
8200.316735,-1219.043831
|
||||
8257.940883,-1220.157413
|
||||
8314.852015,-1220.684859
|
||||
8370.81427,-1220.370599
|
||||
8428.692646,-1219.714431
|
||||
8484.287711,-1218.305622
|
||||
8537.958919,-1221.428118
|
||||
8594.432063,-1219.59307
|
||||
8649.300665,-1220.87185
|
||||
8704.677604,-1221.031256
|
||||
8760.582589,-1218.94684
|
||||
8818.110156,-1221.658566
|
||||
8874.317387,-1219.76819
|
||||
8931.379297,-1221.200231
|
||||
8988.303494,-1220.058048
|
||||
9045.237712,-1219.77335
|
||||
9102.797204,-1220.183627
|
||||
9158.2346,-1220.408912
|
||||
9213.641003,-1218.583548
|
||||
9270.128289,-1221.096059
|
||||
9325.747626,-1220.0666
|
||||
9381.558162,-1219.557564
|
||||
9438.852373,-1219.330095
|
||||
9495.594231,-1221.993273
|
||||
9553.896612,-1223.187594
|
||||
9611.398788,-1219.720465
|
||||
9670.280587,-1218.745421
|
||||
9727.735437,-1220.320216
|
||||
9785.576854,-1218.623738
|
||||
9842.847055,-1220.192794
|
||||
9900.27451,-1218.631224
|
||||
9960.404387,-1219.626358
|
||||
10020.38233,-1218.897498
|
||||
10081.03875,-1220.686493
|
||||
10142.4862,-1222.619642
|
||||
10203.81352,-1221.138795
|
||||
10263.73111,-1219.75704
|
||||
10322.01157,-1221.510761
|
||||
10383.70518,-1220.942931
|
||||
10444.83872,-1220.590052
|
||||
10505.46391,-1220.915548
|
||||
10565.21068,-1220.83712
|
||||
10623.1464,-1220.426405
|
||||
10681.92922,-1221.087758
|
||||
10740.52835,-1219.740644
|
||||
10799.10282,-1219.464613
|
||||
10858.83727,-1218.971765
|
||||
10917.77249,-1221.522595
|
||||
10978.69096,-1219.90022
|
||||
11039.15463,-1219.721981
|
||||
11100.71659,-1218.736818
|
||||
11160.92038,-1218.613676
|
||||
11223.41109,-1220.718136
|
||||
11286.18113,-1217.805265
|
||||
11347.3255,-1221.471473
|
||||
11409.22849,-1218.647645
|
||||
11471.5558,-1218.650381
|
||||
11534.54889,-1221.838598
|
||||
11596.56306,-1216.725289
|
||||
11658.5901,-1220.448722
|
||||
11720.84754,-1218.200733
|
||||
11781.86972,-1219.514898
|
||||
11844.25683,-1219.122911
|
||||
11905.98746,-1219.437825
|
||||
11967.72706,-1220.671814
|
||||
12029.60287,-1221.425081
|
||||
12090.42101,-1220.465861
|
||||
12151.72214,-1218.325056
|
||||
12215.04968,-1221.885592
|
||||
12276.16149,-1217.636273
|
||||
12335.90203,-1218.246472
|
||||
12396.63493,-1218.821679
|
||||
12459.444,-1220.249942
|
||||
12522.88157,-1218.853892
|
||||
12585.31227,-1220.934625
|
||||
12647.22732,-1219.792465
|
||||
12707.95382,-1219.762632
|
||||
12770.33686,-1217.819044
|
||||
12833.20918,-1218.993567
|
||||
12896.81557,-1219.114086
|
||||
12958.97821,-1219.069654
|
||||
13023.42444,-1219.832851
|
||||
13086.90235,-1220.588547
|
||||
13149.95939,-1217.843585
|
||||
13214.8704,-1218.933695
|
||||
13278.31332,-1218.540474
|
||||
13342.29491,-1219.083535
|
||||
13404.31573,-1221.076755
|
||||
13465.29136,-1220.365732
|
||||
13526.70887,-1217.986482
|
||||
13590.52374,-1219.40228
|
||||
13654.7832,-1217.818409
|
||||
13718.30046,-1218.816018
|
||||
13781.31456,-1220.176623
|
||||
13843.58471,-1218.819039
|
||||
13905.20301,-1218.231392
|
||||
13967.45721,-1219.155412
|
||||
14030.04214,-1219.768839
|
||||
14092.84348,-1219.087143
|
||||
14155.60131,-1217.773162
|
||||
14218.77685,-1218.442359
|
||||
14281.016,-1219.142937
|
||||
14346.43126,-1218.540456
|
||||
14409.63096,-1220.27516
|
||||
14475.35352,-1220.047078
|
||||
14541.72266,-1217.934667
|
||||
14605.84911,-1219.121418
|
||||
14673.27209,-1219.250748
|
||||
14737.62863,-1220.368594
|
||||
14802.66072,-1220.611695
|
||||
14867.97276,-1221.085498
|
||||
14935.65309,-1216.312149
|
||||
15001.18007,-1218.441571
|
||||
|
@@ -1,22 +0,0 @@
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# 使用 genfromtxt 函数读取 CSV 文件
|
||||
data = np.genfromtxt('neatpython.csv', delimiter=',', skip_header=1) # 假设有一个头部行
|
||||
mean_time, fitness_mean = data.T
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(mean_time, fitness_mean, color='green', label='NEAT-Python', linestyle=':')
|
||||
ax.set_xlabel('Time (s)')
|
||||
ax.set_ylabel('Average Fitness')
|
||||
ax.set_xlim(0, 500)
|
||||
ax.set_ylim(-2000, -1000)
|
||||
ax.legend()
|
||||
|
||||
|
||||
# ci = 1.96 * neatax_sem
|
||||
# lower_bound = neatax_mean - ci
|
||||
# upper_bound = neatax_mean + ci
|
||||
# plt.plot(mean_time, fitness_mean, color='r', label='NEAT-Python')
|
||||
# plt.fill_between(x_axis, lower_bound, upper_bound, color='red', alpha=0.2)
|
||||
fig.show()
|
||||
111
pipeline_time.py
111
pipeline_time.py
@@ -1,111 +0,0 @@
|
||||
from typing import Type
|
||||
|
||||
import jax
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from algorithm import NEAT, HyperNEAT
|
||||
from config import Config
|
||||
from core import State, Algorithm, Problem
|
||||
|
||||
|
||||
class Pipeline:
|
||||
|
||||
def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
|
||||
self.config = config
|
||||
self.algorithm = algorithm
|
||||
self.problem = problem_type(config.problem)
|
||||
|
||||
if isinstance(algorithm, NEAT):
|
||||
assert config.neat.inputs == self.problem.input_shape[-1]
|
||||
|
||||
elif isinstance(algorithm, HyperNEAT):
|
||||
assert config.hyperneat.inputs == self.problem.input_shape[-1]
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.act_func = self.algorithm.act
|
||||
|
||||
for _ in range(len(self.problem.input_shape) - 1):
|
||||
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
|
||||
|
||||
self.best_genome = None
|
||||
self.best_fitness = float('-inf')
|
||||
self.generation_timestamp = None
|
||||
|
||||
def setup(self):
|
||||
key = jax.random.PRNGKey(self.config.basic.seed)
|
||||
algorithm_key, evaluate_key = jax.random.split(key, 2)
|
||||
state = State()
|
||||
state = self.algorithm.setup(algorithm_key, state)
|
||||
return state.update(
|
||||
evaluate_key=evaluate_key
|
||||
)
|
||||
|
||||
def step(self, state):
|
||||
|
||||
key, sub_key = jax.random.split(state.evaluate_key)
|
||||
keys = jax.random.split(key, self.config.basic.pop_size)
|
||||
|
||||
pop = self.algorithm.ask(state)
|
||||
|
||||
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
|
||||
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
|
||||
pop_transformed)
|
||||
|
||||
state = self.algorithm.tell(state, fitnesses)
|
||||
|
||||
return state.update(evaluate_key=sub_key), fitnesses
|
||||
|
||||
def auto_run(self, ini_state):
|
||||
state = ini_state
|
||||
for _ in range(self.config.basic.generation_limit):
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
previous_pop = self.algorithm.ask(state)
|
||||
|
||||
state, fitnesses = self.step(state)
|
||||
|
||||
fitnesses = jax.device_get(fitnesses)
|
||||
|
||||
self.analysis(state, previous_pop, fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.config.basic.fitness_target:
|
||||
print("Fitness limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
print("Generation limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
def analysis(self, state, pop, fitnesses):
|
||||
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
|
||||
new_timestamp = time.time()
|
||||
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = pop[max_idx]
|
||||
|
||||
member_count = jax.device_get(state.species_info.member_count)
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
print(f"Generation: {state.generation}",
|
||||
f"species: {len(species_sizes)}, {species_sizes}",
|
||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
|
||||
|
||||
def show(self, state, genome):
|
||||
transformed = self.algorithm.transform(state, genome)
|
||||
self.problem.show(state.evaluate_key, state, self.act_func, transformed)
|
||||
|
||||
def pre_compile(self, state):
|
||||
tic = time.time()
|
||||
print("start compile")
|
||||
self.step.lower(self, state).compile()
|
||||
print(f"compile finished, cost time: {time.time() - tic}s")
|
||||
Reference in New Issue
Block a user