complete HyperNEAT!

This commit is contained in:
wls2002
2023-07-21 15:03:12 +08:00
parent 80ee5ea2ea
commit 48f90c7eef
32 changed files with 432 additions and 136 deletions

11
examples/a.py Normal file
View File

@@ -0,0 +1,11 @@
import numpy as np
import jax.numpy as jnp
a = jnp.zeros((5, 5))
k1 = jnp.array([1, 2, 3])
k2 = jnp.array([2, 3, 4])
v = jnp.array([1, 1, 1])
a = a.at[k1, k2].set(v)
print(a)

View File

@@ -1,13 +1,44 @@
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
vals = np.array([1, 2])
weights = np.array([[0, 4], [5, 0]])
@register_pytree_node_class
class Genome:
def __init__(self, nodes, conns):
self.nodes = nodes
self.conns = conns
ins1 = vals * weights[:, 0]
ins2 = vals * weights[:, 1]
ins_all = vals * weights.T
def update_nodes(self, nodes):
return Genome(nodes, self.conns)
print(ins1)
print(ins2)
print(ins_all)
def update_conns(self, conns):
return Genome(self.nodes, conns)
def tree_flatten(self):
children = self.nodes, self.conns
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
def __repr__(self):
return f"Genome ({self.nodes}, \n\t{self.conns})"
@jax.jit
def add_node(self, a: int):
nodes = self.nodes.at[0, :].set(a)
return self.update_nodes(nodes)
nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]])
g = Genome(nodes, conns)
print(g)
g = g.add_node(1)
print(g)
g = jax.jit(g.add_node)(2)
print(g)

View File

@@ -1,15 +0,0 @@
import jax
from jax import numpy as jnp
from algorithm.state import State
@jax.jit
def func(state: State, a):
return state.update(a=a)
state = State(c=1, b=2)
print(state)
vmap_func = jax.vmap(func, in_axes=(None, 0))
print(vmap_func(state, jnp.array([1, 2, 3])))

View File

@@ -1,7 +1,12 @@
[basic]
forward_way = "common"
network_type = "recurrent"
activate_times = 5
fitness_threshold = 4
[population]
fitness_threshold = 4
pop_size = 1000
[neat]
network_type = "recurrent"
num_inputs = 4
num_outputs = 1

View File

@@ -1,8 +1,10 @@
import jax
import numpy as np
from algorithm import Configer, NEAT
from algorithm.neat import NormalGene, RecurrentGene, Pipeline
from pipeline import Pipeline
from config import Configer
from algorithm import NEAT
from algorithm.neat import RecurrentGene
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -21,7 +23,6 @@ def evaluate(forward_func):
def main():
config = Configer.load_config("xor.ini")
# algorithm = NEAT(config, NormalGene)
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm)
best = pipeline.auto_run(evaluate)

33
examples/xor_hyperneat.py Normal file
View File

@@ -0,0 +1,33 @@
import jax
import numpy as np
from pipeline import Pipeline
from config import Configer
from algorithm import NEAT, HyperNEAT
from algorithm.neat import RecurrentGene
from algorithm.hyperneat import BaseSubstrate
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
def main():
config = Configer.load_config("xor.ini")
algorithm = HyperNEAT(config, RecurrentGene, BaseSubstrate)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)
if __name__ == '__main__':
main()

View File

@@ -1,50 +0,0 @@
import jax
import numpy as np
from algorithm.config import Configer
from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline
from algorithm.neat.genome import create_mutate
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
def single_genome(func, nodes, conns):
t = RecurrentGene.forward_transform(nodes, conns)
out1 = func(xor_inputs[0], t)
out2 = func(xor_inputs[1], t)
out3 = func(xor_inputs[2], t)
out4 = func(xor_inputs[3], t)
print(out1, out2, out3, out4)
def batch_genome(func, nodes, conns):
t = NormalGene.forward_transform(nodes, conns)
out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t)
print(out)
def pop_batch_genome(func, pop_nodes, pop_conns):
t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns)
func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0))
out = func(xor_inputs, t)
print(out)
if __name__ == '__main__':
config = Configer.load_config("xor.ini")
# neat = NEAT(config, NormalGene)
neat = NEAT(config, RecurrentGene)
randkey = jax.random.PRNGKey(42)
state = neat.setup(randkey)
forward_func = RecurrentGene.create_forward(config)
mutate_func = create_mutate(config, RecurrentGene)
nodes, conns = state.pop_nodes[0], state.pop_conns[0]
single_genome(forward_func, nodes, conns)
# batch_genome(forward_func, nodes, conns)
nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
single_genome(forward_func, nodes, conns)
# batch_genome(forward_func, nodes, conns)
#