remove create_func....

This commit is contained in:
wls2002
2023-08-02 13:26:01 +08:00
parent 85318f98f3
commit 1499e062fe
34 changed files with 558 additions and 1022 deletions

View File

@@ -1,46 +1,37 @@
from jax import Array, numpy as jnp
from config import GeneConfig
from .state import State
from .genome import Genome
class Gene:
node_attrs = []
conn_attrs = []
@staticmethod
def setup(config: GeneConfig, state: State):
return state
def setup(self, state=State()):
raise NotImplementedError
@staticmethod
def new_node_attrs(state: State):
return jnp.zeros(0)
def update(self, state):
raise NotImplementedError
@staticmethod
def new_conn_attrs(state: State):
return jnp.zeros(0)
def new_node_attrs(self, state: State):
raise NotImplementedError
@staticmethod
def mutate_node(state: State, attrs: Array, randkey: Array):
return attrs
def new_conn_attrs(self, state: State):
raise NotImplementedError
@staticmethod
def mutate_conn(state: State, attrs: Array, randkey: Array):
return attrs
def mutate_node(self, state: State, randkey, node_attrs):
raise NotImplementedError
@staticmethod
def distance_node(state: State, node1: Array, node2: Array):
return node1
def mutate_conn(self, state: State, randkey, conn_attrs):
raise NotImplementedError
@staticmethod
def distance_conn(state: State, conn1: Array, conn2: Array):
return conn1
def distance_node(self, state: State, node_attrs1, node_attrs2):
raise NotImplementedError
@staticmethod
def forward_transform(state: State, genome: Genome):
return jnp.zeros(0) # transformed
def distance_conn(self, state: State, conn_attrs1, conn_attrs2):
raise NotImplementedError
@staticmethod
def create_forward(state: State, config: GeneConfig):
return lambda *args: args # forward function
def forward_transform(self, state: State, genome):
raise NotImplementedError
def forward(self, state: State, inputs, transform):
raise NotImplementedError