bug down! Here it can solve xor successfully!

This commit is contained in:
wls2002
2023-05-07 16:03:52 +08:00
parent d1f54022bd
commit a3b9bca866
12 changed files with 120 additions and 254 deletions

View File

@@ -1,4 +1,4 @@
from .genome import create_initialize_function, expand, expand_single
from .genome import create_initialize_function, expand, expand_single, pop_analysis
from .distance import distance
from .mutate import create_mutate_function
from .forward import create_forward_function

View File

@@ -5,8 +5,7 @@ import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
# from .utils import flatten_connections, unflatten_connections
from algorithms.neat.genome.utils import flatten_connections, unflatten_connections
from .utils import flatten_connections, unflatten_connections
@vmap
@@ -93,59 +92,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
if __name__ == '__main__':
import numpy as np
randkey = jax.random.PRNGKey(40)
nodes1 = np.array([
[4, 1, 1, 1, 1],
[6, 2, 2, 2, 2],
[1, 3, 3, 3, 3],
[5, 4, 4, 4, 4],
[np.nan, np.nan, np.nan, np.nan, np.nan]
])
nodes2 = np.array([
[4, 1.5, 1.5, 1.5, 1.5],
[7, 3.5, 3.5, 3.5, 3.5],
[5, 4.5, 4.5, 4.5, 4.5],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
])
weights1 = np.array([
[
[1, 2, 3, 4., np.nan],
[5, np.nan, 7, 8, np.nan],
[9, 10, 11, 12, np.nan],
[13, 14, 15, 16, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
],
[
[0, 1, 0, 1, np.nan],
[0, np.nan, 0, 1, np.nan],
[0, 1, 0, 1, np.nan],
[0, 1, 0, 1, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
]
])
weights2 = np.array([
[
[1.5, 2.5, 3.5, np.nan, np.nan],
[3.5, 4.5, 5.5, np.nan, np.nan],
[6.5, 7.5, 8.5, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
],
[
[1, 0, 1, np.nan, np.nan],
[1, 0, 1, np.nan, np.nan],
[1, 0, 1, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
]
])
res = crossover(randkey, nodes1, weights1, nodes2, weights2)
print(*res, sep='\n')
return jnp.where(r > 0.5, g1, g2)

View File

@@ -1,9 +1,7 @@
from functools import partial
from jax import jit, vmap, Array
from jax import numpy as jnp
from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
@jit
@@ -14,55 +12,65 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
"""
node_distance = gene_distance(nodes1, nodes2, 'node')
nd = node_distance(nodes1, nodes2) # node distance
# refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
connection_distance = gene_distance(cons1, cons2, 'connection')
return node_distance + connection_distance
cd = connection_distance(cons1, cons2) # connection distance
return nd + cd
@partial(jit, static_argnames=["gene_type"])
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
if gene_type == 'node':
keys1, keys2 = ar1[:, :1], ar2[:, :1]
else: # connection
keys1, keys2 = ar1[:, :2], ar2[:, :2]
@jit
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = jnp.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices]
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
nan_mask = jnp.isnan(nodes[:, 0])
if gene_type == 'node':
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1)
else:
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1)
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1]
n_intersect_mask = n_intersect_mask & node_exist_mask
n_union_mask = n_union_mask & node_exist_mask
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
nd = batch_homologous_node_distance(fr, sr)
nd = jnp.where(jnp.isnan(nd), 0, nd)
homologous_distance = jnp.sum(nd * intersect_mask)
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt)
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = jnp.where(jnp.isnan(node_distance), 0, node_distance)
homologous_distance = jnp.sum(node_distance * n_intersect_mask[:-1])
@jit
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 2])) # weight is not nan, means the connection exists
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 2]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
max_cnt = jnp.maximum(gene_cnt1, gene_cnt2)
cons = jnp.concatenate((cons1, cons2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 2]) & ~jnp.isnan(sr[:, 2])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
cd = batch_homologous_connection_distance(fr, sr)
cd = jnp.where(jnp.isnan(cd), 0, cd)
homologous_distance = jnp.sum(cd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
return jnp.where(max_cnt == 0, 0, val / max_cnt)
@vmap
def batch_homologous_node_distance(b_n1, b_n2):

View File

@@ -7,12 +7,12 @@ from numpy.typing import NDArray
from .aggregations import agg
from .activations import act
from .graph import topological_sort, batch_topological_sort, topological_sort_debug
from .graph import topological_sort, batch_topological_sort
from .utils import I_INT
def create_forward_function(nodes: NDArray, connections: NDArray,
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False):
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
"""
create forward function for different situations
@@ -26,11 +26,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
:return:
"""
if debug:
cal_seqs = topological_sort_debug(nodes, connections)
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
if nodes.ndim == 2: # single genome
cal_seqs = topological_sort(nodes, connections)
if not batch:
@@ -51,7 +46,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx'])
@partial(jit, static_argnames=['N'])
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
@@ -79,38 +73,19 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
new_vals = jnp.where(jnp.isnan(z), carry, carry.at[i].set(z))
new_vals = carry.at[i].set(z)
return new_vals
def miss():
return carry
return jax.lax.cond(i == I_INT, miss, hit), None
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
return vals[output_idx]
def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections):
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
vals = ini_vals
for i in cal_seqs:
if i == I_INT:
break
ins = vals * connections[0, :, i]
z = agg(nodes[i, 4], ins)
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
vals = jnp.where(jnp.isnan(z), vals, vals.at[i].set(z))
return vals[output_idx]
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:

View File

@@ -208,7 +208,6 @@ def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
return res
@jit
def add_node(new_node_key: int, nodes: Array, connections: Array,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
@@ -247,7 +246,7 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra
@jit
def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array,
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
"""
add a new connection to the genome.
"""

View File

@@ -74,26 +74,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
return res
# @jit
def topological_sort_debug(nodes: Array, connections: Array) -> Array:
connections_enable = connections[1, :, :] == 1
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
idx = 0
for _ in range(in_degree.shape[0]):
i = fetch_first(in_degree == 0.)
if i == I_INT:
break
res = res.at[idx].set(i)
idx += 1
in_degree = in_degree.at[i].set(-1)
children = connections_enable[i, :]
in_degree = jnp.where(children, in_degree - 1, in_degree)
return res
@vmap
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
"""
@@ -102,7 +82,7 @@ def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
:param pop_connections:
:return:
"""
return topological_sort(nodes, connections)
return topological_sort(pop_nodes, pop_connections)
@jit
@@ -148,7 +128,6 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
"""
# connections_enable = connections[0, :, :] == 1
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
@@ -191,7 +170,6 @@ if __name__ == '__main__':
]
)
print(topological_sort_debug(nodes, connections))
print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2))

View File

@@ -403,13 +403,13 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
# add two new connections
weight = new_connections[0, from_idx, to_idx]
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
new_nodes, new_connections, weight=0, enabled=True)
new_nodes, new_connections, weight=1., enabled=True)
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
new_nodes, new_connections, weight=weight, enabled=True)
return new_nodes, new_connections
# if from_idx == I_INT, that means no connection exist, do nothing
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node)
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node)
return nodes, connections
@@ -430,16 +430,20 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
# delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
def nothing():
return nodes, connections
# delete connections
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
def successful_delete_node():
# delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
# check node_key valid
nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes) # if node_key is nan, do not delete the node
connections = jnp.where(jnp.isnan(node_key), connections, aux_connections)
# delete connections
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
return aux_nodes, aux_connections
nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
return nodes, connections
@@ -501,7 +505,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
def successfully_delete_connection():
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection)
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection)
return nodes, connections
@@ -544,16 +548,22 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
:param connection:
:return: from_key, to_key, from_idx, to_idx
"""
k1, k2 = jax.random.split(rand_key, num=2)
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
from_idx = fetch_random(k1, has_connections_row)
col = connection[0, from_idx, :]
to_idx = fetch_random(k2, ~jnp.isnan(col))
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan)
to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan)
def nothing():
return jnp.nan, jnp.nan, I_INT, I_INT
def has_connection():
f_idx = fetch_random(k1, has_connections_row)
col = connection[0, f_idx, :]
t_idx = fetch_random(k2, ~jnp.isnan(col))
f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0]
return f_key, t_key, f_idx, t_idx
from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing)
return from_key, to_key, from_idx, to_idx

View File

@@ -82,44 +82,6 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
return val / max_cnt
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
if gene_type == 'node':
keys1, keys2 = ar1[:, :1], ar2[:, :1]
else: # connection
keys1, keys2 = ar1[:, :2], ar2[:, :2]
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = np.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices]
if gene_type == 'node':
node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 1:]), axis=1)
else:
node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 2:]), axis=1)
n_intersect_mask = n_intersect_mask & node_exist_mask
n_union_mask = n_union_mask & node_exist_mask
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = batch_homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = np.where(np.isnan(node_distance), 0, node_distance)
homologous_distance = np.sum(node_distance * n_intersect_mask[:-1])
gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1))
gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1))
max_cnt = np.maximum(gene_cnt1, gene_cnt2)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return np.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
def batch_homologous_node_distance(b_n1, b_n2):
res = []
for n1, n2 in zip(b_n1, b_n2):

View File

@@ -6,7 +6,8 @@ from jax import numpy as jnp, Array
from jax import jit
I_INT = jnp.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
EMPTY_CON = jnp.full((1, 4), jnp.nan)
@jit
def flatten_connections(keys, connections):