initial commit
This commit is contained in:
0
examples/__init__.py
Normal file
0
examples/__init__.py
Normal file
71
examples/genome_test.py
Normal file
71
examples/genome_test.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import time
|
||||
|
||||
import jax.random
|
||||
|
||||
from utils import Configer
|
||||
from algorithms.neat.genome.genome import *
|
||||
|
||||
from algorithms.neat.species import SpeciesController
|
||||
from algorithms.neat.genome.forward import create_forward_function
|
||||
from algorithms.neat.genome.mutate import create_mutate_function
|
||||
|
||||
if __name__ == '__main__':
|
||||
N = 10
|
||||
pop_nodes, pop_connections, input_idx, output_idx = initialize_genomes(10000, N, 2, 1,
|
||||
default_act=9, default_agg=0)
|
||||
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
# forward = create_forward_function(pop_nodes, pop_connections, 5, input_idx, output_idx, batch=True)
|
||||
nodes, connections = pop_nodes[0], pop_connections[0]
|
||||
|
||||
forward = create_forward_function(pop_nodes, pop_connections, N, input_idx, output_idx, batch=True)
|
||||
out = forward(inputs)
|
||||
print(out.shape)
|
||||
print(out)
|
||||
|
||||
config = Configer.load_config()
|
||||
s_c = SpeciesController(config.neat)
|
||||
s_c.speciate(pop_nodes, pop_connections, 0)
|
||||
s_c.speciate(pop_nodes, pop_connections, 0)
|
||||
print(s_c.genome_to_species)
|
||||
|
||||
start = time.time()
|
||||
for i in range(100):
|
||||
print(i)
|
||||
s_c.speciate(pop_nodes, pop_connections, i)
|
||||
print(time.time() - start)
|
||||
|
||||
seed = jax.random.PRNGKey(42)
|
||||
mutate_func = create_mutate_function(config, input_idx, output_idx, batch=False)
|
||||
print(nodes, connections, sep='\n')
|
||||
print(*mutate_func(seed, nodes, connections, 100), sep='\n')
|
||||
|
||||
randseeds = jax.random.split(seed, 10000)
|
||||
new_node_keys = jax.random.randint(randseeds[0], minval=0, maxval=10000, shape=(10000,))
|
||||
batch_mutate_func = create_mutate_function(config, input_idx, output_idx, batch=True)
|
||||
pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys)
|
||||
print(pop_nodes, pop_connections, sep='\n')
|
||||
|
||||
start = time.time()
|
||||
for i in range(100):
|
||||
print(i)
|
||||
pop_nodes, pop_connections = batch_mutate_func(randseeds, pop_nodes, pop_connections, new_node_keys)
|
||||
print(time.time() - start)
|
||||
|
||||
print(nodes, connections, sep='\n')
|
||||
nodes, connections = add_node(6, nodes, connections)
|
||||
nodes, connections = add_node(7, nodes, connections)
|
||||
print(nodes, connections, sep='\n')
|
||||
|
||||
nodes, connections = add_connection(6, 7, nodes, connections)
|
||||
nodes, connections = add_connection(0, 7, nodes, connections)
|
||||
nodes, connections = add_connection(1, 7, nodes, connections)
|
||||
print(nodes, connections, sep='\n')
|
||||
|
||||
nodes, connections = delete_connection(6, 7, nodes, connections)
|
||||
print(nodes, connections, sep='\n')
|
||||
|
||||
nodes, connections = delete_node(6, nodes, connections)
|
||||
print(nodes, connections, sep='\n')
|
||||
|
||||
nodes, connections = delete_node(7, nodes, connections)
|
||||
print(nodes, connections, sep='\n')
|
||||
37
examples/jax_playground.py
Normal file
37
examples/jax_playground.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax import random
|
||||
from jax import vmap, jit
|
||||
|
||||
|
||||
def plus1(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
def minus1(x):
|
||||
return x - 1
|
||||
|
||||
|
||||
def func(rand_key, x):
|
||||
r = jax.random.uniform(rand_key, shape=())
|
||||
return jax.lax.cond(r > 0.5, plus1, minus1, x)
|
||||
|
||||
|
||||
def func2(rand_key):
|
||||
r = jax.random.uniform(rand_key, ())
|
||||
if r < 0.3:
|
||||
return 1
|
||||
elif r < 0.5:
|
||||
return 2
|
||||
else:
|
||||
return 3
|
||||
|
||||
|
||||
|
||||
key = random.PRNGKey(0)
|
||||
print(func(key, 0))
|
||||
|
||||
batch_func = vmap(jit(func))
|
||||
keys = random.split(key, 100)
|
||||
print(batch_func(keys, jnp.zeros(100)))
|
||||
40
examples/xor.py
Normal file
40
examples/xor.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Callable, List
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from utils import Configer
|
||||
from algorithms.neat import Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
||||
xor_outputs = np.array([[0], [1], [1], [0]])
|
||||
|
||||
|
||||
def evaluate(forward_func: Callable) -> List[float]:
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = np.mean((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses.tolist() # returns a list
|
||||
|
||||
|
||||
def main():
|
||||
config = Configer.load_config()
|
||||
pipeline = Pipeline(config)
|
||||
forward_func = pipeline.ask(batch=True)
|
||||
fitnesses = evaluate(forward_func)
|
||||
pipeline.tell(fitnesses)
|
||||
|
||||
|
||||
|
||||
# for i in range(100):
|
||||
# forward_func = pipeline.ask(batch=True)
|
||||
# fitnesses = evaluate(forward_func)
|
||||
# pipeline.tell(fitnesses)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user