diff --git a/src/tensorneat/genome/default.py b/src/tensorneat/genome/default.py index 66a9e64..e308a4e 100644 --- a/src/tensorneat/genome/default.py +++ b/src/tensorneat/genome/default.py @@ -280,8 +280,26 @@ class DefaultGenome(BaseGenome): node2layer = { node: layer for layer, nodes in enumerate(topo_layers) for node in nodes } + + # reorder nodes in each layer to make them more compact + subset_key = {} + for layer, nodes in enumerate(topo_layers): + if layer == 0 or len(nodes) == 1: + subset_key[layer] = nodes + continue + nodes_y = [] + for node in nodes: + node_y = 0 + for y, last_node in enumerate(topo_layers[layer-1]): + if (last_node, node) in conns_list: + node_y += y + nodes_y.append(node_y) + nodes = [node for _, node in sorted(zip(nodes_y, nodes))] + subset_key[layer] = nodes + if reverse_node_order: - topo_order = topo_order[::-1] + for layer, nodes in subset_key.items(): + subset_key[layer] = nodes[::-1] G = nx.DiGraph() @@ -300,7 +318,12 @@ class DefaultGenome(BaseGenome): for conn in conns_list: G.add_edge(conn[0], conn[1]) - pos = nx.multipartite_layout(G) + pos = nx.multipartite_layout(G, subset_key=subset_key) + + # if layout == "spring": + # pos = nx.spring_layout(G, pos = pos, fixed=input_idx + output_idx, weight=None) + # elif layout == "spectral": + # pos = nx.spectral_layout(G, weight=None) def rotate_layout(pos, angle): angle_rad = np.deg2rad(angle) @@ -318,6 +341,43 @@ class DefaultGenome(BaseGenome): node_sizes = [n["size"] for n in G.nodes.values()] node_colors = [n["color"] for n in G.nodes.values()] + # for layer, nodes in enumerate(topo_layers): + # if layer < 2 or len(nodes) == 0: + # continue + # arc_edges_posi, arc_edges_nega = [], [] + # for node in nodes: + # for input_node in input_idx: + # if (input_node, node) not in conns_list: + # continue + # relative_pos = pos[input_node] - pos[node] + # relative_pos = relative_pos[0] * relative_pos[1] + # if relative_pos > 0: + # arc_edges_posi.append((input_node, node)) + # else: + # arc_edges_nega.append((input_node, node)) + # if len(arc_edges_posi) > 0: + # nx.draw_networkx_edges( + # G, + # pos=rotated_pos, + # edgelist=arc_edges_posi, + # arrowstyle=arrowstyle, + # arrowsize=arrowsize, + # edge_color=edge_color, + # connectionstyle="arc3,rad=0.5" + # ) + # G.remove_edges_from(arc_edges_posi) + # if len(arc_edges_nega) > 0: + # nx.draw_networkx_edges( + # G, + # pos=rotated_pos, + # edgelist=arc_edges_nega, + # arrowstyle=arrowstyle, + # arrowsize=arrowsize, + # edge_color=edge_color, + # connectionstyle="arc3,rad=-0.5" + # ) + # G.remove_edges_from(arc_edges_nega) + nx.draw( G, pos=rotated_pos, @@ -327,8 +387,7 @@ class DefaultGenome(BaseGenome): edgecolors=edgecolors, arrowstyle=arrowstyle, arrowsize=arrowsize, - edge_color=edge_color, - **kwargs, + edge_color=edge_color ) plt.savefig(save_path, dpi=save_dpi) plt.close() diff --git a/src/tensorneat/problem/rl/brax.py b/src/tensorneat/problem/rl/brax.py index cf6e84d..f7ac483 100644 --- a/src/tensorneat/problem/rl/brax.py +++ b/src/tensorneat/problem/rl/brax.py @@ -42,7 +42,7 @@ class BraxEnv(RLEnv): **kwargs, ): - assert output_type in ["rgb_array", "gif"] + assert output_type in ["rgb_array", "gif", "mp4"] import jax import imageio @@ -93,8 +93,14 @@ class BraxEnv(RLEnv): return imgs if save_path is None: - save_path = f"{self.env_name}.gif" + save_path = f"{self.env_name}.{output_type}" imageio.mimsave(save_path, imgs, *args, **kwargs) - print("Gif saved to: ", save_path) + if output_type == "gif": + imageio.mimsave(save_path, imgs, *args, **kwargs) + elif output_type == "mp4": + fps = kwargs.get("fps", 30) + imageio.mimsave(save_path, imgs, fps=fps, codec="libx264", format="mp4") + + print(f"{output_type} saved to: ", save_path)