remove create_func....
This commit is contained in:
@@ -6,7 +6,7 @@ class Gene:
|
||||
node_attrs = []
|
||||
conn_attrs = []
|
||||
|
||||
def __init__(self, config: GeneConfig):
|
||||
def __init__(self, config: GeneConfig = GeneConfig()):
|
||||
raise NotImplementedError
|
||||
|
||||
def setup(self, state=State()):
|
||||
|
||||
@@ -19,6 +19,11 @@ class Genome:
|
||||
def __getitem__(self, idx):
|
||||
return self.__class__(self.nodes[idx], self.conns[idx])
|
||||
|
||||
def __eq__(self, other):
|
||||
nodes_eq = jnp.alltrue((self.nodes == other.nodes) | (jnp.isnan(self.nodes) & jnp.isnan(other.nodes)))
|
||||
conns_eq = jnp.alltrue((self.conns == other.conns) | (jnp.isnan(self.conns) & jnp.isnan(other.conns)))
|
||||
return nodes_eq & conns_eq
|
||||
|
||||
def set(self, idx, value: Genome):
|
||||
return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns))
|
||||
|
||||
@@ -83,4 +88,3 @@ class Genome:
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
|
||||
|
||||
@@ -1,15 +1,27 @@
|
||||
from typing import Callable
|
||||
|
||||
from config import ProblemConfig
|
||||
from state import State
|
||||
from .state import State
|
||||
|
||||
|
||||
class Problem:
|
||||
|
||||
def __init__(self, config: ProblemConfig):
|
||||
def __init__(self, problem_config: ProblemConfig = ProblemConfig()):
|
||||
self.config = problem_config
|
||||
|
||||
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
||||
raise NotImplementedError
|
||||
|
||||
def setup(self, state=State()):
|
||||
@property
|
||||
def input_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def evaluate(self, state: State, act_func: Callable, params):
|
||||
@property
|
||||
def output_shape(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def show(self, randkey, state: State, act_func: Callable, params):
|
||||
"""
|
||||
show how a genome perform in this problem
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -4,7 +4,5 @@ from config import SubstrateConfig
|
||||
class Substrate:
|
||||
|
||||
@staticmethod
|
||||
def setup(state, config: SubstrateConfig):
|
||||
def setup(state, config: SubstrateConfig = SubstrateConfig()):
|
||||
return state
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user