small change for elegant code style
This commit is contained in:
@@ -5,7 +5,7 @@ import jax.numpy as jnp
|
||||
|
||||
from algorithm.state import State
|
||||
from .gene import BaseGene
|
||||
from .genome import initialize_genomes, create_mutate, create_distance, crossover
|
||||
from .genome import initialize_genomes
|
||||
from .population import create_tell
|
||||
|
||||
|
||||
@@ -14,11 +14,6 @@ class NEAT:
|
||||
self.config = config
|
||||
self.gene_type = gene_type
|
||||
|
||||
self.mutate = jax.jit(create_mutate(config, self.gene_type))
|
||||
self.distance = jax.jit(create_distance(config, self.gene_type))
|
||||
self.crossover = jax.jit(crossover)
|
||||
self.pop_forward_transform = jax.jit(jax.vmap(self.gene_type.forward_transform))
|
||||
self.forward = jax.jit(self.gene_type.create_forward(config))
|
||||
self.tell_func = jax.jit(create_tell(config, self.gene_type))
|
||||
|
||||
def setup(self, randkey):
|
||||
@@ -64,10 +59,11 @@ class NEAT:
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
|
||||
# avoid jax auto cast from int to float. that would cause re-compilation.
|
||||
generation=jnp.asarray(generation, dtype=jnp.int32),
|
||||
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
|
||||
next_species_key=jnp.asarray(next_species_key)
|
||||
next_species_key=jnp.asarray(next_species_key, dtype=jnp.float32),
|
||||
)
|
||||
|
||||
# move to device
|
||||
|
||||
Reference in New Issue
Block a user