refactor names;

delete useless
This commit is contained in:
wls2002
2023-09-15 22:33:21 +08:00
parent d317317ed2
commit 4efa9445d5
17 changed files with 10 additions and 661 deletions

View File

@@ -1,131 +0,0 @@
from typing import Callable
from time import time
import jax
from jax import numpy as jnp, vmap, jit
import gymnax
import numpy as np
from config import *
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def conf_cartpole():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
generation_limit=150,
pop_size=10000
),
neat=NeatConfig(
inputs=3,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='Pendulum-v1',
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
)
)
def batch_evaluate(
key,
alg_state,
genomes,
env_params,
batch_transform: Callable,
batch_act: Callable,
batch_reset: Callable,
batch_step: Callable,
):
alg_time, env_time, forward_time = 0, 0, 0
pop_size = genomes.nodes.shape[0]
alg_tic = time()
genomes_transform = batch_transform(alg_state, genomes)
alg_time += time() - alg_tic
reset_keys = jax.random.split(key, pop_size)
observations, states = batch_reset(reset_keys, env_params)
done = np.zeros(pop_size, dtype=bool)
fitnesses = np.zeros(pop_size)
while not np.all(done):
key, _ = jax.random.split(key)
vmap_keys = jax.random.split(key, pop_size)
forward_tic = time()
actions = batch_act(alg_state, observations, genomes_transform).block_until_ready()
forward_time += time() - forward_tic
env_tic = time()
observations, states, reward, current_done, _ = batch_step(vmap_keys, states, actions, env_params)
reward, current_done = jax.device_get([reward, current_done])
env_time += time() - env_tic
fitnesses += reward * np.logical_not(done)
done = np.logical_or(done, current_done)
return fitnesses, alg_time, env_time, forward_time
def main():
conf = conf_cartpole()
algorithm = NEAT(conf, NormalGene)
def act(state, inputs, genome):
res = algorithm.act(state, inputs, genome)
return conf.problem.output_transform(res)
batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
# (state, obs, genome_transform) -> action
batch_act = jit(vmap(act, in_axes=(None, 0, 0)))
env, env_params = gymnax.make(conf.problem.env_name)
# (seed, params) -> (ini_obs, ini_state)
batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
# (seed, state, action, params) -> (obs, state, reward, done, info)
batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))
key = jax.random.PRNGKey(conf.basic.seed)
alg_key, pro_key = jax.random.split(key)
alg_state = algorithm.setup(alg_key)
for i in range(conf.basic.generation_limit):
total_tic = time()
pro_key, _ = jax.random.split(pro_key)
fitnesses, a1, env_time, forward_time = batch_evaluate(
pro_key,
alg_state,
algorithm.ask(alg_state),
env_params,
batch_transform,
batch_act,
batch_reset,
batch_step
)
alg_tic = time()
alg_state = algorithm.tell(alg_state, fitnesses)
alg_state = jax.tree_map(lambda x: x.block_until_ready(), alg_state)
a2 = time() - alg_tic
alg_time = a1 + a2
total_time = time() - total_tic
print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, "
f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}")
if __name__ == '__main__':
main()

View File

@@ -1,33 +0,0 @@
import jax.random
import numpy as np
import jax.numpy as jnp
import time
def random_array(key):
return jax.random.normal(key, (1000,))
def random_array_np():
return np.random.normal(size=(1000,))
def t_jax():
key = jax.random.PRNGKey(42)
max_li = []
tic = time.time()
for _ in range(100):
key, sub_key = jax.random.split(key)
array = random_array(sub_key)
array = jax.device_get(array)
max_li.append(max(array))
print(max_li, time.time() - tic)
def t_np():
max_li = []
tic = time.time()
for _ in range(100):
max_li.append(max(random_array_np()))
print(max_li, time.time() - tic)
if __name__ == '__main__':
t_np()

View File

@@ -1,5 +1,5 @@
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig

View File

@@ -1,5 +1,5 @@
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm.neat import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
from problem.func_fit import XOR3d, FuncFitConfig

View File

@@ -1,5 +1,5 @@
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from problem.func_fit import XOR3d, FuncFitConfig

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrateConfig, NormalSubstrate

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,5 +1,5 @@
from config import *
from pipeline_jitable_env import Pipeline
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,53 +0,0 @@
import time
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def conf_with_seed(seed):
return Config(
basic=BasicConfig(
seed=seed,
fitness_target=501,
pop_size=10000,
generation_limit=100
),
neat=NeatConfig(
inputs=4,
outputs=1,
max_species=10000
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
)
)
if __name__ == '__main__':
times = []
for seed in range(10):
conf = conf_with_seed(seed)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup()
pipeline.pre_compile(state)
tic = time.time()
state, best = pipeline.auto_run(state)
time_cost = time.time() - tic
times.append(time_cost)
print(times)
print(f"total_times: {times}")

View File

@@ -1,301 +0,0 @@
mean_totaltime,Average Fitness
16.23358095,-1531.406262
34.16262472,-1469.014171
52.56341431,-1429.524865
68.70164006,-1398.485718
87.61529462,-1369.068788
105.4480726,-1338.810888
120.3260972,-1308.414005
134.2059872,-1280.022147
148.1975725,-1258.681994
161.0129987,-1248.612688
174.4836712,-1242.987934
187.5842981,-1245.195745
201.7285712,-1244.295623
217.0596706,-1241.958828
233.1745317,-1242.23715
249.9489457,-1240.606505
265.2109861,-1241.584935
282.1818033,-1242.254389
301.1290779,-1240.280126
317.5243011,-1238.729867
334.9861856,-1238.403718
354.7175949,-1236.6442
373.5419451,-1238.018048
391.36524,-1236.910213
410.9625152,-1235.673599
430.0237711,-1234.652077
449.6226187,-1236.544195
468.2591698,-1233.919256
489.2314327,-1236.904311
510.1973371,-1236.444161
530.6520383,-1234.882298
552.1866553,-1236.077152
574.0524314,-1235.339124
595.665635,-1233.053687
618.1490554,-1235.346718
640.9171209,-1234.058986
663.9380282,-1233.082789
687.5950083,-1233.227056
711.6376721,-1232.342045
737.2806357,-1232.589748
762.8286344,-1229.058886
787.2488623,-1233.670237
812.7839067,-1232.203953
837.8952895,-1231.658721
864.4482835,-1232.48289
891.955594,-1230.740013
920.0526731,-1229.984369
949.259579,-1230.266798
980.3165732,-1228.807429
1010.551222,-1229.953455
1042.59788,-1230.538116
1073.075853,-1228.559482
1106.577571,-1229.979631
1139.959677,-1229.714391
1173.958232,-1228.932949
1209.781649,-1226.865015
1245.958032,-1227.553255
1281.796618,-1228.989607
1321.521287,-1227.193127
1361.788322,-1227.00891
1401.135148,-1226.937434
1440.664442,-1228.277831
1480.485734,-1226.27974
1523.433305,-1224.812807
1566.113366,-1226.587234
1610.083268,-1227.333781
1654.933877,-1227.395313
1700.679548,-1223.551766
1747.142431,-1227.437868
1794.300387,-1225.468609
1840.876058,-1225.929573
1888.193045,-1227.768213
1936.7367,-1224.390808
1984.760717,-1227.475326
2033.411831,-1223.673373
2084.356429,-1223.815094
2135.077312,-1224.983295
2186.35972,-1224.084667
2238.44057,-1226.530143
2292.238697,-1223.684731
2345.740624,-1224.418812
2399.445412,-1224.068882
2453.074848,-1223.038722
2506.191465,-1224.688486
2560.467831,-1224.369871
2615.123451,-1223.741346
2669.90867,-1224.06301
2723.421276,-1223.571537
2777.560479,-1221.698283
2832.592864,-1223.216613
2887.017148,-1223.993198
2941.483496,-1224.1121
2995.423772,-1223.169763
3050.948137,-1223.459228
3106.069071,-1223.111025
3159.795594,-1224.609992
3214.050797,-1223.487228
3269.429128,-1223.715971
3324.581811,-1223.355198
3380.680088,-1223.137084
3436.805942,-1222.142975
3492.414253,-1222.925149
3549.505774,-1223.379747
3607.501727,-1222.272351
3664.639311,-1224.368308
3722.034017,-1223.901903
3779.537015,-1222.94451
3836.6152,-1221.838935
3894.218441,-1223.663879
3951.011881,-1224.02154
4007.885495,-1223.465968
4063.988944,-1222.439965
4122.273734,-1221.786334
4179.429916,-1222.157602
4237.573861,-1223.250346
4295.552148,-1222.545104
4353.574564,-1222.514246
4411.656075,-1222.108242
4468.160373,-1222.456373
4523.770186,-1221.893127
4578.243049,-1222.369885
4633.660001,-1221.74155
4689.995503,-1221.485722
4745.046551,-1222.821859
4799.297272,-1220.567495
4855.528112,-1221.962994
4912.181782,-1223.706512
4968.106065,-1223.249642
5023.349784,-1223.88398
5077.95836,-1221.843378
5132.109917,-1220.167506
5185.437263,-1222.633295
5239.11293,-1220.070837
5292.425083,-1222.163853
5344.871339,-1221.905685
5398.5643,-1221.195668
5453.522891,-1220.820716
5510.54199,-1221.871458
5565.892794,-1221.127944
5619.692913,-1221.370239
5672.505697,-1220.411562
5726.742534,-1219.956466
5780.314767,-1223.574739
5831.786025,-1220.038731
5883.706538,-1222.285694
5935.963855,-1219.789742
5990.031264,-1220.348486
6042.056406,-1221.695037
6094.931383,-1222.573256
6149.866284,-1220.799972
6201.410016,-1223.62492
6254.863233,-1221.173897
6307.299802,-1218.384144
6358.419653,-1221.949964
6410.39371,-1220.695003
6462.294865,-1221.354209
6514.29837,-1220.013649
6566.028445,-1221.057066
6619.112563,-1220.250728
6671.586666,-1220.74064
6724.056905,-1220.696191
6778.867936,-1221.342228
6834.066508,-1220.67166
6887.848139,-1221.681324
6940.725238,-1221.786548
6993.728177,-1220.293248
7047.068145,-1220.784974
7101.058736,-1221.296277
7155.695031,-1220.314099
7211.131008,-1220.001403
7264.672096,-1222.582639
7319.855532,-1220.218512
7373.416443,-1222.413372
7426.826476,-1221.740957
7481.033519,-1220.12891
7535.098597,-1220.954482
7589.86269,-1221.591681
7645.189188,-1219.780434
7700.019942,-1219.86627
7754.408913,-1221.254077
7809.17354,-1220.560716
7863.955751,-1219.303522
7919.377039,-1220.023321
7973.91837,-1221.021169
8030.191454,-1221.020316
8086.216218,-1219.728618
8143.879354,-1221.815614
8200.316735,-1219.043831
8257.940883,-1220.157413
8314.852015,-1220.684859
8370.81427,-1220.370599
8428.692646,-1219.714431
8484.287711,-1218.305622
8537.958919,-1221.428118
8594.432063,-1219.59307
8649.300665,-1220.87185
8704.677604,-1221.031256
8760.582589,-1218.94684
8818.110156,-1221.658566
8874.317387,-1219.76819
8931.379297,-1221.200231
8988.303494,-1220.058048
9045.237712,-1219.77335
9102.797204,-1220.183627
9158.2346,-1220.408912
9213.641003,-1218.583548
9270.128289,-1221.096059
9325.747626,-1220.0666
9381.558162,-1219.557564
9438.852373,-1219.330095
9495.594231,-1221.993273
9553.896612,-1223.187594
9611.398788,-1219.720465
9670.280587,-1218.745421
9727.735437,-1220.320216
9785.576854,-1218.623738
9842.847055,-1220.192794
9900.27451,-1218.631224
9960.404387,-1219.626358
10020.38233,-1218.897498
10081.03875,-1220.686493
10142.4862,-1222.619642
10203.81352,-1221.138795
10263.73111,-1219.75704
10322.01157,-1221.510761
10383.70518,-1220.942931
10444.83872,-1220.590052
10505.46391,-1220.915548
10565.21068,-1220.83712
10623.1464,-1220.426405
10681.92922,-1221.087758
10740.52835,-1219.740644
10799.10282,-1219.464613
10858.83727,-1218.971765
10917.77249,-1221.522595
10978.69096,-1219.90022
11039.15463,-1219.721981
11100.71659,-1218.736818
11160.92038,-1218.613676
11223.41109,-1220.718136
11286.18113,-1217.805265
11347.3255,-1221.471473
11409.22849,-1218.647645
11471.5558,-1218.650381
11534.54889,-1221.838598
11596.56306,-1216.725289
11658.5901,-1220.448722
11720.84754,-1218.200733
11781.86972,-1219.514898
11844.25683,-1219.122911
11905.98746,-1219.437825
11967.72706,-1220.671814
12029.60287,-1221.425081
12090.42101,-1220.465861
12151.72214,-1218.325056
12215.04968,-1221.885592
12276.16149,-1217.636273
12335.90203,-1218.246472
12396.63493,-1218.821679
12459.444,-1220.249942
12522.88157,-1218.853892
12585.31227,-1220.934625
12647.22732,-1219.792465
12707.95382,-1219.762632
12770.33686,-1217.819044
12833.20918,-1218.993567
12896.81557,-1219.114086
12958.97821,-1219.069654
13023.42444,-1219.832851
13086.90235,-1220.588547
13149.95939,-1217.843585
13214.8704,-1218.933695
13278.31332,-1218.540474
13342.29491,-1219.083535
13404.31573,-1221.076755
13465.29136,-1220.365732
13526.70887,-1217.986482
13590.52374,-1219.40228
13654.7832,-1217.818409
13718.30046,-1218.816018
13781.31456,-1220.176623
13843.58471,-1218.819039
13905.20301,-1218.231392
13967.45721,-1219.155412
14030.04214,-1219.768839
14092.84348,-1219.087143
14155.60131,-1217.773162
14218.77685,-1218.442359
14281.016,-1219.142937
14346.43126,-1218.540456
14409.63096,-1220.27516
14475.35352,-1220.047078
14541.72266,-1217.934667
14605.84911,-1219.121418
14673.27209,-1219.250748
14737.62863,-1220.368594
14802.66072,-1220.611695
14867.97276,-1221.085498
14935.65309,-1216.312149
15001.18007,-1218.441571
1 mean_totaltime Average Fitness
2 16.23358095 -1531.406262
3 34.16262472 -1469.014171
4 52.56341431 -1429.524865
5 68.70164006 -1398.485718
6 87.61529462 -1369.068788
7 105.4480726 -1338.810888
8 120.3260972 -1308.414005
9 134.2059872 -1280.022147
10 148.1975725 -1258.681994
11 161.0129987 -1248.612688
12 174.4836712 -1242.987934
13 187.5842981 -1245.195745
14 201.7285712 -1244.295623
15 217.0596706 -1241.958828
16 233.1745317 -1242.23715
17 249.9489457 -1240.606505
18 265.2109861 -1241.584935
19 282.1818033 -1242.254389
20 301.1290779 -1240.280126
21 317.5243011 -1238.729867
22 334.9861856 -1238.403718
23 354.7175949 -1236.6442
24 373.5419451 -1238.018048
25 391.36524 -1236.910213
26 410.9625152 -1235.673599
27 430.0237711 -1234.652077
28 449.6226187 -1236.544195
29 468.2591698 -1233.919256
30 489.2314327 -1236.904311
31 510.1973371 -1236.444161
32 530.6520383 -1234.882298
33 552.1866553 -1236.077152
34 574.0524314 -1235.339124
35 595.665635 -1233.053687
36 618.1490554 -1235.346718
37 640.9171209 -1234.058986
38 663.9380282 -1233.082789
39 687.5950083 -1233.227056
40 711.6376721 -1232.342045
41 737.2806357 -1232.589748
42 762.8286344 -1229.058886
43 787.2488623 -1233.670237
44 812.7839067 -1232.203953
45 837.8952895 -1231.658721
46 864.4482835 -1232.48289
47 891.955594 -1230.740013
48 920.0526731 -1229.984369
49 949.259579 -1230.266798
50 980.3165732 -1228.807429
51 1010.551222 -1229.953455
52 1042.59788 -1230.538116
53 1073.075853 -1228.559482
54 1106.577571 -1229.979631
55 1139.959677 -1229.714391
56 1173.958232 -1228.932949
57 1209.781649 -1226.865015
58 1245.958032 -1227.553255
59 1281.796618 -1228.989607
60 1321.521287 -1227.193127
61 1361.788322 -1227.00891
62 1401.135148 -1226.937434
63 1440.664442 -1228.277831
64 1480.485734 -1226.27974
65 1523.433305 -1224.812807
66 1566.113366 -1226.587234
67 1610.083268 -1227.333781
68 1654.933877 -1227.395313
69 1700.679548 -1223.551766
70 1747.142431 -1227.437868
71 1794.300387 -1225.468609
72 1840.876058 -1225.929573
73 1888.193045 -1227.768213
74 1936.7367 -1224.390808
75 1984.760717 -1227.475326
76 2033.411831 -1223.673373
77 2084.356429 -1223.815094
78 2135.077312 -1224.983295
79 2186.35972 -1224.084667
80 2238.44057 -1226.530143
81 2292.238697 -1223.684731
82 2345.740624 -1224.418812
83 2399.445412 -1224.068882
84 2453.074848 -1223.038722
85 2506.191465 -1224.688486
86 2560.467831 -1224.369871
87 2615.123451 -1223.741346
88 2669.90867 -1224.06301
89 2723.421276 -1223.571537
90 2777.560479 -1221.698283
91 2832.592864 -1223.216613
92 2887.017148 -1223.993198
93 2941.483496 -1224.1121
94 2995.423772 -1223.169763
95 3050.948137 -1223.459228
96 3106.069071 -1223.111025
97 3159.795594 -1224.609992
98 3214.050797 -1223.487228
99 3269.429128 -1223.715971
100 3324.581811 -1223.355198
101 3380.680088 -1223.137084
102 3436.805942 -1222.142975
103 3492.414253 -1222.925149
104 3549.505774 -1223.379747
105 3607.501727 -1222.272351
106 3664.639311 -1224.368308
107 3722.034017 -1223.901903
108 3779.537015 -1222.94451
109 3836.6152 -1221.838935
110 3894.218441 -1223.663879
111 3951.011881 -1224.02154
112 4007.885495 -1223.465968
113 4063.988944 -1222.439965
114 4122.273734 -1221.786334
115 4179.429916 -1222.157602
116 4237.573861 -1223.250346
117 4295.552148 -1222.545104
118 4353.574564 -1222.514246
119 4411.656075 -1222.108242
120 4468.160373 -1222.456373
121 4523.770186 -1221.893127
122 4578.243049 -1222.369885
123 4633.660001 -1221.74155
124 4689.995503 -1221.485722
125 4745.046551 -1222.821859
126 4799.297272 -1220.567495
127 4855.528112 -1221.962994
128 4912.181782 -1223.706512
129 4968.106065 -1223.249642
130 5023.349784 -1223.88398
131 5077.95836 -1221.843378
132 5132.109917 -1220.167506
133 5185.437263 -1222.633295
134 5239.11293 -1220.070837
135 5292.425083 -1222.163853
136 5344.871339 -1221.905685
137 5398.5643 -1221.195668
138 5453.522891 -1220.820716
139 5510.54199 -1221.871458
140 5565.892794 -1221.127944
141 5619.692913 -1221.370239
142 5672.505697 -1220.411562
143 5726.742534 -1219.956466
144 5780.314767 -1223.574739
145 5831.786025 -1220.038731
146 5883.706538 -1222.285694
147 5935.963855 -1219.789742
148 5990.031264 -1220.348486
149 6042.056406 -1221.695037
150 6094.931383 -1222.573256
151 6149.866284 -1220.799972
152 6201.410016 -1223.62492
153 6254.863233 -1221.173897
154 6307.299802 -1218.384144
155 6358.419653 -1221.949964
156 6410.39371 -1220.695003
157 6462.294865 -1221.354209
158 6514.29837 -1220.013649
159 6566.028445 -1221.057066
160 6619.112563 -1220.250728
161 6671.586666 -1220.74064
162 6724.056905 -1220.696191
163 6778.867936 -1221.342228
164 6834.066508 -1220.67166
165 6887.848139 -1221.681324
166 6940.725238 -1221.786548
167 6993.728177 -1220.293248
168 7047.068145 -1220.784974
169 7101.058736 -1221.296277
170 7155.695031 -1220.314099
171 7211.131008 -1220.001403
172 7264.672096 -1222.582639
173 7319.855532 -1220.218512
174 7373.416443 -1222.413372
175 7426.826476 -1221.740957
176 7481.033519 -1220.12891
177 7535.098597 -1220.954482
178 7589.86269 -1221.591681
179 7645.189188 -1219.780434
180 7700.019942 -1219.86627
181 7754.408913 -1221.254077
182 7809.17354 -1220.560716
183 7863.955751 -1219.303522
184 7919.377039 -1220.023321
185 7973.91837 -1221.021169
186 8030.191454 -1221.020316
187 8086.216218 -1219.728618
188 8143.879354 -1221.815614
189 8200.316735 -1219.043831
190 8257.940883 -1220.157413
191 8314.852015 -1220.684859
192 8370.81427 -1220.370599
193 8428.692646 -1219.714431
194 8484.287711 -1218.305622
195 8537.958919 -1221.428118
196 8594.432063 -1219.59307
197 8649.300665 -1220.87185
198 8704.677604 -1221.031256
199 8760.582589 -1218.94684
200 8818.110156 -1221.658566
201 8874.317387 -1219.76819
202 8931.379297 -1221.200231
203 8988.303494 -1220.058048
204 9045.237712 -1219.77335
205 9102.797204 -1220.183627
206 9158.2346 -1220.408912
207 9213.641003 -1218.583548
208 9270.128289 -1221.096059
209 9325.747626 -1220.0666
210 9381.558162 -1219.557564
211 9438.852373 -1219.330095
212 9495.594231 -1221.993273
213 9553.896612 -1223.187594
214 9611.398788 -1219.720465
215 9670.280587 -1218.745421
216 9727.735437 -1220.320216
217 9785.576854 -1218.623738
218 9842.847055 -1220.192794
219 9900.27451 -1218.631224
220 9960.404387 -1219.626358
221 10020.38233 -1218.897498
222 10081.03875 -1220.686493
223 10142.4862 -1222.619642
224 10203.81352 -1221.138795
225 10263.73111 -1219.75704
226 10322.01157 -1221.510761
227 10383.70518 -1220.942931
228 10444.83872 -1220.590052
229 10505.46391 -1220.915548
230 10565.21068 -1220.83712
231 10623.1464 -1220.426405
232 10681.92922 -1221.087758
233 10740.52835 -1219.740644
234 10799.10282 -1219.464613
235 10858.83727 -1218.971765
236 10917.77249 -1221.522595
237 10978.69096 -1219.90022
238 11039.15463 -1219.721981
239 11100.71659 -1218.736818
240 11160.92038 -1218.613676
241 11223.41109 -1220.718136
242 11286.18113 -1217.805265
243 11347.3255 -1221.471473
244 11409.22849 -1218.647645
245 11471.5558 -1218.650381
246 11534.54889 -1221.838598
247 11596.56306 -1216.725289
248 11658.5901 -1220.448722
249 11720.84754 -1218.200733
250 11781.86972 -1219.514898
251 11844.25683 -1219.122911
252 11905.98746 -1219.437825
253 11967.72706 -1220.671814
254 12029.60287 -1221.425081
255 12090.42101 -1220.465861
256 12151.72214 -1218.325056
257 12215.04968 -1221.885592
258 12276.16149 -1217.636273
259 12335.90203 -1218.246472
260 12396.63493 -1218.821679
261 12459.444 -1220.249942
262 12522.88157 -1218.853892
263 12585.31227 -1220.934625
264 12647.22732 -1219.792465
265 12707.95382 -1219.762632
266 12770.33686 -1217.819044
267 12833.20918 -1218.993567
268 12896.81557 -1219.114086
269 12958.97821 -1219.069654
270 13023.42444 -1219.832851
271 13086.90235 -1220.588547
272 13149.95939 -1217.843585
273 13214.8704 -1218.933695
274 13278.31332 -1218.540474
275 13342.29491 -1219.083535
276 13404.31573 -1221.076755
277 13465.29136 -1220.365732
278 13526.70887 -1217.986482
279 13590.52374 -1219.40228
280 13654.7832 -1217.818409
281 13718.30046 -1218.816018
282 13781.31456 -1220.176623
283 13843.58471 -1218.819039
284 13905.20301 -1218.231392
285 13967.45721 -1219.155412
286 14030.04214 -1219.768839
287 14092.84348 -1219.087143
288 14155.60131 -1217.773162
289 14218.77685 -1218.442359
290 14281.016 -1219.142937
291 14346.43126 -1218.540456
292 14409.63096 -1220.27516
293 14475.35352 -1220.047078
294 14541.72266 -1217.934667
295 14605.84911 -1219.121418
296 14673.27209 -1219.250748
297 14737.62863 -1220.368594
298 14802.66072 -1220.611695
299 14867.97276 -1221.085498
300 14935.65309 -1216.312149
301 15001.18007 -1218.441571

View File

@@ -1,22 +0,0 @@
from matplotlib import pyplot as plt
import numpy as np
# 使用 genfromtxt 函数读取 CSV 文件
data = np.genfromtxt('neatpython.csv', delimiter=',', skip_header=1) # 假设有一个头部行
mean_time, fitness_mean = data.T
fig, ax = plt.subplots()
ax.plot(mean_time, fitness_mean, color='green', label='NEAT-Python', linestyle=':')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Average Fitness')
ax.set_xlim(0, 500)
ax.set_ylim(-2000, -1000)
ax.legend()
# ci = 1.96 * neatax_sem
# lower_bound = neatax_mean - ci
# upper_bound = neatax_mean + ci
# plt.plot(mean_time, fitness_mean, color='r', label='NEAT-Python')
# plt.fill_between(x_axis, lower_bound, upper_bound, color='red', alpha=0.2)
fig.show()

View File

@@ -1,111 +0,0 @@
from typing import Type
import jax
import time
import numpy as np
from algorithm import NEAT, HyperNEAT
from config import Config
from core import State, Algorithm, Problem
class Pipeline:
def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
self.config = config
self.algorithm = algorithm
self.problem = problem_type(config.problem)
if isinstance(algorithm, NEAT):
assert config.neat.inputs == self.problem.input_shape[-1]
elif isinstance(algorithm, HyperNEAT):
assert config.hyperneat.inputs == self.problem.input_shape[-1]
else:
raise NotImplementedError
self.act_func = self.algorithm.act
for _ in range(len(self.problem.input_shape) - 1):
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
self.best_genome = None
self.best_fitness = float('-inf')
self.generation_timestamp = None
def setup(self):
key = jax.random.PRNGKey(self.config.basic.seed)
algorithm_key, evaluate_key = jax.random.split(key, 2)
state = State()
state = self.algorithm.setup(algorithm_key, state)
return state.update(
evaluate_key=evaluate_key
)
def step(self, state):
key, sub_key = jax.random.split(state.evaluate_key)
keys = jax.random.split(key, self.config.basic.pop_size)
pop = self.algorithm.ask(state)
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
pop_transformed)
state = self.algorithm.tell(state, fitnesses)
return state.update(evaluate_key=sub_key), fitnesses
def auto_run(self, ini_state):
state = ini_state
for _ in range(self.config.basic.generation_limit):
self.generation_timestamp = time.time()
previous_pop = self.algorithm.ask(state)
state, fitnesses = self.step(state)
fitnesses = jax.device_get(fitnesses)
self.analysis(state, previous_pop, fitnesses)
if max(fitnesses) >= self.config.basic.fitness_target:
print("Fitness limit reached!")
return state, self.best_genome
print("Generation limit reached!")
return state, self.best_genome
def analysis(self, state, pop, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[max_idx]
member_count = jax.device_get(state.species_info.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {state.generation}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, genome):
transformed = self.algorithm.transform(state, genome)
self.problem.show(state.evaluate_key, state, self.act_func, transformed)
def pre_compile(self, state):
tic = time.time()
print("start compile")
self.step.lower(self, state).compile()
print(f"compile finished, cost time: {time.time() - tic}s")