small change for elegant code style

This commit is contained in:
wls2002
2023-07-19 16:38:43 +08:00
parent a684e6584d
commit 80ee5ea2ea
5 changed files with 7 additions and 87 deletions

View File

@@ -14,7 +14,7 @@ activate_times = 10
fitness_threshold = 3.9999 fitness_threshold = 3.9999
generation_limit = 1000 generation_limit = 1000
fitness_criterion = "max" fitness_criterion = "max"
pop_size = 1000 pop_size = 50000
[genome] [genome]
compatibility_disjoint = 1.0 compatibility_disjoint = 1.0

View File

@@ -1,5 +1,3 @@
from typing import Tuple
import jax import jax
from jax import jit, Array, numpy as jnp from jax import jit, Array, numpy as jnp

View File

@@ -12,28 +12,7 @@ from ..utils import fetch_first, I_INT
@jit @jit
def topological_sort(nodes: Array, conns: Array) -> Array: def topological_sort(nodes: Array, conns: Array) -> Array:
""" """
a jit-able version of topological_sort! that's crazy! a jit-able version of topological_sort!
:param nodes: nodes array
:param conns: connections array
:return: topological sorted sequence
Example:
nodes = jnp.array([
[0],
[1],
[2],
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
topological_sort(nodes, connections) -> [0, 1, 2, 3]
""" """
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0)) in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
@@ -65,30 +44,9 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array: def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
""" """
Check whether a new connection (from_idx -> to_idx) will cause a cycle. Check whether a new connection (from_idx -> to_idx) will cause a cycle.
Example:
nodes = jnp.array([
[0],
[1],
[2],
[3]
])
connections = jnp.array([
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
])
check_cycles(nodes, conns, 3, 2) -> True
check_cycles(nodes, conns, 2, 3) -> False
check_cycles(nodes, conns, 0, 3) -> False
check_cycles(nodes, conns, 1, 0) -> False
""" """
conns = conns.at[from_idx, to_idx].set(True) conns = conns.at[from_idx, to_idx].set(True)
# conns_enable = ~jnp.isnan(conns[0, :, :])
# conns_enable = conns_enable.at[from_idx, to_idx].set(True)
visited = jnp.full(nodes.shape[0], False) visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True) new_visited = visited.at[to_idx].set(True)
@@ -107,36 +65,3 @@ def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited)) _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx] return visited[from_idx]
# if __name__ == '__main__':
# nodes = jnp.array([
# [0],
# [1],
# [2],
# [3],
# [jnp.nan]
# ])
# connections = jnp.array([
# [
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
# ],
# [
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
# ]
# ]
# )
#
# print(topological_sort(nodes, connections))
#
# print(check_cycles(nodes, connections, 3, 2))
# print(check_cycles(nodes, connections, 2, 3))
# print(check_cycles(nodes, connections, 0, 3))
# print(check_cycles(nodes, connections, 1, 0))

View File

@@ -5,7 +5,7 @@ import jax.numpy as jnp
from algorithm.state import State from algorithm.state import State
from .gene import BaseGene from .gene import BaseGene
from .genome import initialize_genomes, create_mutate, create_distance, crossover from .genome import initialize_genomes
from .population import create_tell from .population import create_tell
@@ -14,11 +14,6 @@ class NEAT:
self.config = config self.config = config
self.gene_type = gene_type self.gene_type = gene_type
self.mutate = jax.jit(create_mutate(config, self.gene_type))
self.distance = jax.jit(create_distance(config, self.gene_type))
self.crossover = jax.jit(crossover)
self.pop_forward_transform = jax.jit(jax.vmap(self.gene_type.forward_transform))
self.forward = jax.jit(self.gene_type.create_forward(config))
self.tell_func = jax.jit(create_tell(config, self.gene_type)) self.tell_func = jax.jit(create_tell(config, self.gene_type))
def setup(self, randkey): def setup(self, randkey):
@@ -64,10 +59,11 @@ class NEAT:
idx2species=idx2species, idx2species=idx2species,
center_nodes=center_nodes, center_nodes=center_nodes,
center_conns=center_conns, center_conns=center_conns,
# avoid jax auto cast from int to float. that would cause re-compilation. # avoid jax auto cast from int to float. that would cause re-compilation.
generation=jnp.asarray(generation, dtype=jnp.int32), generation=jnp.asarray(generation, dtype=jnp.int32),
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32), next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
next_species_key=jnp.asarray(next_species_key) next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
) )
# move to device # move to device

View File

@@ -5,6 +5,7 @@ import jax
from jax import vmap, jit from jax import vmap, jit
import numpy as np import numpy as np
class Pipeline: class Pipeline:
""" """
Neat algorithm pipeline. Neat algorithm pipeline.
@@ -73,4 +74,4 @@ class Pipeline:
print(f"Generation: {self.state.generation}", print(f"Generation: {self.state.generation}",
f"species: {len(species_sizes)}, {species_sizes}", f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}") f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")