complete HyperNEAT!
This commit is contained in:
11
examples/a.py
Normal file
11
examples/a.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
|
||||
a = jnp.zeros((5, 5))
|
||||
k1 = jnp.array([1, 2, 3])
|
||||
k2 = jnp.array([2, 3, 4])
|
||||
v = jnp.array([1, 1, 1])
|
||||
|
||||
a = a.at[k1, k2].set(v)
|
||||
|
||||
print(a)
|
||||
@@ -1,13 +1,44 @@
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
|
||||
vals = np.array([1, 2])
|
||||
weights = np.array([[0, 4], [5, 0]])
|
||||
@register_pytree_node_class
|
||||
class Genome:
|
||||
def __init__(self, nodes, conns):
|
||||
self.nodes = nodes
|
||||
self.conns = conns
|
||||
|
||||
ins1 = vals * weights[:, 0]
|
||||
ins2 = vals * weights[:, 1]
|
||||
ins_all = vals * weights.T
|
||||
def update_nodes(self, nodes):
|
||||
return Genome(nodes, self.conns)
|
||||
|
||||
print(ins1)
|
||||
print(ins2)
|
||||
print(ins_all)
|
||||
def update_conns(self, conns):
|
||||
return Genome(self.nodes, conns)
|
||||
|
||||
def tree_flatten(self):
|
||||
children = self.nodes, self.conns
|
||||
aux_data = None
|
||||
return children, aux_data
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Genome ({self.nodes}, \n\t{self.conns})"
|
||||
|
||||
@jax.jit
|
||||
def add_node(self, a: int):
|
||||
nodes = self.nodes.at[0, :].set(a)
|
||||
return self.update_nodes(nodes)
|
||||
|
||||
|
||||
nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]])
|
||||
g = Genome(nodes, conns)
|
||||
print(g)
|
||||
|
||||
g = g.add_node(1)
|
||||
print(g)
|
||||
|
||||
g = jax.jit(g.add_node)(2)
|
||||
print(g)
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from algorithm.state import State
|
||||
|
||||
|
||||
@jax.jit
|
||||
def func(state: State, a):
|
||||
return state.update(a=a)
|
||||
|
||||
|
||||
state = State(c=1, b=2)
|
||||
print(state)
|
||||
|
||||
vmap_func = jax.vmap(func, in_axes=(None, 0))
|
||||
print(vmap_func(state, jnp.array([1, 2, 3])))
|
||||
@@ -1,7 +1,12 @@
|
||||
[basic]
|
||||
forward_way = "common"
|
||||
network_type = "recurrent"
|
||||
activate_times = 5
|
||||
fitness_threshold = 4
|
||||
|
||||
[population]
|
||||
fitness_threshold = 4
|
||||
pop_size = 1000
|
||||
|
||||
[neat]
|
||||
network_type = "recurrent"
|
||||
num_inputs = 4
|
||||
num_outputs = 1
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from algorithm import Configer, NEAT
|
||||
from algorithm.neat import NormalGene, RecurrentGene, Pipeline
|
||||
from pipeline import Pipeline
|
||||
from config import Configer
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat import RecurrentGene
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
@@ -21,7 +23,6 @@ def evaluate(forward_func):
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
# algorithm = NEAT(config, NormalGene)
|
||||
algorithm = NEAT(config, RecurrentGene)
|
||||
pipeline = Pipeline(config, algorithm)
|
||||
best = pipeline.auto_run(evaluate)
|
||||
|
||||
33
examples/xor_hyperneat.py
Normal file
33
examples/xor_hyperneat.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from pipeline import Pipeline
|
||||
from config import Configer
|
||||
from algorithm import NEAT, HyperNEAT
|
||||
from algorithm.neat import RecurrentGene
|
||||
from algorithm.hyperneat import BaseSubstrate
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses
|
||||
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
algorithm = HyperNEAT(config, RecurrentGene, BaseSubstrate)
|
||||
pipeline = Pipeline(config, algorithm)
|
||||
pipeline.auto_run(evaluate)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,50 +0,0 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from algorithm.config import Configer
|
||||
from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline
|
||||
from algorithm.neat.genome import create_mutate
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
|
||||
|
||||
def single_genome(func, nodes, conns):
|
||||
t = RecurrentGene.forward_transform(nodes, conns)
|
||||
out1 = func(xor_inputs[0], t)
|
||||
out2 = func(xor_inputs[1], t)
|
||||
out3 = func(xor_inputs[2], t)
|
||||
out4 = func(xor_inputs[3], t)
|
||||
print(out1, out2, out3, out4)
|
||||
|
||||
|
||||
def batch_genome(func, nodes, conns):
|
||||
t = NormalGene.forward_transform(nodes, conns)
|
||||
out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t)
|
||||
print(out)
|
||||
|
||||
|
||||
def pop_batch_genome(func, pop_nodes, pop_conns):
|
||||
t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns)
|
||||
func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0))
|
||||
out = func(xor_inputs, t)
|
||||
print(out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Configer.load_config("xor.ini")
|
||||
# neat = NEAT(config, NormalGene)
|
||||
neat = NEAT(config, RecurrentGene)
|
||||
randkey = jax.random.PRNGKey(42)
|
||||
state = neat.setup(randkey)
|
||||
forward_func = RecurrentGene.create_forward(config)
|
||||
mutate_func = create_mutate(config, RecurrentGene)
|
||||
|
||||
nodes, conns = state.pop_nodes[0], state.pop_conns[0]
|
||||
single_genome(forward_func, nodes, conns)
|
||||
# batch_genome(forward_func, nodes, conns)
|
||||
|
||||
nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
|
||||
single_genome(forward_func, nodes, conns)
|
||||
|
||||
# batch_genome(forward_func, nodes, conns)
|
||||
#
|
||||
Reference in New Issue
Block a user