initial commit

This commit is contained in:
wls2002
2023-05-05 14:19:13 +08:00
commit 6faa07f507
43 changed files with 2517 additions and 0 deletions

0
examples/__init__.py Normal file
View File

71
examples/genome_test.py Normal file
View File

@@ -0,0 +1,71 @@
import time
import jax.random
from utils import Configer
from algorithms.neat.genome.genome import *
from algorithms.neat.species import SpeciesController
from algorithms.neat.genome.forward import create_forward_function
from algorithms.neat.genome.mutate import create_mutate_function
if __name__ == '__main__':
N = 10
pop_nodes, pop_connections, input_idx, output_idx = initialize_genomes(10000, N, 2, 1,
default_act=9, default_agg=0)
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# forward = create_forward_function(pop_nodes, pop_connections, 5, input_idx, output_idx, batch=True)
nodes, connections = pop_nodes[0], pop_connections[0]
forward = create_forward_function(pop_nodes, pop_connections, N, input_idx, output_idx, batch=True)
out = forward(inputs)
print(out.shape)
print(out)
config = Configer.load_config()
s_c = SpeciesController(config.neat)
s_c.speciate(pop_nodes, pop_connections, 0)
s_c.speciate(pop_nodes, pop_connections, 0)
print(s_c.genome_to_species)
start = time.time()
for i in range(100):
print(i)
s_c.speciate(pop_nodes, pop_connections, i)
print(time.time() - start)
seed = jax.random.PRNGKey(42)
mutate_func = create_mutate_function(config, input_idx, output_idx, batch=False)
print(nodes, connections, sep='\n')
print(*mutate_func(seed, nodes, connections, 100), sep='\n')
randseeds = jax.random.split(seed, 10000)
new_node_keys = jax.random.randint(randseeds[0], minval=0, maxval=10000, shape=(10000,))
batch_mutate_func = create_mutate_function(config, input_idx, output_idx, batch=True)
pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys)
print(pop_nodes, pop_connections, sep='\n')
start = time.time()
for i in range(100):
print(i)
pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys)
print(time.time() - start)
print(nodes, connections, sep='\n')
nodes, connections = add_node(6, nodes, connections)
nodes, connections = add_node(7, nodes, connections)
print(nodes, connections, sep='\n')
nodes, connections = add_connection(6, 7, nodes, connections)
nodes, connections = add_connection(0, 7, nodes, connections)
nodes, connections = add_connection(1, 7, nodes, connections)
print(nodes, connections, sep='\n')
nodes, connections = delete_connection(6, 7, nodes, connections)
print(nodes, connections, sep='\n')
nodes, connections = delete_node(6, nodes, connections)
print(nodes, connections, sep='\n')
nodes, connections = delete_node(7, nodes, connections)
print(nodes, connections, sep='\n')

View File

@@ -0,0 +1,37 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax import vmap, jit
def plus1(x):
return x + 1
def minus1(x):
return x - 1
def func(rand_key, x):
r = jax.random.uniform(rand_key, shape=())
return jax.lax.cond(r > 0.5, plus1, minus1, x)
def func2(rand_key):
r = jax.random.uniform(rand_key, ())
if r < 0.3:
return 1
elif r < 0.5:
return 2
else:
return 3
key = random.PRNGKey(0)
print(func(key, 0))
batch_func = vmap(jit(func))
keys = random.split(key, 100)
print(batch_func(keys, jnp.zeros(100)))

40
examples/xor.py Normal file
View File

@@ -0,0 +1,40 @@
from typing import Callable, List
import jax
import numpy as np
from utils import Configer
from algorithms.neat import Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
xor_outputs = np.array([[0], [1], [1], [0]])
def evaluate(forward_func: Callable) -> List[float]:
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses.tolist() # returns a list
def main():
config = Configer.load_config()
pipeline = Pipeline(config)
forward_func = pipeline.ask(batch=True)
fitnesses = evaluate(forward_func)
pipeline.tell(fitnesses)
# for i in range(100):
# forward_func = pipeline.ask(batch=True)
# fitnesses = evaluate(forward_func)
# pipeline.tell(fitnesses)
if __name__ == '__main__':
main()