refactor genome.py use (C, 4) to replace (2, N, N) to represent connections

faster, faster and faster!
This commit is contained in:
wls2002
2023-05-12 00:57:55 +08:00
parent e5fc1167d9
commit 47b1a1dbb2
16 changed files with 363 additions and 419 deletions

View File

@@ -20,7 +20,7 @@ from jax import numpy as jnp
from jax import jit
from jax import Array
from .utils import fetch_first, EMPTY_NODE
from .utils import fetch_first
def initialize_genomes(pop_size: int,
@@ -124,79 +124,6 @@ def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tupl
return new_nodes, new_cons
def analysis(nodes: NDArray, cons: NDArray, input_keys, output_keys) -> \
Tuple[Dict[int, Tuple[float, float, int, int]], Dict[Tuple[int, int], Tuple[float, bool]]]:
"""
Convert a genome from array to dict.
:param nodes: (N, 5)
:param cons: (C, 4)
:param output_keys:
:param input_keys:
:return: nodes_dict[key: (bias, response, act, agg)], cons_dict[(i_key, o_key): (weight, enabled)]
"""
# update nodes_dict
try:
nodes_dict = {}
for i, node in enumerate(nodes):
if np.isnan(node[0]):
continue
key = int(node[0])
assert key not in nodes_dict, f"Duplicate node key: {key}!"
bias = node[1] if not np.isnan(node[1]) else None
response = node[2] if not np.isnan(node[2]) else None
act = node[3] if not np.isnan(node[3]) else None
agg = node[4] if not np.isnan(node[4]) else None
nodes_dict[key] = (bias, response, act, agg)
# check nodes_dict
for i in input_keys:
assert i in nodes_dict, f"Input node {i} not found in nodes_dict!"
bias, response, act, agg = nodes_dict[i]
assert bias is None and response is None and act is None and agg is None, \
f"Input node {i} must has None bias, response, act, or agg!"
for o in output_keys:
assert o in nodes_dict, f"Output node {o} not found in nodes_dict!"
for k, v in nodes_dict.items():
if k not in input_keys:
bias, response, act, agg = v
assert bias is not None and response is not None and act is not None and agg is not None, \
f"Normal node {k} must has non-None bias, response, act, or agg!"
# update connections
cons_dict = {}
for i, con in enumerate(cons):
if np.isnan(con[0]):
continue
assert ~np.isnan(con[1]), f"Connection {i} must has non-None o_key!"
i_key = int(con[0])
o_key = int(con[1])
assert i_key in nodes_dict, f"Input node {i_key} not found in nodes_dict!"
assert o_key in nodes_dict, f"Output node {o_key} not found in nodes_dict!"
key = (i_key, o_key)
weight = con[2] if not np.isnan(con[2]) else None
enabled = (con[3] == 1) if not np.isnan(con[3]) else None
assert weight is not None, f"Connection {key} must has non-None weight!"
assert enabled is not None, f"Connection {key} must has non-None enabled!"
cons_dict[key] = (weight, enabled)
return nodes_dict, cons_dict
except AssertionError:
print(nodes)
print(cons)
raise AssertionError
def pop_analysis(pop_nodes, pop_cons, input_keys, output_keys):
res = []
for nodes, cons in zip(pop_nodes, pop_cons):
res.append(analysis(nodes, cons, input_keys, output_keys))
return res
@jit
def count(nodes, cons):
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
@@ -231,7 +158,7 @@ def delete_node_by_idx(nodes: Array, cons: Array, idx: int) -> Tuple[Array, Arra
"""
use idx to delete a node from the genome. only delete the node, regardless of connections.
"""
nodes = nodes.at[idx].set(EMPTY_NODE)
nodes = nodes.at[idx].set(np.nan)
return nodes, cons
@@ -243,7 +170,7 @@ def add_connection(nodes: Array, cons: Array, i_key: int, o_key: int,
"""
con_keys = cons[:, 0]
idx = fetch_first(jnp.isnan(con_keys))
return add_connection_by_idx(idx, nodes, cons, i_key, o_key, weight, enabled)
return add_connection_by_idx(nodes, cons, idx, i_key, o_key, weight, enabled)
@jit