complete HyperNEAT!

This commit is contained in:
wls2002
2023-07-21 15:03:12 +08:00
parent 80ee5ea2ea
commit 48f90c7eef
32 changed files with 432 additions and 136 deletions

View File

@@ -3,22 +3,25 @@ from typing import Type
import jax
import jax.numpy as jnp
from algorithm.state import State
from algorithm import Algorithm, State
from .gene import BaseGene
from .genome import initialize_genomes
from .population import create_tell
class NEAT:
class NEAT(Algorithm):
def __init__(self, config, gene_type: Type[BaseGene]):
super().__init__()
self.config = config
self.gene_type = gene_type
self.tell_func = jax.jit(create_tell(config, self.gene_type))
self.tell = create_tell(config, self.gene_type)
self.ask = None
self.forward = self.gene_type.create_forward(config)
self.forward_transform = self.gene_type.forward_transform
def setup(self, randkey):
state = State(
def setup(self, randkey, state=State()):
state = state.update(
P=self.config['pop_size'],
N=self.config['maximum_nodes'],
C=self.config['maximum_conns'],
@@ -69,7 +72,4 @@ class NEAT:
# move to device
state = jax.device_put(state)
return state
def step(self, state, fitness):
return self.tell_func(state, fitness)
return state