small change for elegant code style

This commit is contained in:
wls2002
2023-07-19 16:38:43 +08:00
parent a684e6584d
commit 80ee5ea2ea
5 changed files with 7 additions and 87 deletions

View File

@@ -5,7 +5,7 @@ import jax.numpy as jnp
from algorithm.state import State
from .gene import BaseGene
from .genome import initialize_genomes, create_mutate, create_distance, crossover
from .genome import initialize_genomes
from .population import create_tell
@@ -14,11 +14,6 @@ class NEAT:
self.config = config
self.gene_type = gene_type
self.mutate = jax.jit(create_mutate(config, self.gene_type))
self.distance = jax.jit(create_distance(config, self.gene_type))
self.crossover = jax.jit(crossover)
self.pop_forward_transform = jax.jit(jax.vmap(self.gene_type.forward_transform))
self.forward = jax.jit(self.gene_type.create_forward(config))
self.tell_func = jax.jit(create_tell(config, self.gene_type))
def setup(self, randkey):
@@ -64,10 +59,11 @@ class NEAT:
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
# avoid jax auto cast from int to float. that would cause re-compilation.
generation=jnp.asarray(generation, dtype=jnp.int32),
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
next_species_key=jnp.asarray(next_species_key)
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
)
# move to device