remove create_func....
This commit is contained in:
@@ -2,4 +2,4 @@ from .algorithm import Algorithm
|
||||
from .state import State
|
||||
from .genome import Genome
|
||||
from .gene import Gene
|
||||
from .substrate import Substrate
|
||||
from .substrate import Substrate
|
||||
@@ -1,28 +1,50 @@
|
||||
from jax import Array
|
||||
from functools import partial
|
||||
import jax
|
||||
from .state import State
|
||||
from .genome import Genome
|
||||
|
||||
EMPTY = lambda *args: args
|
||||
|
||||
|
||||
class Algorithm:
|
||||
|
||||
def setup(self, randkey, state: State = State()):
|
||||
"""initialize the state of the algorithm"""
|
||||
pass
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def ask(self, state: State):
|
||||
"""require the population to be evaluated"""
|
||||
pass
|
||||
|
||||
return self.ask_algorithm(state)
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def tell(self, state: State, fitness):
|
||||
"""update the state of the algorithm"""
|
||||
pass
|
||||
|
||||
def forward(self, inputs: Array, transformed: Array):
|
||||
"""the forward function of a single forward transformation"""
|
||||
pass
|
||||
return self.tell_algorithm(state, fitness)
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def transform(self, state: State, genome: Genome):
|
||||
"""transform the genome into a neural network"""
|
||||
|
||||
return self.forward_transform(state, genome)
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def act(self, state: State, inputs, genome: Genome):
|
||||
return self.forward(state, inputs, genome)
|
||||
|
||||
def forward_transform(self, state: State, genome: Genome):
|
||||
"""create the forward transformation of a genome"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state: State, inputs, genome: Genome):
|
||||
raise NotImplementedError
|
||||
|
||||
def ask_algorithm(self, state: State):
|
||||
"""ask the specific algorithm for a new population"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def tell_algorithm(self, state: State, fitness):
|
||||
"""tell the specific algorithm the fitness of the population"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
51
core/gene.py
51
core/gene.py
@@ -1,46 +1,37 @@
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
from config import GeneConfig
|
||||
from .state import State
|
||||
from .genome import Genome
|
||||
|
||||
|
||||
class Gene:
|
||||
node_attrs = []
|
||||
conn_attrs = []
|
||||
|
||||
@staticmethod
|
||||
def setup(config: GeneConfig, state: State):
|
||||
return state
|
||||
def setup(self, state=State()):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def new_node_attrs(state: State):
|
||||
return jnp.zeros(0)
|
||||
def update(self, state):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def new_conn_attrs(state: State):
|
||||
return jnp.zeros(0)
|
||||
def new_node_attrs(self, state: State):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def mutate_node(state: State, attrs: Array, randkey: Array):
|
||||
return attrs
|
||||
def new_conn_attrs(self, state: State):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def mutate_conn(state: State, attrs: Array, randkey: Array):
|
||||
return attrs
|
||||
def mutate_node(self, state: State, randkey, node_attrs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def distance_node(state: State, node1: Array, node2: Array):
|
||||
return node1
|
||||
def mutate_conn(self, state: State, randkey, conn_attrs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def distance_conn(state: State, conn1: Array, conn2: Array):
|
||||
return conn1
|
||||
def distance_node(self, state: State, node_attrs1, node_attrs2):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(state: State, genome: Genome):
|
||||
return jnp.zeros(0) # transformed
|
||||
def distance_conn(self, state: State, conn_attrs1, conn_attrs2):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def create_forward(state: State, config: GeneConfig):
|
||||
return lambda *args: args # forward function
|
||||
def forward_transform(self, state: State, genome):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state: State, inputs, transform):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -84,4 +84,3 @@ class Genome:
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user