add repr for genome and gene;

add ipynb test for testing whether add node or add conn will not change the output for the network.
This commit is contained in:
wls2002
2024-06-09 22:32:29 +08:00
parent 52e5d603f5
commit dfc8f9198e
11 changed files with 127 additions and 10 deletions

View File

@@ -41,15 +41,13 @@ class DefaultMutation(BaseMutation):
return nodes, conns return nodes, conns
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key): def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
remain_node_space = jnp.isnan(nodes[:, 0]).sum()
remain_conn_space = jnp.isnan(conns[:, 0]).sum()
def mutate_add_node(key_, nodes_, conns_): def mutate_add_node(key_, nodes_, conns_):
""" """
add a node while do not influence the output of the network add a node while do not influence the output of the network
""" """
remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
i_key, o_key, idx = self.choose_connection_key( i_key, o_key, idx = self.choose_connection_key(
key_, conns_ key_, conns_
) # choose a connection ) # choose a connection
@@ -83,7 +81,7 @@ class DefaultMutation(BaseMutation):
return new_nodes, new_conns return new_nodes, new_conns
return jax.lax.cond( return jax.lax.cond(
(idx == I_INF) & (remain_node_space < 1) & (remain_conn_space < 2), (idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2),
lambda: (nodes_, conns_), # do nothing lambda: (nodes_, conns_), # do nothing
successful_add_node, successful_add_node,
) )
@@ -92,7 +90,6 @@ class DefaultMutation(BaseMutation):
""" """
delete a node delete a node
""" """
# randomly choose a node # randomly choose a node
key, idx = self.choose_node_key( key, idx = self.choose_node_key(
key_, key_,
@@ -127,6 +124,8 @@ class DefaultMutation(BaseMutation):
add a connection while do not influence the output of the network add a connection while do not influence the output of the network
""" """
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
# randomly choose two nodes # randomly choose two nodes
k1_, k2_ = jax.random.split(key_, num=2) k1_, k2_ = jax.random.split(key_, num=2)
@@ -164,7 +163,7 @@ class DefaultMutation(BaseMutation):
if genome.network_type == "feedforward": if genome.network_type == "feedforward":
u_conns = unflatten_conns(nodes_, conns_) u_conns = unflatten_conns(nodes_, conns_)
conns_exist = (u_conns != I_INF) conns_exist = u_conns != I_INF
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx) is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
return jax.lax.cond( return jax.lax.cond(

View File

@@ -40,3 +40,6 @@ class BaseGene(StatefulBaseClass):
@property @property
def length(self): def length(self):
return len(self.fixed_attrs) + len(self.custom_attrs) return len(self.fixed_attrs) + len(self.custom_attrs)
def repr(self, state, gene, precision=2):
raise NotImplementedError

View File

@@ -22,3 +22,12 @@ class BaseConnGene(BaseGene):
jax.vmap(self.forward, in_axes=(None, None, 0))(state, attrs, batch_inputs), jax.vmap(self.forward, in_axes=(None, None, 0))(state, attrs, batch_inputs),
attrs, attrs,
) )
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
in_idx, out_idx = conn[:2]
in_idx = int(in_idx)
out_idx = int(out_idx)
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format(
self.__class__.__name__, in_idx, out_idx, idx_width=idx_width
)

View File

@@ -60,3 +60,19 @@ class DefaultConnGene(BaseConnGene):
def forward(self, state, attrs, inputs): def forward(self, state, attrs, inputs):
weight = attrs[0] weight = attrs[0]
return inputs * weight return inputs * weight
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
in_idx, out_idx, weight = conn
in_idx = int(in_idx)
out_idx = int(out_idx)
weight = round(float(weight), precision)
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, weight: {:<{float_width}})".format(
self.__class__.__name__,
in_idx,
out_idx,
weight,
idx_width=idx_width,
float_width=precision + 3,
)

View File

@@ -39,3 +39,11 @@ class BaseNodeGene(BaseGene):
), ),
attrs, attrs,
) )
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
idx = node[0]
idx = int(idx)
return "{}(idx={:<{idx_width}})".format(
self.__class__.__name__, idx, idx_width=idx_width
)

View File

@@ -122,3 +122,28 @@ class DefaultNodeGene(BaseNodeGene):
) )
return z return z
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
idx, bias, res, agg, act = node
idx = int(idx)
bias = round(float(bias), precision)
res = round(float(res), precision)
agg = int(agg)
act = int(act)
if act == -1:
act_func = Act.identity
else:
act_func = self.activation_options[act]
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
self.__class__.__name__,
idx,
bias,
res,
self.aggregation_options[agg].__name__,
act_func.__name__,
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width
)

View File

@@ -98,3 +98,26 @@ class NodeGeneWithoutResponse(BaseNodeGene):
) )
return z return z
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
idx, bias, agg, act = node
idx = int(idx)
bias = round(float(bias), precision)
agg = int(agg)
act = int(act)
if act == -1:
act_func = Act.identity
else:
act_func = self.activation_options[act]
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
self.__class__.__name__,
idx,
bias,
self.aggregation_options[agg].__name__,
act_func.__name__,
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width,
)

View File

@@ -25,3 +25,8 @@ class KANNode(BaseNodeGene):
def forward(self, state, attrs, inputs, is_output_node=False): def forward(self, state, attrs, inputs, is_output_node=False):
return Agg.sum(inputs) return Agg.sum(inputs)
def repr(self, state, node, precision=2):
idx = node[0]
idx = int(idx)
return "{}(idx: {})".format(self.__class__.__name__, idx)

View File

@@ -1,3 +1,4 @@
import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover from ..ga import BaseMutation, BaseCrossover
@@ -20,8 +21,8 @@ class BaseGenome(StatefulBaseClass):
): ):
self.num_inputs = num_inputs self.num_inputs = num_inputs
self.num_outputs = num_outputs self.num_outputs = num_outputs
self.input_idx = jnp.arange(num_inputs) self.input_idx = np.arange(num_inputs)
self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs) self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
self.max_nodes = max_nodes self.max_nodes = max_nodes
self.max_conns = max_conns self.max_conns = max_conns
self.node_gene = node_gene self.node_gene = node_gene
@@ -127,3 +128,30 @@ class BaseGenome(StatefulBaseClass):
Update the genome by a batch of data. Update the genome by a batch of data.
""" """
raise NotImplementedError raise NotImplementedError
def repr(self, state, nodes, conns, precision=2):
nodes, conns = jax.device_get([nodes, conns])
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns)
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
s += f"\tNodes:\n"
for node in nodes:
if np.isnan(node[0]):
break
s += f"\t\t{self.node_gene.repr(state, node, precision=precision)}"
node_idx = int(node[0])
if np.isin(node_idx, self.input_idx):
s += " (input)"
elif np.isin(node_idx, self.output_idx):
s += " (output)"
s += "\n"
s += f"\tConns:\n"
for conn in conns:
if np.isnan(conn[0]):
break
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
return s
@classmethod
def valid_cnt(cls, arr):
return jnp.sum(~jnp.isnan(arr[:, 0]))

View File

@@ -47,3 +47,4 @@ if __name__ == "__main__":
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)
# show result # show result
pipeline.show(state, best) pipeline.show(state, best)
pipeline.save(state=state)

View File

@@ -37,7 +37,7 @@ class StatefulBaseClass:
if "aux_for_state" in obj.__dict__: if "aux_for_state" in obj.__dict__:
if warning: if warning:
warnings.warn( warnings.warn(
"This object state to load, ignore it", "This object has state to load, ignore it",
category=UserWarning, category=UserWarning,
) )
del obj.__dict__["aux_for_state"] del obj.__dict__["aux_for_state"]