Bug fixed! It can now solve XOR successfully!

This commit is contained in:
wls2002
2023-05-07 16:03:52 +08:00
parent d1f54022bd
commit a3b9bca866
12 changed files with 120 additions and 254 deletions

View File

@@ -1,4 +1,4 @@
from .genome import create_initialize_function, expand, expand_single from .genome import create_initialize_function, expand, expand_single, pop_analysis
from .distance import distance from .distance import distance
from .mutate import create_mutate_function from .mutate import create_mutate_function
from .forward import create_forward_function from .forward import create_forward_function

View File

@@ -5,8 +5,7 @@ import jax
from jax import jit, vmap, Array from jax import jit, vmap, Array
from jax import numpy as jnp from jax import numpy as jnp
# from .utils import flatten_connections, unflatten_connections from .utils import flatten_connections, unflatten_connections
from algorithms.neat.genome.utils import flatten_connections, unflatten_connections
@vmap @vmap
@@ -94,58 +93,3 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
""" """
r = jax.random.uniform(rand_key, shape=g1.shape) r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2) return jnp.where(r > 0.5, g1, g2)
if __name__ == '__main__':
    # Ad-hoc smoke run: cross two small hand-written genomes and print
    # every element of the result on its own line.
    import numpy as np

    nan = np.nan
    empty_row = [nan] * 5  # an all-nan row (slot carrying no data)

    randkey = jax.random.PRNGKey(40)

    # First genome: four populated node rows followed by one all-nan row.
    nodes1 = np.array([
        [4, 1, 1, 1, 1],
        [6, 2, 2, 2, 2],
        [1, 3, 3, 3, 3],
        [5, 4, 4, 4, 4],
        empty_row,
    ])

    # Second genome: three populated node rows followed by two all-nan rows.
    nodes2 = np.array([
        [4, 1.5, 1.5, 1.5, 1.5],
        [7, 3.5, 3.5, 3.5, 3.5],
        [5, 4.5, 4.5, 4.5, 4.5],
        empty_row,
        empty_row,
    ])

    # Both arrays below have shape (2, 5, 5).
    # NOTE(review): the two planes are presumably [weights, enabled flags] - confirm.
    flags_row1 = [0, 1, 0, 1, nan]
    weights1 = np.array([
        [
            [1, 2, 3, 4., nan],
            [5, nan, 7, 8, nan],
            [9, 10, 11, 12, nan],
            [13, 14, 15, 16, nan],
            empty_row,
        ],
        [
            flags_row1,
            [0, nan, 0, 1, nan],
            flags_row1,
            flags_row1,
            empty_row,
        ],
    ])

    flags_row2 = [1, 0, 1, nan, nan]
    weights2 = np.array([
        [
            [1.5, 2.5, 3.5, nan, nan],
            [3.5, 4.5, 5.5, nan, nan],
            [6.5, 7.5, 8.5, nan, nan],
            empty_row,
            empty_row,
        ],
        [flags_row2] * 3 + [empty_row] * 2,
    ])

    res = crossover(randkey, nodes1, weights1, nodes2, weights2)
    print(*res, sep='\n')

View File

@@ -1,9 +1,7 @@
from functools import partial
from jax import jit, vmap, Array from jax import jit, vmap, Array
from jax import numpy as jnp from jax import numpy as jnp
from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
@jit @jit
@@ -14,55 +12,65 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable] connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
""" """
node_distance = gene_distance(nodes1, nodes2, 'node') nd = node_distance(nodes1, nodes2) # node distance
# refactor connections # refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1) cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2) cons2 = flatten_connections(keys2, connections2)
cd = connection_distance(cons1, cons2) # connection distance
connection_distance = gene_distance(cons1, cons2, 'connection') return nd + cd
return node_distance + connection_distance
@partial(jit, static_argnames=["gene_type"]) @jit
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.): def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
if gene_type == 'node': node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
keys1, keys2 = ar1[:, :1], ar2[:, :1] node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
else: # connection max_cnt = jnp.maximum(node_cnt1, node_cnt2)
keys1, keys2 = ar1[:, :2], ar2[:, :2]
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2) nodes = jnp.concatenate((nodes1, nodes2), axis=0)
nodes = jnp.concatenate((ar1, ar2), axis=0) keys = nodes[:, 0]
sorted_nodes = nodes[n_sorted_indices] sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
nan_mask = jnp.isnan(nodes[:, 0])
if gene_type == 'node': intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1]
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1)
else:
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1)
n_intersect_mask = n_intersect_mask & node_exist_mask non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
n_union_mask = n_union_mask & node_exist_mask nd = batch_homologous_node_distance(fr, sr)
nd = jnp.where(jnp.isnan(nd), 0, nd)
homologous_distance = jnp.sum(nd * intersect_mask)
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:] val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt)
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = jnp.where(jnp.isnan(node_distance), 0, node_distance) @jit
homologous_distance = jnp.sum(node_distance * n_intersect_mask[:-1]) def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 2])) # weight is not nan, means the connection exists
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 2]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1)) cons = jnp.concatenate((cons1, cons2), axis=0)
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1)) keys = cons[:, :2]
max_cnt = jnp.maximum(gene_cnt1, gene_cnt2) sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 2]) & ~jnp.isnan(sr[:, 2])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
cd = batch_homologous_connection_distance(fr, sr)
cd = jnp.where(jnp.isnan(cd), 0, cd)
homologous_distance = jnp.sum(cd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene return jnp.where(max_cnt == 0, 0, val / max_cnt)
@vmap @vmap
def batch_homologous_node_distance(b_n1, b_n2): def batch_homologous_node_distance(b_n1, b_n2):

View File

@@ -7,12 +7,12 @@ from numpy.typing import NDArray
from .aggregations import agg from .aggregations import agg
from .activations import act from .activations import act
from .graph import topological_sort, batch_topological_sort, topological_sort_debug from .graph import topological_sort, batch_topological_sort
from .utils import I_INT from .utils import I_INT
def create_forward_function(nodes: NDArray, connections: NDArray, def create_forward_function(nodes: NDArray, connections: NDArray,
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False): N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
""" """
create forward function for different situations create forward function for different situations
@@ -26,11 +26,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
:return: :return:
""" """
if debug:
cal_seqs = topological_sort_debug(nodes, connections)
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
if nodes.ndim == 2: # single genome if nodes.ndim == 2: # single genome
cal_seqs = topological_sort(nodes, connections) cal_seqs = topological_sort(nodes, connections)
if not batch: if not batch:
@@ -51,7 +46,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}") raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx'])
@partial(jit, static_argnames=['N']) @partial(jit, static_argnames=['N'])
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array: cal_seqs: Array, nodes: Array, connections: Array) -> Array:
@@ -79,38 +73,19 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
z = z * nodes[i, 2] + nodes[i, 1] z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z) z = act(nodes[i, 3], z)
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals new_vals = carry.at[i].set(z)
new_vals = jnp.where(jnp.isnan(z), carry, carry.at[i].set(z))
return new_vals return new_vals
def miss(): def miss():
return carry return carry
return jax.lax.cond(i == I_INT, miss, hit), None return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs) vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
return vals[output_idx] return vals[output_idx]
def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections):
    """
    Forward pass for one genome written as a plain Python loop (debugging aid).

    :param inputs: values written into the value buffer at ``input_idx``
    :param N: length of the value buffer
    :param input_idx: positions in the value buffer that receive ``inputs``
    :param output_idx: positions of the value buffer that are returned
    :param cal_seqs: node indices in the order they are evaluated; evaluation
        stops at the first entry equal to ``I_INT``
    :param nodes: per-node rows; column 4 is passed to ``agg``, column 2 scales
        the aggregate, column 1 is added to it and column 3 is passed to ``act``
    :param connections: only ``connections[0]`` is read here
    :return: the value buffer indexed by ``output_idx``
    """
    # every slot starts as nan, then the input slots are filled in
    vals = jnp.full((N,), jnp.nan).at[input_idx].set(inputs)

    for node in cal_seqs:
        if node == I_INT:
            break

        incoming = vals * connections[0, :, node]
        pre_activation = agg(nodes[node, 4], incoming) * nodes[node, 2] + nodes[node, 1]
        out = act(nodes[node, 3], pre_activation)

        # a nan result (e.g. for input nodes) must not overwrite the stored value
        vals = jnp.where(jnp.isnan(out), vals, vals.at[node].set(out))

    return vals[output_idx]
@partial(vmap, in_axes=(0, None, None, None, None, None, None)) @partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array: cal_seqs: Array, nodes: Array, connections: Array) -> Array:

View File

@@ -208,7 +208,6 @@ def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
return res return res
@jit @jit
def add_node(new_node_key: int, nodes: Array, connections: Array, def add_node(new_node_key: int, nodes: Array, connections: Array,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]: bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
@@ -247,7 +246,7 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra
@jit @jit
def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array, def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array,
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]: weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
""" """
add a new connection to the genome. add a new connection to the genome.
""" """

View File

@@ -74,26 +74,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
return res return res
# @jit
def topological_sort_debug(nodes: Array, connections: Array) -> Array:
    """
    Topological ordering of a genome's nodes, written as a plain Python loop.

    :param nodes: node array; rows whose first column is nan get a nan in-degree
        and are therefore never selected
    :param connections: ``connections[1]`` is compared against 1 to decide which
        edges count
    :return: array of node indices in evaluation order, padded with ``I_INT``
    """
    enabled = connections[1, :, :] == 1
    in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(enabled, axis=0))
    order = jnp.full(in_degree.shape, I_INT)

    # the loop counter doubles as the next free position in `order`
    for slot in range(in_degree.shape[0]):
        # NOTE(review): fetch_first presumably yields I_INT when no entry matches - confirm
        ready = fetch_first(in_degree == 0.)
        if ready == I_INT:
            break

        order = order.at[slot].set(ready)
        # -1 marks the node as consumed so it is not picked again
        in_degree = in_degree.at[ready].set(-1)
        in_degree = jnp.where(enabled[ready, :], in_degree - 1, in_degree)

    return order
@vmap @vmap
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array: def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
""" """
@@ -102,7 +82,7 @@ def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
:param pop_connections: :param pop_connections:
:return: :return:
""" """
return topological_sort(nodes, connections) return topological_sort(pop_nodes, pop_connections)
@jit @jit
@@ -148,7 +128,6 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
check_cycles(nodes, connections, 0, 3) -> False check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False check_cycles(nodes, connections, 1, 0) -> False
""" """
# connections_enable = connections[0, :, :] == 1
connections_enable = ~jnp.isnan(connections[0, :, :]) connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True) connections_enable = connections_enable.at[from_idx, to_idx].set(True)
@@ -191,7 +170,6 @@ if __name__ == '__main__':
] ]
) )
print(topological_sort_debug(nodes, connections))
print(topological_sort(nodes, connections)) print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2)) print(check_cycles(nodes, connections, 3, 2))

View File

@@ -403,13 +403,13 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
# add two new connections # add two new connections
weight = new_connections[0, from_idx, to_idx] weight = new_connections[0, from_idx, to_idx]
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx, new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
new_nodes, new_connections, weight=0, enabled=True) new_nodes, new_connections, weight=1., enabled=True)
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx, new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
new_nodes, new_connections, weight=weight, enabled=True) new_nodes, new_connections, weight=weight, enabled=True)
return new_nodes, new_connections return new_nodes, new_connections
# if from_idx == I_INT, that means no connection exist, do nothing # if from_idx == I_INT, that means no connection exist, do nothing
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node) nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node)
return nodes, connections return nodes, connections
@@ -430,6 +430,10 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys, node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False) allow_input_keys=False, allow_output_keys=False)
def nothing():
return nodes, connections
def successful_delete_node():
# delete the node # delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections) aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
@@ -437,9 +441,9 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan) aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan) aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
# check node_key valid return aux_nodes, aux_connections
nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes) # if node_key is nan, do not delete the node
connections = jnp.where(jnp.isnan(node_key), connections, aux_connections) nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
return nodes, connections return nodes, connections
@@ -501,7 +505,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
def successfully_delete_connection(): def successfully_delete_connection():
return delete_connection_by_idx(from_idx, to_idx, nodes, connections) return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection) nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection)
return nodes, connections return nodes, connections
@@ -544,16 +548,22 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
:param connection: :param connection:
:return: from_key, to_key, from_idx, to_idx :return: from_key, to_key, from_idx, to_idx
""" """
k1, k2 = jax.random.split(rand_key, num=2) k1, k2 = jax.random.split(rand_key, num=2)
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1) has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
from_idx = fetch_random(k1, has_connections_row)
col = connection[0, from_idx, :]
to_idx = fetch_random(k2, ~jnp.isnan(col))
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan) def nothing():
to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan) return jnp.nan, jnp.nan, I_INT, I_INT
def has_connection():
f_idx = fetch_random(k1, has_connections_row)
col = connection[0, f_idx, :]
t_idx = fetch_random(k2, ~jnp.isnan(col))
f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0]
return f_key, t_key, f_idx, t_idx
from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing)
return from_key, to_key, from_idx, to_idx return from_key, to_key, from_idx, to_idx

View File

@@ -82,44 +82,6 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
return val / max_cnt return val / max_cnt
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
    """
    Distance between two gene arrays (numpy implementation).

    :param ar1: gene rows of the first genome
    :param ar2: gene rows of the second genome
    :param gene_type: ``'node'`` uses the first column as key; any other value
        is treated as a connection and uses the first two columns as key
    :param compatibility_coe: weight of the summed distance between matching genes
    :param disjoint_coe: weight of the count of non-matching genes
    :return: weighted sum divided by the larger gene count, or 0 when both
        genomes have no gene
    """
    is_node = gene_type == 'node'
    key_cols = 1 if is_node else 2
    pair_distance = batch_homologous_node_distance if is_node else batch_homologous_connection_distance

    sorted_indices, intersect_mask, union_mask = set_operation_analysis(ar1[:, :key_cols], ar2[:, :key_cols])
    sorted_genes = np.concatenate((ar1, ar2), axis=0)[sorted_indices]

    # a row counts as an existing gene when any non-key column holds a value
    exists = np.any(~np.isnan(sorted_genes[:, key_cols:]), axis=1)
    intersect_mask = intersect_mask & exists
    union_mask = union_mask & exists

    non_homologous_cnt = np.sum(union_mask) - np.sum(intersect_mask)

    # compare every row with its successor in the sorted order
    pairwise = pair_distance(sorted_genes[:-1], sorted_genes[1:])
    pairwise = np.where(np.isnan(pairwise), 0, pairwise)
    homologous_distance = np.sum(pairwise * intersect_mask[:-1])

    max_cnt = np.maximum(
        np.sum(np.all(~np.isnan(ar1), axis=1)),
        np.sum(np.all(~np.isnan(ar2), axis=1)),
    )

    val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
    # both genomes empty -> distance 0 instead of a division by zero
    return np.where(max_cnt == 0, 0, val / max_cnt)
def batch_homologous_node_distance(b_n1, b_n2): def batch_homologous_node_distance(b_n1, b_n2):
res = [] res = []
for n1, n2 in zip(b_n1, b_n2): for n1, n2 in zip(b_n1, b_n2):

View File

@@ -6,7 +6,8 @@ from jax import numpy as jnp, Array
from jax import jit from jax import jit
I_INT = jnp.iinfo(jnp.int32).max # infinite int I_INT = jnp.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
EMPTY_CON = jnp.full((1, 4), jnp.nan)
@jit @jit
def flatten_connections(keys, connections): def flatten_connections(keys, connections):

View File

@@ -1,12 +1,12 @@
from typing import List, Union, Tuple, Callable from typing import List, Union, Tuple, Callable
import time import time
import numpy as np import jax
from .species import SpeciesController from .species import SpeciesController
from .genome.numpy import create_initialize_function, create_mutate_function, create_forward_function from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome.numpy import batch_crossover from .genome import batch_crossover
from .genome.numpy import expand, expand_single, pop_analysis from .genome import expand, expand_single, pop_analysis
from .genome.origin_neat import * from .genome.origin_neat import *
@@ -19,7 +19,8 @@ class Pipeline:
Neat algorithm pipeline. Neat algorithm pipeline.
""" """
def __init__(self, config): def __init__(self, config, seed=42):
self.randkey = jax.random.PRNGKey(seed)
self.config = config self.config = config
self.N = config.basic.init_maximum_nodes self.N = config.basic.init_maximum_nodes
@@ -69,23 +70,23 @@ class Pipeline:
self.update_next_generation(crossover_pair) self.update_next_generation(crossover_pair)
analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx) # analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
try: # try:
for nodes, connections in zip(self.pop_nodes, self.pop_connections): # for nodes, connections in zip(self.pop_nodes, self.pop_connections):
g = array2object(self.config, nodes, connections) # g = array2object(self.config, nodes, connections)
print(g) # print(g)
net = FeedForwardNetwork.create(g) # net = FeedForwardNetwork.create(g)
real_out = [net.activate(x) for x in xor_inputs] # real_out = [net.activate(x) for x in xor_inputs]
func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True) # func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True)
out = func(xor_inputs) # out = func(xor_inputs)
real_out = np.array(real_out) # real_out = np.array(real_out)
out = np.array(out) # out = np.array(out)
print(real_out, out) # print(real_out, out)
assert np.allclose(real_out, out) # assert np.allclose(real_out, out)
except AssertionError: # except AssertionError:
np.save("err_nodes.npy", self.pop_nodes) # np.save("err_nodes.npy", self.pop_nodes)
np.save("err_connections.npy", self.pop_connections) # np.save("err_connections.npy", self.pop_connections)
# print(g) # print(g)
@@ -93,7 +94,6 @@ class Pipeline:
self.expand() self.expand()
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.neat.population.generation_limit): for _ in range(self.config.neat.population.generation_limit):
forward_func = self.ask(batch=True) forward_func = self.ask(batch=True)
@@ -109,7 +109,6 @@ class Pipeline:
self.tell(fitnesses) self.tell(fitnesses)
print("Generation limit reached!") print("Generation limit reached!")
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None: def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
""" """
create the next generation create the next generation
@@ -117,6 +116,7 @@ class Pipeline:
""" """
assert self.pop_nodes.shape[0] == self.pop_size assert self.pop_nodes.shape[0] == self.pop_size
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
# crossover # crossover
# prepare elitism mask and crossover pair # prepare elitism mask and crossover pair
@@ -127,19 +127,20 @@ class Pipeline:
crossover_pair[i] = (pair, pair) crossover_pair[i] = (pair, pair)
crossover_pair = np.array(crossover_pair) crossover_pair = np.array(crossover_pair)
crossover_rand_keys = jax.random.split(k1, self.pop_size)
# batch crossover # batch crossover
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
npn, npc = batch_crossover(wpn, wpc, lpn, lpc)
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
# mutate # mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size)
new_node_keys = np.array(self.fetch_new_node_keys()) new_node_keys = np.array(self.fetch_new_node_keys())
m_npn, m_npc = self.mutate_func(npn, npc, new_node_keys) # mutate_new_pop_nodes m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate # elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1) # (pop_size, ) to (pop_size, 1, 1)
@@ -156,7 +157,6 @@ class Pipeline:
unused.append(key) unused.append(key)
self.new_node_keys_pool = unused + self.new_node_keys_pool self.new_node_keys_pool = unused + self.new_node_keys_pool
def expand(self): def expand(self):
""" """
Expand the population if needed. Expand the population if needed.
@@ -176,7 +176,6 @@ class Pipeline:
for s in self.species_controller.species.values(): for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N) s.representative = expand_single(*s.representative, self.N)
def fetch_new_node_keys(self): def fetch_new_node_keys(self):
# if remain unused keys are not enough, create new keys # if remain unused keys are not enough, create new keys
if len(self.new_node_keys_pool) < self.pop_size: if len(self.new_node_keys_pool) < self.pop_size:
@@ -189,7 +188,6 @@ class Pipeline:
self.new_node_keys_pool = self.new_node_keys_pool[self.pop_size:] self.new_node_keys_pool = self.new_node_keys_pool[self.pop_size:]
return res return res
def default_analysis(self, fitnesses): def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses) max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
species_sizes = [len(s.members) for s in self.species_controller.species.values()] species_sizes = [len(s.members) for s in self.species_controller.species.values()]

View File

@@ -1,9 +1,11 @@
from typing import List, Tuple, Dict, Union from typing import List, Tuple, Dict, Union
from itertools import count from itertools import count
import jax
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from .genome.numpy import distance
from .genome import distance
class Species(object): class Species(object):
@@ -45,6 +47,9 @@ class SpeciesController:
self.species_idxer = count(0) self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species self.species: Dict[int, Species] = {} # species_id -> species
self.distance = distance
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0))
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None: def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
""" """
:param pop_nodes: :param pop_nodes:
@@ -62,7 +67,7 @@ class SpeciesController:
for sid, species in self.species.items(): for sid, species in self.species.items():
# calculate the distance between the representative and the population # calculate the distance between the representative and the population
r_nodes, r_connections = species.representative r_nodes, r_connections = species.representative
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections) distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx new_representatives[sid] = min_idx
@@ -75,7 +80,7 @@ class SpeciesController:
if previous_species_list: # exist previous species if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list] rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [ res_pop_distance = [
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections) self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list for rid in rid_list
] ]
@@ -102,7 +107,7 @@ class SpeciesController:
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()])) sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [ distances = [
distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r]) self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid for r in rid
] ]
distances = np.array(distances) distances = np.array(distances)
@@ -314,16 +319,3 @@ def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \
sorted_members = [item[0] for item in sorted_combined] sorted_members = [item[0] for item in sorted_combined]
sorted_fitnesses = [item[1] for item in sorted_combined] sorted_fitnesses = [item[1] for item in sorted_combined]
return sorted_members, sorted_fitnesses return sorted_members, sorted_fitnesses
def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections):
    """
    Distance from one genome to every genome of a population (one-to-many).

    :param r_nodes: nodes of the reference genome
    :param r_connections: connections of the reference genome
    :param pop_nodes: nodes of each population member, iterated along axis 0
    :param pop_connections: connections of each population member, iterated along axis 0
    :return: the per-member distances stacked along axis 0
    :raises AssertionError: if ``distance`` returns a negative value
    """
    distances = []
    for member, (nodes, connections) in enumerate(zip(pop_nodes, pop_connections)):
        d = distance(r_nodes, r_connections, nodes, connections)
        if d < 0:
            # Raise explicitly: a bare `assert False` is removed under `python -O`,
            # which would let a negative distance through unnoticed.
            raise AssertionError(f"negative distance {d} for population member {member}")
        distances.append(d)
    return np.stack(distances, axis=0)

View File

@@ -17,7 +17,7 @@ def evaluate(forward_func: Callable) -> List[float]:
:return: :return:
""" """
outs = forward_func(xor_inputs) outs = forward_func(xor_inputs)
fitnesses = 4 - np.sum(np.abs(outs - xor_outputs), axis=(1, 2)) fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
# print(fitnesses) # print(fitnesses)
return fitnesses.tolist() # returns a list return fitnesses.tolist() # returns a list
@@ -26,7 +26,7 @@ def evaluate(forward_func: Callable) -> List[float]:
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/") @partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main(): def main():
config = Configer.load_config() config = Configer.load_config()
pipeline = Pipeline(config) pipeline = Pipeline(config, seed=123123)
pipeline.auto_run(evaluate) pipeline.auto_run(evaluate)
# for _ in range(100): # for _ in range(100):
@@ -38,5 +38,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
np.random.seed(63124326)
main() main()