diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index 2786022..83f994e 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -41,15 +41,13 @@ class DefaultMutation(BaseMutation): return nodes, conns def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key): - - remain_node_space = jnp.isnan(nodes[:, 0]).sum() - remain_conn_space = jnp.isnan(conns[:, 0]).sum() - def mutate_add_node(key_, nodes_, conns_): """ add a node while do not influence the output of the network """ + remain_node_space = jnp.isnan(nodes_[:, 0]).sum() + remain_conn_space = jnp.isnan(conns_[:, 0]).sum() i_key, o_key, idx = self.choose_connection_key( key_, conns_ ) # choose a connection @@ -83,7 +81,7 @@ class DefaultMutation(BaseMutation): return new_nodes, new_conns return jax.lax.cond( - (idx == I_INF) & (remain_node_space < 1) & (remain_conn_space < 2), + (idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2), lambda: (nodes_, conns_), # do nothing successful_add_node, ) @@ -92,7 +90,6 @@ class DefaultMutation(BaseMutation): """ delete a node """ - # randomly choose a node key, idx = self.choose_node_key( key_, @@ -127,6 +124,8 @@ class DefaultMutation(BaseMutation): add a connection while do not influence the output of the network """ + remain_conn_space = jnp.isnan(conns_[:, 0]).sum() + # randomly choose two nodes k1_, k2_ = jax.random.split(key_, num=2) @@ -164,7 +163,7 @@ class DefaultMutation(BaseMutation): if genome.network_type == "feedforward": u_conns = unflatten_conns(nodes_, conns_) - conns_exist = (u_conns != I_INF) + conns_exist = u_conns != I_INF is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx) return jax.lax.cond( diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index d3cbad6..296a6c4 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -40,3 +40,6 @@ class BaseGene(StatefulBaseClass): @property def length(self): return len(self.fixed_attrs) + len(self.custom_attrs) + + def repr(self, state, gene, precision=2): + raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/conn/base.py b/tensorneat/algorithm/neat/gene/conn/base.py index 39d72fe..cfabaf5 100644 --- a/tensorneat/algorithm/neat/gene/conn/base.py +++ b/tensorneat/algorithm/neat/gene/conn/base.py @@ -22,3 +22,12 @@ class BaseConnGene(BaseGene): jax.vmap(self.forward, in_axes=(None, None, 0))(state, attrs, batch_inputs), attrs, ) + + def repr(self, state, conn, precision=2, idx_width=3, func_width=8): + in_idx, out_idx = conn[:2] + in_idx = int(in_idx) + out_idx = int(out_idx) + + return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format( + self.__class__.__name__, in_idx, out_idx, idx_width=idx_width + ) diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py index 00520cc..0e31a16 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -60,3 +60,19 @@ class DefaultConnGene(BaseConnGene): def forward(self, state, attrs, inputs): weight = attrs[0] return inputs * weight + + def repr(self, state, conn, precision=2, idx_width=3, func_width=8): + in_idx, out_idx, weight = conn + + in_idx = int(in_idx) + out_idx = int(out_idx) + weight = round(float(weight), precision) + + return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, weight: {:<{float_width}})".format( + self.__class__.__name__, + in_idx, + out_idx, + weight, + idx_width=idx_width, + float_width=precision + 3, + ) diff --git a/tensorneat/algorithm/neat/gene/node/base.py b/tensorneat/algorithm/neat/gene/node/base.py index 452bf91..59481d7 100644 --- a/tensorneat/algorithm/neat/gene/node/base.py +++ b/tensorneat/algorithm/neat/gene/node/base.py @@ -39,3 +39,11 @@ class BaseNodeGene(BaseGene): ), attrs, ) + + def repr(self, state, node, precision=2, idx_width=3, func_width=8): + idx = node[0] + + idx = int(idx) + return "{}(idx={:<{idx_width}})".format( + self.__class__.__name__, idx, idx_width=idx_width + ) diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index cd89eab..99098c3 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -122,3 +122,28 @@ class DefaultNodeGene(BaseNodeGene): ) return z + + def repr(self, state, node, precision=2, idx_width=3, func_width=8): + idx, bias, res, agg, act = node + + idx = int(idx) + bias = round(float(bias), precision) + res = round(float(res), precision) + agg = int(agg) + act = int(act) + + if act == -1: + act_func = Act.identity + else: + act_func = self.activation_options[act] + return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format( + self.__class__.__name__, + idx, + bias, + res, + self.aggregation_options[agg].__name__, + act_func.__name__, + idx_width=idx_width, + float_width=precision + 3, + func_width=func_width + ) diff --git a/tensorneat/algorithm/neat/gene/node/default_without_response.py b/tensorneat/algorithm/neat/gene/node/default_without_response.py index 5f52ea8..f5ee75c 100644 --- a/tensorneat/algorithm/neat/gene/node/default_without_response.py +++ b/tensorneat/algorithm/neat/gene/node/default_without_response.py @@ -98,3 +98,26 @@ class NodeGeneWithoutResponse(BaseNodeGene): ) return z + + def repr(self, state, node, precision=2, idx_width=3, func_width=8): + idx, bias, agg, act = node + + idx = int(idx) + bias = round(float(bias), precision) + agg = int(agg) + act = int(act) + + if act == -1: + act_func = Act.identity + else: + act_func = self.activation_options[act] + return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format( + self.__class__.__name__, + idx, + bias, + self.aggregation_options[agg].__name__, + act_func.__name__, + idx_width=idx_width, + float_width=precision + 3, + func_width=func_width, + ) diff --git a/tensorneat/algorithm/neat/gene/node/kan_node.py b/tensorneat/algorithm/neat/gene/node/kan_node.py index 300a6a2..e283272 100644 --- a/tensorneat/algorithm/neat/gene/node/kan_node.py +++ b/tensorneat/algorithm/neat/gene/node/kan_node.py @@ -25,3 +25,8 @@ class KANNode(BaseNodeGene): def forward(self, state, attrs, inputs, is_output_node=False): return Agg.sum(inputs) + + def repr(self, state, node, precision=2): + idx = node[0] + idx = int(idx) + return "{}(idx: {})".format(self.__class__.__name__, idx) diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index 71170df..3259351 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -1,3 +1,4 @@ +import numpy as np import jax, jax.numpy as jnp from ..gene import BaseNodeGene, BaseConnGene from ..ga import BaseMutation, BaseCrossover @@ -20,8 +21,8 @@ class BaseGenome(StatefulBaseClass): ): self.num_inputs = num_inputs self.num_outputs = num_outputs - self.input_idx = jnp.arange(num_inputs) - self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs) + self.input_idx = np.arange(num_inputs) + self.output_idx = np.arange(num_inputs, num_inputs + num_outputs) self.max_nodes = max_nodes self.max_conns = max_conns self.node_gene = node_gene @@ -127,3 +128,30 @@ class BaseGenome(StatefulBaseClass): Update the genome by a batch of data. """ raise NotImplementedError + + def repr(self, state, nodes, conns, precision=2): + nodes, conns = jax.device_get([nodes, conns]) + nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns) + s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n" + s += f"\tNodes:\n" + for node in nodes: + if np.isnan(node[0]): + break + s += f"\t\t{self.node_gene.repr(state, node, precision=precision)}" + node_idx = int(node[0]) + if np.isin(node_idx, self.input_idx): + s += " (input)" + elif np.isin(node_idx, self.output_idx): + s += " (output)" + s += "\n" + + s += f"\tConns:\n" + for conn in conns: + if np.isnan(conn[0]): + break + s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n" + return s + + @classmethod + def valid_cnt(cls, arr): + return jnp.sum(~jnp.isnan(arr[:, 0])) diff --git a/tensorneat/examples/func_fit/xor.py b/tensorneat/examples/func_fit/xor.py index 24ebc0f..e6038cd 100644 --- a/tensorneat/examples/func_fit/xor.py +++ b/tensorneat/examples/func_fit/xor.py @@ -47,3 +47,4 @@ if __name__ == "__main__": state, best = pipeline.auto_run(state) # show result pipeline.show(state, best) + pipeline.save(state=state) diff --git a/tensorneat/utils/stateful_class.py b/tensorneat/utils/stateful_class.py index e865531..a1c8090 100644 --- a/tensorneat/utils/stateful_class.py +++ b/tensorneat/utils/stateful_class.py @@ -37,7 +37,7 @@ class StatefulBaseClass: if "aux_for_state" in obj.__dict__: if warning: warnings.warn( - "This object state to load, ignore it", + "This object has state to load, ignore it", category=UserWarning, ) del obj.__dict__["aux_for_state"]