complete show() in brax env

This commit is contained in:
wls2002
2023-10-22 21:01:06 +08:00
parent 7f042e07c2
commit 15dadebd7e
11 changed files with 152 additions and 22 deletions

View File

@@ -105,7 +105,7 @@ class HyperNEATGene:
values = values.at[input_idx].set(inputs_with_bias) values = values.at[input_idx].set(inputs_with_bias)
nodes_ins = values * weights.T nodes_ins = values * weights.T
values = batch_agg(nodes_ins) # z = agg(ins) values = batch_agg(nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias # values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(values) # z = act(z) values = batch_act(values) # z = act(z)
return values return values

View File

@@ -22,7 +22,7 @@ class Problem:
def output_shape(self): def output_shape(self):
raise NotImplementedError raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params): def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
""" """
show how a genome perform in this problem show how a genome perform in this problem
""" """

View File

@@ -12,7 +12,7 @@ def example_conf():
basic=BasicConfig( basic=BasicConfig(
seed=42, seed=42,
fitness_target=10000, fitness_target=10000,
pop_size=100 pop_size=1000
), ),
neat=NeatConfig( neat=NeatConfig(
inputs=27, inputs=27,
@@ -23,6 +23,7 @@ def example_conf():
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
), ),
problem=BraxConfig( problem=BraxConfig(
env_name="ant"
) )
) )

View File

@@ -15,7 +15,8 @@ def example_conf():
basic=BasicConfig( basic=BasicConfig(
seed=42, seed=42,
fitness_target=10000, fitness_target=10000,
pop_size=10000 generation_limit=10,
pop_size=100
), ),
neat=NeatConfig( neat=NeatConfig(
inputs=17, inputs=17,
@@ -33,9 +34,9 @@ def example_conf():
if __name__ == '__main__': if __name__ == '__main__':
conf = example_conf() conf = example_conf()
algorithm = NEAT(conf, NormalGene) algorithm = NEAT(conf, NormalGene)
pipeline = Pipeline(conf, algorithm, BraxEnv) pipeline = Pipeline(conf, algorithm, BraxEnv)
state = pipeline.setup() state = pipeline.setup()
pipeline.pre_compile(state) pipeline.pre_compile(state)
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)
pipeline.show(state, best, save_path="half_cheetah.gif", )

View File

@@ -12,7 +12,7 @@ def example_conf():
basic=BasicConfig( basic=BasicConfig(
seed=42, seed=42,
fitness_target=10000, fitness_target=10000,
pop_size=10000 pop_size=1000
), ),
neat=NeatConfig( neat=NeatConfig(
inputs=11, inputs=11,

View File

@@ -1,7 +1,14 @@
import imageio
import jax import jax
import brax import brax
from brax import envs from brax import envs
from brax.io import image
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import numpy as np
def inference_func(key, *args): def inference_func(key, *args):
@@ -17,20 +24,50 @@ jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step) jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_func) jit_inference_fn = jax.jit(inference_func)
rollout = []
rng = jax.random.PRNGKey(seed=1) rng = jax.random.PRNGKey(seed=1)
ori_state = jit_env_reset(rng=rng) ori_state = jit_env_reset(rng=rng)
state = ori_state state = ori_state
for _ in range(100): render_history = []
rollout.append(state.pipeline_state)
for i in range(100):
act_rng, rng = jax.random.split(rng) act_rng, rng = jax.random.split(rng)
tic = time.time()
act = jit_inference_fn(act_rng, state.obs) act = jit_inference_fn(act_rng, state.obs)
state = jit_env_step(state, act) state = jit_env_step(state, act)
print("step time: ", time.time() - tic)
render_history.append(state.pipeline_state)
# img = image.render_array(sys=env.sys, state=pipeline_state, width=512, height=512)
# print("render time: ", time.time() - tic)
# plt.imsave("../images/ant_{}.png".format(i), img)
reward = state.reward reward = state.reward
# print(reward) done = state.done
print(i, reward)
a = 1 render_history = jax.device_get(render_history)
# print(render_history)
imgs = [image.render_array(sys=env.sys, state=s, width=512, height=512) for s in tqdm(render_history)]
# for i, s in enumerate(tqdm(render_history)):
# img = image.render_array(sys=env.sys, state=s, width=512, height=512)
# print(img.shape)
# # print(type(img))
# plt.imsave("../images/ant_{}.png".format(i), img)
def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
for image in image_list:
# 确保图像的数据类型正确
formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image)
create_gif(imgs, "../images/ant.gif", 0.1)

54
examples/brax_render.py Normal file
View File

@@ -0,0 +1,54 @@
import brax
from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.io import image
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import traceback
# print(f"Using Brax {brax.__version__}, Jax {jax.__version__}")
# print("From GymWrapper, env.reset()")
# try:
# env = envs.create("inverted_pendulum",
# batch_size=1,
# episode_length=150,
# backend='generalized')
# env = gym_wrapper.GymWrapper(env)
# env.reset()
# img = env.render(mode='rgb_array')
# plt.imshow(img)
# except Exception:
# traceback.print_exc()
#
# print("From GymWrapper, env.reset() and action")
# try:
# env = envs.create("inverted_pendulum",
# batch_size=1,
# episode_length=150,
# backend='generalized')
# env = gym_wrapper.GymWrapper(env)
# env.reset()
# action = jnp.zeros(env.action_space.shape)
# env.step(action)
# img = env.render(mode='rgb_array')
# plt.imshow(img)
# except Exception:
# traceback.print_exc()
print("From brax env")
try:
env = envs.create("inverted_pendulum",
batch_size=1,
episode_length=150,
backend='generalized')
key = jax.random.PRNGKey(0)
initial_env_state = env.reset(key)
base_state = initial_env_state.pipeline_state
pipeline_state = env.pipeline_init(base_state.q.ravel(), base_state.qd.ravel())
img = image.render_array(sys=env.sys, state=pipeline_state, width=256, height=256)
print(f"pixel values: [{img.min()}, {img.max()}]")
plt.imshow(img)
plt.show()
except Exception:
traceback.print_exc()

View File

@@ -111,9 +111,9 @@ class Pipeline:
f"species: {len(species_sizes)}, {species_sizes}", f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms") f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, genome): def show(self, state, genome, *args, **kwargs):
transformed = self.algorithm.transform(state, genome) transformed = self.algorithm.transform(state, genome)
self.problem.show(state.evaluate_key, state, self.act_func, transformed) self.problem.show(state.evaluate_key, state, self.act_func, transformed, *args, **kwargs)
def pre_compile(self, state): def pre_compile(self, state):
tic = time.time() tic = time.time()

View File

@@ -44,7 +44,7 @@ class FuncFit(Problem):
return -loss return -loss
def show(self, randkey, state: State, act_func: Callable, params): def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
predict = act_func(state, self.inputs, params) predict = act_func(state, self.inputs, params)
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
loss = -self.evaluate(randkey, state, act_func, params) loss = -self.evaluate(randkey, state, act_func, params)

View File

@@ -34,12 +34,50 @@ class BraxEnv(RLEnv):
@property @property
def input_shape(self): def input_shape(self):
return (self.env.observation_size, ) return (self.env.observation_size,)
@property @property
def output_shape(self): def output_shape(self):
return (self.env.action_size, ) return (self.env.action_size,)
def show(self, randkey, state: State, act_func: Callable, params, save_path=None, height=512, width=512,
duration=0.1, *args,
**kwargs):
import jax
import imageio
import numpy as np
from brax.io import image
from tqdm import tqdm
obs, env_state = self.reset(randkey)
reward, done = 0.0, False
state_histories = []
def step(key, env_state, obs):
key, _ = jax.random.split(key)
net_out = act_func(state, obs, params)
action = self.config.output_transform(net_out)
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
return key, env_state, obs, r, done
while not done:
state_histories.append(env_state.pipeline_state)
key, env_state, obs, r, done = jax.jit(step)(randkey, env_state, obs)
reward += r
imgs = [image.render_array(sys=self.env.sys, state=s, width=width, height=height) for s in
tqdm(state_histories, desc="Rendering")]
def create_gif(image_list, gif_name, duration):
with imageio.get_writer(gif_name, mode='I', duration=duration) as writer:
for image in image_list:
# 确保图像的数据类型正确
formatted_image = np.array(image, dtype=np.uint8)
writer.append_data(formatted_image)
create_gif(imgs, save_path, duration=0.1)
print("Gif saved to: ", save_path)
print("Total reward: ", reward)
def show(self, randkey, state: State, act_func: Callable, params):
# TODO
raise NotImplementedError("im busy! to de done!")

View File

@@ -32,7 +32,6 @@ class RLEnv(Problem):
def body_func(carry): def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward obs, env_state, rng, _, tr = carry # total reward
net_out = act_func(state, obs, params) net_out = act_func(state, obs, params)
action = self.config.output_transform(net_out) action = self.config.output_transform(net_out)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng) next_rng, _ = jax.random.split(rng)
@@ -68,5 +67,5 @@ class RLEnv(Problem):
def output_shape(self): def output_shape(self):
raise NotImplementedError raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params): def show(self, randkey, state: State, act_func: Callable, params, *args, **kwargs):
raise NotImplementedError raise NotImplementedError