From 4efa9445d54ccec9e7629b2c82558494fed83c6d Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 15 Sep 2023 22:33:21 +0800 Subject: [PATCH] refactor names; delete useless --- aaai_exp.py | 131 --------- examples/a.py | 33 --- examples/func_fit/xor.py | 2 +- examples/func_fit/xor_hyperneat.py | 2 +- examples/func_fit/xor_recurrent.py | 2 +- examples/gymnax/acrobot.py | 2 +- examples/gymnax/cartpole.py | 2 +- examples/gymnax/cartpole_hyperneat.py | 2 +- examples/gymnax/mountain_car.py | 2 +- examples/gymnax/mountain_car_continuous.py | 2 +- examples/gymnax/pendulum.py | 2 +- examples/gymnax/reacher.py | 2 +- exp_for_hardwares.py | 53 ---- graph/fitness-time/neatpython.csv | 301 --------------------- graph/fitness-time/script.py | 22 -- pipeline_jitable_env.py => pipeline.py | 0 pipeline_time.py | 111 -------- 17 files changed, 10 insertions(+), 661 deletions(-) delete mode 100644 aaai_exp.py delete mode 100644 examples/a.py delete mode 100644 exp_for_hardwares.py delete mode 100644 graph/fitness-time/neatpython.csv delete mode 100644 graph/fitness-time/script.py rename pipeline_jitable_env.py => pipeline.py (100%) delete mode 100644 pipeline_time.py diff --git a/aaai_exp.py b/aaai_exp.py deleted file mode 100644 index 915da53..0000000 --- a/aaai_exp.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Callable -from time import time - -import jax -from jax import numpy as jnp, vmap, jit -import gymnax -import numpy as np - -from config import * -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import GymNaxConfig, GymNaxEnv - - -def conf_cartpole(): - return Config( - basic=BasicConfig( - seed=42, - fitness_target=500, - generation_limit=150, - pop_size=10000 - ), - neat=NeatConfig( - inputs=3, - outputs=1, - ), - gene=NormalGeneConfig( - activation_default=Act.tanh, - activation_options=(Act.tanh,), - ), - problem=GymNaxConfig( - env_name='Pendulum-v1', - output_transform=lambda out: out * 2 # the action of pendulum is [-2, 2] - ) - ) - - -def batch_evaluate( - key, - alg_state, - genomes, - env_params, - batch_transform: Callable, - batch_act: Callable, - batch_reset: Callable, - batch_step: Callable, -): - alg_time, env_time, forward_time = 0, 0, 0 - pop_size = genomes.nodes.shape[0] - - alg_tic = time() - genomes_transform = batch_transform(alg_state, genomes) - alg_time += time() - alg_tic - - reset_keys = jax.random.split(key, pop_size) - observations, states = batch_reset(reset_keys, env_params) - - done = np.zeros(pop_size, dtype=bool) - fitnesses = np.zeros(pop_size) - - while not np.all(done): - key, _ = jax.random.split(key) - vmap_keys = jax.random.split(key, pop_size) - - forward_tic = time() - actions = batch_act(alg_state, observations, genomes_transform).block_until_ready() - forward_time += time() - forward_tic - - env_tic = time() - observations, states, reward, current_done, _ = batch_step(vmap_keys, states, actions, env_params) - reward, current_done = jax.device_get([reward, current_done]) - env_time += time() - env_tic - - fitnesses += reward * np.logical_not(done) - done = np.logical_or(done, current_done) - - return fitnesses, alg_time, env_time, forward_time - - -def main(): - conf = conf_cartpole() - algorithm = NEAT(conf, NormalGene) - - def act(state, inputs, genome): - res = algorithm.act(state, inputs, genome) - return conf.problem.output_transform(res) - - batch_transform = jit(vmap(algorithm.transform, in_axes=(None, 0))) - # (state, obs, genome_transform) -> action - batch_act = jit(vmap(act, in_axes=(None, 0, 0))) - - env, env_params = gymnax.make(conf.problem.env_name) - # (seed, params) -> (ini_obs, ini_state) - batch_reset = jit(vmap(env.reset, in_axes=(0, None))) - # (seed, state, action, params) -> (obs, state, reward, done, info) - batch_step = jit(vmap(env.step, in_axes=(0, 0, 0, None))) - - key = jax.random.PRNGKey(conf.basic.seed) - alg_key, pro_key = jax.random.split(key) - alg_state = algorithm.setup(alg_key) - - for i in range(conf.basic.generation_limit): - - total_tic = time() - - pro_key, _ = jax.random.split(pro_key) - - fitnesses, a1, env_time, forward_time = batch_evaluate( - pro_key, - alg_state, - algorithm.ask(alg_state), - env_params, - batch_transform, - batch_act, - batch_reset, - batch_step - ) - alg_tic = time() - alg_state = algorithm.tell(alg_state, fitnesses) - alg_state = jax.tree_map(lambda x: x.block_until_ready(), alg_state) - a2 = time() - alg_tic - - alg_time = a1 + a2 - total_time = time() - total_tic - - print(f"generation: {i}, alg_time: {alg_time:.2f}, env_time: {env_time:.2f}, forward_time: {forward_time:.2f}, total_time: {total_time: .2f}, " - f"max_fitness: {np.max(fitnesses):.2f}", f"avg_fitness: {np.mean(fitnesses):.2f}") - - -if __name__ == '__main__': - main() diff --git a/examples/a.py b/examples/a.py deleted file mode 100644 index d100c6e..0000000 --- a/examples/a.py +++ /dev/null @@ -1,33 +0,0 @@ -import jax.random -import numpy as np -import jax.numpy as jnp -import time - - -def random_array(key): - return jax.random.normal(key, (1000,)) - -def random_array_np(): - return np.random.normal(size=(1000,)) - - -def t_jax(): - key = jax.random.PRNGKey(42) - max_li = [] - tic = time.time() - for _ in range(100): - key, sub_key = jax.random.split(key) - array = random_array(sub_key) - array = jax.device_get(array) - max_li.append(max(array)) - print(max_li, time.time() - tic) - -def t_np(): - max_li = [] - tic = time.time() - for _ in range(100): - max_li.append(max(random_array_np())) - print(max_li, time.time() - tic) - -if __name__ == '__main__': - t_np() \ No newline at end of file diff --git a/examples/func_fit/xor.py b/examples/func_fit/xor.py index 8e5ca6c..a2d45ee 100644 --- a/examples/func_fit/xor.py +++ b/examples/func_fit/xor.py @@ -1,5 +1,5 @@ from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.func_fit import XOR, FuncFitConfig diff --git a/examples/func_fit/xor_hyperneat.py b/examples/func_fit/xor_hyperneat.py index 0148e28..cfd23f1 100644 --- a/examples/func_fit/xor_hyperneat.py +++ b/examples/func_fit/xor_hyperneat.py @@ -1,5 +1,5 @@ from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm.neat import NormalGene, NormalGeneConfig from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig from problem.func_fit import XOR3d, FuncFitConfig diff --git a/examples/func_fit/xor_recurrent.py b/examples/func_fit/xor_recurrent.py index 6787a7d..d100fd8 100644 --- a/examples/func_fit/xor_recurrent.py +++ b/examples/func_fit/xor_recurrent.py @@ -1,5 +1,5 @@ from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm import NEAT from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig from problem.func_fit import XOR3d, FuncFitConfig diff --git a/examples/gymnax/acrobot.py b/examples/gymnax/acrobot.py index 0f6cdd0..3867ab3 100644 --- a/examples/gymnax/acrobot.py +++ b/examples/gymnax/acrobot.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.rl_env import GymNaxConfig, GymNaxEnv diff --git a/examples/gymnax/cartpole.py b/examples/gymnax/cartpole.py index 7931e8b..d5564c5 100644 --- a/examples/gymnax/cartpole.py +++ b/examples/gymnax/cartpole.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.rl_env import GymNaxConfig, GymNaxEnv diff --git a/examples/gymnax/cartpole_hyperneat.py b/examples/gymnax/cartpole_hyperneat.py index 792c380..3a689ff 100644 --- a/examples/gymnax/cartpole_hyperneat.py +++ b/examples/gymnax/cartpole_hyperneat.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from algorithm.hyperneat import HyperNEAT, NormalSubstrateConfig, NormalSubstrate diff --git a/examples/gymnax/mountain_car.py b/examples/gymnax/mountain_car.py index 7a897bb..6c3a43d 100644 --- a/examples/gymnax/mountain_car.py +++ b/examples/gymnax/mountain_car.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.rl_env import GymNaxConfig, GymNaxEnv diff --git a/examples/gymnax/mountain_car_continuous.py b/examples/gymnax/mountain_car_continuous.py index d5b5a01..41169a4 100644 --- a/examples/gymnax/mountain_car_continuous.py +++ b/examples/gymnax/mountain_car_continuous.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.rl_env import GymNaxConfig, GymNaxEnv diff --git a/examples/gymnax/pendulum.py b/examples/gymnax/pendulum.py index 6b91b5a..5a75832 100644 --- a/examples/gymnax/pendulum.py +++ b/examples/gymnax/pendulum.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.rl_env import GymNaxConfig, GymNaxEnv diff --git a/examples/gymnax/reacher.py b/examples/gymnax/reacher.py index 08cf04c..39afdbd 100644 --- a/examples/gymnax/reacher.py +++ b/examples/gymnax/reacher.py @@ -1,5 +1,5 @@ from config import * -from pipeline_jitable_env import Pipeline +from pipeline import Pipeline from algorithm import NEAT from algorithm.neat.gene import NormalGene, NormalGeneConfig from problem.rl_env import GymNaxConfig, GymNaxEnv diff --git a/exp_for_hardwares.py b/exp_for_hardwares.py deleted file mode 100644 index 3bb52f4..0000000 --- a/exp_for_hardwares.py +++ /dev/null @@ -1,53 +0,0 @@ -import time - -import jax.numpy as jnp -from config import * -from pipeline_jitable_env import Pipeline -from algorithm import NEAT -from algorithm.neat.gene import NormalGene, NormalGeneConfig -from problem.rl_env import GymNaxConfig, GymNaxEnv - - -def conf_with_seed(seed): - return Config( - basic=BasicConfig( - seed=seed, - fitness_target=501, - pop_size=10000, - generation_limit=100 - ), - neat=NeatConfig( - inputs=4, - outputs=1, - max_species=10000 - ), - gene=NormalGeneConfig( - activation_default=Act.sigmoid, - activation_options=(Act.sigmoid,), - ), - problem=GymNaxConfig( - env_name='CartPole-v1', - output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1} - ) - ) - - -if __name__ == '__main__': - - times = [] - - for seed in range(10): - conf = conf_with_seed(seed) - algorithm = NEAT(conf, NormalGene) - pipeline = Pipeline(conf, algorithm, GymNaxEnv) - state = pipeline.setup() - pipeline.pre_compile(state) - tic = time.time() - state, best = pipeline.auto_run(state) - time_cost = time.time() - tic - times.append(time_cost) - print(times) - - print(f"total_times: {times}") - - diff --git a/graph/fitness-time/neatpython.csv b/graph/fitness-time/neatpython.csv deleted file mode 100644 index f774585..0000000 --- a/graph/fitness-time/neatpython.csv +++ /dev/null @@ -1,301 +0,0 @@ -mean_totaltime,Average Fitness -16.23358095,-1531.406262 -34.16262472,-1469.014171 -52.56341431,-1429.524865 -68.70164006,-1398.485718 -87.61529462,-1369.068788 -105.4480726,-1338.810888 -120.3260972,-1308.414005 -134.2059872,-1280.022147 -148.1975725,-1258.681994 -161.0129987,-1248.612688 -174.4836712,-1242.987934 -187.5842981,-1245.195745 -201.7285712,-1244.295623 -217.0596706,-1241.958828 -233.1745317,-1242.23715 -249.9489457,-1240.606505 -265.2109861,-1241.584935 -282.1818033,-1242.254389 -301.1290779,-1240.280126 -317.5243011,-1238.729867 -334.9861856,-1238.403718 -354.7175949,-1236.6442 -373.5419451,-1238.018048 -391.36524,-1236.910213 -410.9625152,-1235.673599 -430.0237711,-1234.652077 -449.6226187,-1236.544195 -468.2591698,-1233.919256 -489.2314327,-1236.904311 -510.1973371,-1236.444161 -530.6520383,-1234.882298 -552.1866553,-1236.077152 -574.0524314,-1235.339124 -595.665635,-1233.053687 -618.1490554,-1235.346718 -640.9171209,-1234.058986 -663.9380282,-1233.082789 -687.5950083,-1233.227056 -711.6376721,-1232.342045 -737.2806357,-1232.589748 -762.8286344,-1229.058886 -787.2488623,-1233.670237 -812.7839067,-1232.203953 -837.8952895,-1231.658721 -864.4482835,-1232.48289 -891.955594,-1230.740013 -920.0526731,-1229.984369 -949.259579,-1230.266798 -980.3165732,-1228.807429 -1010.551222,-1229.953455 -1042.59788,-1230.538116 -1073.075853,-1228.559482 -1106.577571,-1229.979631 -1139.959677,-1229.714391 -1173.958232,-1228.932949 -1209.781649,-1226.865015 -1245.958032,-1227.553255 -1281.796618,-1228.989607 -1321.521287,-1227.193127 -1361.788322,-1227.00891 -1401.135148,-1226.937434 -1440.664442,-1228.277831 -1480.485734,-1226.27974 -1523.433305,-1224.812807 -1566.113366,-1226.587234 -1610.083268,-1227.333781 -1654.933877,-1227.395313 -1700.679548,-1223.551766 -1747.142431,-1227.437868 -1794.300387,-1225.468609 -1840.876058,-1225.929573 -1888.193045,-1227.768213 -1936.7367,-1224.390808 -1984.760717,-1227.475326 -2033.411831,-1223.673373 -2084.356429,-1223.815094 -2135.077312,-1224.983295 -2186.35972,-1224.084667 -2238.44057,-1226.530143 -2292.238697,-1223.684731 -2345.740624,-1224.418812 -2399.445412,-1224.068882 -2453.074848,-1223.038722 -2506.191465,-1224.688486 -2560.467831,-1224.369871 -2615.123451,-1223.741346 -2669.90867,-1224.06301 -2723.421276,-1223.571537 -2777.560479,-1221.698283 -2832.592864,-1223.216613 -2887.017148,-1223.993198 -2941.483496,-1224.1121 -2995.423772,-1223.169763 -3050.948137,-1223.459228 -3106.069071,-1223.111025 -3159.795594,-1224.609992 -3214.050797,-1223.487228 -3269.429128,-1223.715971 -3324.581811,-1223.355198 -3380.680088,-1223.137084 -3436.805942,-1222.142975 -3492.414253,-1222.925149 -3549.505774,-1223.379747 -3607.501727,-1222.272351 -3664.639311,-1224.368308 -3722.034017,-1223.901903 -3779.537015,-1222.94451 -3836.6152,-1221.838935 -3894.218441,-1223.663879 -3951.011881,-1224.02154 -4007.885495,-1223.465968 -4063.988944,-1222.439965 -4122.273734,-1221.786334 -4179.429916,-1222.157602 -4237.573861,-1223.250346 -4295.552148,-1222.545104 -4353.574564,-1222.514246 -4411.656075,-1222.108242 -4468.160373,-1222.456373 -4523.770186,-1221.893127 -4578.243049,-1222.369885 -4633.660001,-1221.74155 -4689.995503,-1221.485722 -4745.046551,-1222.821859 -4799.297272,-1220.567495 -4855.528112,-1221.962994 -4912.181782,-1223.706512 -4968.106065,-1223.249642 -5023.349784,-1223.88398 -5077.95836,-1221.843378 -5132.109917,-1220.167506 -5185.437263,-1222.633295 -5239.11293,-1220.070837 -5292.425083,-1222.163853 -5344.871339,-1221.905685 -5398.5643,-1221.195668 -5453.522891,-1220.820716 -5510.54199,-1221.871458 -5565.892794,-1221.127944 -5619.692913,-1221.370239 -5672.505697,-1220.411562 -5726.742534,-1219.956466 -5780.314767,-1223.574739 -5831.786025,-1220.038731 -5883.706538,-1222.285694 -5935.963855,-1219.789742 -5990.031264,-1220.348486 -6042.056406,-1221.695037 -6094.931383,-1222.573256 -6149.866284,-1220.799972 -6201.410016,-1223.62492 -6254.863233,-1221.173897 -6307.299802,-1218.384144 -6358.419653,-1221.949964 -6410.39371,-1220.695003 -6462.294865,-1221.354209 -6514.29837,-1220.013649 -6566.028445,-1221.057066 -6619.112563,-1220.250728 -6671.586666,-1220.74064 -6724.056905,-1220.696191 -6778.867936,-1221.342228 -6834.066508,-1220.67166 -6887.848139,-1221.681324 -6940.725238,-1221.786548 -6993.728177,-1220.293248 -7047.068145,-1220.784974 -7101.058736,-1221.296277 -7155.695031,-1220.314099 -7211.131008,-1220.001403 -7264.672096,-1222.582639 -7319.855532,-1220.218512 -7373.416443,-1222.413372 -7426.826476,-1221.740957 -7481.033519,-1220.12891 -7535.098597,-1220.954482 -7589.86269,-1221.591681 -7645.189188,-1219.780434 -7700.019942,-1219.86627 -7754.408913,-1221.254077 -7809.17354,-1220.560716 -7863.955751,-1219.303522 -7919.377039,-1220.023321 -7973.91837,-1221.021169 -8030.191454,-1221.020316 -8086.216218,-1219.728618 -8143.879354,-1221.815614 -8200.316735,-1219.043831 -8257.940883,-1220.157413 -8314.852015,-1220.684859 -8370.81427,-1220.370599 -8428.692646,-1219.714431 -8484.287711,-1218.305622 -8537.958919,-1221.428118 -8594.432063,-1219.59307 -8649.300665,-1220.87185 -8704.677604,-1221.031256 -8760.582589,-1218.94684 -8818.110156,-1221.658566 -8874.317387,-1219.76819 -8931.379297,-1221.200231 -8988.303494,-1220.058048 -9045.237712,-1219.77335 -9102.797204,-1220.183627 -9158.2346,-1220.408912 -9213.641003,-1218.583548 -9270.128289,-1221.096059 -9325.747626,-1220.0666 -9381.558162,-1219.557564 -9438.852373,-1219.330095 -9495.594231,-1221.993273 -9553.896612,-1223.187594 -9611.398788,-1219.720465 -9670.280587,-1218.745421 -9727.735437,-1220.320216 -9785.576854,-1218.623738 -9842.847055,-1220.192794 -9900.27451,-1218.631224 -9960.404387,-1219.626358 -10020.38233,-1218.897498 -10081.03875,-1220.686493 -10142.4862,-1222.619642 -10203.81352,-1221.138795 -10263.73111,-1219.75704 -10322.01157,-1221.510761 -10383.70518,-1220.942931 -10444.83872,-1220.590052 -10505.46391,-1220.915548 -10565.21068,-1220.83712 -10623.1464,-1220.426405 -10681.92922,-1221.087758 -10740.52835,-1219.740644 -10799.10282,-1219.464613 -10858.83727,-1218.971765 -10917.77249,-1221.522595 -10978.69096,-1219.90022 -11039.15463,-1219.721981 -11100.71659,-1218.736818 -11160.92038,-1218.613676 -11223.41109,-1220.718136 -11286.18113,-1217.805265 -11347.3255,-1221.471473 -11409.22849,-1218.647645 -11471.5558,-1218.650381 -11534.54889,-1221.838598 -11596.56306,-1216.725289 -11658.5901,-1220.448722 -11720.84754,-1218.200733 -11781.86972,-1219.514898 -11844.25683,-1219.122911 -11905.98746,-1219.437825 -11967.72706,-1220.671814 -12029.60287,-1221.425081 -12090.42101,-1220.465861 -12151.72214,-1218.325056 -12215.04968,-1221.885592 -12276.16149,-1217.636273 -12335.90203,-1218.246472 -12396.63493,-1218.821679 -12459.444,-1220.249942 -12522.88157,-1218.853892 -12585.31227,-1220.934625 -12647.22732,-1219.792465 -12707.95382,-1219.762632 -12770.33686,-1217.819044 -12833.20918,-1218.993567 -12896.81557,-1219.114086 -12958.97821,-1219.069654 -13023.42444,-1219.832851 -13086.90235,-1220.588547 -13149.95939,-1217.843585 -13214.8704,-1218.933695 -13278.31332,-1218.540474 -13342.29491,-1219.083535 -13404.31573,-1221.076755 -13465.29136,-1220.365732 -13526.70887,-1217.986482 -13590.52374,-1219.40228 -13654.7832,-1217.818409 -13718.30046,-1218.816018 -13781.31456,-1220.176623 -13843.58471,-1218.819039 -13905.20301,-1218.231392 -13967.45721,-1219.155412 -14030.04214,-1219.768839 -14092.84348,-1219.087143 -14155.60131,-1217.773162 -14218.77685,-1218.442359 -14281.016,-1219.142937 -14346.43126,-1218.540456 -14409.63096,-1220.27516 -14475.35352,-1220.047078 -14541.72266,-1217.934667 -14605.84911,-1219.121418 -14673.27209,-1219.250748 -14737.62863,-1220.368594 -14802.66072,-1220.611695 -14867.97276,-1221.085498 -14935.65309,-1216.312149 -15001.18007,-1218.441571 diff --git a/graph/fitness-time/script.py b/graph/fitness-time/script.py deleted file mode 100644 index d952b34..0000000 --- a/graph/fitness-time/script.py +++ /dev/null @@ -1,22 +0,0 @@ -from matplotlib import pyplot as plt -import numpy as np - -# 使用 genfromtxt 函数读取 CSV 文件 -data = np.genfromtxt('neatpython.csv', delimiter=',', skip_header=1) # 假设有一个头部行 -mean_time, fitness_mean = data.T - -fig, ax = plt.subplots() -ax.plot(mean_time, fitness_mean, color='green', label='NEAT-Python', linestyle=':') -ax.set_xlabel('Time (s)') -ax.set_ylabel('Average Fitness') -ax.set_xlim(0, 500) -ax.set_ylim(-2000, -1000) -ax.legend() - - -# ci = 1.96 * neatax_sem -# lower_bound = neatax_mean - ci -# upper_bound = neatax_mean + ci -# plt.plot(mean_time, fitness_mean, color='r', label='NEAT-Python') -# plt.fill_between(x_axis, lower_bound, upper_bound, color='red', alpha=0.2) -fig.show() diff --git a/pipeline_jitable_env.py b/pipeline.py similarity index 100% rename from pipeline_jitable_env.py rename to pipeline.py diff --git a/pipeline_time.py b/pipeline_time.py deleted file mode 100644 index d54751a..0000000 --- a/pipeline_time.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import Type - -import jax -import time -import numpy as np - -from algorithm import NEAT, HyperNEAT -from config import Config -from core import State, Algorithm, Problem - - -class Pipeline: - - def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]): - self.config = config - self.algorithm = algorithm - self.problem = problem_type(config.problem) - - if isinstance(algorithm, NEAT): - assert config.neat.inputs == self.problem.input_shape[-1] - - elif isinstance(algorithm, HyperNEAT): - assert config.hyperneat.inputs == self.problem.input_shape[-1] - - else: - raise NotImplementedError - - self.act_func = self.algorithm.act - - for _ in range(len(self.problem.input_shape) - 1): - self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None)) - - self.best_genome = None - self.best_fitness = float('-inf') - self.generation_timestamp = None - - def setup(self): - key = jax.random.PRNGKey(self.config.basic.seed) - algorithm_key, evaluate_key = jax.random.split(key, 2) - state = State() - state = self.algorithm.setup(algorithm_key, state) - return state.update( - evaluate_key=evaluate_key - ) - - def step(self, state): - - key, sub_key = jax.random.split(state.evaluate_key) - keys = jax.random.split(key, self.config.basic.pop_size) - - pop = self.algorithm.ask(state) - - pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop) - - fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func, - pop_transformed) - - state = self.algorithm.tell(state, fitnesses) - - return state.update(evaluate_key=sub_key), fitnesses - - def auto_run(self, ini_state): - state = ini_state - for _ in range(self.config.basic.generation_limit): - - self.generation_timestamp = time.time() - - previous_pop = self.algorithm.ask(state) - - state, fitnesses = self.step(state) - - fitnesses = jax.device_get(fitnesses) - - self.analysis(state, previous_pop, fitnesses) - - if max(fitnesses) >= self.config.basic.fitness_target: - print("Fitness limit reached!") - return state, self.best_genome - - print("Generation limit reached!") - return state, self.best_genome - - def analysis(self, state, pop, fitnesses): - - max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) - - new_timestamp = time.time() - - cost_time = new_timestamp - self.generation_timestamp - - max_idx = np.argmax(fitnesses) - if fitnesses[max_idx] > self.best_fitness: - self.best_fitness = fitnesses[max_idx] - self.best_genome = pop[max_idx] - - member_count = jax.device_get(state.species_info.member_count) - species_sizes = [int(i) for i in member_count if i > 0] - - print(f"Generation: {state.generation}", - f"species: {len(species_sizes)}, {species_sizes}", - f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") - - def show(self, state, genome): - transformed = self.algorithm.transform(state, genome) - self.problem.show(state.evaluate_key, state, self.act_func, transformed) - - def pre_compile(self, state): - tic = time.time() - print("start compile") - self.step.lower(self, state).compile() - print(f"compile finished, cost time: {time.time() - tic}s")