fix bugs
This commit is contained in:
@@ -8,7 +8,7 @@ if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
genome=DefaultGenome(
|
||||
genome=DenseInitialize(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
@@ -21,7 +21,7 @@ if __name__ == "__main__":
|
||||
# aggregation_options=(Agg.sum,),
|
||||
aggregation_options=AGG_ALL,
|
||||
),
|
||||
output_transform=Act.sigmoid, # the activation function for output node
|
||||
output_transform=Act.standard_sigmoid, # the activation function for output node
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.1,
|
||||
conn_add=0.1,
|
||||
@@ -29,7 +29,7 @@ if __name__ == "__main__":
|
||||
conn_delete=0,
|
||||
),
|
||||
),
|
||||
pop_size=100000,
|
||||
pop_size=10000,
|
||||
species_size=20,
|
||||
compatibility_threshold=2,
|
||||
survival_threshold=0.01, # magic
|
||||
@@ -37,7 +37,7 @@ if __name__ == "__main__":
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
fitness_target=-1e-8,
|
||||
fitness_target=-1e-3,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
|
||||
@@ -6,7 +6,7 @@ from algorithm.neat import *
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
|
||||
def action_policy(forward_func, obs):
|
||||
def action_policy(randkey, forward_func, obs):
|
||||
return jnp.argmax(forward_func(obs))
|
||||
|
||||
|
||||
@@ -27,7 +27,9 @@ if __name__ == "__main__":
|
||||
species_size=10,
|
||||
),
|
||||
),
|
||||
problem=GymNaxEnv(env_name="CartPole-v1", repeat_times=5, action_policy=action_policy),
|
||||
problem=GymNaxEnv(
|
||||
env_name="CartPole-v1", repeat_times=5, action_policy=action_policy
|
||||
),
|
||||
generation_limit=10000,
|
||||
fitness_target=500,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.genome.hidden import AdvanceInitialize
|
||||
from algorithm.neat.genome.dense import DenseInitialize
|
||||
from utils.graph import topological_sort_python
|
||||
from utils import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
genome = AdvanceInitialize(
|
||||
num_inputs=17,
|
||||
num_outputs=6,
|
||||
hidden_cnt=8,
|
||||
if __name__ == "__main__":
|
||||
genome = DenseInitialize(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=500,
|
||||
)
|
||||
@@ -19,16 +19,19 @@ if __name__ == '__main__':
|
||||
nodes, conns = genome.initialize(state, randkey)
|
||||
|
||||
network = genome.network_dict(state, nodes, conns)
|
||||
print(set(network["nodes"]), set(network["conns"]))
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
print(order)
|
||||
|
||||
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
|
||||
print(input_idx, output_idx)
|
||||
|
||||
print(genome.repr(state, nodes, conns))
|
||||
print(network)
|
||||
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid)
|
||||
(symbols,
|
||||
args_symbols,
|
||||
input_symbols,
|
||||
nodes_exprs,
|
||||
output_exprs,
|
||||
forward_func,) = res
|
||||
|
||||
res = genome.sympy_func(state, network, precision=3)
|
||||
print(res)
|
||||
print(symbols)
|
||||
print(output_exprs[0].subs(args_symbols))
|
||||
|
||||
inputs = jnp.zeros(3)
|
||||
print(forward_func(inputs))
|
||||
|
||||
Reference in New Issue
Block a user