bug down! Here it can solve xor successfully!
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from .genome import create_initialize_function, expand, expand_single
|
||||
from .genome import create_initialize_function, expand, expand_single, pop_analysis
|
||||
from .distance import distance
|
||||
from .mutate import create_mutate_function
|
||||
from .forward import create_forward_function
|
||||
|
||||
@@ -5,8 +5,7 @@ import jax
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
# from .utils import flatten_connections, unflatten_connections
|
||||
from algorithms.neat.genome.utils import flatten_connections, unflatten_connections
|
||||
from .utils import flatten_connections, unflatten_connections
|
||||
|
||||
|
||||
@vmap
|
||||
@@ -93,59 +92,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
|
||||
only gene with the same key will be crossover, thus don't need to consider change key
|
||||
"""
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import numpy as np
|
||||
|
||||
randkey = jax.random.PRNGKey(40)
|
||||
nodes1 = np.array([
|
||||
[4, 1, 1, 1, 1],
|
||||
[6, 2, 2, 2, 2],
|
||||
[1, 3, 3, 3, 3],
|
||||
[5, 4, 4, 4, 4],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
])
|
||||
nodes2 = np.array([
|
||||
[4, 1.5, 1.5, 1.5, 1.5],
|
||||
[7, 3.5, 3.5, 3.5, 3.5],
|
||||
[5, 4.5, 4.5, 4.5, 4.5],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
])
|
||||
weights1 = np.array([
|
||||
[
|
||||
[1, 2, 3, 4., np.nan],
|
||||
[5, np.nan, 7, 8, np.nan],
|
||||
[9, 10, 11, 12, np.nan],
|
||||
[13, 14, 15, 16, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
],
|
||||
[
|
||||
[0, 1, 0, 1, np.nan],
|
||||
[0, np.nan, 0, 1, np.nan],
|
||||
[0, 1, 0, 1, np.nan],
|
||||
[0, 1, 0, 1, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
]
|
||||
])
|
||||
weights2 = np.array([
|
||||
[
|
||||
[1.5, 2.5, 3.5, np.nan, np.nan],
|
||||
[3.5, 4.5, 5.5, np.nan, np.nan],
|
||||
[6.5, 7.5, 8.5, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
],
|
||||
[
|
||||
[1, 0, 1, np.nan, np.nan],
|
||||
[1, 0, 1, np.nan, np.nan],
|
||||
[1, 0, 1, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
[np.nan, np.nan, np.nan, np.nan, np.nan]
|
||||
]
|
||||
])
|
||||
|
||||
res = crossover(randkey, nodes1, weights1, nodes2, weights2)
|
||||
print(*res, sep='\n')
|
||||
return jnp.where(r > 0.5, g1, g2)
|
||||
@@ -1,9 +1,7 @@
|
||||
from functools import partial
|
||||
|
||||
from jax import jit, vmap, Array
|
||||
from jax import numpy as jnp
|
||||
|
||||
from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis
|
||||
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
|
||||
|
||||
|
||||
@jit
|
||||
@@ -14,55 +12,65 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
|
||||
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
|
||||
"""
|
||||
|
||||
node_distance = gene_distance(nodes1, nodes2, 'node')
|
||||
nd = node_distance(nodes1, nodes2) # node distance
|
||||
|
||||
# refactor connections
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
cons1 = flatten_connections(keys1, connections1)
|
||||
cons2 = flatten_connections(keys2, connections2)
|
||||
|
||||
connection_distance = gene_distance(cons1, cons2, 'connection')
|
||||
return node_distance + connection_distance
|
||||
cd = connection_distance(cons1, cons2) # connection distance
|
||||
return nd + cd
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["gene_type"])
|
||||
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
|
||||
if gene_type == 'node':
|
||||
keys1, keys2 = ar1[:, :1], ar2[:, :1]
|
||||
else: # connection
|
||||
keys1, keys2 = ar1[:, :2], ar2[:, :2]
|
||||
@jit
|
||||
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
|
||||
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
|
||||
nodes = jnp.concatenate((ar1, ar2), axis=0)
|
||||
sorted_nodes = nodes[n_sorted_indices]
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
nan_mask = jnp.isnan(nodes[:, 0])
|
||||
|
||||
if gene_type == 'node':
|
||||
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1)
|
||||
else:
|
||||
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1)
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1]
|
||||
|
||||
n_intersect_mask = n_intersect_mask & node_exist_mask
|
||||
n_union_mask = n_union_mask & node_exist_mask
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
nd = batch_homologous_node_distance(fr, sr)
|
||||
nd = jnp.where(jnp.isnan(nd), 0, nd)
|
||||
homologous_distance = jnp.sum(nd * intersect_mask)
|
||||
|
||||
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
|
||||
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
|
||||
if gene_type == 'node':
|
||||
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
|
||||
else: # connection
|
||||
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
|
||||
|
||||
node_distance = jnp.where(jnp.isnan(node_distance), 0, node_distance)
|
||||
homologous_distance = jnp.sum(node_distance * n_intersect_mask[:-1])
|
||||
@jit
|
||||
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 2])) # weight is not nan, means the connection exists
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 2]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
|
||||
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
|
||||
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
|
||||
max_cnt = jnp.maximum(gene_cnt1, gene_cnt2)
|
||||
cons = jnp.concatenate((cons1, cons2), axis=0)
|
||||
keys = cons[:, :2]
|
||||
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||
cons = cons[sorted_indices]
|
||||
cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
|
||||
# both genome has such connection
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 2]) & ~jnp.isnan(sr[:, 2])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
cd = batch_homologous_connection_distance(fr, sr)
|
||||
cd = jnp.where(jnp.isnan(cd), 0, cd)
|
||||
homologous_distance = jnp.sum(cd * intersect_mask)
|
||||
|
||||
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
@vmap
|
||||
def batch_homologous_node_distance(b_n1, b_n2):
|
||||
|
||||
@@ -7,12 +7,12 @@ from numpy.typing import NDArray
|
||||
|
||||
from .aggregations import agg
|
||||
from .activations import act
|
||||
from .graph import topological_sort, batch_topological_sort, topological_sort_debug
|
||||
from .graph import topological_sort, batch_topological_sort
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
def create_forward_function(nodes: NDArray, connections: NDArray,
|
||||
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False):
|
||||
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
|
||||
"""
|
||||
create forward function for different situations
|
||||
|
||||
@@ -26,11 +26,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
|
||||
:return:
|
||||
"""
|
||||
|
||||
if debug:
|
||||
cal_seqs = topological_sort_debug(nodes, connections)
|
||||
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
|
||||
cal_seqs, nodes, connections)
|
||||
|
||||
if nodes.ndim == 2: # single genome
|
||||
cal_seqs = topological_sort(nodes, connections)
|
||||
if not batch:
|
||||
@@ -51,7 +46,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
|
||||
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
|
||||
|
||||
|
||||
# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx'])
|
||||
@partial(jit, static_argnames=['N'])
|
||||
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
||||
@@ -79,38 +73,19 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
z = z * nodes[i, 2] + nodes[i, 1]
|
||||
z = act(nodes[i, 3], z)
|
||||
|
||||
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
|
||||
new_vals = jnp.where(jnp.isnan(z), carry, carry.at[i].set(z))
|
||||
new_vals = carry.at[i].set(z)
|
||||
return new_vals
|
||||
|
||||
def miss():
|
||||
return carry
|
||||
|
||||
return jax.lax.cond(i == I_INT, miss, hit), None
|
||||
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
|
||||
|
||||
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
|
||||
def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections):
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
ini_vals = ini_vals.at[input_idx].set(inputs)
|
||||
vals = ini_vals
|
||||
for i in cal_seqs:
|
||||
if i == I_INT:
|
||||
break
|
||||
ins = vals * connections[0, :, i]
|
||||
z = agg(nodes[i, 4], ins)
|
||||
z = z * nodes[i, 2] + nodes[i, 1]
|
||||
z = act(nodes[i, 3], z)
|
||||
|
||||
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
|
||||
vals = jnp.where(jnp.isnan(z), vals, vals.at[i].set(z))
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
|
||||
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
|
||||
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
|
||||
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
|
||||
|
||||
@@ -208,7 +208,6 @@ def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@jit
|
||||
def add_node(new_node_key: int, nodes: Array, connections: Array,
|
||||
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
|
||||
@@ -247,7 +246,7 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra
|
||||
|
||||
@jit
|
||||
def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array,
|
||||
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
|
||||
weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
|
||||
"""
|
||||
add a new connection to the genome.
|
||||
"""
|
||||
|
||||
@@ -74,26 +74,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
|
||||
return res
|
||||
|
||||
|
||||
# @jit
|
||||
def topological_sort_debug(nodes: Array, connections: Array) -> Array:
|
||||
connections_enable = connections[1, :, :] == 1
|
||||
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
|
||||
res = jnp.full(in_degree.shape, I_INT)
|
||||
idx = 0
|
||||
|
||||
for _ in range(in_degree.shape[0]):
|
||||
i = fetch_first(in_degree == 0.)
|
||||
if i == I_INT:
|
||||
break
|
||||
res = res.at[idx].set(i)
|
||||
idx += 1
|
||||
in_degree = in_degree.at[i].set(-1)
|
||||
children = connections_enable[i, :]
|
||||
in_degree = jnp.where(children, in_degree - 1, in_degree)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
"""
|
||||
@@ -102,7 +82,7 @@ def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
|
||||
:param pop_connections:
|
||||
:return:
|
||||
"""
|
||||
return topological_sort(nodes, connections)
|
||||
return topological_sort(pop_nodes, pop_connections)
|
||||
|
||||
|
||||
@jit
|
||||
@@ -148,7 +128,6 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
||||
check_cycles(nodes, connections, 0, 3) -> False
|
||||
check_cycles(nodes, connections, 1, 0) -> False
|
||||
"""
|
||||
# connections_enable = connections[0, :, :] == 1
|
||||
connections_enable = ~jnp.isnan(connections[0, :, :])
|
||||
|
||||
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
||||
@@ -191,7 +170,6 @@ if __name__ == '__main__':
|
||||
]
|
||||
)
|
||||
|
||||
print(topological_sort_debug(nodes, connections))
|
||||
print(topological_sort(nodes, connections))
|
||||
|
||||
print(check_cycles(nodes, connections, 3, 2))
|
||||
|
||||
@@ -403,13 +403,13 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
|
||||
# add two new connections
|
||||
weight = new_connections[0, from_idx, to_idx]
|
||||
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
|
||||
new_nodes, new_connections, weight=0, enabled=True)
|
||||
new_nodes, new_connections, weight=1., enabled=True)
|
||||
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
|
||||
new_nodes, new_connections, weight=weight, enabled=True)
|
||||
return new_nodes, new_connections
|
||||
|
||||
# if from_idx == I_INT, that means no connection exist, do nothing
|
||||
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node)
|
||||
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
@@ -430,16 +430,20 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
|
||||
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
|
||||
allow_input_keys=False, allow_output_keys=False)
|
||||
|
||||
# delete the node
|
||||
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
|
||||
def nothing():
|
||||
return nodes, connections
|
||||
|
||||
# delete connections
|
||||
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
|
||||
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
|
||||
|
||||
# check node_key valid
|
||||
nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes) # if node_key is nan, do not delete the node
|
||||
connections = jnp.where(jnp.isnan(node_key), connections, aux_connections)
|
||||
# delete connections
|
||||
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
|
||||
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
|
||||
|
||||
return aux_nodes, aux_connections
|
||||
|
||||
nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
@@ -501,7 +505,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
|
||||
def successfully_delete_connection():
|
||||
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
|
||||
|
||||
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection)
|
||||
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection)
|
||||
|
||||
return nodes, connections
|
||||
|
||||
@@ -544,16 +548,22 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
|
||||
:param connection:
|
||||
:return: from_key, to_key, from_idx, to_idx
|
||||
"""
|
||||
|
||||
k1, k2 = jax.random.split(rand_key, num=2)
|
||||
|
||||
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
|
||||
from_idx = fetch_random(k1, has_connections_row)
|
||||
col = connection[0, from_idx, :]
|
||||
to_idx = fetch_random(k2, ~jnp.isnan(col))
|
||||
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
|
||||
|
||||
from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan)
|
||||
to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan)
|
||||
def nothing():
|
||||
return jnp.nan, jnp.nan, I_INT, I_INT
|
||||
|
||||
def has_connection():
|
||||
f_idx = fetch_random(k1, has_connections_row)
|
||||
col = connection[0, f_idx, :]
|
||||
t_idx = fetch_random(k2, ~jnp.isnan(col))
|
||||
f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0]
|
||||
return f_key, t_key, f_idx, t_idx
|
||||
|
||||
from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing)
|
||||
return from_key, to_key, from_idx, to_idx
|
||||
|
||||
|
||||
|
||||
@@ -82,44 +82,6 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
return val / max_cnt
|
||||
|
||||
|
||||
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
|
||||
if gene_type == 'node':
|
||||
keys1, keys2 = ar1[:, :1], ar2[:, :1]
|
||||
else: # connection
|
||||
keys1, keys2 = ar1[:, :2], ar2[:, :2]
|
||||
|
||||
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
|
||||
nodes = np.concatenate((ar1, ar2), axis=0)
|
||||
sorted_nodes = nodes[n_sorted_indices]
|
||||
|
||||
if gene_type == 'node':
|
||||
node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 1:]), axis=1)
|
||||
else:
|
||||
node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 2:]), axis=1)
|
||||
|
||||
n_intersect_mask = n_intersect_mask & node_exist_mask
|
||||
n_union_mask = n_union_mask & node_exist_mask
|
||||
|
||||
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
|
||||
|
||||
non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask)
|
||||
if gene_type == 'node':
|
||||
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
|
||||
else: # connection
|
||||
node_distance = batch_homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
|
||||
|
||||
node_distance = np.where(np.isnan(node_distance), 0, node_distance)
|
||||
homologous_distance = np.sum(node_distance * n_intersect_mask[:-1])
|
||||
|
||||
gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1))
|
||||
gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1))
|
||||
max_cnt = np.maximum(gene_cnt1, gene_cnt2)
|
||||
|
||||
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
|
||||
|
||||
return np.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
|
||||
|
||||
|
||||
def batch_homologous_node_distance(b_n1, b_n2):
|
||||
res = []
|
||||
for n1, n2 in zip(b_n1, b_n2):
|
||||
|
||||
@@ -6,7 +6,8 @@ from jax import numpy as jnp, Array
|
||||
from jax import jit
|
||||
|
||||
I_INT = jnp.iinfo(jnp.int32).max # infinite int
|
||||
|
||||
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
|
||||
EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
||||
|
||||
@jit
|
||||
def flatten_connections(keys, connections):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import List, Union, Tuple, Callable
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
|
||||
from .species import SpeciesController
|
||||
from .genome.numpy import create_initialize_function, create_mutate_function, create_forward_function
|
||||
from .genome.numpy import batch_crossover
|
||||
from .genome.numpy import expand, expand_single, pop_analysis
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||
from .genome import batch_crossover
|
||||
from .genome import expand, expand_single, pop_analysis
|
||||
|
||||
from .genome.origin_neat import *
|
||||
|
||||
@@ -19,7 +19,8 @@ class Pipeline:
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, seed=42):
|
||||
self.randkey = jax.random.PRNGKey(seed)
|
||||
|
||||
self.config = config
|
||||
self.N = config.basic.init_maximum_nodes
|
||||
@@ -69,23 +70,23 @@ class Pipeline:
|
||||
|
||||
self.update_next_generation(crossover_pair)
|
||||
|
||||
analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
|
||||
# analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
|
||||
|
||||
try:
|
||||
for nodes, connections in zip(self.pop_nodes, self.pop_connections):
|
||||
g = array2object(self.config, nodes, connections)
|
||||
print(g)
|
||||
net = FeedForwardNetwork.create(g)
|
||||
real_out = [net.activate(x) for x in xor_inputs]
|
||||
func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True)
|
||||
out = func(xor_inputs)
|
||||
real_out = np.array(real_out)
|
||||
out = np.array(out)
|
||||
print(real_out, out)
|
||||
assert np.allclose(real_out, out)
|
||||
except AssertionError:
|
||||
np.save("err_nodes.npy", self.pop_nodes)
|
||||
np.save("err_connections.npy", self.pop_connections)
|
||||
# try:
|
||||
# for nodes, connections in zip(self.pop_nodes, self.pop_connections):
|
||||
# g = array2object(self.config, nodes, connections)
|
||||
# print(g)
|
||||
# net = FeedForwardNetwork.create(g)
|
||||
# real_out = [net.activate(x) for x in xor_inputs]
|
||||
# func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True)
|
||||
# out = func(xor_inputs)
|
||||
# real_out = np.array(real_out)
|
||||
# out = np.array(out)
|
||||
# print(real_out, out)
|
||||
# assert np.allclose(real_out, out)
|
||||
# except AssertionError:
|
||||
# np.save("err_nodes.npy", self.pop_nodes)
|
||||
# np.save("err_connections.npy", self.pop_connections)
|
||||
|
||||
# print(g)
|
||||
|
||||
@@ -93,7 +94,6 @@ class Pipeline:
|
||||
|
||||
self.expand()
|
||||
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config.neat.population.generation_limit):
|
||||
forward_func = self.ask(batch=True)
|
||||
@@ -109,7 +109,6 @@ class Pipeline:
|
||||
self.tell(fitnesses)
|
||||
print("Generation limit reached!")
|
||||
|
||||
|
||||
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
|
||||
"""
|
||||
create the next generation
|
||||
@@ -117,6 +116,7 @@ class Pipeline:
|
||||
"""
|
||||
|
||||
assert self.pop_nodes.shape[0] == self.pop_size
|
||||
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
|
||||
|
||||
# crossover
|
||||
# prepare elitism mask and crossover pair
|
||||
@@ -127,19 +127,20 @@ class Pipeline:
|
||||
crossover_pair[i] = (pair, pair)
|
||||
crossover_pair = np.array(crossover_pair)
|
||||
|
||||
crossover_rand_keys = jax.random.split(k1, self.pop_size)
|
||||
|
||||
# batch crossover
|
||||
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
|
||||
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
|
||||
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
|
||||
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
|
||||
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
npn, npc = batch_crossover(wpn, wpc, lpn, lpc)
|
||||
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
|
||||
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
|
||||
|
||||
# mutate
|
||||
mutate_rand_keys = jax.random.split(k2, self.pop_size)
|
||||
new_node_keys = np.array(self.fetch_new_node_keys())
|
||||
|
||||
m_npn, m_npc = self.mutate_func(npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
|
||||
|
||||
# elitism don't mutate
|
||||
# (pop_size, ) to (pop_size, 1, 1)
|
||||
@@ -156,7 +157,6 @@ class Pipeline:
|
||||
unused.append(key)
|
||||
self.new_node_keys_pool = unused + self.new_node_keys_pool
|
||||
|
||||
|
||||
def expand(self):
|
||||
"""
|
||||
Expand the population if needed.
|
||||
@@ -176,7 +176,6 @@ class Pipeline:
|
||||
for s in self.species_controller.species.values():
|
||||
s.representative = expand_single(*s.representative, self.N)
|
||||
|
||||
|
||||
def fetch_new_node_keys(self):
|
||||
# if remain unused keys are not enough, create new keys
|
||||
if len(self.new_node_keys_pool) < self.pop_size:
|
||||
@@ -189,7 +188,6 @@ class Pipeline:
|
||||
self.new_node_keys_pool = self.new_node_keys_pool[self.pop_size:]
|
||||
return res
|
||||
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
species_sizes = [len(s.members) for s in self.species_controller.species.values()]
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from typing import List, Tuple, Dict, Union
|
||||
from itertools import count
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from .genome.numpy import distance
|
||||
|
||||
from .genome import distance
|
||||
|
||||
|
||||
class Species(object):
|
||||
@@ -45,6 +47,9 @@ class SpeciesController:
|
||||
self.species_idxer = count(0)
|
||||
self.species: Dict[int, Species] = {} # species_id -> species
|
||||
|
||||
self.distance = distance
|
||||
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0))
|
||||
|
||||
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
|
||||
"""
|
||||
:param pop_nodes:
|
||||
@@ -62,7 +67,7 @@ class SpeciesController:
|
||||
for sid, species in self.species.items():
|
||||
# calculate the distance between the representative and the population
|
||||
r_nodes, r_connections = species.representative
|
||||
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
|
||||
distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
|
||||
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
|
||||
|
||||
new_representatives[sid] = min_idx
|
||||
@@ -75,7 +80,7 @@ class SpeciesController:
|
||||
if previous_species_list: # exist previous species
|
||||
rid_list = [new_representatives[sid] for sid in previous_species_list]
|
||||
res_pop_distance = [
|
||||
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||
self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
|
||||
for rid in rid_list
|
||||
]
|
||||
|
||||
@@ -102,7 +107,7 @@ class SpeciesController:
|
||||
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
|
||||
|
||||
distances = [
|
||||
distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
|
||||
for r in rid
|
||||
]
|
||||
distances = np.array(distances)
|
||||
@@ -314,16 +319,3 @@ def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \
|
||||
sorted_members = [item[0] for item in sorted_combined]
|
||||
sorted_fitnesses = [item[1] for item in sorted_combined]
|
||||
return sorted_members, sorted_fitnesses
|
||||
|
||||
|
||||
def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections):
|
||||
distances = []
|
||||
for nodes, connections in zip(pop_nodes, pop_connections):
|
||||
d = distance(r_nodes, r_connections, nodes, connections)
|
||||
if d < 0:
|
||||
d = distance(r_nodes, r_connections, nodes, connections)
|
||||
print(d)
|
||||
assert False
|
||||
distances.append(d)
|
||||
distances = np.stack(distances, axis=0)
|
||||
return distances
|
||||
|
||||
@@ -17,7 +17,7 @@ def evaluate(forward_func: Callable) -> List[float]:
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
fitnesses = 4 - np.sum(np.abs(outs - xor_outputs), axis=(1, 2))
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
# print(fitnesses)
|
||||
return fitnesses.tolist() # returns a list
|
||||
|
||||
@@ -26,7 +26,7 @@ def evaluate(forward_func: Callable) -> List[float]:
|
||||
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
def main():
|
||||
config = Configer.load_config()
|
||||
pipeline = Pipeline(config)
|
||||
pipeline = Pipeline(config, seed=123123)
|
||||
pipeline.auto_run(evaluate)
|
||||
|
||||
# for _ in range(100):
|
||||
@@ -38,5 +38,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.random.seed(63124326)
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user