remove create_func....

wls2002
2023-08-04 17:29:36 +08:00
parent c7fb1ddabe
commit 0e44b13291
29 changed files with 591 additions and 259 deletions

examples/func_fit/xor.py Normal file
@@ -0,0 +1,36 @@
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig

if __name__ == '__main__':
    config = Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=-1e-2,
            pop_size=10000
        ),
        neat=NeatConfig(
            max_nodes=50,
            max_conns=100,
            max_species=30,
            conn_add=0.8,
            conn_delete=0,
            node_add=0.4,
            node_delete=0,
            inputs=2,
            outputs=1
        ),
        gene=NormalGeneConfig(),
        problem=FuncFitConfig(
            error_method='rmse'
        )
    )

    algorithm = NEAT(config, NormalGene)
    pipeline = Pipeline(config, algorithm, XOR)
    state = pipeline.setup()
    pipeline.pre_compile(state)
    state, best = pipeline.auto_run(state)
    pipeline.show(state, best)
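
For context, the XOR dataset behind the func_fit problem above is the standard four-case truth table; it appears verbatim in the removed examples further down this diff. A minimal standalone sketch of the data and of the RMSE that error_method='rmse' presumably computes, assuming fitness is reported as the negated error so that fitness_target=-1e-2 means an RMSE below 0.01 (the helper rmse below is illustrative, not part of the library):

import numpy as np

# The four XOR cases, same arrays as in the removed examples below.
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)

def rmse(predictions, targets):
    # Root-mean-square error over all cases and output dimensions.
    return float(np.sqrt(np.mean((predictions - targets) ** 2)))

print(rmse(xor_outputs, xor_outputs))                     # 0.0 for a perfect network
print(rmse(np.full_like(xor_outputs, 0.5), xor_outputs))  # 0.5 for a constant 0.5 guess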

@@ -0,0 +1,40 @@
from config import *
from pipeline import Pipeline
from algorithm.neat import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
from problem.func_fit import XOR3d, FuncFitConfig

if __name__ == '__main__':
    config = Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=0,
            pop_size=1000
        ),
        neat=NeatConfig(
            max_nodes=50,
            max_conns=100,
            max_species=30,
            inputs=4,
            outputs=1
        ),
        hyperneat=HyperNeatConfig(
            inputs=3,
            outputs=1
        ),
        substrate=NormalSubstrateConfig(
            input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
        ),
        gene=NormalGeneConfig(
            activation_default='tanh',
            activation_options=('tanh', ),
        ),
        problem=FuncFitConfig()
    )

    algorithm = HyperNEAT(config, NormalGene, NormalSubstrate)
    pipeline = Pipeline(config, algorithm, XOR3d)
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)
    pipeline.show(state, best)

@@ -0,0 +1,40 @@
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from problem.func_fit import XOR3d, FuncFitConfig

if __name__ == '__main__':
    config = Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=-1e-2,
            generation_limit=300,
            pop_size=1000
        ),
        neat=NeatConfig(
            network_type="recurrent",
            max_nodes=50,
            max_conns=100,
            max_species=30,
            conn_add=0.5,
            conn_delete=0.5,
            node_add=0.4,
            node_delete=0.4,
            inputs=3,
            outputs=1
        ),
        gene=RecurrentGeneConfig(
            activate_times=10
        ),
        problem=FuncFitConfig(
            error_method='rmse'
        )
    )

    algorithm = NEAT(config, RecurrentGene)
    pipeline = Pipeline(config, algorithm, XOR3d)
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)
    pipeline.show(state, best)
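
XOR3d itself is not shown in this diff; the config above (inputs=3, outputs=1) suggests a three-input parity variant of XOR. Purely for illustration, a hedged sketch of what such a dataset could look like (the names inputs_3d and targets_3d are made up here, not library identifiers):

import numpy as np
from itertools import product

# All 8 input combinations and their parity (1 if an odd number of ones).
inputs_3d = np.array(list(product([0.0, 1.0], repeat=3)), dtype=np.float32)
targets_3d = (inputs_3d.sum(axis=1) % 2).reshape(-1, 1)

print(inputs_3d.shape, targets_3d.shape)  # (8, 3) (8, 1)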

@@ -0,0 +1,84 @@
import jax.numpy as jnp

from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv


def example_conf1():
    return Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=500,
            pop_size=10000
        ),
        neat=NeatConfig(
            inputs=4,
            outputs=1,
        ),
        gene=NormalGeneConfig(
            activation_default='sigmoid',
            activation_options=('sigmoid',),
        ),
        problem=GymNaxConfig(
            env_name='CartPole-v1',
            output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0)  # the action of cartpole is {0, 1}
        )
    )


def example_conf2():
    return Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=500,
            pop_size=10000
        ),
        neat=NeatConfig(
            inputs=4,
            outputs=1,
        ),
        gene=NormalGeneConfig(
            activation_default='tanh',
            activation_options=('tanh',),
        ),
        problem=GymNaxConfig(
            env_name='CartPole-v1',
            output_transform=lambda out: jnp.where(out[0] > 0, 1, 0)  # the action of cartpole is {0, 1}
        )
    )


def example_conf3():
    return Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=500,
            pop_size=10000
        ),
        neat=NeatConfig(
            inputs=4,
            outputs=2,
        ),
        gene=NormalGeneConfig(
            activation_default='tanh',
            activation_options=('tanh',),
        ),
        problem=GymNaxConfig(
            env_name='CartPole-v1',
            output_transform=lambda out: jnp.argmax(out)  # the action of cartpole is {0, 1}
        )
    )


if __name__ == '__main__':
    # all config files above can solve cartpole
    conf = example_conf3()
    algorithm = NEAT(conf, NormalGene)
    pipeline = Pipeline(conf, algorithm, GymNaxEnv)
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)
    pipeline.show(state, best)
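
The three configs differ only in how the raw network output is mapped onto CartPole's discrete action space {0, 1}: example_conf1 thresholds a single sigmoid output at 0.5, example_conf2 thresholds a single tanh output at 0, and example_conf3 takes the argmax over two outputs. A small standalone check of those lambdas (the sample arrays below are made up; only jax.numpy is needed):

import jax.numpy as jnp

sigmoid_out = jnp.array([0.73])      # conf1: one sigmoid output in [0, 1]
tanh_out = jnp.array([-0.42])        # conf2: one tanh output in [-1, 1]
two_outputs = jnp.array([0.2, 1.3])  # conf3: one output per action

print(jnp.where(sigmoid_out[0] > 0.5, 1, 0))  # 1: above the 0.5 threshold
print(jnp.where(tanh_out[0] > 0, 1, 0))       # 0: below the 0 threshold
print(jnp.argmax(two_outputs))                # 1: index of the larger output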

@@ -1,31 +0,0 @@
from functools import partial

import jax

from utils import unflatten_conns, act, agg, Activation, Aggregation
from algorithm.neat.gene import RecurrentGeneConfig

config = RecurrentGeneConfig(
    activation_options=("tanh", "sigmoid"),
    activation_default="tanh",
)


class A:
    def __init__(self):
        self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
        self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
        self.isTrue = False

    @partial(jax.jit, static_argnums=(0,))
    def step(self):
        i = jax.numpy.array([0, 1])
        z = jax.numpy.array([
            [1, 1],
            [2, 2]
        ])
        print(self.act_funcs)
        return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs)


AA = A()
print(AA.step())

@@ -1,40 +0,0 @@
import jax
import numpy as np

from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig

xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)


def evaluate(forward_func):
    """
    :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
    :return:
    """
    outs = forward_func(xor_inputs)
    outs = jax.device_get(outs)
    fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
    return fitnesses


if __name__ == '__main__':
    config = Config(
        basic=BasicConfig(
            fitness_target=3.9999999,
            pop_size=10000
        ),
        neat=NeatConfig(
            maximum_nodes=50,
            maximum_conns=100,
            compatibility_threshold=4
        ),
        gene=NormalGeneConfig()
    )

    algorithm = NEAT(config, NormalGene)
    pipeline = Pipeline(config, algorithm)
    pipeline.auto_run(evaluate)
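
For reference, the removed evaluate above scores each population member as 4 minus its summed squared error over the four XOR cases, so a perfect network scores exactly 4; that is why fitness_target is set just below 4. A small self-contained check with fake population outputs (numpy only, no pipeline):

import numpy as np

xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)

# Fake outputs with shape (pop_size=2, 4, 1): one perfect member,
# one that always answers 0.5.
outs = np.stack([xor_outputs, np.full_like(xor_outputs, 0.5)])
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
print(fitnesses)  # [4. 3.]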

@@ -1,49 +0,0 @@
import jax
import numpy as np

from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT, HyperNEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from algorithm.hyperneat.substrate import NormalSubstrate, NormalSubstrateConfig

xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)


def evaluate(forward_func):
    """
    :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
    :return:
    """
    outs = forward_func(xor_inputs)
    outs = jax.device_get(outs)
    fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
    return fitnesses


if __name__ == '__main__':
    config = Config(
        basic=BasicConfig(
            fitness_target=3.99999,
            pop_size=10000
        ),
        neat=NeatConfig(
            network_type="recurrent",
            maximum_nodes=50,
            maximum_conns=100,
            inputs=4,
            outputs=1
        ),
        gene=RecurrentGeneConfig(
            activation_default="tanh",
            activation_options=("tanh",),
        ),
        substrate=NormalSubstrateConfig(),
    )

    neat = NEAT(config, RecurrentGene)
    hyperNEAT = HyperNEAT(config, neat, NormalSubstrate)
    pipeline = Pipeline(config, hyperNEAT)
    pipeline.auto_run(evaluate)

@@ -1,42 +0,0 @@
import jax
import numpy as np

from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig

xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)


def evaluate(forward_func):
    """
    :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
    :return:
    """
    outs = forward_func(xor_inputs)
    outs = jax.device_get(outs)
    fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
    return fitnesses


if __name__ == '__main__':
    config = Config(
        basic=BasicConfig(
            fitness_target=3.99999,
            pop_size=10000
        ),
        neat=NeatConfig(
            network_type="recurrent",
            maximum_nodes=50,
            maximum_conns=100
        ),
        gene=RecurrentGeneConfig(
            activate_times=3
        )
    )

    algorithm = NEAT(config, RecurrentGene)
    pipeline = Pipeline(config, algorithm)
    pipeline.auto_run(evaluate)