remove create_func....

This commit is contained in:
wls2002
2023-08-02 15:02:08 +08:00
parent 1499e062fe
commit c7fb1ddabe
22 changed files with 425 additions and 21 deletions

View File

@@ -1,24 +1,31 @@
from functools import partial
import jax
from utils import unflatten_conns, act, agg, Activation, Aggregation
from algorithm.neat.gene import RecurrentGeneConfig
config = RecurrentGeneConfig(
activation_options=("tanh", "sigmoid"),
activation_default="tanh",
)
class A:
def __init__(self):
self.a = 1
self.b = 2
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
self.isTrue = False
@partial(jax.jit, static_argnums=(0,))
def step(self):
if self.isTrue:
return self.a + 1
else:
return self.b + 1
i = jax.numpy.array([0, 1])
z = jax.numpy.array([
[1, 1],
[2, 2]
])
print(self.act_funcs)
return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs)
AA = A()
print(AA.step(), hash(AA))
print(AA.step(), hash(AA))
print(AA.step(), hash(AA))
AA.a = (2, 3, 4)
print(AA.step(), hash(AA))
print(AA.step())

View File

@@ -28,11 +28,13 @@ if __name__ == '__main__':
pop_size=10000
),
neat=NeatConfig(
maximum_nodes=20,
maximum_conns=50,
)
maximum_nodes=50,
maximum_conns=100,
compatibility_threshold=4
),
gene=NormalGeneConfig()
)
normal_gene = NormalGene(NormalGeneConfig())
algorithm = NEAT(config, normal_gene)
algorithm = NEAT(config, NormalGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)

49
examples/xor_hyperneat.py Normal file
View File

@@ -0,0 +1,49 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT, HyperNEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from algorithm.hyperneat.substrate import NormalSubstrate, NormalSubstrateConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=10000
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100,
inputs=4,
outputs=1
),
gene=RecurrentGeneConfig(
activation_default="tanh",
activation_options=("tanh",),
),
substrate=NormalSubstrateConfig(),
)
neat = NEAT(config, RecurrentGene)
hyperNEAT = HyperNEAT(config, neat, NormalSubstrate)
pipeline = Pipeline(config, hyperNEAT)
pipeline.auto_run(evaluate)

42
examples/xor_recurrent.py Normal file
View File

@@ -0,0 +1,42 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=10000
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100
),
gene=RecurrentGeneConfig(
activate_times=3
)
)
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)