remove create_func....

This commit is contained in:
wls2002
2023-08-04 17:29:36 +08:00
parent c7fb1ddabe
commit 0e44b13291
29 changed files with 591 additions and 259 deletions

View File

@@ -6,7 +6,7 @@ class Gene:
node_attrs = []
conn_attrs = []
def __init__(self, config: GeneConfig):
def __init__(self, config: GeneConfig = GeneConfig()):
raise NotImplementedError
def setup(self, state=State()):

View File

@@ -19,6 +19,11 @@ class Genome:
def __getitem__(self, idx):
return self.__class__(self.nodes[idx], self.conns[idx])
def __eq__(self, other):
nodes_eq = jnp.alltrue((self.nodes == other.nodes) | (jnp.isnan(self.nodes) & jnp.isnan(other.nodes)))
conns_eq = jnp.alltrue((self.conns == other.conns) | (jnp.isnan(self.conns) & jnp.isnan(other.conns)))
return nodes_eq & conns_eq
def set(self, idx, value: Genome):
return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns))
@@ -83,4 +88,3 @@ class Genome:
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)

View File

@@ -1,15 +1,27 @@
from typing import Callable
from config import ProblemConfig
from state import State
from .state import State
class Problem:
def __init__(self, config: ProblemConfig):
def __init__(self, problem_config: ProblemConfig = ProblemConfig()):
self.config = problem_config
def evaluate(self, randkey, state: State, act_func: Callable, params):
raise NotImplementedError
def setup(self, state=State()):
@property
def input_shape(self):
raise NotImplementedError
def evaluate(self, state: State, act_func: Callable, params):
@property
def output_shape(self):
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params):
"""
show how a genome perform in this problem
"""
raise NotImplementedError

View File

@@ -4,7 +4,5 @@ from config import SubstrateConfig
class Substrate:
@staticmethod
def setup(state, config: SubstrateConfig):
def setup(state, config: SubstrateConfig = SubstrateConfig()):
return state