add repr for genome and gene;

add ipynb test for testing whether add node or add conn will not change the output for the network.
This commit is contained in:
wls2002
2024-06-09 22:32:29 +08:00
parent 52e5d603f5
commit dfc8f9198e
11 changed files with 127 additions and 10 deletions

View File

@@ -1,3 +1,4 @@
import numpy as np
import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover
@@ -20,8 +21,8 @@ class BaseGenome(StatefulBaseClass):
):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.input_idx = jnp.arange(num_inputs)
self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)
self.input_idx = np.arange(num_inputs)
self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
self.max_nodes = max_nodes
self.max_conns = max_conns
self.node_gene = node_gene
@@ -127,3 +128,30 @@ class BaseGenome(StatefulBaseClass):
Update the genome by a batch of data.
"""
raise NotImplementedError
def repr(self, state, nodes, conns, precision=2):
nodes, conns = jax.device_get([nodes, conns])
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns)
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
s += f"\tNodes:\n"
for node in nodes:
if np.isnan(node[0]):
break
s += f"\t\t{self.node_gene.repr(state, node, precision=precision)}"
node_idx = int(node[0])
if np.isin(node_idx, self.input_idx):
s += " (input)"
elif np.isin(node_idx, self.output_idx):
s += " (output)"
s += "\n"
s += f"\tConns:\n"
for conn in conns:
if np.isnan(conn[0]):
break
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
return s
@classmethod
def valid_cnt(cls, arr):
return jnp.sum(~jnp.isnan(arr[:, 0]))