adjust parameter for xor problem
This commit is contained in:
@@ -26,7 +26,7 @@ def distance(nodes1: Array, connections1: Array, nodes2: Array, connections2: Ar
|
||||
def node_distance(nodes1, nodes2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2) - 2
|
||||
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
@@ -72,6 +72,7 @@ def connection_distance(cons1, cons2, disjoint_coe=1., compatibility_coe=0.5):
|
||||
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
|
||||
@vmap
|
||||
def batch_homologous_node_distance(b_n1, b_n2):
|
||||
return homologous_node_distance(b_n1, b_n2)
|
||||
|
||||
@@ -16,7 +16,7 @@ def object2array(genome, N):
|
||||
nodes = np.full((N, 5), np.nan)
|
||||
connections = np.full((2, N, N), np.nan)
|
||||
|
||||
assert len(genome.nodes) + len(genome.input_keys) + 1 <= N # remain one inf row for mutation adding extra node
|
||||
assert len(genome.nodes) + len(genome.input_keys) <= N # remain one inf row for mutation adding extra node
|
||||
|
||||
idx = 0
|
||||
n2i = {}
|
||||
|
||||
@@ -18,8 +18,8 @@ class Genome:
|
||||
# Fitness results.
|
||||
self.fitness = None
|
||||
|
||||
self.input_keys = [-i - 1 for i in range(config.basic.num_inputs)]
|
||||
self.output_keys = [i for i in range(config.basic.num_outputs)]
|
||||
# self.input_keys = [-i - 1 for i in range(config.basic.num_inputs)]
|
||||
# self.output_keys = [i for i in range(config.basic.num_outputs)]
|
||||
|
||||
if init_val:
|
||||
self.initialize()
|
||||
|
||||
@@ -6,7 +6,7 @@ import jax
|
||||
from .species import SpeciesController
|
||||
from .genome import create_initialize_function, create_mutate_function, create_forward_function
|
||||
from .genome import batch_crossover
|
||||
from .genome import expand, expand_single, pop_analysis
|
||||
from .genome import expand, expand_single, distance
|
||||
|
||||
from .genome.origin_neat import *
|
||||
|
||||
@@ -53,14 +53,6 @@ class Pipeline:
|
||||
return func
|
||||
|
||||
def tell(self, fitnesses):
|
||||
# idx = np.argmax(fitnesses)
|
||||
# print(f"argmax: {idx}, max: {np.max(fitnesses)}, a_max: {fitnesses[idx]}")
|
||||
# n, c = self.pop_nodes[idx], self.pop_connections[idx]
|
||||
# func = create_forward_function(n, c, self.N, self.input_idx, self.output_idx, batch=True)
|
||||
# out = func(xor_inputs)
|
||||
# print(f"max fitness: {fitnesses[idx]}")
|
||||
# print(f"real fitness: {4 - np.sum(np.abs(out - xor_outputs), axis=0)}")
|
||||
# print(f"Out:\n{func(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]))}")
|
||||
|
||||
self.generation += 1
|
||||
|
||||
@@ -70,6 +62,18 @@ class Pipeline:
|
||||
|
||||
self.update_next_generation(crossover_pair)
|
||||
|
||||
# for i in range(self.pop_size):
|
||||
# for j in range(self.pop_size):
|
||||
# n1, c1 = self.pop_nodes[i], self.pop_connections[i]
|
||||
# n2, c2 = self.pop_nodes[j], self.pop_connections[j]
|
||||
# g1 = array2object(self.config.neat, n1, c1)
|
||||
# g2 = array2object(self.config.neat, n2, c2)
|
||||
# d_real = g1.distance(g2)
|
||||
# d = distance(n1, c1, n2, c2)
|
||||
# print(d_real, d)
|
||||
# assert np.allclose(d_real, d)
|
||||
|
||||
|
||||
# analysis = pop_analysis(self.pop_nodes, self.pop_connections, self.input_idx, self.output_idx)
|
||||
|
||||
# try:
|
||||
|
||||
@@ -26,7 +26,7 @@ def evaluate(forward_func: Callable) -> List[float]:
|
||||
@partial(using_cprofile, root_abs_path='/mnt/e/neat-jax/', replace_pattern="/mnt/e/neat-jax/")
|
||||
def main():
|
||||
config = Configer.load_config()
|
||||
pipeline = Pipeline(config, seed=123123)
|
||||
pipeline = Pipeline(config, seed=11323)
|
||||
pipeline.auto_run(evaluate)
|
||||
|
||||
# for _ in range(100):
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
"basic": {
|
||||
"num_inputs": 2,
|
||||
"num_outputs": 1,
|
||||
"init_maximum_nodes": 5,
|
||||
"expands_coe": 1.5
|
||||
"init_maximum_nodes": 10,
|
||||
"expands_coe": 2
|
||||
},
|
||||
"neat": {
|
||||
"population": {
|
||||
@@ -39,7 +39,7 @@
|
||||
"mutate_rate": 0.01
|
||||
},
|
||||
"weight": {
|
||||
"init_mean": 0.0,
|
||||
"init_mean": 1.0,
|
||||
"init_stdev": 1.0,
|
||||
"mutate_power": 0.5,
|
||||
"mutate_rate": 0.8,
|
||||
@@ -59,7 +59,7 @@
|
||||
"node_delete_prob": 0.2
|
||||
},
|
||||
"species": {
|
||||
"compatibility_threshold": 3.5,
|
||||
"compatibility_threshold": 3,
|
||||
"species_fitness_func": "max",
|
||||
"max_stagnation": 20,
|
||||
"species_elitism": 2,
|
||||
|
||||
Reference in New Issue
Block a user