change str in config (act, agg) from str to callable
This commit is contained in:
@@ -3,6 +3,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import NormalGene, NormalGeneConfig
|
||||
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
|
||||
from problem.func_fit import XOR3d, FuncFitConfig
|
||||
from utils import Act
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -27,8 +28,8 @@ if __name__ == '__main__':
|
||||
input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default='tanh',
|
||||
activation_options=('tanh', ),
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh, ),
|
||||
),
|
||||
problem=FuncFitConfig()
|
||||
)
|
||||
|
||||
@@ -36,5 +36,6 @@ if __name__ == '__main__':
|
||||
algorithm = NEAT(config, RecurrentGene)
|
||||
pipeline = Pipeline(config, algorithm, XOR3d)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
pipeline.show(state, best)
|
||||
|
||||
@@ -19,8 +19,8 @@ def example_conf1():
|
||||
outputs=1,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default='sigmoid',
|
||||
activation_options=('sigmoid',),
|
||||
activation_default=Act.sigmoid,
|
||||
activation_options=(Act.sigmoid,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
@@ -41,8 +41,8 @@ def example_conf2():
|
||||
outputs=1,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default='tanh',
|
||||
activation_options=('tanh',),
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
@@ -63,8 +63,8 @@ def example_conf3():
|
||||
outputs=2,
|
||||
),
|
||||
gene=NormalGeneConfig(
|
||||
activation_default='tanh',
|
||||
activation_options=('tanh',),
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
problem=GymNaxConfig(
|
||||
env_name='CartPole-v1',
|
||||
@@ -80,5 +80,5 @@ if __name__ == '__main__':
|
||||
algorithm = NEAT(conf, NormalGene)
|
||||
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||
state = pipeline.setup()
|
||||
pipeline.pre_compile(state)
|
||||
state, best = pipeline.auto_run(state)
|
||||
pipeline.show(state, best)
|
||||
|
||||
Reference in New Issue
Block a user