add repr for genome and gene;
add ipynb test for testing whether add node or add conn will not change the output for the network.
This commit is contained in:
@@ -41,15 +41,13 @@ class DefaultMutation(BaseMutation):
|
||||
return nodes, conns
|
||||
|
||||
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
|
||||
remain_node_space = jnp.isnan(nodes[:, 0]).sum()
|
||||
remain_conn_space = jnp.isnan(conns[:, 0]).sum()
|
||||
|
||||
def mutate_add_node(key_, nodes_, conns_):
|
||||
"""
|
||||
add a node while do not influence the output of the network
|
||||
"""
|
||||
|
||||
remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
|
||||
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
|
||||
i_key, o_key, idx = self.choose_connection_key(
|
||||
key_, conns_
|
||||
) # choose a connection
|
||||
@@ -83,7 +81,7 @@ class DefaultMutation(BaseMutation):
|
||||
return new_nodes, new_conns
|
||||
|
||||
return jax.lax.cond(
|
||||
(idx == I_INF) & (remain_node_space < 1) & (remain_conn_space < 2),
|
||||
(idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2),
|
||||
lambda: (nodes_, conns_), # do nothing
|
||||
successful_add_node,
|
||||
)
|
||||
@@ -92,7 +90,6 @@ class DefaultMutation(BaseMutation):
|
||||
"""
|
||||
delete a node
|
||||
"""
|
||||
|
||||
# randomly choose a node
|
||||
key, idx = self.choose_node_key(
|
||||
key_,
|
||||
@@ -127,6 +124,8 @@ class DefaultMutation(BaseMutation):
|
||||
add a connection while do not influence the output of the network
|
||||
"""
|
||||
|
||||
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
|
||||
|
||||
# randomly choose two nodes
|
||||
k1_, k2_ = jax.random.split(key_, num=2)
|
||||
|
||||
@@ -164,7 +163,7 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
if genome.network_type == "feedforward":
|
||||
u_conns = unflatten_conns(nodes_, conns_)
|
||||
conns_exist = (u_conns != I_INF)
|
||||
conns_exist = u_conns != I_INF
|
||||
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
||||
|
||||
return jax.lax.cond(
|
||||
|
||||
@@ -40,3 +40,6 @@ class BaseGene(StatefulBaseClass):
|
||||
@property
|
||||
def length(self):
|
||||
return len(self.fixed_attrs) + len(self.custom_attrs)
|
||||
|
||||
def repr(self, state, gene, precision=2):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -22,3 +22,12 @@ class BaseConnGene(BaseGene):
|
||||
jax.vmap(self.forward, in_axes=(None, None, 0))(state, attrs, batch_inputs),
|
||||
attrs,
|
||||
)
|
||||
|
||||
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
|
||||
in_idx, out_idx = conn[:2]
|
||||
in_idx = int(in_idx)
|
||||
out_idx = int(out_idx)
|
||||
|
||||
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format(
|
||||
self.__class__.__name__, in_idx, out_idx, idx_width=idx_width
|
||||
)
|
||||
|
||||
@@ -60,3 +60,19 @@ class DefaultConnGene(BaseConnGene):
|
||||
def forward(self, state, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
|
||||
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
|
||||
in_idx, out_idx, weight = conn
|
||||
|
||||
in_idx = int(in_idx)
|
||||
out_idx = int(out_idx)
|
||||
weight = round(float(weight), precision)
|
||||
|
||||
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, weight: {:<{float_width}})".format(
|
||||
self.__class__.__name__,
|
||||
in_idx,
|
||||
out_idx,
|
||||
weight,
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
)
|
||||
|
||||
@@ -39,3 +39,11 @@ class BaseNodeGene(BaseGene):
|
||||
),
|
||||
attrs,
|
||||
)
|
||||
|
||||
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
|
||||
idx = node[0]
|
||||
|
||||
idx = int(idx)
|
||||
return "{}(idx={:<{idx_width}})".format(
|
||||
self.__class__.__name__, idx, idx_width=idx_width
|
||||
)
|
||||
|
||||
@@ -122,3 +122,28 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
)
|
||||
|
||||
return z
|
||||
|
||||
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
|
||||
idx, bias, res, agg, act = node
|
||||
|
||||
idx = int(idx)
|
||||
bias = round(float(bias), precision)
|
||||
res = round(float(res), precision)
|
||||
agg = int(agg)
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
|
||||
self.__class__.__name__,
|
||||
idx,
|
||||
bias,
|
||||
res,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width
|
||||
)
|
||||
|
||||
@@ -98,3 +98,26 @@ class NodeGeneWithoutResponse(BaseNodeGene):
|
||||
)
|
||||
|
||||
return z
|
||||
|
||||
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
|
||||
idx, bias, agg, act = node
|
||||
|
||||
idx = int(idx)
|
||||
bias = round(float(bias), precision)
|
||||
agg = int(agg)
|
||||
act = int(act)
|
||||
|
||||
if act == -1:
|
||||
act_func = Act.identity
|
||||
else:
|
||||
act_func = self.activation_options[act]
|
||||
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
|
||||
self.__class__.__name__,
|
||||
idx,
|
||||
bias,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
)
|
||||
|
||||
@@ -25,3 +25,8 @@ class KANNode(BaseNodeGene):
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
return Agg.sum(inputs)
|
||||
|
||||
def repr(self, state, node, precision=2):
|
||||
idx = node[0]
|
||||
idx = int(idx)
|
||||
return "{}(idx: {})".format(self.__class__.__name__, idx)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover
|
||||
@@ -20,8 +21,8 @@ class BaseGenome(StatefulBaseClass):
|
||||
):
|
||||
self.num_inputs = num_inputs
|
||||
self.num_outputs = num_outputs
|
||||
self.input_idx = jnp.arange(num_inputs)
|
||||
self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)
|
||||
self.input_idx = np.arange(num_inputs)
|
||||
self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
||||
self.max_nodes = max_nodes
|
||||
self.max_conns = max_conns
|
||||
self.node_gene = node_gene
|
||||
@@ -127,3 +128,30 @@ class BaseGenome(StatefulBaseClass):
|
||||
Update the genome by a batch of data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def repr(self, state, nodes, conns, precision=2):
|
||||
nodes, conns = jax.device_get([nodes, conns])
|
||||
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns)
|
||||
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
|
||||
s += f"\tNodes:\n"
|
||||
for node in nodes:
|
||||
if np.isnan(node[0]):
|
||||
break
|
||||
s += f"\t\t{self.node_gene.repr(state, node, precision=precision)}"
|
||||
node_idx = int(node[0])
|
||||
if np.isin(node_idx, self.input_idx):
|
||||
s += " (input)"
|
||||
elif np.isin(node_idx, self.output_idx):
|
||||
s += " (output)"
|
||||
s += "\n"
|
||||
|
||||
s += f"\tConns:\n"
|
||||
for conn in conns:
|
||||
if np.isnan(conn[0]):
|
||||
break
|
||||
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
|
||||
return s
|
||||
|
||||
@classmethod
|
||||
def valid_cnt(cls, arr):
|
||||
return jnp.sum(~jnp.isnan(arr[:, 0]))
|
||||
|
||||
@@ -47,3 +47,4 @@ if __name__ == "__main__":
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
pipeline.show(state, best)
|
||||
pipeline.save(state=state)
|
||||
|
||||
@@ -37,7 +37,7 @@ class StatefulBaseClass:
|
||||
if "aux_for_state" in obj.__dict__:
|
||||
if warning:
|
||||
warnings.warn(
|
||||
"This object state to load, ignore it",
|
||||
"This object has state to load, ignore it",
|
||||
category=UserWarning,
|
||||
)
|
||||
del obj.__dict__["aux_for_state"]
|
||||
|
||||
Reference in New Issue
Block a user