Bug fixed! The pipeline can now solve XOR successfully.

This commit is contained in:
wls2002
2023-05-07 16:03:52 +08:00
parent d1f54022bd
commit a3b9bca866
12 changed files with 120 additions and 254 deletions

View File

@@ -1,4 +1,4 @@
from .genome import create_initialize_function, expand, expand_single
from .genome import create_initialize_function, expand, expand_single, pop_analysis
from .distance import distance
from .mutate import create_mutate_function
from .forward import create_forward_function

View File

@@ -5,8 +5,7 @@ import jax
from jax import jit, vmap, Array
from jax import numpy as jnp
# from .utils import flatten_connections, unflatten_connections
from algorithms.neat.genome.utils import flatten_connections, unflatten_connections
from .utils import flatten_connections, unflatten_connections
@vmap
@@ -93,59 +92,4 @@ def crossover_gene(rand_key: Array, g1: Array, g2: Array) -> Array:
only gene with the same key will be crossover, thus don't need to consider change key
"""
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
if __name__ == '__main__':
import numpy as np
randkey = jax.random.PRNGKey(40)
nodes1 = np.array([
[4, 1, 1, 1, 1],
[6, 2, 2, 2, 2],
[1, 3, 3, 3, 3],
[5, 4, 4, 4, 4],
[np.nan, np.nan, np.nan, np.nan, np.nan]
])
nodes2 = np.array([
[4, 1.5, 1.5, 1.5, 1.5],
[7, 3.5, 3.5, 3.5, 3.5],
[5, 4.5, 4.5, 4.5, 4.5],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
])
weights1 = np.array([
[
[1, 2, 3, 4., np.nan],
[5, np.nan, 7, 8, np.nan],
[9, 10, 11, 12, np.nan],
[13, 14, 15, 16, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
],
[
[0, 1, 0, 1, np.nan],
[0, np.nan, 0, 1, np.nan],
[0, 1, 0, 1, np.nan],
[0, 1, 0, 1, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
]
])
weights2 = np.array([
[
[1.5, 2.5, 3.5, np.nan, np.nan],
[3.5, 4.5, 5.5, np.nan, np.nan],
[6.5, 7.5, 8.5, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
],
[
[1, 0, 1, np.nan, np.nan],
[1, 0, 1, np.nan, np.nan],
[1, 0, 1, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan]
]
])
res = crossover(randkey, nodes1, weights1, nodes2, weights2)
print(*res, sep='\n')
return jnp.where(r > 0.5, g1, g2)

View File

@@ -1,9 +1,7 @@
from functools import partial
from jax import jit, vmap, Array
from jax import numpy as jnp
from algorithms.neat.genome.utils import flatten_connections, set_operation_analysis
from .utils import flatten_connections, EMPTY_NODE, EMPTY_CON
@jit
@@ -14,55 +12,65 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
connections are a 3-d array with shape (2, N, N), axis 0 means [weights, enable]
"""
node_distance = gene_distance(nodes1, nodes2, 'node')
nd = node_distance(nodes1, nodes2) # node distance
# refactor connections
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
cons1 = flatten_connections(keys1, connections1)
cons2 = flatten_connections(keys2, connections2)
connection_distance = gene_distance(cons1, cons2, 'connection')
return node_distance + connection_distance
cd = connection_distance(cons1, cons2) # connection distance
return nd + cd
@partial(jit, static_argnames=["gene_type"])
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
if gene_type == 'node':
keys1, keys2 = ar1[:, :1], ar2[:, :1]
else: # connection
keys1, keys2 = ar1[:, :2], ar2[:, :2]
@jit
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = jnp.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices]
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate([nodes, EMPTY_NODE], axis=0) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
nan_mask = jnp.isnan(nodes[:, 0])
if gene_type == 'node':
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 1:]), axis=1)
else:
node_exist_mask = jnp.any(~jnp.isnan(sorted_nodes[:, 2:]), axis=1)
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~nan_mask[:-1]
n_intersect_mask = n_intersect_mask & node_exist_mask
n_union_mask = n_union_mask & node_exist_mask
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
nd = batch_homologous_node_distance(fr, sr)
nd = jnp.where(jnp.isnan(nd), 0, nd)
homologous_distance = jnp.sum(nd * intersect_mask)
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt)
non_homologous_cnt = jnp.sum(n_union_mask) - jnp.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = jnp.where(jnp.isnan(node_distance), 0, node_distance)
homologous_distance = jnp.sum(node_distance * n_intersect_mask[:-1])
@jit
def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
con_cnt1 = jnp.sum(~jnp.isnan(cons1[:, 2])) # weight is not nan, means the connection exists
con_cnt2 = jnp.sum(~jnp.isnan(cons2[:, 2]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
gene_cnt1 = jnp.sum(jnp.all(~jnp.isnan(ar1), axis=1))
gene_cnt2 = jnp.sum(jnp.all(~jnp.isnan(ar2), axis=1))
max_cnt = jnp.maximum(gene_cnt1, gene_cnt2)
cons = jnp.concatenate((cons1, cons2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate([cons, EMPTY_CON], axis=0) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 2]) & ~jnp.isnan(sr[:, 2])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
cd = batch_homologous_connection_distance(fr, sr)
cd = jnp.where(jnp.isnan(cd), 0, cd)
homologous_distance = jnp.sum(cd * intersect_mask)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return jnp.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
return jnp.where(max_cnt == 0, 0, val / max_cnt)
@vmap
def batch_homologous_node_distance(b_n1, b_n2):

View File

@@ -7,12 +7,12 @@ from numpy.typing import NDArray
from .aggregations import agg
from .activations import act
from .graph import topological_sort, batch_topological_sort, topological_sort_debug
from .graph import topological_sort, batch_topological_sort
from .utils import I_INT
def create_forward_function(nodes: NDArray, connections: NDArray,
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool, debug: bool = False):
N: int, input_idx: NDArray, output_idx: NDArray, batch: bool):
"""
create forward function for different situations
@@ -26,11 +26,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
:return:
"""
if debug:
cal_seqs = topological_sort_debug(nodes, connections)
return lambda inputs: forward_single_debug(inputs, N, input_idx, output_idx,
cal_seqs, nodes, connections)
if nodes.ndim == 2: # single genome
cal_seqs = topological_sort(nodes, connections)
if not batch:
@@ -51,7 +46,6 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
# @partial(jit, static_argnames=['N', 'input_idx', 'output_idx'])
@partial(jit, static_argnames=['N'])
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
@@ -79,38 +73,19 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
new_vals = jnp.where(jnp.isnan(z), carry, carry.at[i].set(z))
new_vals = carry.at[i].set(z)
return new_vals
def miss():
return carry
return jax.lax.cond(i == I_INT, miss, hit), None
return jax.lax.cond((i == I_INT) | (jnp.isin(i, input_idx)), miss, hit), None
vals, _ = jax.lax.scan(scan_body, ini_vals, cal_seqs)
return vals[output_idx]
def forward_single_debug(inputs, N, input_idx, output_idx: Array, cal_seqs, nodes, connections):
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
vals = ini_vals
for i in cal_seqs:
if i == I_INT:
break
ins = vals * connections[0, :, i]
z = agg(nodes[i, 4], ins)
z = z * nodes[i, 2] + nodes[i, 1]
z = act(nodes[i, 3], z)
# for some nodes (inputs nodes), the output z will be nan, thus we do not update the vals
vals = jnp.where(jnp.isnan(z), vals, vals.at[i].set(z))
return vals[output_idx]
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:

View File

@@ -208,7 +208,6 @@ def pop_analysis(pop_nodes, pop_connections, input_keys, output_keys):
return res
@jit
def add_node(new_node_key: int, nodes: Array, connections: Array,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[Array, Array]:
@@ -247,7 +246,7 @@ def delete_node_by_idx(idx: int, nodes: Array, connections: Array) -> Tuple[Arra
@jit
def add_connection(from_node: int, to_node: int, nodes: Array, connections: Array,
weight: float = 0.0, enabled: bool = True) -> Tuple[Array, Array]:
weight: float = 1.0, enabled: bool = True) -> Tuple[Array, Array]:
"""
add a new connection to the genome.
"""

View File

@@ -74,26 +74,6 @@ def topological_sort(nodes: Array, connections: Array) -> Array:
return res
# @jit
def topological_sort_debug(nodes: Array, connections: Array) -> Array:
connections_enable = connections[1, :, :] == 1
in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(connections_enable, axis=0))
res = jnp.full(in_degree.shape, I_INT)
idx = 0
for _ in range(in_degree.shape[0]):
i = fetch_first(in_degree == 0.)
if i == I_INT:
break
res = res.at[idx].set(i)
idx += 1
in_degree = in_degree.at[i].set(-1)
children = connections_enable[i, :]
in_degree = jnp.where(children, in_degree - 1, in_degree)
return res
@vmap
def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
"""
@@ -102,7 +82,7 @@ def batch_topological_sort(pop_nodes: Array, pop_connections: Array) -> Array:
:param pop_connections:
:return:
"""
return topological_sort(nodes, connections)
return topological_sort(pop_nodes, pop_connections)
@jit
@@ -148,7 +128,6 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
"""
# connections_enable = connections[0, :, :] == 1
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
@@ -191,7 +170,6 @@ if __name__ == '__main__':
]
)
print(topological_sort_debug(nodes, connections))
print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2))

View File

@@ -403,13 +403,13 @@ def mutate_add_node(rand_key: Array, new_node_key: int, nodes: Array, connection
# add two new connections
weight = new_connections[0, from_idx, to_idx]
new_nodes, new_connections = add_connection_by_idx(from_idx, new_idx,
new_nodes, new_connections, weight=0, enabled=True)
new_nodes, new_connections, weight=1., enabled=True)
new_nodes, new_connections = add_connection_by_idx(new_idx, to_idx,
new_nodes, new_connections, weight=weight, enabled=True)
return new_nodes, new_connections
# if from_idx == I_INT, that means no connection exist, do nothing
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successful_add_node)
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successful_add_node)
return nodes, connections
@@ -430,16 +430,20 @@ def mutate_delete_node(rand_key: Array, nodes: Array, connections: Array,
node_key, node_idx = choice_node_key(rand_key, nodes, input_keys, output_keys,
allow_input_keys=False, allow_output_keys=False)
# delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
def nothing():
return nodes, connections
# delete connections
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
def successful_delete_node():
# delete the node
aux_nodes, aux_connections = delete_node_by_idx(node_idx, nodes, connections)
# check node_key valid
nodes = jnp.where(jnp.isnan(node_key), nodes, aux_nodes) # if node_key is nan, do not delete the node
connections = jnp.where(jnp.isnan(node_key), connections, aux_connections)
# delete connections
aux_connections = aux_connections.at[:, node_idx, :].set(jnp.nan)
aux_connections = aux_connections.at[:, :, node_idx].set(jnp.nan)
return aux_nodes, aux_connections
nodes, connections = jax.lax.cond(node_idx == I_INT, nothing, successful_delete_node)
return nodes, connections
@@ -501,7 +505,7 @@ def mutate_delete_connection(rand_key: Array, nodes: Array, connections: Array):
def successfully_delete_connection():
return delete_connection_by_idx(from_idx, to_idx, nodes, connections)
nodes, connections = jax.lax.select(from_idx == I_INT, nothing, successfully_delete_connection)
nodes, connections = jax.lax.cond(from_idx == I_INT, nothing, successfully_delete_connection)
return nodes, connections
@@ -544,16 +548,22 @@ def choice_connection_key(rand_key: Array, nodes: Array, connection: Array) -> T
:param connection:
:return: from_key, to_key, from_idx, to_idx
"""
k1, k2 = jax.random.split(rand_key, num=2)
has_connections_row = jnp.any(~jnp.isnan(connection[0, :, :]), axis=1)
from_idx = fetch_random(k1, has_connections_row)
col = connection[0, from_idx, :]
to_idx = fetch_random(k2, ~jnp.isnan(col))
from_key, to_key = nodes[from_idx, 0], nodes[to_idx, 0]
from_key = jnp.where(from_idx != I_INT, from_key, jnp.nan)
to_key = jnp.where(to_idx != I_INT, to_key, jnp.nan)
def nothing():
return jnp.nan, jnp.nan, I_INT, I_INT
def has_connection():
f_idx = fetch_random(k1, has_connections_row)
col = connection[0, f_idx, :]
t_idx = fetch_random(k2, ~jnp.isnan(col))
f_key, t_key = nodes[f_idx, 0], nodes[t_idx, 0]
return f_key, t_key, f_idx, t_idx
from_key, to_key, from_idx, to_idx = jax.lax.cond(jnp.any(has_connections_row), has_connection, nothing)
return from_key, to_key, from_idx, to_idx

View File

@@ -82,44 +82,6 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
return val / max_cnt
def gene_distance(ar1, ar2, gene_type, compatibility_coe=0.5, disjoint_coe=1.):
if gene_type == 'node':
keys1, keys2 = ar1[:, :1], ar2[:, :1]
else: # connection
keys1, keys2 = ar1[:, :2], ar2[:, :2]
n_sorted_indices, n_intersect_mask, n_union_mask = set_operation_analysis(keys1, keys2)
nodes = np.concatenate((ar1, ar2), axis=0)
sorted_nodes = nodes[n_sorted_indices]
if gene_type == 'node':
node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 1:]), axis=1)
else:
node_exist_mask = np.any(~np.isnan(sorted_nodes[:, 2:]), axis=1)
n_intersect_mask = n_intersect_mask & node_exist_mask
n_union_mask = n_union_mask & node_exist_mask
fr_sorted_nodes, sr_sorted_nodes = sorted_nodes[:-1], sorted_nodes[1:]
non_homologous_cnt = np.sum(n_union_mask) - np.sum(n_intersect_mask)
if gene_type == 'node':
node_distance = batch_homologous_node_distance(fr_sorted_nodes, sr_sorted_nodes)
else: # connection
node_distance = batch_homologous_connection_distance(fr_sorted_nodes, sr_sorted_nodes)
node_distance = np.where(np.isnan(node_distance), 0, node_distance)
homologous_distance = np.sum(node_distance * n_intersect_mask[:-1])
gene_cnt1 = np.sum(np.all(~np.isnan(ar1), axis=1))
gene_cnt2 = np.sum(np.all(~np.isnan(ar2), axis=1))
max_cnt = np.maximum(gene_cnt1, gene_cnt2)
val = non_homologous_cnt * disjoint_coe + homologous_distance * compatibility_coe
return np.where(max_cnt == 0, 0, val / max_cnt) # consider the case that both genome has no gene
def batch_homologous_node_distance(b_n1, b_n2):
res = []
for n1, n2 in zip(b_n1, b_n2):

View File

@@ -6,7 +6,8 @@ from jax import numpy as jnp, Array
from jax import jit
I_INT = jnp.iinfo(jnp.int32).max # infinite int
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
EMPTY_CON = jnp.full((1, 4), jnp.nan)
@jit
def flatten_connections(keys, connections):

View File

@@ -1,12 +1,12 @@
from typing import List, Union, Tuple, Callable
import time
import numpy as np
import jax
from .species import SpeciesController
from .genome.numpy import create_initialize_function, create_mutate_function, create_forward_function
from .genome.numpy import batch_crossover
from .genome.numpy import expand, expand_single, pop_analysis
from .genome import create_initialize_function, create_mutate_function, create_forward_function
from .genome import batch_crossover
from .genome import expand, expand_single, pop_analysis
from .genome.origin_neat import *
@@ -19,7 +19,8 @@ class Pipeline:
Neat algorithm pipeline.
"""
def __init__(self, config):
def __init__(self, config, seed=42):
self.randkey = jax.random.PRNGKey(seed)
self.config = config
self.N = config.basic.init_maximum_nodes
@@ -69,23 +70,23 @@ class Pipeline:
self.update_next_generation(crossover_pair)
analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
# analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
try:
for nodes, connections in zip(self.pop_nodes, self.pop_connections):
g = array2object(self.config, nodes, connections)
print(g)
net = FeedForwardNetwork.create(g)
real_out = [net.activate(x) for x in xor_inputs]
func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True)
out = func(xor_inputs)
real_out = np.array(real_out)
out = np.array(out)
print(real_out, out)
assert np.allclose(real_out, out)
except AssertionError:
np.save("err_nodes.npy", self.pop_nodes)
np.save("err_connections.npy", self.pop_connections)
# try:
# for nodes, connections in zip(self.pop_nodes, self.pop_connections):
# g = array2object(self.config, nodes, connections)
# print(g)
# net = FeedForwardNetwork.create(g)
# real_out = [net.activate(x) for x in xor_inputs]
# func = create_forward_function(nodes, connections, self.N, self.input_idx, self.output_idx, batch=True)
# out = func(xor_inputs)
# real_out = np.array(real_out)
# out = np.array(out)
# print(real_out, out)
# assert np.allclose(real_out, out)
# except AssertionError:
# np.save("err_nodes.npy", self.pop_nodes)
# np.save("err_connections.npy", self.pop_connections)
# print(g)
@@ -93,7 +94,6 @@ class Pipeline:
self.expand()
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.neat.population.generation_limit):
forward_func = self.ask(batch=True)
@@ -109,7 +109,6 @@ class Pipeline:
self.tell(fitnesses)
print("Generation limit reached!")
def update_next_generation(self, crossover_pair: List[Union[int, Tuple[int, int]]]) -> None:
"""
create the next generation
@@ -117,6 +116,7 @@ class Pipeline:
"""
assert self.pop_nodes.shape[0] == self.pop_size
k1, k2, self.randkey = jax.random.split(self.randkey, 3)
# crossover
# prepare elitism mask and crossover pair
@@ -127,19 +127,20 @@ class Pipeline:
crossover_pair[i] = (pair, pair)
crossover_pair = np.array(crossover_pair)
crossover_rand_keys = jax.random.split(k1, self.pop_size)
# batch crossover
wpn = self.pop_nodes[crossover_pair[:, 0]] # winner pop nodes
wpc = self.pop_connections[crossover_pair[:, 0]] # winner pop connections
lpn = self.pop_nodes[crossover_pair[:, 1]] # loser pop nodes
lpc = self.pop_connections[crossover_pair[:, 1]] # loser pop connections
# npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
npn, npc = batch_crossover(wpn, wpc, lpn, lpc)
# print(pop_analysis(npn, npc, self.input_idx, self.output_idx))
npn, npc = batch_crossover(crossover_rand_keys, wpn, wpc, lpn, lpc) # new pop nodes, new pop connections
# mutate
mutate_rand_keys = jax.random.split(k2, self.pop_size)
new_node_keys = np.array(self.fetch_new_node_keys())
m_npn, m_npc = self.mutate_func(npn, npc, new_node_keys) # mutate_new_pop_nodes
m_npn, m_npc = self.mutate_func(mutate_rand_keys, npn, npc, new_node_keys) # mutate_new_pop_nodes
# elitism don't mutate
# (pop_size, ) to (pop_size, 1, 1)
@@ -156,7 +157,6 @@ class Pipeline:
unused.append(key)
self.new_node_keys_pool = unused + self.new_node_keys_pool
def expand(self):
"""
Expand the population if needed.
@@ -176,7 +176,6 @@ class Pipeline:
for s in self.species_controller.species.values():
s.representative = expand_single(*s.representative, self.N)
def fetch_new_node_keys(self):
# if remain unused keys are not enough, create new keys
if len(self.new_node_keys_pool) < self.pop_size:
@@ -189,7 +188,6 @@ class Pipeline:
self.new_node_keys_pool = self.new_node_keys_pool[self.pop_size:]
return res
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
species_sizes = [len(s.members) for s in self.species_controller.species.values()]

View File

@@ -1,9 +1,11 @@
from typing import List, Tuple, Dict, Union
from itertools import count
import jax
import numpy as np
from numpy.typing import NDArray
from .genome.numpy import distance
from .genome import distance
class Species(object):
@@ -45,6 +47,9 @@ class SpeciesController:
self.species_idxer = count(0)
self.species: Dict[int, Species] = {} # species_id -> species
self.distance = distance
self.o2m_distance = jax.vmap(distance, in_axes=(None, None, 0, 0))
def speciate(self, pop_nodes: NDArray, pop_connections: NDArray, generation: int) -> None:
"""
:param pop_nodes:
@@ -62,7 +67,7 @@ class SpeciesController:
for sid, species in self.species.items():
# calculate the distance between the representative and the population
r_nodes, r_connections = species.representative
distances = o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
distances = self.o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections)
min_idx = find_min_with_mask(distances, unspeciated) # find the min un-specified distance
new_representatives[sid] = min_idx
@@ -75,7 +80,7 @@ class SpeciesController:
if previous_species_list: # exist previous species
rid_list = [new_representatives[sid] for sid in previous_species_list]
res_pop_distance = [
o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
self.o2m_distance(pop_nodes[rid], pop_connections[rid], pop_nodes, pop_connections)
for rid in rid_list
]
@@ -102,7 +107,7 @@ class SpeciesController:
sid, rid = list(zip(*[(k, v) for k, v in new_representatives.items()]))
distances = [
distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
self.distance(pop_nodes[i], pop_connections[i], pop_nodes[r], pop_connections[r])
for r in rid
]
distances = np.array(distances)
@@ -314,16 +319,3 @@ def sort_element_with_fitnesses(members: List[int], fitnesses: List[float]) \
sorted_members = [item[0] for item in sorted_combined]
sorted_fitnesses = [item[1] for item in sorted_combined]
return sorted_members, sorted_fitnesses
def o2m_distance(r_nodes, r_connections, pop_nodes, pop_connections):
distances = []
for nodes, connections in zip(pop_nodes, pop_connections):
d = distance(r_nodes, r_connections, nodes, connections)
if d < 0:
d = distance(r_nodes, r_connections, nodes, connections)
print(d)
assert False
distances.append(d)
distances = np.stack(distances, axis=0)
return distances

View File

@@ -17,7 +17,7 @@ def evaluate(forward_func: Callable) -> List[float]:
:return:
"""
outs = forward_func(xor_inputs)
fitnesses = 4 - np.sum(np.abs(outs - xor_outputs), axis=(1, 2))
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
# print(fitnesses)
return fitnesses.tolist() # returns a list
@@ -26,7 +26,7 @@ def evaluate(forward_func: Callable) -> List[float]:
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
def main():
config = Configer.load_config()
pipeline = Pipeline(config)
pipeline = Pipeline(config, seed=123123)
pipeline.auto_run(evaluate)
# for _ in range(100):
@@ -38,5 +38,4 @@ def main():
if __name__ == '__main__':
np.random.seed(63124326)
main()