diff --git a/src/tensorneat/algorithm/hyperneat/substrate/default.py b/src/tensorneat/algorithm/hyperneat/substrate/default.py index 40c691c..af99191 100644 --- a/src/tensorneat/algorithm/hyperneat/substrate/default.py +++ b/src/tensorneat/algorithm/hyperneat/substrate/default.py @@ -1,4 +1,5 @@ -from jax import vmap, numpy as jnp +from jax import vmap +import numpy as np from .base import BaseSubstrate from tensorneat.genome.utils import set_conn_attrs @@ -8,9 +9,9 @@ class DefaultSubstrate(BaseSubstrate): def __init__(self, num_inputs, num_outputs, coors, nodes, conns): self.inputs = num_inputs self.outputs = num_outputs - self.coors = jnp.array(coors) - self.nodes = jnp.array(nodes) - self.conns = jnp.array(conns) + self.coors = np.array(coors) + self.nodes = np.array(nodes) + self.conns = np.array(conns) def make_nodes(self, query_res): return self.nodes diff --git a/src/tensorneat/common/state.py b/src/tensorneat/common/state.py index 36bd165..ba0a09c 100644 --- a/src/tensorneat/common/state.py +++ b/src/tensorneat/common/state.py @@ -1,3 +1,5 @@ +import pickle + from jax.tree_util import register_pytree_node_class @@ -39,6 +41,15 @@ class State: def __contains__(self, item): return item in self.state_dict + def save(self, file_name): + with open(file_name, "wb") as f: + pickle.dump(self, f) + + @classmethod + def load(cls, file_name): + with open(file_name, "rb") as f: + return pickle.load(f) + def tree_flatten(self): children = list(self.state_dict.values()) aux_data = list(self.state_dict.keys()) diff --git a/src/tensorneat/common/stateful_class.py b/src/tensorneat/common/stateful_class.py index a7e2c2a..2f126f4 100644 --- a/src/tensorneat/common/stateful_class.py +++ b/src/tensorneat/common/stateful_class.py @@ -9,30 +9,6 @@ class StatefulBaseClass: def setup(self, state=State()): return state - def save(self, state: Optional[State] = None, path: Optional[str] = None): - if path is None: - time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - path = f"./{self.__class__.__name__} {time}.pkl" - if state is not None: - self.__dict__["aux_for_state"] = state - with open(path, "wb") as f: - pickle.dump(self, f) - - def __getstate__(self): - # only pickle the picklable attributes - state = self.__dict__.copy() - non_picklable_keys = [] - for key, value in state.items(): - try: - pickle.dumps(value) - except Exception: - non_picklable_keys.append(key) - - for key in non_picklable_keys: - state.pop(key) - - return state - def show_config(self, registered_objects=None): if registered_objects is None: # root call registered_objects = [] @@ -47,27 +23,53 @@ class StatefulBaseClass: config[str(key)] = str(value) return config - @classmethod - def load(cls, path: str, with_state: bool = False, warning: bool = True): - with open(path, "rb") as f: - obj = pickle.load(f) - if with_state: - if "aux_for_state" not in obj.__dict__: - if warning: - warnings.warn( - "This object does not have state to load, return empty state", - category=UserWarning, - ) - return obj, State() - state = obj.__dict__["aux_for_state"] - del obj.__dict__["aux_for_state"] - return obj, state - else: - if "aux_for_state" in obj.__dict__: - if warning: - warnings.warn( - "This object has state to load, ignore it", - category=UserWarning, - ) - del obj.__dict__["aux_for_state"] - return obj + # TODO: Bug need be fixed + # def save(self, state: Optional[State] = None, path: Optional[str] = None): + # if path is None: + # time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + # path = f"./{self.__class__.__name__} {time}.pkl" + # if state is not None: + # self.__dict__["aux_for_state"] = state + # with open(path, "wb") as f: + # pickle.dump(self, f) + + # def __getstate__(self): + # # only pickle the picklable attributes + # state = self.__dict__.copy() + # non_picklable_keys = [] + # for key, value in state.items(): + # try: + # pickle.dumps(value) + # except Exception as e: + # print(f"Cannot pickle key {key}: {e}") + # non_picklable_keys.append(key) + + # for key in non_picklable_keys: + # state.pop(key) + + # return state + + # @classmethod + # def load(cls, path: str, with_state: bool = False, warning: bool = True): + # with open(path, "rb") as f: + # obj = pickle.load(f) + # if with_state: + # if "aux_for_state" not in obj.__dict__: + # if warning: + # warnings.warn( + # "This object does not have state to load, return empty state", + # category=UserWarning, + # ) + # return obj, State() + # state = obj.__dict__["aux_for_state"] + # del obj.__dict__["aux_for_state"] + # return obj, state + # else: + # if "aux_for_state" in obj.__dict__: + # if warning: + # warnings.warn( + # "This object has state to load, ignore it", + # category=UserWarning, + # ) + # del obj.__dict__["aux_for_state"] + # return obj diff --git a/src/tensorneat/genome/base.py b/src/tensorneat/genome/base.py index 8252d3c..bad52d4 100644 --- a/src/tensorneat/genome/base.py +++ b/src/tensorneat/genome/base.py @@ -96,9 +96,9 @@ class BaseGenome(StatefulBaseClass): def setup(self, state=State()): state = self.node_gene.setup(state) state = self.conn_gene.setup(state) - state = self.mutation.setup(state, self) - state = self.crossover.setup(state, self) - state = self.distance.setup(state, self) + state = self.mutation.setup(state) + state = self.crossover.setup(state) + state = self.distance.setup(state) return state def transform(self, state, nodes, conns): @@ -114,13 +114,13 @@ class BaseGenome(StatefulBaseClass): raise NotImplementedError def execute_mutation(self, state, randkey, nodes, conns, new_node_key): - return self.mutation(state, randkey, nodes, conns, new_node_key) + return self.mutation(state, self, randkey, nodes, conns, new_node_key) def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2): - return self.crossover(state, randkey, nodes1, conns1, nodes2, conns2) + return self.crossover(state, self, randkey, nodes1, conns1, nodes2, conns2) def execute_distance(self, state, nodes1, conns1, nodes2, conns2): - return self.distance(state, nodes1, conns1, nodes2, conns2) + return self.distance(state, self, nodes1, conns1, nodes2, conns2) def initialize(self, state, randkey): k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns diff --git a/src/tensorneat/genome/gene/node/bias.py b/src/tensorneat/genome/gene/node/bias.py index 37f5ae5..d162c09 100644 --- a/src/tensorneat/genome/gene/node/bias.py +++ b/src/tensorneat/genome/gene/node/bias.py @@ -13,7 +13,7 @@ from tensorneat.common import ( get_func_name ) -from . import BaseNode +from .base import BaseNode class BiasNode(BaseNode): diff --git a/src/tensorneat/genome/operations/crossover/base.py b/src/tensorneat/genome/operations/crossover/base.py index 143d519..e44b49e 100644 --- a/src/tensorneat/genome/operations/crossover/base.py +++ b/src/tensorneat/genome/operations/crossover/base.py @@ -3,10 +3,5 @@ from tensorneat.common import StatefulBaseClass, State class BaseCrossover(StatefulBaseClass): - def setup(self, state=State(), genome = None): - assert genome is not None, "genome should not be None" - self.genome = genome - return state - - def __call__(self, state, randkey, nodes1, nodes2, conns1, conns2): + def __call__(self, state, genome, randkey, nodes1, nodes2, conns1, conns2): raise NotImplementedError diff --git a/src/tensorneat/genome/operations/crossover/default.py b/src/tensorneat/genome/operations/crossover/default.py index 1548152..d17a0f4 100644 --- a/src/tensorneat/genome/operations/crossover/default.py +++ b/src/tensorneat/genome/operations/crossover/default.py @@ -11,14 +11,14 @@ from ...utils import ( class DefaultCrossover(BaseCrossover): - def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2): + def __call__(self, state, genome, randkey, nodes1, conns1, nodes2, conns2): """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) """ randkey1, randkey2 = jax.random.split(randkey, 2) - randkeys1 = jax.random.split(randkey1, self.genome.max_nodes) - randkeys2 = jax.random.split(randkey2, self.genome.max_conns) + randkeys1 = jax.random.split(randkey1, genome.max_nodes) + randkeys2 = jax.random.split(randkey2, genome.max_conns) # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] @@ -33,7 +33,7 @@ class DefaultCrossover(BaseCrossover): new_node_attrs = jnp.where( jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner) - vmap(self.genome.node_gene.crossover, in_axes=(None, 0, 0, 0))( + vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))( state, randkeys1, node_attrs1, node_attrs2 ), # homologous or both nan ) @@ -49,7 +49,7 @@ class DefaultCrossover(BaseCrossover): new_conn_attrs = jnp.where( jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2), conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner) - vmap(self.genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))( + vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))( state, randkeys2, conns_attrs1, conns_attrs2 ), # homologous or both nan ) diff --git a/src/tensorneat/genome/operations/distance/base.py b/src/tensorneat/genome/operations/distance/base.py index f3f6a65..cbb0a7e 100644 --- a/src/tensorneat/genome/operations/distance/base.py +++ b/src/tensorneat/genome/operations/distance/base.py @@ -3,12 +3,7 @@ from tensorneat.common import StatefulBaseClass, State class BaseDistance(StatefulBaseClass): - def setup(self, state=State(), genome = None): - assert genome is not None, "genome should not be None" - self.genome = genome - return state - - def __call__(self, state, nodes1, nodes2, conns1, conns2): + def __call__(self, state, genome, nodes1, nodes2, conns1, conns2): """ The distance between two genomes """ diff --git a/src/tensorneat/genome/operations/distance/default.py b/src/tensorneat/genome/operations/distance/default.py index d992d2e..00d7284 100644 --- a/src/tensorneat/genome/operations/distance/default.py +++ b/src/tensorneat/genome/operations/distance/default.py @@ -13,16 +13,16 @@ class DefaultDistance(BaseDistance): self.compatibility_disjoint = compatibility_disjoint self.compatibility_weight = compatibility_weight - def __call__(self, state, nodes1, conns1, nodes2, conns2): + def __call__(self, state, genome, nodes1, conns1, nodes2, conns2): """ The distance between two genomes """ - d = self.node_distance(state, nodes1, nodes2) + self.conn_distance( - state, conns1, conns2 + d = self.node_distance(state, genome, nodes1, nodes2) + self.conn_distance( + state, genome, conns1, conns2 ) return d - def node_distance(self, state, nodes1, nodes2): + def node_distance(self, state, genome, nodes1, nodes2): """ The distance of the nodes part for two genomes """ @@ -50,7 +50,7 @@ class DefaultDistance(BaseDistance): # calculate the distance of homologous nodes fr_attrs = vmap(extract_node_attrs)(fr) sr_attrs = vmap(extract_node_attrs)(sr) - hnd = vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))( + hnd = vmap(genome.node_gene.distance, in_axes=(None, 0, 0))( state, fr_attrs, sr_attrs ) # homologous node distance hnd = jnp.where(jnp.isnan(hnd), 0, hnd) @@ -65,7 +65,7 @@ class DefaultDistance(BaseDistance): return val - def conn_distance(self, state, conns1, conns2): + def conn_distance(self, state, genome, conns1, conns2): """ The distance of the conns part for two genomes """ @@ -89,7 +89,7 @@ class DefaultDistance(BaseDistance): fr_attrs = vmap(extract_conn_attrs)(fr) sr_attrs = vmap(extract_conn_attrs)(sr) - hcd = vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))( + hcd = vmap(genome.conn_gene.distance, in_axes=(None, 0, 0))( state, fr_attrs, sr_attrs ) # homologous connection distance hcd = jnp.where(jnp.isnan(hcd), 0, hcd) diff --git a/src/tensorneat/genome/operations/mutation/base.py b/src/tensorneat/genome/operations/mutation/base.py index 2d138af..a1b08cb 100644 --- a/src/tensorneat/genome/operations/mutation/base.py +++ b/src/tensorneat/genome/operations/mutation/base.py @@ -3,10 +3,5 @@ from tensorneat.common import StatefulBaseClass, State class BaseMutation(StatefulBaseClass): - def setup(self, state=State(), genome = None): - assert genome is not None, "genome should not be None" - self.genome = genome - return state - - def __call__(self, state, randkey, nodes, conns, new_node_key): + def __call__(self, state, genome, randkey, nodes, conns, new_node_key): raise NotImplementedError diff --git a/src/tensorneat/genome/operations/mutation/default.py b/src/tensorneat/genome/operations/mutation/default.py index 7a0f2b5..d761e1a 100644 --- a/src/tensorneat/genome/operations/mutation/default.py +++ b/src/tensorneat/genome/operations/mutation/default.py @@ -33,17 +33,17 @@ class DefaultMutation(BaseMutation): self.node_add = node_add self.node_delete = node_delete - def __call__(self, state, randkey, nodes, conns, new_node_key): + def __call__(self, state, genome, randkey, nodes, conns, new_node_key): k1, k2 = jax.random.split(randkey) nodes, conns = self.mutate_structure( - state, k1, nodes, conns, new_node_key + state, genome, k1, nodes, conns, new_node_key ) - nodes, conns = self.mutate_values(state, k2, nodes, conns) + nodes, conns = self.mutate_values(state, genome, k2, nodes, conns) return nodes, conns - def mutate_structure(self, state, randkey, nodes, conns, new_node_key): + def mutate_structure(self, state, genome, randkey, nodes, conns, new_node_key): def mutate_add_node(key_, nodes_, conns_): """ add a node while do not influence the output of the network @@ -62,7 +62,7 @@ class DefaultMutation(BaseMutation): # add a new node with identity attrs new_nodes = add_node( - nodes_, new_node_key, self.genome.node_gene.new_identity_attrs(state) + nodes_, new_node_key, genome.node_gene.new_identity_attrs(state) ) # add two new connections @@ -71,7 +71,7 @@ class DefaultMutation(BaseMutation): new_conns, i_key, new_node_key, - self.genome.conn_gene.new_identity_attrs(state), + genome.conn_gene.new_identity_attrs(state), ) # second is with the origin attrs new_conns = add_conn( @@ -97,8 +97,8 @@ class DefaultMutation(BaseMutation): key, idx = self.choose_node_key( key_, nodes_, - self.genome.input_idx, - self.genome.output_idx, + genome.input_idx, + genome.output_idx, allow_input_keys=False, allow_output_keys=False, ) @@ -136,8 +136,8 @@ class DefaultMutation(BaseMutation): i_key, from_idx = self.choose_node_key( k1_, nodes_, - self.genome.input_idx, - self.genome.output_idx, + genome.input_idx, + genome.output_idx, allow_input_keys=True, allow_output_keys=True, ) @@ -146,8 +146,8 @@ class DefaultMutation(BaseMutation): o_key, to_idx = self.choose_node_key( k2_, nodes_, - self.genome.input_idx, - self.genome.output_idx, + genome.input_idx, + genome.output_idx, allow_input_keys=False, allow_output_keys=True, ) @@ -161,10 +161,10 @@ class DefaultMutation(BaseMutation): def successful(): # add a connection with zero attrs return nodes_, add_conn( - conns_, i_key, o_key, self.genome.conn_gene.new_zero_attrs(state) + conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state) ) - if self.genome.network_type == "feedforward": + if genome.network_type == "feedforward": u_conns = unflatten_conns(nodes_, conns_) conns_exist = u_conns != I_INF is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx) @@ -175,7 +175,7 @@ class DefaultMutation(BaseMutation): successful, ) - elif self.genome.network_type == "recurrent": + elif genome.network_type == "recurrent": return jax.lax.cond( is_already_exist | (remain_conn_space < 1), nothing, @@ -183,7 +183,7 @@ class DefaultMutation(BaseMutation): ) else: - raise ValueError(f"Invalid network type: {self.genome.network_type}") + raise ValueError(f"Invalid network type: {genome.network_type}") def mutate_delete_conn(key_, nodes_, conns_): # randomly choose a connection @@ -223,19 +223,19 @@ class DefaultMutation(BaseMutation): return nodes, conns - def mutate_values(self, state, randkey, nodes, conns): + def mutate_values(self, state, genome, randkey, nodes, conns): k1, k2 = jax.random.split(randkey) - nodes_randkeys = jax.random.split(k1, num=self.genome.max_nodes) - conns_randkeys = jax.random.split(k2, num=self.genome.max_conns) + nodes_randkeys = jax.random.split(k1, num=genome.max_nodes) + conns_randkeys = jax.random.split(k2, num=genome.max_conns) node_attrs = vmap(extract_node_attrs)(nodes) - new_node_attrs = vmap(self.genome.node_gene.mutate, in_axes=(None, 0, 0))( + new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))( state, nodes_randkeys, node_attrs ) new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs) conn_attrs = vmap(extract_conn_attrs)(conns) - new_conn_attrs = vmap(self.genome.conn_gene.mutate, in_axes=(None, 0, 0))( + new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))( state, conns_randkeys, conn_attrs ) new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs) diff --git a/src/tensorneat/pipeline.py b/src/tensorneat/pipeline.py index 0de5a71..7a86eda 100644 --- a/src/tensorneat/pipeline.py +++ b/src/tensorneat/pipeline.py @@ -184,6 +184,6 @@ class Pipeline(StatefulBaseClass): def show(self, state, best, *args, **kwargs): transformed = self.algorithm.transform(state, best) - self.problem.show( + return self.problem.show( state, state.randkey, self.algorithm.forward, transformed, *args, **kwargs )