refactor names;

delete useless
This commit is contained in:
wls2002
2023-09-15 22:33:21 +08:00
parent d317317ed2
commit 4efa9445d5
17 changed files with 10 additions and 661 deletions

View File

@@ -1,131 +0,0 @@
from typing import Callable
from time import time
import jax
from jax import numpy as jnp, vmap, jit
import gymnax
import numpy as np
from config import *
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def conf_cartpole():
return Config(
basic=BasicConfig(
seed=42,
fitness_target=500,
generation_limit=150,
pop_size=10000
),
neat=NeatConfig(
inputs=3,
outputs=1,
),
gene=NormalGeneConfig(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
problem=GymNaxConfig(
env_name='Pendulum-v1',
output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2]
)
)
def batch_evaluate(
key,
alg_state,
genomes,
env_params,
batch_transform: Callable,
batch_act: Callable,
batch_reset: Callable,
batch_step: Callable,
):
alg_time, env_time, forward_time = 0, 0, 0
pop_size = genomes.nodes.shape[0]
alg_tic = time()
genomes_transform = batch_transform(alg_state, genomes)
alg_time += time() - alg_tic
reset_keys = jax.random.split(key, pop_size)
observations, states = batch_reset(reset_keys, env_params)
done = np.zeros(pop_size, dtype=bool)
fitnesses = np.zeros(pop_size)
while not np.all(done):
key, _ = jax.random.split(key)
vmap_keys = jax.random.split(key, pop_size)
forward_tic = time()
actions = batch_act(alg_state, observations, genomes_transform).block_until_ready()
forward_time += time() - forward_tic
env_tic = time()
observations, states, reward, current_done, _ = batch_step(vmap_keys, states, actions, env_params)
reward, current_done = jax.device_get([reward, current_done])
env_time += time() - env_tic
fitnesses += reward * np.logical_not(done)
done = np.logical_or(done, current_done)
return fitnesses, alg_time, env_time, forward_time
def main():
conf = conf_cartpole()
algorithm = NEAT(conf, NormalGene)
def act(state, inputs, genome):
res = algorithm.act(state, inputs, genome)
return conf.problem.output_transform(res)
batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0)))
# (state, obs, genome_transform) -> action
batch_act = jit(vmap(act, in_axes=(None, 0, 0)))
env, env_params = gymnax.make(conf.problem.env_name)
# (seed, params) -> (ini_obs, ini_state)
batch_reset = jit(vmap(env.reset, in_axes=(0, None)))
# (seed, state, action, params) -> (obs, state, reward, done, info)
batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None)))
key = jax.random.PRNGKey(conf.basic.seed)
alg_key, pro_key = jax.random.split(key)
alg_state = algorithm.setup(alg_key)
for i in range(conf.basic.generation_limit):
total_tic = time()
pro_key, _ = jax.random.split(pro_key)
fitnesses, a1, env_time, forward_time = batch_evaluate(
pro_key,
alg_state,
algorithm.ask(alg_state),
env_params,
batch_transform,
batch_act,
batch_reset,
batch_step
)
alg_tic = time()
alg_state = algorithm.tell(alg_state, fitnesses)
alg_state = jax.tree_map(lambda x: x.block_until_ready(), alg_state)
a2 = time() - alg_tic
alg_time = a1 + a2
total_time = time() - total_tic
print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, "
f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}")
if __name__ == '__main__':
main()

View File

@@ -1,33 +0,0 @@
import jax.random
import numpy as np
import jax.numpy as jnp
import time
def random_array(key):
return jax.random.normal(key, (1000,))
def random_array_np():
return np.random.normal(size=(1000,))
def t_jax():
key = jax.random.PRNGKey(42)
max_li = []
tic = time.time()
for _ in range(100):
key, sub_key = jax.random.split(key)
array = random_array(sub_key)
array = jax.device_get(array)
max_li.append(max(array))
print(max_li, time.time() - tic)
def t_np():
max_li = []
tic = time.time()
for _ in range(100):
max_li.append(max(random_array_np()))
print(max_li, time.time() - tic)
if __name__ == '__main__':
t_np()

View File

@@ -1,5 +1,5 @@
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig from problem.func_fit import XOR, FuncFitConfig

View File

@@ -1,5 +1,5 @@
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm.neat import NormalGene, NormalGeneConfig from algorithm.neat import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
from problem.func_fit import XOR3d, FuncFitConfig from problem.func_fit import XOR3d, FuncFitConfig

View File

@@ -1,5 +1,5 @@
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from problem.func_fit import XOR3d, FuncFitConfig from problem.func_fit import XOR3d, FuncFitConfig

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp import jax.numpy as jnp
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp import jax.numpy as jnp
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp import jax.numpy as jnp
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig from algorithm.neat.gene import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrateConfig, NormalSubstrate from algorithm.hyperneat import HyperNEAT, NormalSubstrateConfig, NormalSubstrate

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp import jax.numpy as jnp
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp import jax.numpy as jnp
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp import jax.numpy as jnp
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,5 +1,5 @@
from config import * from config import *
from pipeline_jitable_env import Pipeline from pipeline import Pipeline
from algorithm import NEAT from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv from problem.rl_env import GymNaxConfig, GymNaxEnv

View File

@@ -1,53 +0,0 @@
import time
import jax.numpy as jnp
from config import *
from pipeline_jitable_env import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def conf_with_seed(seed):
return Config(
basic=BasicConfig(
seed=seed,
fitness_target=501,
pop_size=10000,
generation_limit=100
),
neat=NeatConfig(
inputs=4,
outputs=1,
max_species=10000
),
gene=NormalGeneConfig(
activation_default=Act.sigmoid,
activation_options=(Act.sigmoid,),
),
problem=GymNaxConfig(
env_name='CartPole-v1',
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
)
)
if __name__ == '__main__':
times = []
for seed in range(10):
conf = conf_with_seed(seed)
algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
state = pipeline.setup()
pipeline.pre_compile(state)
tic = time.time()
state, best = pipeline.auto_run(state)
time_cost = time.time() - tic
times.append(time_cost)
print(times)
print(f"total_times: {times}")

View File

@@ -1,301 +0,0 @@
mean_totaltime,Average Fitness
16.23358095,-1531.406262
34.16262472,-1469.014171
52.56341431,-1429.524865
68.70164006,-1398.485718
87.61529462,-1369.068788
105.4480726,-1338.810888
120.3260972,-1308.414005
134.2059872,-1280.022147
148.1975725,-1258.681994
161.0129987,-1248.612688
174.4836712,-1242.987934
187.5842981,-1245.195745
201.7285712,-1244.295623
217.0596706,-1241.958828
233.1745317,-1242.23715
249.9489457,-1240.606505
265.2109861,-1241.584935
282.1818033,-1242.254389
301.1290779,-1240.280126
317.5243011,-1238.729867
334.9861856,-1238.403718
354.7175949,-1236.6442
373.5419451,-1238.018048
391.36524,-1236.910213
410.9625152,-1235.673599
430.0237711,-1234.652077
449.6226187,-1236.544195
468.2591698,-1233.919256
489.2314327,-1236.904311
510.1973371,-1236.444161
530.6520383,-1234.882298
552.1866553,-1236.077152
574.0524314,-1235.339124
595.665635,-1233.053687
618.1490554,-1235.346718
640.9171209,-1234.058986
663.9380282,-1233.082789
687.5950083,-1233.227056
711.6376721,-1232.342045
737.2806357,-1232.589748
762.8286344,-1229.058886
787.2488623,-1233.670237
812.7839067,-1232.203953
837.8952895,-1231.658721
864.4482835,-1232.48289
891.955594,-1230.740013
920.0526731,-1229.984369
949.259579,-1230.266798
980.3165732,-1228.807429
1010.551222,-1229.953455
1042.59788,-1230.538116
1073.075853,-1228.559482
1106.577571,-1229.979631
1139.959677,-1229.714391
1173.958232,-1228.932949
1209.781649,-1226.865015
1245.958032,-1227.553255
1281.796618,-1228.989607
1321.521287,-1227.193127
1361.788322,-1227.00891
1401.135148,-1226.937434
1440.664442,-1228.277831
1480.485734,-1226.27974
1523.433305,-1224.812807
1566.113366,-1226.587234
1610.083268,-1227.333781
1654.933877,-1227.395313
1700.679548,-1223.551766
1747.142431,-1227.437868
1794.300387,-1225.468609
1840.876058,-1225.929573
1888.193045,-1227.768213
1936.7367,-1224.390808
1984.760717,-1227.475326
2033.411831,-1223.673373
2084.356429,-1223.815094
2135.077312,-1224.983295
2186.35972,-1224.084667
2238.44057,-1226.530143
2292.238697,-1223.684731
2345.740624,-1224.418812
2399.445412,-1224.068882
2453.074848,-1223.038722
2506.191465,-1224.688486
2560.467831,-1224.369871
2615.123451,-1223.741346
2669.90867,-1224.06301
2723.421276,-1223.571537
2777.560479,-1221.698283
2832.592864,-1223.216613
2887.017148,-1223.993198
2941.483496,-1224.1121
2995.423772,-1223.169763
3050.948137,-1223.459228
3106.069071,-1223.111025
3159.795594,-1224.609992
3214.050797,-1223.487228
3269.429128,-1223.715971
3324.581811,-1223.355198
3380.680088,-1223.137084
3436.805942,-1222.142975
3492.414253,-1222.925149
3549.505774,-1223.379747
3607.501727,-1222.272351
3664.639311,-1224.368308
3722.034017,-1223.901903
3779.537015,-1222.94451
3836.6152,-1221.838935
3894.218441,-1223.663879
3951.011881,-1224.02154
4007.885495,-1223.465968
4063.988944,-1222.439965
4122.273734,-1221.786334
4179.429916,-1222.157602
4237.573861,-1223.250346
4295.552148,-1222.545104
4353.574564,-1222.514246
4411.656075,-1222.108242
4468.160373,-1222.456373
4523.770186,-1221.893127
4578.243049,-1222.369885
4633.660001,-1221.74155
4689.995503,-1221.485722
4745.046551,-1222.821859
4799.297272,-1220.567495
4855.528112,-1221.962994
4912.181782,-1223.706512
4968.106065,-1223.249642
5023.349784,-1223.88398
5077.95836,-1221.843378
5132.109917,-1220.167506
5185.437263,-1222.633295
5239.11293,-1220.070837
5292.425083,-1222.163853
5344.871339,-1221.905685
5398.5643,-1221.195668
5453.522891,-1220.820716
5510.54199,-1221.871458
5565.892794,-1221.127944
5619.692913,-1221.370239
5672.505697,-1220.411562
5726.742534,-1219.956466
5780.314767,-1223.574739
5831.786025,-1220.038731
5883.706538,-1222.285694
5935.963855,-1219.789742
5990.031264,-1220.348486
6042.056406,-1221.695037
6094.931383,-1222.573256
6149.866284,-1220.799972
6201.410016,-1223.62492
6254.863233,-1221.173897
6307.299802,-1218.384144
6358.419653,-1221.949964
6410.39371,-1220.695003
6462.294865,-1221.354209
6514.29837,-1220.013649
6566.028445,-1221.057066
6619.112563,-1220.250728
6671.586666,-1220.74064
6724.056905,-1220.696191
6778.867936,-1221.342228
6834.066508,-1220.67166
6887.848139,-1221.681324
6940.725238,-1221.786548
6993.728177,-1220.293248
7047.068145,-1220.784974
7101.058736,-1221.296277
7155.695031,-1220.314099
7211.131008,-1220.001403
7264.672096,-1222.582639
7319.855532,-1220.218512
7373.416443,-1222.413372
7426.826476,-1221.740957
7481.033519,-1220.12891
7535.098597,-1220.954482
7589.86269,-1221.591681
7645.189188,-1219.780434
7700.019942,-1219.86627
7754.408913,-1221.254077
7809.17354,-1220.560716
7863.955751,-1219.303522
7919.377039,-1220.023321
7973.91837,-1221.021169
8030.191454,-1221.020316
8086.216218,-1219.728618
8143.879354,-1221.815614
8200.316735,-1219.043831
8257.940883,-1220.157413
8314.852015,-1220.684859
8370.81427,-1220.370599
8428.692646,-1219.714431
8484.287711,-1218.305622
8537.958919,-1221.428118
8594.432063,-1219.59307
8649.300665,-1220.87185
8704.677604,-1221.031256
8760.582589,-1218.94684
8818.110156,-1221.658566
8874.317387,-1219.76819
8931.379297,-1221.200231
8988.303494,-1220.058048
9045.237712,-1219.77335
9102.797204,-1220.183627
9158.2346,-1220.408912
9213.641003,-1218.583548
9270.128289,-1221.096059
9325.747626,-1220.0666
9381.558162,-1219.557564
9438.852373,-1219.330095
9495.594231,-1221.993273
9553.896612,-1223.187594
9611.398788,-1219.720465
9670.280587,-1218.745421
9727.735437,-1220.320216
9785.576854,-1218.623738
9842.847055,-1220.192794
9900.27451,-1218.631224
9960.404387,-1219.626358
10020.38233,-1218.897498
10081.03875,-1220.686493
10142.4862,-1222.619642
10203.81352,-1221.138795
10263.73111,-1219.75704
10322.01157,-1221.510761
10383.70518,-1220.942931
10444.83872,-1220.590052
10505.46391,-1220.915548
10565.21068,-1220.83712
10623.1464,-1220.426405
10681.92922,-1221.087758
10740.52835,-1219.740644
10799.10282,-1219.464613
10858.83727,-1218.971765
10917.77249,-1221.522595
10978.69096,-1219.90022
11039.15463,-1219.721981
11100.71659,-1218.736818
11160.92038,-1218.613676
11223.41109,-1220.718136
11286.18113,-1217.805265
11347.3255,-1221.471473
11409.22849,-1218.647645
11471.5558,-1218.650381
11534.54889,-1221.838598
11596.56306,-1216.725289
11658.5901,-1220.448722
11720.84754,-1218.200733
11781.86972,-1219.514898
11844.25683,-1219.122911
11905.98746,-1219.437825
11967.72706,-1220.671814
12029.60287,-1221.425081
12090.42101,-1220.465861
12151.72214,-1218.325056
12215.04968,-1221.885592
12276.16149,-1217.636273
12335.90203,-1218.246472
12396.63493,-1218.821679
12459.444,-1220.249942
12522.88157,-1218.853892
12585.31227,-1220.934625
12647.22732,-1219.792465
12707.95382,-1219.762632
12770.33686,-1217.819044
12833.20918,-1218.993567
12896.81557,-1219.114086
12958.97821,-1219.069654
13023.42444,-1219.832851
13086.90235,-1220.588547
13149.95939,-1217.843585
13214.8704,-1218.933695
13278.31332,-1218.540474
13342.29491,-1219.083535
13404.31573,-1221.076755
13465.29136,-1220.365732
13526.70887,-1217.986482
13590.52374,-1219.40228
13654.7832,-1217.818409
13718.30046,-1218.816018
13781.31456,-1220.176623
13843.58471,-1218.819039
13905.20301,-1218.231392
13967.45721,-1219.155412
14030.04214,-1219.768839
14092.84348,-1219.087143
14155.60131,-1217.773162
14218.77685,-1218.442359
14281.016,-1219.142937
14346.43126,-1218.540456
14409.63096,-1220.27516
14475.35352,-1220.047078
14541.72266,-1217.934667
14605.84911,-1219.121418
14673.27209,-1219.250748
14737.62863,-1220.368594
14802.66072,-1220.611695
14867.97276,-1221.085498
14935.65309,-1216.312149
15001.18007,-1218.441571
1 mean_totaltime Average Fitness
2 16.23358095 -1531.406262
3 34.16262472 -1469.014171
4 52.56341431 -1429.524865
5 68.70164006 -1398.485718
6 87.61529462 -1369.068788
7 105.4480726 -1338.810888
8 120.3260972 -1308.414005
9 134.2059872 -1280.022147
10 148.1975725 -1258.681994
11 161.0129987 -1248.612688
12 174.4836712 -1242.987934
13 187.5842981 -1245.195745
14 201.7285712 -1244.295623
15 217.0596706 -1241.958828
16 233.1745317 -1242.23715
17 249.9489457 -1240.606505
18 265.2109861 -1241.584935
19 282.1818033 -1242.254389
20 301.1290779 -1240.280126
21 317.5243011 -1238.729867
22 334.9861856 -1238.403718
23 354.7175949 -1236.6442
24 373.5419451 -1238.018048
25 391.36524 -1236.910213
26 410.9625152 -1235.673599
27 430.0237711 -1234.652077
28 449.6226187 -1236.544195
29 468.2591698 -1233.919256
30 489.2314327 -1236.904311
31 510.1973371 -1236.444161
32 530.6520383 -1234.882298
33 552.1866553 -1236.077152
34 574.0524314 -1235.339124
35 595.665635 -1233.053687
36 618.1490554 -1235.346718
37 640.9171209 -1234.058986
38 663.9380282 -1233.082789
39 687.5950083 -1233.227056
40 711.6376721 -1232.342045
41 737.2806357 -1232.589748
42 762.8286344 -1229.058886
43 787.2488623 -1233.670237
44 812.7839067 -1232.203953
45 837.8952895 -1231.658721
46 864.4482835 -1232.48289
47 891.955594 -1230.740013
48 920.0526731 -1229.984369
49 949.259579 -1230.266798
50 980.3165732 -1228.807429
51 1010.551222 -1229.953455
52 1042.59788 -1230.538116
53 1073.075853 -1228.559482
54 1106.577571 -1229.979631
55 1139.959677 -1229.714391
56 1173.958232 -1228.932949
57 1209.781649 -1226.865015
58 1245.958032 -1227.553255
59 1281.796618 -1228.989607
60 1321.521287 -1227.193127
61 1361.788322 -1227.00891
62 1401.135148 -1226.937434
63 1440.664442 -1228.277831
64 1480.485734 -1226.27974
65 1523.433305 -1224.812807
66 1566.113366 -1226.587234
67 1610.083268 -1227.333781
68 1654.933877 -1227.395313
69 1700.679548 -1223.551766
70 1747.142431 -1227.437868
71 1794.300387 -1225.468609
72 1840.876058 -1225.929573
73 1888.193045 -1227.768213
74 1936.7367 -1224.390808
75 1984.760717 -1227.475326
76 2033.411831 -1223.673373
77 2084.356429 -1223.815094
78 2135.077312 -1224.983295
79 2186.35972 -1224.084667
80 2238.44057 -1226.530143
81 2292.238697 -1223.684731
82 2345.740624 -1224.418812
83 2399.445412 -1224.068882
84 2453.074848 -1223.038722
85 2506.191465 -1224.688486
86 2560.467831 -1224.369871
87 2615.123451 -1223.741346
88 2669.90867 -1224.06301
89 2723.421276 -1223.571537
90 2777.560479 -1221.698283
91 2832.592864 -1223.216613
92 2887.017148 -1223.993198
93 2941.483496 -1224.1121
94 2995.423772 -1223.169763
95 3050.948137 -1223.459228
96 3106.069071 -1223.111025
97 3159.795594 -1224.609992
98 3214.050797 -1223.487228
99 3269.429128 -1223.715971
100 3324.581811 -1223.355198
101 3380.680088 -1223.137084
102 3436.805942 -1222.142975
103 3492.414253 -1222.925149
104 3549.505774 -1223.379747
105 3607.501727 -1222.272351
106 3664.639311 -1224.368308
107 3722.034017 -1223.901903
108 3779.537015 -1222.94451
109 3836.6152 -1221.838935
110 3894.218441 -1223.663879
111 3951.011881 -1224.02154
112 4007.885495 -1223.465968
113 4063.988944 -1222.439965
114 4122.273734 -1221.786334
115 4179.429916 -1222.157602
116 4237.573861 -1223.250346
117 4295.552148 -1222.545104
118 4353.574564 -1222.514246
119 4411.656075 -1222.108242
120 4468.160373 -1222.456373
121 4523.770186 -1221.893127
122 4578.243049 -1222.369885
123 4633.660001 -1221.74155
124 4689.995503 -1221.485722
125 4745.046551 -1222.821859
126 4799.297272 -1220.567495
127 4855.528112 -1221.962994
128 4912.181782 -1223.706512
129 4968.106065 -1223.249642
130 5023.349784 -1223.88398
131 5077.95836 -1221.843378
132 5132.109917 -1220.167506
133 5185.437263 -1222.633295
134 5239.11293 -1220.070837
135 5292.425083 -1222.163853
136 5344.871339 -1221.905685
137 5398.5643 -1221.195668
138 5453.522891 -1220.820716
139 5510.54199 -1221.871458
140 5565.892794 -1221.127944
141 5619.692913 -1221.370239
142 5672.505697 -1220.411562
143 5726.742534 -1219.956466
144 5780.314767 -1223.574739
145 5831.786025 -1220.038731
146 5883.706538 -1222.285694
147 5935.963855 -1219.789742
148 5990.031264 -1220.348486
149 6042.056406 -1221.695037
150 6094.931383 -1222.573256
151 6149.866284 -1220.799972
152 6201.410016 -1223.62492
153 6254.863233 -1221.173897
154 6307.299802 -1218.384144
155 6358.419653 -1221.949964
156 6410.39371 -1220.695003
157 6462.294865 -1221.354209
158 6514.29837 -1220.013649
159 6566.028445 -1221.057066
160 6619.112563 -1220.250728
161 6671.586666 -1220.74064
162 6724.056905 -1220.696191
163 6778.867936 -1221.342228
164 6834.066508 -1220.67166
165 6887.848139 -1221.681324
166 6940.725238 -1221.786548
167 6993.728177 -1220.293248
168 7047.068145 -1220.784974
169 7101.058736 -1221.296277
170 7155.695031 -1220.314099
171 7211.131008 -1220.001403
172 7264.672096 -1222.582639
173 7319.855532 -1220.218512
174 7373.416443 -1222.413372
175 7426.826476 -1221.740957
176 7481.033519 -1220.12891
177 7535.098597 -1220.954482
178 7589.86269 -1221.591681
179 7645.189188 -1219.780434
180 7700.019942 -1219.86627
181 7754.408913 -1221.254077
182 7809.17354 -1220.560716
183 7863.955751 -1219.303522
184 7919.377039 -1220.023321
185 7973.91837 -1221.021169
186 8030.191454 -1221.020316
187 8086.216218 -1219.728618
188 8143.879354 -1221.815614
189 8200.316735 -1219.043831
190 8257.940883 -1220.157413
191 8314.852015 -1220.684859
192 8370.81427 -1220.370599
193 8428.692646 -1219.714431
194 8484.287711 -1218.305622
195 8537.958919 -1221.428118
196 8594.432063 -1219.59307
197 8649.300665 -1220.87185
198 8704.677604 -1221.031256
199 8760.582589 -1218.94684
200 8818.110156 -1221.658566
201 8874.317387 -1219.76819
202 8931.379297 -1221.200231
203 8988.303494 -1220.058048
204 9045.237712 -1219.77335
205 9102.797204 -1220.183627
206 9158.2346 -1220.408912
207 9213.641003 -1218.583548
208 9270.128289 -1221.096059
209 9325.747626 -1220.0666
210 9381.558162 -1219.557564
211 9438.852373 -1219.330095
212 9495.594231 -1221.993273
213 9553.896612 -1223.187594
214 9611.398788 -1219.720465
215 9670.280587 -1218.745421
216 9727.735437 -1220.320216
217 9785.576854 -1218.623738
218 9842.847055 -1220.192794
219 9900.27451 -1218.631224
220 9960.404387 -1219.626358
221 10020.38233 -1218.897498
222 10081.03875 -1220.686493
223 10142.4862 -1222.619642
224 10203.81352 -1221.138795
225 10263.73111 -1219.75704
226 10322.01157 -1221.510761
227 10383.70518 -1220.942931
228 10444.83872 -1220.590052
229 10505.46391 -1220.915548
230 10565.21068 -1220.83712
231 10623.1464 -1220.426405
232 10681.92922 -1221.087758
233 10740.52835 -1219.740644
234 10799.10282 -1219.464613
235 10858.83727 -1218.971765
236 10917.77249 -1221.522595
237 10978.69096 -1219.90022
238 11039.15463 -1219.721981
239 11100.71659 -1218.736818
240 11160.92038 -1218.613676
241 11223.41109 -1220.718136
242 11286.18113 -1217.805265
243 11347.3255 -1221.471473
244 11409.22849 -1218.647645
245 11471.5558 -1218.650381
246 11534.54889 -1221.838598
247 11596.56306 -1216.725289
248 11658.5901 -1220.448722
249 11720.84754 -1218.200733
250 11781.86972 -1219.514898
251 11844.25683 -1219.122911
252 11905.98746 -1219.437825
253 11967.72706 -1220.671814
254 12029.60287 -1221.425081
255 12090.42101 -1220.465861
256 12151.72214 -1218.325056
257 12215.04968 -1221.885592
258 12276.16149 -1217.636273
259 12335.90203 -1218.246472
260 12396.63493 -1218.821679
261 12459.444 -1220.249942
262 12522.88157 -1218.853892
263 12585.31227 -1220.934625
264 12647.22732 -1219.792465
265 12707.95382 -1219.762632
266 12770.33686 -1217.819044
267 12833.20918 -1218.993567
268 12896.81557 -1219.114086
269 12958.97821 -1219.069654
270 13023.42444 -1219.832851
271 13086.90235 -1220.588547
272 13149.95939 -1217.843585
273 13214.8704 -1218.933695
274 13278.31332 -1218.540474
275 13342.29491 -1219.083535
276 13404.31573 -1221.076755
277 13465.29136 -1220.365732
278 13526.70887 -1217.986482
279 13590.52374 -1219.40228
280 13654.7832 -1217.818409
281 13718.30046 -1218.816018
282 13781.31456 -1220.176623
283 13843.58471 -1218.819039
284 13905.20301 -1218.231392
285 13967.45721 -1219.155412
286 14030.04214 -1219.768839
287 14092.84348 -1219.087143
288 14155.60131 -1217.773162
289 14218.77685 -1218.442359
290 14281.016 -1219.142937
291 14346.43126 -1218.540456
292 14409.63096 -1220.27516
293 14475.35352 -1220.047078
294 14541.72266 -1217.934667
295 14605.84911 -1219.121418
296 14673.27209 -1219.250748
297 14737.62863 -1220.368594
298 14802.66072 -1220.611695
299 14867.97276 -1221.085498
300 14935.65309 -1216.312149
301 15001.18007 -1218.441571

View File

@@ -1,22 +0,0 @@
from matplotlib import pyplot as plt
import numpy as np
# 使用 genfromtxt 函数读取 CSV 文件
data = np.genfromtxt('neatpython.csv', delimiter=',', skip_header=1) # 假设有一个头部行
mean_time, fitness_mean = data.T
fig, ax = plt.subplots()
ax.plot(mean_time, fitness_mean, color='green', label='NEAT-Python', linestyle=':')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Average Fitness')
ax.set_xlim(0, 500)
ax.set_ylim(-2000, -1000)
ax.legend()
# ci = 1.96 * neatax_sem
# lower_bound = neatax_mean - ci
# upper_bound = neatax_mean + ci
# plt.plot(mean_time, fitness_mean, color='r', label='NEAT-Python')
# plt.fill_between(x_axis, lower_bound, upper_bound, color='red', alpha=0.2)
fig.show()

View File

@@ -1,111 +0,0 @@
from typing import Type
import jax
import time
import numpy as np
from algorithm import NEAT, HyperNEAT
from config import Config
from core import State, Algorithm, Problem
class Pipeline:
def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
self.config = config
self.algorithm = algorithm
self.problem = problem_type(config.problem)
if isinstance(algorithm, NEAT):
assert config.neat.inputs == self.problem.input_shape[-1]
elif isinstance(algorithm, HyperNEAT):
assert config.hyperneat.inputs == self.problem.input_shape[-1]
else:
raise NotImplementedError
self.act_func = self.algorithm.act
for _ in range(len(self.problem.input_shape) - 1):
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
self.best_genome = None
self.best_fitness = float('-inf')
self.generation_timestamp = None
def setup(self):
key = jax.random.PRNGKey(self.config.basic.seed)
algorithm_key, evaluate_key = jax.random.split(key, 2)
state = State()
state = self.algorithm.setup(algorithm_key, state)
return state.update(
evaluate_key=evaluate_key
)
def step(self, state):
key, sub_key = jax.random.split(state.evaluate_key)
keys = jax.random.split(key, self.config.basic.pop_size)
pop = self.algorithm.ask(state)
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
pop_transformed)
state = self.algorithm.tell(state, fitnesses)
return state.update(evaluate_key=sub_key), fitnesses
def auto_run(self, ini_state):
state = ini_state
for _ in range(self.config.basic.generation_limit):
self.generation_timestamp = time.time()
previous_pop = self.algorithm.ask(state)
state, fitnesses = self.step(state)
fitnesses = jax.device_get(fitnesses)
self.analysis(state, previous_pop, fitnesses)
if max(fitnesses) >= self.config.basic.fitness_target:
print("Fitness limit reached!")
return state, self.best_genome
print("Generation limit reached!")
return state, self.best_genome
def analysis(self, state, pop, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = pop[max_idx]
member_count = jax.device_get(state.species_info.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {state.generation}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, genome):
transformed = self.algorithm.transform(state, genome)
self.problem.show(state.evaluate_key, state, self.act_func, transformed)
def pre_compile(self, state):
tic = time.time()
print("start compile")
self.step.lower(self, state).compile()
print(f"compile finished, cost time: {time.time() - tic}s")