hyper neat
This commit is contained in:
@@ -12,11 +12,10 @@ print(asdict(config))
|
||||
pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3))
|
||||
pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5))
|
||||
|
||||
pop_genomes1 = jax.vmap(Genome)(pop_nodes, pop_conns)
|
||||
pop_genomes2 = Genome(pop_nodes, pop_conns)
|
||||
pop_genomes = Genome(pop_nodes, pop_conns)
|
||||
|
||||
print(pop_genomes)
|
||||
print(pop_genomes[0])
|
||||
print(pop_genomes[0: 20])
|
||||
|
||||
@jax.vmap
|
||||
def pop_cnts(genome):
|
||||
|
||||
@@ -15,5 +15,9 @@ def func(d):
|
||||
|
||||
|
||||
d = {0: 1, 1: NetworkType.ANN.value}
|
||||
n = None
|
||||
|
||||
print(n or d)
|
||||
print(d)
|
||||
|
||||
print(func(d))
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from config import Config, BasicConfig
|
||||
from config import Config, BasicConfig, NeatConfig
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||
from algorithm.neat.neat import NEAT
|
||||
from algorithm import NEAT, NormalGene, NormalGeneConfig
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
@@ -23,8 +22,14 @@ def evaluate(forward_func):
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Config(
|
||||
basic=BasicConfig(fitness_target=4),
|
||||
gene=NormalGeneConfig()
|
||||
basic=BasicConfig(
|
||||
fitness_target=3.99999,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
maximum_nodes=50,
|
||||
maximum_conns=100,
|
||||
)
|
||||
)
|
||||
algorithm = NEAT(config, NormalGene)
|
||||
pipeline = Pipeline(config, algorithm)
|
||||
|
||||
49
examples/xor_hyperNEAT.py
Normal file
49
examples/xor_hyperNEAT.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from config import Config, BasicConfig, NeatConfig
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig
|
||||
from algorithm import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
|
||||
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Config(
|
||||
basic=BasicConfig(
|
||||
fitness_target=3.99999,
|
||||
pop_size=1000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
network_type="recurrent",
|
||||
maximum_nodes=50,
|
||||
maximum_conns=100,
|
||||
inputs=4,
|
||||
outputs=1
|
||||
|
||||
),
|
||||
gene=RecurrentGeneConfig(
|
||||
activation_default="tanh",
|
||||
activation_options=("tanh", ),
|
||||
),
|
||||
substrate=NormalSubstrateConfig(),
|
||||
)
|
||||
neat = NEAT(config, RecurrentGene)
|
||||
hyperNEAT = HyperNEAT(config, neat, NormalSubstrate)
|
||||
|
||||
pipeline = Pipeline(config, hyperNEAT)
|
||||
pipeline.auto_run(evaluate)
|
||||
39
examples/xor_recurrent.py
Normal file
39
examples/xor_recurrent.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from config import Config, BasicConfig, NeatConfig
|
||||
from pipeline import Pipeline
|
||||
from algorithm import NEAT, RecurrentGene, RecurrentGeneConfig
|
||||
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Config(
|
||||
basic=BasicConfig(
|
||||
fitness_target=3.99999,
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
network_type="recurrent",
|
||||
maximum_nodes=50,
|
||||
maximum_conns=100
|
||||
),
|
||||
gene=RecurrentGeneConfig()
|
||||
)
|
||||
algorithm = NEAT(config, RecurrentGene)
|
||||
pipeline = Pipeline(config, algorithm)
|
||||
pipeline.auto_run(evaluate)
|
||||
Reference in New Issue
Block a user