perfect! fix bug about jax auto recompile

add task xor-3d
This commit is contained in:
wls2002
2023-07-02 22:15:26 +08:00
parent e711146f41
commit c4d34e877b
11 changed files with 234 additions and 104 deletions

View File

@@ -34,8 +34,6 @@ def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topologica
return evaluate(func)
def equal(ar1, ar2):
if ar1.shape != ar2.shape:
return False

View File

@@ -2,4 +2,4 @@
forward_way = "common"
[population]
fitness_threshold = 3.9999
fitness_threshold = 4

View File

@@ -2,7 +2,6 @@ import jax
import numpy as np
from configs import Configer
from algorithms.neat import Genome
from pipeline import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
@@ -22,10 +21,10 @@ def evaluate(forward_func):
def main():
config = Configer.load_config("xor.ini")
pipeline = Pipeline(config, seed=6)
pipeline = Pipeline(config)
nodes, cons = pipeline.auto_run(evaluate)
g = Genome(nodes, cons, config)
print(g)
# g = Genome(nodes, cons, config)
# print(g)
if __name__ == '__main__':

47
examples/xor3d.ini Normal file
View File

@@ -0,0 +1,47 @@
[basic]
num_inputs = 3
num_outputs = 1
maximum_nodes = 50
maximum_connections = 50
maximum_species = 10
forward_way = "common"
batch_size = 4
random_seed = 42
[population]
fitness_threshold = 8
generation_limit = 1000
fitness_criterion = "max"
pop_size = 100000
[genome]
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.4
conn_add_trials = 1
conn_delete_prob = 0
node_add_prob = 0.2
node_delete_prob = 0
[species]
compatibility_threshold = 3
species_elitism = 1
max_stagnation = 15
genome_elitism = 2
survival_threshold = 0.2
min_species_size = 1
spawn_number_move_rate = 0.5
[gene-bias]
bias_init_mean = 0.0
bias_init_std = 1.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
[gene-weight]
weight_init_mean = 0.0
weight_init_std = 1.0
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1

31
examples/xor3d.py Normal file
View File

@@ -0,0 +1,31 @@
import jax
import numpy as np
from configs import Configer
from pipeline import Pipeline
xor_inputs = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0], [1], [0], [0], [1]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 8 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
def main():
config = Configer.load_config("xor3d.ini")
pipeline = Pipeline(config)
nodes, cons = pipeline.auto_run(evaluate)
# g = Genome(nodes, cons, config)
# print(g)
if __name__ == '__main__':
main()