remove create_func....

This commit is contained in:
wls2002
2023-08-02 13:26:01 +08:00
parent 85318f98f3
commit 1499e062fe
34 changed files with 558 additions and 1022 deletions

View File

@@ -1,28 +1,50 @@
from jax import Array
from functools import partial
import jax
from .state import State
from .genome import Genome
EMPTY = lambda *args: args
class Algorithm:
def setup(self, randkey, state: State = State()):
"""initialize the state of the algorithm"""
pass
raise NotImplementedError
@partial(jax.jit, static_argnums=(0,))
def ask(self, state: State):
"""require the population to be evaluated"""
pass
return self.ask_algorithm(state)
@partial(jax.jit, static_argnums=(0,))
def tell(self, state: State, fitness):
"""update the state of the algorithm"""
pass
def forward(self, inputs: Array, transformed: Array):
"""the forward function of a single forward transformation"""
pass
return self.tell_algorithm(state, fitness)
@partial(jax.jit, static_argnums=(0,))
def transform(self, state: State, genome: Genome):
"""transform the genome into a neural network"""
return self.forward_transform(state, genome)
@partial(jax.jit, static_argnums=(0,))
def act(self, state: State, inputs, genome: Genome):
return self.forward(state, inputs, genome)
def forward_transform(self, state: State, genome: Genome):
"""create the forward transformation of a genome"""
pass
raise NotImplementedError
def forward(self, state: State, inputs, genome: Genome):
raise NotImplementedError
def ask_algorithm(self, state: State):
"""ask the specific algorithm for a new population"""
raise NotImplementedError
def tell_algorithm(self, state: State, fitness):
"""tell the specific algorithm the fitness of the population"""
raise NotImplementedError