complete HyperNEAT!
This commit is contained in:
@@ -3,22 +3,25 @@ from typing import Type
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from algorithm.state import State
|
||||
from algorithm import Algorithm, State
|
||||
from .gene import BaseGene
|
||||
from .genome import initialize_genomes
|
||||
from .population import create_tell
|
||||
|
||||
|
||||
class NEAT:
|
||||
class NEAT(Algorithm):
|
||||
def __init__(self, config, gene_type: Type[BaseGene]):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.gene_type = gene_type
|
||||
|
||||
self.tell_func = jax.jit(create_tell(config, self.gene_type))
|
||||
self.tell = create_tell(config, self.gene_type)
|
||||
self.ask = None
|
||||
self.forward = self.gene_type.create_forward(config)
|
||||
self.forward_transform = self.gene_type.forward_transform
|
||||
|
||||
def setup(self, randkey):
|
||||
|
||||
state = State(
|
||||
def setup(self, randkey, state=State()):
|
||||
state = state.update(
|
||||
P=self.config['pop_size'],
|
||||
N=self.config['maximum_nodes'],
|
||||
C=self.config['maximum_conns'],
|
||||
@@ -69,7 +72,4 @@ class NEAT:
|
||||
# move to device
|
||||
state = jax.device_put(state)
|
||||
|
||||
return state
|
||||
|
||||
def step(self, state, fitness):
|
||||
return self.tell_func(state, fitness)
|
||||
return state
|
||||
Reference in New Issue
Block a user