add repr for genome and gene;
add ipynb test for testing whether add node or add conn will not change the output for the network.
This commit is contained in:
@@ -41,15 +41,13 @@ class DefaultMutation(BaseMutation):
|
|||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
|
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||||
|
|
||||||
remain_node_space = jnp.isnan(nodes[:, 0]).sum()
|
|
||||||
remain_conn_space = jnp.isnan(conns[:, 0]).sum()
|
|
||||||
|
|
||||||
def mutate_add_node(key_, nodes_, conns_):
|
def mutate_add_node(key_, nodes_, conns_):
|
||||||
"""
|
"""
|
||||||
add a node while do not influence the output of the network
|
add a node while do not influence the output of the network
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
|
||||||
|
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
|
||||||
i_key, o_key, idx = self.choose_connection_key(
|
i_key, o_key, idx = self.choose_connection_key(
|
||||||
key_, conns_
|
key_, conns_
|
||||||
) # choose a connection
|
) # choose a connection
|
||||||
@@ -83,7 +81,7 @@ class DefaultMutation(BaseMutation):
|
|||||||
return new_nodes, new_conns
|
return new_nodes, new_conns
|
||||||
|
|
||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
(idx == I_INF) & (remain_node_space < 1) & (remain_conn_space < 2),
|
(idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2),
|
||||||
lambda: (nodes_, conns_), # do nothing
|
lambda: (nodes_, conns_), # do nothing
|
||||||
successful_add_node,
|
successful_add_node,
|
||||||
)
|
)
|
||||||
@@ -92,7 +90,6 @@ class DefaultMutation(BaseMutation):
|
|||||||
"""
|
"""
|
||||||
delete a node
|
delete a node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# randomly choose a node
|
# randomly choose a node
|
||||||
key, idx = self.choose_node_key(
|
key, idx = self.choose_node_key(
|
||||||
key_,
|
key_,
|
||||||
@@ -127,6 +124,8 @@ class DefaultMutation(BaseMutation):
|
|||||||
add a connection while do not influence the output of the network
|
add a connection while do not influence the output of the network
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
|
||||||
|
|
||||||
# randomly choose two nodes
|
# randomly choose two nodes
|
||||||
k1_, k2_ = jax.random.split(key_, num=2)
|
k1_, k2_ = jax.random.split(key_, num=2)
|
||||||
|
|
||||||
@@ -164,7 +163,7 @@ class DefaultMutation(BaseMutation):
|
|||||||
|
|
||||||
if genome.network_type == "feedforward":
|
if genome.network_type == "feedforward":
|
||||||
u_conns = unflatten_conns(nodes_, conns_)
|
u_conns = unflatten_conns(nodes_, conns_)
|
||||||
conns_exist = (u_conns != I_INF)
|
conns_exist = u_conns != I_INF
|
||||||
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
||||||
|
|
||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
|
|||||||
@@ -40,3 +40,6 @@ class BaseGene(StatefulBaseClass):
|
|||||||
@property
|
@property
|
||||||
def length(self):
|
def length(self):
|
||||||
return len(self.fixed_attrs) + len(self.custom_attrs)
|
return len(self.fixed_attrs) + len(self.custom_attrs)
|
||||||
|
|
||||||
|
def repr(self, state, gene, precision=2):
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -22,3 +22,12 @@ class BaseConnGene(BaseGene):
|
|||||||
jax.vmap(self.forward, in_axes=(None, None, 0))(state, attrs, batch_inputs),
|
jax.vmap(self.forward, in_axes=(None, None, 0))(state, attrs, batch_inputs),
|
||||||
attrs,
|
attrs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
|
||||||
|
in_idx, out_idx = conn[:2]
|
||||||
|
in_idx = int(in_idx)
|
||||||
|
out_idx = int(out_idx)
|
||||||
|
|
||||||
|
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format(
|
||||||
|
self.__class__.__name__, in_idx, out_idx, idx_width=idx_width
|
||||||
|
)
|
||||||
|
|||||||
@@ -60,3 +60,19 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
def forward(self, state, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
weight = attrs[0]
|
weight = attrs[0]
|
||||||
return inputs * weight
|
return inputs * weight
|
||||||
|
|
||||||
|
def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
|
||||||
|
in_idx, out_idx, weight = conn
|
||||||
|
|
||||||
|
in_idx = int(in_idx)
|
||||||
|
out_idx = int(out_idx)
|
||||||
|
weight = round(float(weight), precision)
|
||||||
|
|
||||||
|
return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, weight: {:<{float_width}})".format(
|
||||||
|
self.__class__.__name__,
|
||||||
|
in_idx,
|
||||||
|
out_idx,
|
||||||
|
weight,
|
||||||
|
idx_width=idx_width,
|
||||||
|
float_width=precision + 3,
|
||||||
|
)
|
||||||
|
|||||||
@@ -39,3 +39,11 @@ class BaseNodeGene(BaseGene):
|
|||||||
),
|
),
|
||||||
attrs,
|
attrs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
|
||||||
|
idx = node[0]
|
||||||
|
|
||||||
|
idx = int(idx)
|
||||||
|
return "{}(idx={:<{idx_width}})".format(
|
||||||
|
self.__class__.__name__, idx, idx_width=idx_width
|
||||||
|
)
|
||||||
|
|||||||
@@ -122,3 +122,28 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
|
||||||
|
idx, bias, res, agg, act = node
|
||||||
|
|
||||||
|
idx = int(idx)
|
||||||
|
bias = round(float(bias), precision)
|
||||||
|
res = round(float(res), precision)
|
||||||
|
agg = int(agg)
|
||||||
|
act = int(act)
|
||||||
|
|
||||||
|
if act == -1:
|
||||||
|
act_func = Act.identity
|
||||||
|
else:
|
||||||
|
act_func = self.activation_options[act]
|
||||||
|
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
|
||||||
|
self.__class__.__name__,
|
||||||
|
idx,
|
||||||
|
bias,
|
||||||
|
res,
|
||||||
|
self.aggregation_options[agg].__name__,
|
||||||
|
act_func.__name__,
|
||||||
|
idx_width=idx_width,
|
||||||
|
float_width=precision + 3,
|
||||||
|
func_width=func_width
|
||||||
|
)
|
||||||
|
|||||||
@@ -98,3 +98,26 @@ class NodeGeneWithoutResponse(BaseNodeGene):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
|
||||||
|
idx, bias, agg, act = node
|
||||||
|
|
||||||
|
idx = int(idx)
|
||||||
|
bias = round(float(bias), precision)
|
||||||
|
agg = int(agg)
|
||||||
|
act = int(act)
|
||||||
|
|
||||||
|
if act == -1:
|
||||||
|
act_func = Act.identity
|
||||||
|
else:
|
||||||
|
act_func = self.activation_options[act]
|
||||||
|
return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
|
||||||
|
self.__class__.__name__,
|
||||||
|
idx,
|
||||||
|
bias,
|
||||||
|
self.aggregation_options[agg].__name__,
|
||||||
|
act_func.__name__,
|
||||||
|
idx_width=idx_width,
|
||||||
|
float_width=precision + 3,
|
||||||
|
func_width=func_width,
|
||||||
|
)
|
||||||
|
|||||||
@@ -25,3 +25,8 @@ class KANNode(BaseNodeGene):
|
|||||||
|
|
||||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||||
return Agg.sum(inputs)
|
return Agg.sum(inputs)
|
||||||
|
|
||||||
|
def repr(self, state, node, precision=2):
|
||||||
|
idx = node[0]
|
||||||
|
idx = int(idx)
|
||||||
|
return "{}(idx: {})".format(self.__class__.__name__, idx)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import numpy as np
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
from ..gene import BaseNodeGene, BaseConnGene
|
from ..gene import BaseNodeGene, BaseConnGene
|
||||||
from ..ga import BaseMutation, BaseCrossover
|
from ..ga import BaseMutation, BaseCrossover
|
||||||
@@ -20,8 +21,8 @@ class BaseGenome(StatefulBaseClass):
|
|||||||
):
|
):
|
||||||
self.num_inputs = num_inputs
|
self.num_inputs = num_inputs
|
||||||
self.num_outputs = num_outputs
|
self.num_outputs = num_outputs
|
||||||
self.input_idx = jnp.arange(num_inputs)
|
self.input_idx = np.arange(num_inputs)
|
||||||
self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)
|
self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
||||||
self.max_nodes = max_nodes
|
self.max_nodes = max_nodes
|
||||||
self.max_conns = max_conns
|
self.max_conns = max_conns
|
||||||
self.node_gene = node_gene
|
self.node_gene = node_gene
|
||||||
@@ -127,3 +128,30 @@ class BaseGenome(StatefulBaseClass):
|
|||||||
Update the genome by a batch of data.
|
Update the genome by a batch of data.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def repr(self, state, nodes, conns, precision=2):
|
||||||
|
nodes, conns = jax.device_get([nodes, conns])
|
||||||
|
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns)
|
||||||
|
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
|
||||||
|
s += f"\tNodes:\n"
|
||||||
|
for node in nodes:
|
||||||
|
if np.isnan(node[0]):
|
||||||
|
break
|
||||||
|
s += f"\t\t{self.node_gene.repr(state, node, precision=precision)}"
|
||||||
|
node_idx = int(node[0])
|
||||||
|
if np.isin(node_idx, self.input_idx):
|
||||||
|
s += " (input)"
|
||||||
|
elif np.isin(node_idx, self.output_idx):
|
||||||
|
s += " (output)"
|
||||||
|
s += "\n"
|
||||||
|
|
||||||
|
s += f"\tConns:\n"
|
||||||
|
for conn in conns:
|
||||||
|
if np.isnan(conn[0]):
|
||||||
|
break
|
||||||
|
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
|
||||||
|
return s
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def valid_cnt(cls, arr):
|
||||||
|
return jnp.sum(~jnp.isnan(arr[:, 0]))
|
||||||
|
|||||||
@@ -47,3 +47,4 @@ if __name__ == "__main__":
|
|||||||
state, best = pipeline.auto_run(state)
|
state, best = pipeline.auto_run(state)
|
||||||
# show result
|
# show result
|
||||||
pipeline.show(state, best)
|
pipeline.show(state, best)
|
||||||
|
pipeline.save(state=state)
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class StatefulBaseClass:
|
|||||||
if "aux_for_state" in obj.__dict__:
|
if "aux_for_state" in obj.__dict__:
|
||||||
if warning:
|
if warning:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"This object state to load, ignore it",
|
"This object has state to load, ignore it",
|
||||||
category=UserWarning,
|
category=UserWarning,
|
||||||
)
|
)
|
||||||
del obj.__dict__["aux_for_state"]
|
del obj.__dict__["aux_for_state"]
|
||||||
|
|||||||
Reference in New Issue
Block a user