add mp4 format to show; fix a bug in visualize

This commit is contained in:
Zenbook
2025-02-23 18:16:44 +08:00
parent d86a3196bd
commit c2566c3931
2 changed files with 72 additions and 7 deletions

View File

@@ -280,8 +280,26 @@ class DefaultGenome(BaseGenome):
node2layer = {
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
}
# reorder nodes in each layer to make them more compact
subset_key = {}
for layer, nodes in enumerate(topo_layers):
if layer == 0 or len(nodes) == 1:
subset_key[layer] = nodes
continue
nodes_y = []
for node in nodes:
node_y = 0
for y, last_node in enumerate(topo_layers[layer-1]):
if (last_node, node) in conns_list:
node_y += y
nodes_y.append(node_y)
nodes = [node for _, node in sorted(zip(nodes_y, nodes))]
subset_key[layer] = nodes
if reverse_node_order:
topo_order = topo_order[::-1]
for layer, nodes in subset_key.items():
subset_key[layer] = nodes[::-1]
G = nx.DiGraph()
@@ -300,7 +318,12 @@ class DefaultGenome(BaseGenome):
for conn in conns_list:
G.add_edge(conn[0], conn[1])
pos = nx.multipartite_layout(G)
pos = nx.multipartite_layout(G, subset_key=subset_key)
# if layout == "spring":
# pos = nx.spring_layout(G, pos = pos, fixed=input_idx + output_idx, weight=None)
# elif layout == "spectral":
# pos = nx.spectral_layout(G, weight=None)
def rotate_layout(pos, angle):
angle_rad = np.deg2rad(angle)
@@ -318,6 +341,43 @@ class DefaultGenome(BaseGenome):
node_sizes = [n["size"] for n in G.nodes.values()]
node_colors = [n["color"] for n in G.nodes.values()]
# for layer, nodes in enumerate(topo_layers):
# if layer < 2 or len(nodes) == 0:
# continue
# arc_edges_posi, arc_edges_nega = [], []
# for node in nodes:
# for input_node in input_idx:
# if (input_node, node) not in conns_list:
# continue
# relative_pos = pos[input_node] - pos[node]
# relative_pos = relative_pos[0] * relative_pos[1]
# if relative_pos > 0:
# arc_edges_posi.append((input_node, node))
# else:
# arc_edges_nega.append((input_node, node))
# if len(arc_edges_posi) > 0:
# nx.draw_networkx_edges(
# G,
# pos=rotated_pos,
# edgelist=arc_edges_posi,
# arrowstyle=arrowstyle,
# arrowsize=arrowsize,
# edge_color=edge_color,
# connectionstyle="arc3,rad=0.5"
# )
# G.remove_edges_from(arc_edges_posi)
# if len(arc_edges_nega) > 0:
# nx.draw_networkx_edges(
# G,
# pos=rotated_pos,
# edgelist=arc_edges_nega,
# arrowstyle=arrowstyle,
# arrowsize=arrowsize,
# edge_color=edge_color,
# connectionstyle="arc3,rad=-0.5"
# )
# G.remove_edges_from(arc_edges_nega)
nx.draw(
G,
pos=rotated_pos,
@@ -327,8 +387,7 @@ class DefaultGenome(BaseGenome):
edgecolors=edgecolors,
arrowstyle=arrowstyle,
arrowsize=arrowsize,
edge_color=edge_color,
**kwargs,
edge_color=edge_color
)
plt.savefig(save_path, dpi=save_dpi)
plt.close()

View File

@@ -42,7 +42,7 @@ class BraxEnv(RLEnv):
**kwargs,
):
assert output_type in ["rgb_array", "gif"]
assert output_type in ["rgb_array", "gif", "mp4"]
import jax
import imageio
@@ -93,8 +93,14 @@ class BraxEnv(RLEnv):
return imgs
if save_path is None:
save_path = f"{self.env_name}.gif"
save_path = f"{self.env_name}.{output_type}"
imageio.mimsave(save_path, imgs, *args, **kwargs)
print("Gif saved to: ", save_path)
if output_type == "gif":
imageio.mimsave(save_path, imgs, *args, **kwargs)
elif output_type == "mp4":
fps = kwargs.get("fps", 30)
imageio.mimsave(save_path, imgs, fps=fps, codec="libx264", format="mp4")
print(f"{output_type} saved to: ", save_path)