add mp4 format to show; fix a bug in visualize
This commit is contained in:
@@ -280,8 +280,26 @@ class DefaultGenome(BaseGenome):
|
|||||||
node2layer = {
|
node2layer = {
|
||||||
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
|
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# reorder nodes in each layer to make them more compact
|
||||||
|
subset_key = {}
|
||||||
|
for layer, nodes in enumerate(topo_layers):
|
||||||
|
if layer == 0 or len(nodes) == 1:
|
||||||
|
subset_key[layer] = nodes
|
||||||
|
continue
|
||||||
|
nodes_y = []
|
||||||
|
for node in nodes:
|
||||||
|
node_y = 0
|
||||||
|
for y, last_node in enumerate(topo_layers[layer-1]):
|
||||||
|
if (last_node, node) in conns_list:
|
||||||
|
node_y += y
|
||||||
|
nodes_y.append(node_y)
|
||||||
|
nodes = [node for _, node in sorted(zip(nodes_y, nodes))]
|
||||||
|
subset_key[layer] = nodes
|
||||||
|
|
||||||
if reverse_node_order:
|
if reverse_node_order:
|
||||||
topo_order = topo_order[::-1]
|
for layer, nodes in subset_key.items():
|
||||||
|
subset_key[layer] = nodes[::-1]
|
||||||
|
|
||||||
G = nx.DiGraph()
|
G = nx.DiGraph()
|
||||||
|
|
||||||
@@ -300,7 +318,12 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
for conn in conns_list:
|
for conn in conns_list:
|
||||||
G.add_edge(conn[0], conn[1])
|
G.add_edge(conn[0], conn[1])
|
||||||
pos = nx.multipartite_layout(G)
|
pos = nx.multipartite_layout(G, subset_key=subset_key)
|
||||||
|
|
||||||
|
# if layout == "spring":
|
||||||
|
# pos = nx.spring_layout(G, pos = pos, fixed=input_idx + output_idx, weight=None)
|
||||||
|
# elif layout == "spectral":
|
||||||
|
# pos = nx.spectral_layout(G, weight=None)
|
||||||
|
|
||||||
def rotate_layout(pos, angle):
|
def rotate_layout(pos, angle):
|
||||||
angle_rad = np.deg2rad(angle)
|
angle_rad = np.deg2rad(angle)
|
||||||
@@ -318,6 +341,43 @@ class DefaultGenome(BaseGenome):
|
|||||||
node_sizes = [n["size"] for n in G.nodes.values()]
|
node_sizes = [n["size"] for n in G.nodes.values()]
|
||||||
node_colors = [n["color"] for n in G.nodes.values()]
|
node_colors = [n["color"] for n in G.nodes.values()]
|
||||||
|
|
||||||
|
# for layer, nodes in enumerate(topo_layers):
|
||||||
|
# if layer < 2 or len(nodes) == 0:
|
||||||
|
# continue
|
||||||
|
# arc_edges_posi, arc_edges_nega = [], []
|
||||||
|
# for node in nodes:
|
||||||
|
# for input_node in input_idx:
|
||||||
|
# if (input_node, node) not in conns_list:
|
||||||
|
# continue
|
||||||
|
# relative_pos = pos[input_node] - pos[node]
|
||||||
|
# relative_pos = relative_pos[0] * relative_pos[1]
|
||||||
|
# if relative_pos > 0:
|
||||||
|
# arc_edges_posi.append((input_node, node))
|
||||||
|
# else:
|
||||||
|
# arc_edges_nega.append((input_node, node))
|
||||||
|
# if len(arc_edges_posi) > 0:
|
||||||
|
# nx.draw_networkx_edges(
|
||||||
|
# G,
|
||||||
|
# pos=rotated_pos,
|
||||||
|
# edgelist=arc_edges_posi,
|
||||||
|
# arrowstyle=arrowstyle,
|
||||||
|
# arrowsize=arrowsize,
|
||||||
|
# edge_color=edge_color,
|
||||||
|
# connectionstyle="arc3,rad=0.5"
|
||||||
|
# )
|
||||||
|
# G.remove_edges_from(arc_edges_posi)
|
||||||
|
# if len(arc_edges_nega) > 0:
|
||||||
|
# nx.draw_networkx_edges(
|
||||||
|
# G,
|
||||||
|
# pos=rotated_pos,
|
||||||
|
# edgelist=arc_edges_nega,
|
||||||
|
# arrowstyle=arrowstyle,
|
||||||
|
# arrowsize=arrowsize,
|
||||||
|
# edge_color=edge_color,
|
||||||
|
# connectionstyle="arc3,rad=-0.5"
|
||||||
|
# )
|
||||||
|
# G.remove_edges_from(arc_edges_nega)
|
||||||
|
|
||||||
nx.draw(
|
nx.draw(
|
||||||
G,
|
G,
|
||||||
pos=rotated_pos,
|
pos=rotated_pos,
|
||||||
@@ -327,8 +387,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
edgecolors=edgecolors,
|
edgecolors=edgecolors,
|
||||||
arrowstyle=arrowstyle,
|
arrowstyle=arrowstyle,
|
||||||
arrowsize=arrowsize,
|
arrowsize=arrowsize,
|
||||||
edge_color=edge_color,
|
edge_color=edge_color
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
plt.savefig(save_path, dpi=save_dpi)
|
plt.savefig(save_path, dpi=save_dpi)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class BraxEnv(RLEnv):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
assert output_type in ["rgb_array", "gif"]
|
assert output_type in ["rgb_array", "gif", "mp4"]
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import imageio
|
import imageio
|
||||||
@@ -93,8 +93,14 @@ class BraxEnv(RLEnv):
|
|||||||
return imgs
|
return imgs
|
||||||
|
|
||||||
if save_path is None:
|
if save_path is None:
|
||||||
save_path = f"{self.env_name}.gif"
|
save_path = f"{self.env_name}.{output_type}"
|
||||||
|
|
||||||
imageio.mimsave(save_path, imgs, *args, **kwargs)
|
imageio.mimsave(save_path, imgs, *args, **kwargs)
|
||||||
|
|
||||||
print("Gif saved to: ", save_path)
|
if output_type == "gif":
|
||||||
|
imageio.mimsave(save_path, imgs, *args, **kwargs)
|
||||||
|
elif output_type == "mp4":
|
||||||
|
fps = kwargs.get("fps", 30)
|
||||||
|
imageio.mimsave(save_path, imgs, fps=fps, codec="libx264", format="mp4")
|
||||||
|
|
||||||
|
print(f"{output_type} saved to: ", save_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user