From 15dadebd7e48c54cfe037b32069867245b40e055 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Sun, 22 Oct 2023 21:01:06 +0800 Subject: [PATCH] complete show() in brax env --- algorithm/hyperneat/hyperneat.py | 2 +- core/problem.py | 2 +- examples/brax/ant.py | 3 +- examples/brax/half_cheetah.py | 5 +-- examples/brax/reacher.py | 2 +- examples/brax_env.py | 49 +++++++++++++++++++++++++---- examples/brax_render.py | 54 ++++++++++++++++++++++++++++++++ pipeline.py | 4 +-- problem/func_fit/func_fit.py | 2 +- problem/rl_env/brax_env.py | 48 +++++++++++++++++++++++++--- problem/rl_env/rl_jit.py | 3 +- 11 files changed, 152 insertions(+), 22 deletions(-) create mode 100644 examples/brax_render.py diff --git a/algorithm/hyperneat/hyperneat.py b/algorithm/hyperneat/hyperneat.py index 9bd21ae..0ef08cf 100644 --- a/algorithm/hyperneat/hyperneat.py +++ b/algorithm/hyperneat/hyperneat.py @@ -105,7 +105,7 @@ class HyperNEATGene: values = values.at[input_idx].set(inputs_with_bias) nodes_ins = values * weights.T values = batch_agg(nodes_ins) # z = agg(ins) - values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias + # values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias values = batch_act(values) # z = act(z) return values diff --git a/core/problem.py b/core/problem.py index f65af85..c57e2fd 100644 --- a/core/problem.py +++ b/core/problem.py @@ -22,7 +22,7 @@ class Problem: def output_shape(self): raise NotImplementedError - def show(self, randkey, state: State, act_func: Callable, params): + def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs): """ show how a genome perform in this problem """ diff --git a/examples/brax/ant.py b/examples/brax/ant.py index f4034b1..e8f4b54 100644 --- a/examples/brax/ant.py +++ b/examples/brax/ant.py @@ -12,7 +12,7 @@ def example_conf(): basic=BasicConfig( seed=42, fitness_target=10000, - pop_size=100 + pop_size=1000 ), neat=NeatConfig( inputs=27, @@ -23,6 +23,7 @@ def example_conf(): activation_options=(Act.tanh,), ), problem=BraxConfig( + env_name="ant" ) ) diff --git a/examples/brax/half_cheetah.py b/examples/brax/half_cheetah.py index eb2baf9..dbc6207 100644 --- a/examples/brax/half_cheetah.py +++ b/examples/brax/half_cheetah.py @@ -15,7 +15,8 @@ def example_conf(): basic=BasicConfig( seed=42, fitness_target=10000, - pop_size=10000 + generation_limit=10, + pop_size=100 ), neat=NeatConfig( inputs=17, @@ -33,9 +34,9 @@ def example_conf(): if __name__ == '__main__': conf = example_conf() - algorithm = NEAT(conf, NormalGene) pipeline = Pipeline(conf, algorithm, BraxEnv) state = pipeline.setup() pipeline.pre_compile(state) state, best = pipeline.auto_run(state) + pipeline.show(state, best, save_path="half_cheetah.gif", ) diff --git a/examples/brax/reacher.py b/examples/brax/reacher.py index af31765..a6ed280 100644 --- a/examples/brax/reacher.py +++ b/examples/brax/reacher.py @@ -12,7 +12,7 @@ def example_conf(): basic=BasicConfig( seed=42, fitness_target=10000, - pop_size=10000 + pop_size=1000 ), neat=NeatConfig( inputs=11, diff --git a/examples/brax_env.py b/examples/brax_env.py index 61d94e7..2124f24 100644 --- a/examples/brax_env.py +++ b/examples/brax_env.py @@ -1,7 +1,14 @@ +import imageio import jax import brax from brax import envs +from brax.io import image +import matplotlib.pyplot as plt + +import time +from tqdm import tqdm +import numpy as np def inference_func(key, *args): @@ -17,20 +24,50 @@ jit_env_reset = jax.jit(env.reset) jit_env_step = jax.jit(env.step) jit_inference_fn = jax.jit(inference_func) - -rollout = [] rng = jax.random.PRNGKey(seed=1) ori_state = jit_env_reset(rng=rng) state = ori_state -for _ in range(100): - rollout.append(state.pipeline_state) +render_history = [] + +for i in range(100): act_rng, rng = jax.random.split(rng) + + tic = time.time() act = jit_inference_fn(act_rng, state.obs) state = jit_env_step(state, act) + print("step time: ", time.time() - tic) + + render_history.append(state.pipeline_state) + + # img = image.render_array(sys=env.sys, state=pipeline_state, width=512, height=512) + # print("render time: ", time.time() - tic) + + # plt.imsave("../images/ant_{}.png".format(i), img) + reward = state.reward - # print(reward) + done = state.done + print(i, reward) -a = 1 +render_history = jax.device_get(render_history) +# print(render_history) + +imgs = [image.render_array(sys=env.sys, state=s, width=512, height=512) for s in tqdm(render_history)] +# for i, s in enumerate(tqdm(render_history)): +# img = image.render_array(sys=env.sys, state=s, width=512, height=512) +# print(img.shape) +# # print(type(img)) +# plt.imsave("../images/ant_{}.png".format(i), img) + + +def create_gif(image_list, gif_name, duration): + with imageio.get_writer(gif_name, mode='I', duration=duration) as writer: + for image in image_list: + # 确保图像的数据类型正确 + formatted_image = np.array(image, dtype=np.uint8) + writer.append_data(formatted_image) + + +create_gif(imgs, "../images/ant.gif", 0.1) diff --git a/examples/brax_render.py b/examples/brax_render.py new file mode 100644 index 0000000..4347697 --- /dev/null +++ b/examples/brax_render.py @@ -0,0 +1,54 @@ +import brax +from brax import envs +from brax.envs.wrappers import gym as gym_wrapper +from brax.io import image +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import traceback + +# print(f"Using Brax {brax.__version__}, Jax {jax.__version__}") +# print("From GymWrapper, env.reset()") +# try: +# env = envs.create("inverted_pendulum", +# batch_size=1, +# episode_length=150, +# backend='generalized') +# env = gym_wrapper.GymWrapper(env) +# env.reset() +# img = env.render(mode='rgb_array') +# plt.imshow(img) +# except Exception: +# traceback.print_exc() +# +# print("From GymWrapper, env.reset() and action") +# try: +# env = envs.create("inverted_pendulum", +# batch_size=1, +# episode_length=150, +# backend='generalized') +# env = gym_wrapper.GymWrapper(env) +# env.reset() +# action = jnp.zeros(env.action_space.shape) +# env.step(action) +# img = env.render(mode='rgb_array') +# plt.imshow(img) +# except Exception: +# traceback.print_exc() + +print("From brax env") +try: + env = envs.create("inverted_pendulum", + batch_size=1, + episode_length=150, + backend='generalized') + key = jax.random.PRNGKey(0) + initial_env_state = env.reset(key) + base_state = initial_env_state.pipeline_state + pipeline_state = env.pipeline_init(base_state.q.ravel(), base_state.qd.ravel()) + img = image.render_array(sys=env.sys, state=pipeline_state, width=256, height=256) + print(f"pixel values: [{img.min()}, {img.max()}]") + plt.imshow(img) + plt.show() +except Exception: + traceback.print_exc() \ No newline at end of file diff --git a/pipeline.py b/pipeline.py index f6ebbea..8694d23 100644 --- a/pipeline.py +++ b/pipeline.py @@ -111,9 +111,9 @@ class Pipeline: f"species: {len(species_sizes)}, {species_sizes}", f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") - def show(self, state, genome): + def show(self, state, genome, *args, **kwargs): transformed = self.algorithm.transform(state, genome) - self.problem.show(state.evaluate_key, state, self.act_func, transformed) + self.problem.show(state.evaluate_key, state, self.act_func, transformed, *args, **kwargs) def pre_compile(self, state): tic = time.time() diff --git a/problem/func_fit/func_fit.py b/problem/func_fit/func_fit.py index 43f3043..a438972 100644 --- a/problem/func_fit/func_fit.py +++ b/problem/func_fit/func_fit.py @@ -44,7 +44,7 @@ class FuncFit(Problem): return -loss - def show(self, randkey, state: State, act_func: Callable, params): + def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs): predict = act_func(state, self.inputs, params) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) loss = -self.evaluate(randkey, state, act_func, params) diff --git a/problem/rl_env/brax_env.py b/problem/rl_env/brax_env.py index 710834b..7b82e40 100644 --- a/problem/rl_env/brax_env.py +++ b/problem/rl_env/brax_env.py @@ -34,12 +34,50 @@ class BraxEnv(RLEnv): @property def input_shape(self): - return (self.env.observation_size, ) + return (self.env.observation_size,) @property def output_shape(self): - return (self.env.action_size, ) + return (self.env.action_size,) + + def show(self, randkey, state: State, act_func: Callable, params, save_path=None, height=512, width=512, + duration=0.1, *args, + **kwargs): + + import jax + import imageio + import numpy as np + from brax.io import image + from tqdm import tqdm + + obs, env_state = self.reset(randkey) + reward, done = 0.0, False + state_histories = [] + + def step(key, env_state, obs): + key, _ = jax.random.split(key) + net_out = act_func(state, obs, params) + action = self.config.output_transform(net_out) + obs, env_state, r, done, _ = self.step(randkey, env_state, action) + return key, env_state, obs, r, done + + while not done: + state_histories.append(env_state.pipeline_state) + key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs) + reward += r + + imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in + tqdm(state_histories, desc="Rendering")] + + def create_gif(image_list, gif_name, duration): + with imageio.get_writer(gif_name, mode='I', duration=duration) as writer: + for image in image_list: + # 确保图像的数据类型正确 + formatted_image = np.array(image, dtype=np.uint8) + writer.append_data(formatted_image) + + create_gif(imgs, save_path, duration=0.1) + print("Gif saved to: ", save_path) + print("Total reward: ", reward) + - def show(self, randkey, state: State, act_func: Callable, params): - # TODO - raise NotImplementedError("im busy! to de done!") diff --git a/problem/rl_env/rl_jit.py b/problem/rl_env/rl_jit.py index f8244d2..84a512b 100644 --- a/problem/rl_env/rl_jit.py +++ b/problem/rl_env/rl_jit.py @@ -32,7 +32,6 @@ class RLEnv(Problem): def body_func(carry): obs, env_state, rng, _, tr = carry # total reward net_out = act_func(state, obs, params) - action = self.config.output_transform(net_out) next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) next_rng, _ = jax.random.split(rng) @@ -68,5 +67,5 @@ class RLEnv(Problem): def output_shape(self): raise NotImplementedError - def show(self, randkey, state: State, act_func: Callable, params): + def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs): raise NotImplementedError