create function "distance_numpy", serve as o2o distance function

This commit is contained in:
wls2002
2023-05-07 23:47:53 +08:00
parent b257505bee
commit 64f8eaccaf
3 changed files with 95 additions and 6 deletions

View File

@@ -1,5 +1,7 @@
from jax import jit, vmap, Array from jax import jit, vmap, Array
from jax import numpy as jnp from jax import numpy as jnp
import numpy as np
from numpy.typing import NDArray
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
@@ -14,7 +16,11 @@ def create_distance_function(config, type: str):
compatibility_coe = config.neat.genome.compatibility_weight_coefficient compatibility_coe = config.neat.genome.compatibility_weight_coefficient
if type == 'o2o': if type == 'o2o':
return lambda nodes1, connections1, nodes2, connections2: \ return lambda nodes1, connections1, nodes2, connections2: \
distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe) distance_numpy(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
# return lambda nodes1, connections1, nodes2, connections2: \
# distance(nodes1, connections1, nodes2, connections2, disjoint_coe, compatibility_coe)
elif type == 'o2m': elif type == 'o2m':
func = vmap(distance, in_axes=(None, None, 0, 0, None, None)) func = vmap(distance, in_axes=(None, None, 0, 0, None, None))
return lambda nodes1, connections1, batch_nodes2, batch_connections2: \ return lambda nodes1, connections1, batch_nodes2, batch_connections2: \
@@ -23,6 +29,89 @@ def create_distance_function(config, type: str):
raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]') raise ValueError(f'unknown distance type: {type}, should be one of ["o2o", "o2m"]')
def distance_numpy(nodes1: NDArray, connection1: NDArray, nodes2: NDArray,
connection2: NDArray, disjoint_coe: float = 1., compatibility_coe: float = 0.5):
"""
use in o2o distance.
o2o can't use vmap, numpy should be faster than jax function
:param nodes1:
:param connection1:
:param nodes2:
:param connection2:
:param disjoint_coe:
:param compatibility_coe:
:return:
"""
def analysis(nodes, connections):
nodes_dict = {}
idx2key = {}
for i, node in enumerate(nodes):
if np.isnan(node[0]):
continue
key = int(node[0])
nodes_dict[key] = (node[1], node[2], node[3], node[4])
idx2key[i] = key
connections_dict = {}
for i in range(connections.shape[1]):
for j in range(connections.shape[2]):
if np.isnan(connections[0, i, j]) and np.isnan(connections[1, i, j]):
continue
key = (idx2key[i], idx2key[j])
weight = connections[0, i, j] if not np.isnan(connections[0, i, j]) else None
enabled = (connections[1, i, j] == 1) if not np.isnan(connections[1, i, j]) else None
connections_dict[key] = (weight, enabled)
return nodes_dict, connections_dict
nodes1, connections1 = analysis(nodes1, connection1)
nodes2, connections2 = analysis(nodes2, connection2)
nd = 0.0
if nodes1 or nodes2: # otherwise, both are empty
disjoint_nodes = 0
for k2 in nodes2:
if k2 not in nodes1:
disjoint_nodes += 1
for k1, n1 in nodes1.items():
n2 = nodes2.get(k1)
if n2 is None:
disjoint_nodes += 1
else:
if np.isnan(n1[0]): # n1[1] is nan means input nodes
continue
d = abs(n1[0] - n2[0]) + abs(n1[1] - n2[1])
d += 1 if n1[2] != n2[2] else 0
d += 1 if n1[3] != n2[3] else 0
nd += d
max_nodes = max(len(nodes1), len(nodes2))
nd = (compatibility_coe * nd + disjoint_coe * disjoint_nodes) / max_nodes
cd = 0.0
if connections1 or connections2:
disjoint_connections = 0
for k2 in connections2:
if k2 not in connections1:
disjoint_connections += 1
for k1, c1 in connections1.items():
c2 = connections2.get(k1)
if c2 is None:
disjoint_connections += 1
else:
# Homologous genes compute their own distance value.
d = abs(c1[0] - c2[0])
d += 1 if c1[1] != c2[1] else 0
cd += d
max_conn = max(len(connections1), len(connections2))
cd = (compatibility_coe * cd + disjoint_coe * disjoint_connections) / max_conn
return nd + cd
@jit @jit
def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1., def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Array, disjoint_coe: float = 1.,
compatibility_coe: float = 0.5) -> Array: compatibility_coe: float = 0.5) -> Array:
@@ -46,7 +135,7 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5): def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0])) node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0])) node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2) - 2 max_cnt = jnp.maximum(node_cnt1, node_cnt2)
nodes = jnp.concatenate((nodes1, nodes2), axis=0) nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0] keys = nodes[:, 0]

View File

@@ -23,8 +23,8 @@ def evaluate(forward_func: Callable) -> List[float]:
return fitnesses.tolist() # returns a list return fitnesses.tolist() # returns a list
@using_cprofile # @using_cprofile
# @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main(): def main():
config = Configer.load_config() config = Configer.load_config()
pipeline = Pipeline(config, seed=11323) pipeline = Pipeline(config, seed=11323)

View File

@@ -9,8 +9,8 @@
"population": { "population": {
"fitness_criterion": "max", "fitness_criterion": "max",
"fitness_threshold": 76, "fitness_threshold": 76,
"generation_limit": 1000, "generation_limit": 100,
"pop_size": 200, "pop_size": 1000,
"reset_on_extinction": "False" "reset_on_extinction": "False"
}, },
"gene": { "gene": {