From 80ee5ea2ea379e5f48616b0ca553927744c17be3 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 19 Jul 2023 16:38:43 +0800 Subject: [PATCH] small change for elegant code style --- algorithm/default_config.ini | 2 +- algorithm/neat/genome/crossover.py | 2 - algorithm/neat/genome/graph.py | 77 +----------------------------- algorithm/neat/neat.py | 10 ++-- algorithm/neat/pipeline.py | 3 +- 5 files changed, 7 insertions(+), 87 deletions(-) diff --git a/algorithm/default_config.ini b/algorithm/default_config.ini index 2131d21..97e2803 100644 --- a/algorithm/default_config.ini +++ b/algorithm/default_config.ini @@ -14,7 +14,7 @@ activate_times = 10 fitness_threshold = 3.9999 generation_limit = 1000 fitness_criterion = "max" -pop_size = 1000 +pop_size = 50000 [genome] compatibility_disjoint = 1.0 diff --git a/algorithm/neat/genome/crossover.py b/algorithm/neat/genome/crossover.py index d61448f..302c82d 100644 --- a/algorithm/neat/genome/crossover.py +++ b/algorithm/neat/genome/crossover.py @@ -1,5 +1,3 @@ -from typing import Tuple - import jax from jax import jit, Array, numpy as jnp diff --git a/algorithm/neat/genome/graph.py b/algorithm/neat/genome/graph.py index 1f65feb..b813666 100644 --- a/algorithm/neat/genome/graph.py +++ b/algorithm/neat/genome/graph.py @@ -12,28 +12,7 @@ from ..utils import fetch_first, I_INT @jit def topological_sort(nodes: Array, conns: Array) -> Array: """ - a jit-able version of topological_sort! that's crazy! - :param nodes: nodes array - :param conns: connections array - :return: topological sorted sequence - - Example: - nodes = jnp.array([ - [0], - [1], - [2], - [3] - ]) - connections = jnp.array([ - [ - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 1], - [0, 0, 0, 0] - ] - ]) - - topological_sort(nodes, connections) -> [0, 1, 2, 3] + a jit-able version of topological_sort! """ in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0)) @@ -65,30 +44,9 @@ def topological_sort(nodes: Array, conns: Array) -> Array: def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array: """ Check whether a new connection (from_idx -> to_idx) will cause a cycle. - - Example: - nodes = jnp.array([ - [0], - [1], - [2], - [3] - ]) - connections = jnp.array([ - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 1], - [0, 0, 0, 0] - ]) - - check_cycles(nodes, conns, 3, 2) -> True - check_cycles(nodes, conns, 2, 3) -> False - check_cycles(nodes, conns, 0, 3) -> False - check_cycles(nodes, conns, 1, 0) -> False """ conns = conns.at[from_idx, to_idx].set(True) - # conns_enable = ~jnp.isnan(conns[0, :, :]) - # conns_enable = conns_enable.at[from_idx, to_idx].set(True) visited = jnp.full(nodes.shape[0], False) new_visited = visited.at[to_idx].set(True) @@ -107,36 +65,3 @@ def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array: _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited)) return visited[from_idx] - -# if __name__ == '__main__': -# nodes = jnp.array([ -# [0], -# [1], -# [2], -# [3], -# [jnp.nan] -# ]) -# connections = jnp.array([ -# [ -# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], -# [jnp.nan, jnp.nan, 1, 1, jnp.nan], -# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], -# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], -# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] -# ], -# [ -# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], -# [jnp.nan, jnp.nan, 1, 1, jnp.nan], -# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], -# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], -# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] -# ] -# ] -# ) -# -# print(topological_sort(nodes, connections)) -# -# print(check_cycles(nodes, connections, 3, 2)) -# print(check_cycles(nodes, connections, 2, 3)) -# print(check_cycles(nodes, connections, 0, 3)) -# print(check_cycles(nodes, connections, 1, 0)) diff --git a/algorithm/neat/neat.py b/algorithm/neat/neat.py index 03dc7d1..6ad078a 100644 --- a/algorithm/neat/neat.py +++ b/algorithm/neat/neat.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from algorithm.state import State from .gene import BaseGene -from .genome import initialize_genomes, create_mutate, create_distance, crossover +from .genome import initialize_genomes from .population import create_tell @@ -14,11 +14,6 @@ class NEAT: self.config = config self.gene_type = gene_type - self.mutate = jax.jit(create_mutate(config, self.gene_type)) - self.distance = jax.jit(create_distance(config, self.gene_type)) - self.crossover = jax.jit(crossover) - self.pop_forward_transform = jax.jit(jax.vmap(self.gene_type.forward_transform)) - self.forward = jax.jit(self.gene_type.create_forward(config)) self.tell_func = jax.jit(create_tell(config, self.gene_type)) def setup(self, randkey): @@ -64,10 +59,11 @@ class NEAT: idx2species=idx2species, center_nodes=center_nodes, center_conns=center_conns, + # avoid jax auto cast from int to float. that would cause re-compilation. generation=jnp.asarray(generation, dtype=jnp.int32), next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32), - next_species_key=jnp.asarray(next_species_key) + next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32), ) # move to device diff --git a/algorithm/neat/pipeline.py b/algorithm/neat/pipeline.py index a03d0fd..612b5e9 100644 --- a/algorithm/neat/pipeline.py +++ b/algorithm/neat/pipeline.py @@ -5,6 +5,7 @@ import jax from jax import vmap, jit import numpy as np + class Pipeline: """ Neat algorithm pipeline. @@ -73,4 +74,4 @@ class Pipeline: print(f"Generation: {self.state.generation}", f"species: {len(species_sizes)}, {species_sizes}", - f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}") \ No newline at end of file + f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")