add repr for genome and gene;

add ipynb test for testing whether add node or add conn will not change the output for the network.
This commit is contained in:
wls2002
2024-06-09 22:32:29 +08:00
parent 52e5d603f5
commit dfc8f9198e
11 changed files with 127 additions and 10 deletions

View File

@@ -41,15 +41,13 @@ class DefaultMutation(BaseMutation):
return nodes, conns
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
remain_node_space = jnp.isnan(nodes[:, 0]).sum()
remain_conn_space = jnp.isnan(conns[:, 0]).sum()
def mutate_add_node(key_, nodes_, conns_):
"""
add a node while do not influence the output of the network
"""
remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
i_key, o_key, idx = self.choose_connection_key(
key_, conns_
) # choose a connection
@@ -83,7 +81,7 @@ class DefaultMutation(BaseMutation):
return new_nodes, new_conns
return jax.lax.cond(
(idx == I_INF) & (remain_node_space < 1) & (remain_conn_space < 2),
(idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2),
lambda: (nodes_, conns_), # do nothing
successful_add_node,
)
@@ -92,7 +90,6 @@ class DefaultMutation(BaseMutation):
"""
delete a node
"""
# randomly choose a node
key, idx = self.choose_node_key(
key_,
@@ -127,6 +124,8 @@ class DefaultMutation(BaseMutation):
add a connection while do not influence the output of the network
"""
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
# randomly choose two nodes
k1_, k2_ = jax.random.split(key_, num=2)
@@ -164,7 +163,7 @@ class DefaultMutation(BaseMutation):
if genome.network_type == "feedforward":
u_conns = unflatten_conns(nodes_, conns_)
conns_exist = (u_conns != I_INF)
conns_exist = u_conns != I_INF
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
return jax.lax.cond(

View File

@@ -40,3 +40,6 @@ class BaseGene(StatefulBaseClass):
@property
def length(self):
return len(self.fixed_attrs) + len(self.custom_attrs)
def repr(self, state, gene, precision=2):
raise NotImplementedError

View File

@@ -22,3 +22,12 @@ class BaseConnGene(BaseGene):
jax.vmap(self.forward, in_axes=(None, None, 0))(state, attrs, batch_inputs),
attrs,
)
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
in_idx, out_idx = conn[:2]
in_idx = int(in_idx)
out_idx = int(out_idx)
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format(
self.__class__.__name__, in_idx, out_idx, idx_width=idx_width
)

View File

@@ -60,3 +60,19 @@ class DefaultConnGene(BaseConnGene):
def forward(self, state, attrs, inputs):
weight = attrs[0]
return inputs * weight
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
in_idx, out_idx, weight = conn
in_idx = int(in_idx)
out_idx = int(out_idx)
weight = round(float(weight), precision)
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, weight: {:<{float_width}})".format(
self.__class__.__name__,
in_idx,
out_idx,
weight,
idx_width=idx_width,
float_width=precision + 3,
)

View File

@@ -39,3 +39,11 @@ class BaseNodeGene(BaseGene):
),
attrs,
)
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
idx = node[0]
idx = int(idx)
return "{}(idx={:<{idx_width}})".format(
self.__class__.__name__, idx, idx_width=idx_width
)

View File

@@ -122,3 +122,28 @@ class DefaultNodeGene(BaseNodeGene):
)
return z
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
idx, bias, res, agg, act = node
idx = int(idx)
bias = round(float(bias), precision)
res = round(float(res), precision)
agg = int(agg)
act = int(act)
if act == -1:
act_func = Act.identity
else:
act_func = self.activation_options[act]
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
self.__class__.__name__,
idx,
bias,
res,
self.aggregation_options[agg].__name__,
act_func.__name__,
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width
)

View File

@@ -98,3 +98,26 @@ class NodeGeneWithoutResponse(BaseNodeGene):
)
return z
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
idx, bias, agg, act = node
idx = int(idx)
bias = round(float(bias), precision)
agg = int(agg)
act = int(act)
if act == -1:
act_func = Act.identity
else:
act_func = self.activation_options[act]
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
self.__class__.__name__,
idx,
bias,
self.aggregation_options[agg].__name__,
act_func.__name__,
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width,
)

View File

@@ -25,3 +25,8 @@ class KANNode(BaseNodeGene):
def forward(self, state, attrs, inputs, is_output_node=False):
return Agg.sum(inputs)
def repr(self, state, node, precision=2):
idx = node[0]
idx = int(idx)
return "{}(idx: {})".format(self.__class__.__name__, idx)

View File

@@ -1,3 +1,4 @@
import numpy as np
import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover
@@ -20,8 +21,8 @@ class BaseGenome(StatefulBaseClass):
):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.input_idx = jnp.arange(num_inputs)
self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)
self.input_idx = np.arange(num_inputs)
self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
self.max_nodes = max_nodes
self.max_conns = max_conns
self.node_gene = node_gene
@@ -127,3 +128,30 @@ class BaseGenome(StatefulBaseClass):
Update the genome by a batch of data.
"""
raise NotImplementedError
def repr(self, state, nodes, conns, precision=2):
nodes, conns = jax.device_get([nodes, conns])
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns)
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
s += f"\tNodes:\n"
for node in nodes:
if np.isnan(node[0]):
break
s += f"\t\t{self.node_gene.repr(state, node, precision=precision)}"
node_idx = int(node[0])
if np.isin(node_idx, self.input_idx):
s += " (input)"
elif np.isin(node_idx, self.output_idx):
s += " (output)"
s += "\n"
s += f"\tConns:\n"
for conn in conns:
if np.isnan(conn[0]):
break
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
return s
@classmethod
def valid_cnt(cls, arr):
return jnp.sum(~jnp.isnan(arr[:, 0]))

View File

@@ -47,3 +47,4 @@ if __name__ == "__main__":
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)
pipeline.save(state=state)

View File

@@ -37,7 +37,7 @@ class StatefulBaseClass:
if "aux_for_state" in obj.__dict__:
if warning:
warnings.warn(
"This object state to load, ignore it",
"This object has state to load, ignore it",
category=UserWarning,
)
del obj.__dict__["aux_for_state"]