small change for elegant code style
This commit is contained in:
@@ -14,7 +14,7 @@ activate_times = 10
|
|||||||
fitness_threshold = 3.9999
|
fitness_threshold = 3.9999
|
||||||
generation_limit = 1000
|
generation_limit = 1000
|
||||||
fitness_criterion = "max"
|
fitness_criterion = "max"
|
||||||
pop_size = 1000
|
pop_size = 50000
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import jit, Array, numpy as jnp
|
from jax import jit, Array, numpy as jnp
|
||||||
|
|
||||||
|
|||||||
@@ -12,28 +12,7 @@ from ..utils import fetch_first, I_INT
|
|||||||
@jit
|
@jit
|
||||||
def topological_sort(nodes: Array, conns: Array) -> Array:
|
def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||||
"""
|
"""
|
||||||
a jit-able version of topological_sort! that's crazy!
|
a jit-able version of topological_sort!
|
||||||
:param nodes: nodes array
|
|
||||||
:param conns: connections array
|
|
||||||
:return: topological sorted sequence
|
|
||||||
|
|
||||||
Example:
|
|
||||||
nodes = jnp.array([
|
|
||||||
[0],
|
|
||||||
[1],
|
|
||||||
[2],
|
|
||||||
[3]
|
|
||||||
])
|
|
||||||
connections = jnp.array([
|
|
||||||
[
|
|
||||||
[0, 0, 1, 0],
|
|
||||||
[0, 0, 1, 1],
|
|
||||||
[0, 0, 0, 1],
|
|
||||||
[0, 0, 0, 0]
|
|
||||||
]
|
|
||||||
])
|
|
||||||
|
|
||||||
topological_sort(nodes, connections) -> [0, 1, 2, 3]
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
|
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
|
||||||
@@ -65,30 +44,9 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
|
|||||||
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||||
"""
|
"""
|
||||||
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
|
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
|
||||||
|
|
||||||
Example:
|
|
||||||
nodes = jnp.array([
|
|
||||||
[0],
|
|
||||||
[1],
|
|
||||||
[2],
|
|
||||||
[3]
|
|
||||||
])
|
|
||||||
connections = jnp.array([
|
|
||||||
[0, 0, 1, 0],
|
|
||||||
[0, 0, 1, 1],
|
|
||||||
[0, 0, 0, 1],
|
|
||||||
[0, 0, 0, 0]
|
|
||||||
])
|
|
||||||
|
|
||||||
check_cycles(nodes, conns, 3, 2) -> True
|
|
||||||
check_cycles(nodes, conns, 2, 3) -> False
|
|
||||||
check_cycles(nodes, conns, 0, 3) -> False
|
|
||||||
check_cycles(nodes, conns, 1, 0) -> False
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
conns = conns.at[from_idx, to_idx].set(True)
|
conns = conns.at[from_idx, to_idx].set(True)
|
||||||
# conns_enable = ~jnp.isnan(conns[0, :, :])
|
|
||||||
# conns_enable = conns_enable.at[from_idx, to_idx].set(True)
|
|
||||||
|
|
||||||
visited = jnp.full(nodes.shape[0], False)
|
visited = jnp.full(nodes.shape[0], False)
|
||||||
new_visited = visited.at[to_idx].set(True)
|
new_visited = visited.at[to_idx].set(True)
|
||||||
@@ -107,36 +65,3 @@ def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
|||||||
|
|
||||||
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
||||||
return visited[from_idx]
|
return visited[from_idx]
|
||||||
|
|
||||||
# if __name__ == '__main__':
|
|
||||||
# nodes = jnp.array([
|
|
||||||
# [0],
|
|
||||||
# [1],
|
|
||||||
# [2],
|
|
||||||
# [3],
|
|
||||||
# [jnp.nan]
|
|
||||||
# ])
|
|
||||||
# connections = jnp.array([
|
|
||||||
# [
|
|
||||||
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
|
||||||
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
|
||||||
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
|
||||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
|
||||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
|
||||||
# ],
|
|
||||||
# [
|
|
||||||
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
|
||||||
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
|
||||||
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
|
||||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
|
||||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
|
||||||
# ]
|
|
||||||
# ]
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# print(topological_sort(nodes, connections))
|
|
||||||
#
|
|
||||||
# print(check_cycles(nodes, connections, 3, 2))
|
|
||||||
# print(check_cycles(nodes, connections, 2, 3))
|
|
||||||
# print(check_cycles(nodes, connections, 0, 3))
|
|
||||||
# print(check_cycles(nodes, connections, 1, 0))
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import jax.numpy as jnp
|
|||||||
|
|
||||||
from algorithm.state import State
|
from algorithm.state import State
|
||||||
from .gene import BaseGene
|
from .gene import BaseGene
|
||||||
from .genome import initialize_genomes, create_mutate, create_distance, crossover
|
from .genome import initialize_genomes
|
||||||
from .population import create_tell
|
from .population import create_tell
|
||||||
|
|
||||||
|
|
||||||
@@ -14,11 +14,6 @@ class NEAT:
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.gene_type = gene_type
|
self.gene_type = gene_type
|
||||||
|
|
||||||
self.mutate = jax.jit(create_mutate(config, self.gene_type))
|
|
||||||
self.distance = jax.jit(create_distance(config, self.gene_type))
|
|
||||||
self.crossover = jax.jit(crossover)
|
|
||||||
self.pop_forward_transform = jax.jit(jax.vmap(self.gene_type.forward_transform))
|
|
||||||
self.forward = jax.jit(self.gene_type.create_forward(config))
|
|
||||||
self.tell_func = jax.jit(create_tell(config, self.gene_type))
|
self.tell_func = jax.jit(create_tell(config, self.gene_type))
|
||||||
|
|
||||||
def setup(self, randkey):
|
def setup(self, randkey):
|
||||||
@@ -64,10 +59,11 @@ class NEAT:
|
|||||||
idx2species=idx2species,
|
idx2species=idx2species,
|
||||||
center_nodes=center_nodes,
|
center_nodes=center_nodes,
|
||||||
center_conns=center_conns,
|
center_conns=center_conns,
|
||||||
|
|
||||||
# avoid jax auto cast from int to float. that would cause re-compilation.
|
# avoid jax auto cast from int to float. that would cause re-compilation.
|
||||||
generation=jnp.asarray(generation, dtype=jnp.int32),
|
generation=jnp.asarray(generation, dtype=jnp.int32),
|
||||||
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
|
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
|
||||||
next_species_key=jnp.asarray(next_species_key)
|
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
|
||||||
)
|
)
|
||||||
|
|
||||||
# move to device
|
# move to device
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import jax
|
|||||||
from jax import vmap, jit
|
from jax import vmap, jit
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""
|
"""
|
||||||
Neat algorithm pipeline.
|
Neat algorithm pipeline.
|
||||||
@@ -73,4 +74,4 @@ class Pipeline:
|
|||||||
|
|
||||||
print(f"Generation: {self.state.generation}",
|
print(f"Generation: {self.state.generation}",
|
||||||
f"species: {len(species_sizes)}, {species_sizes}",
|
f"species: {len(species_sizes)}, {species_sizes}",
|
||||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
|
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
|
||||||
|
|||||||
Reference in New Issue
Block a user