This commit is contained in:
wls2002
2024-06-20 16:32:52 +08:00
parent 9f72813c35
commit 075460f896
17 changed files with 224 additions and 140 deletions

View File

@@ -8,7 +8,7 @@ if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
genome=DenseInitialize(
num_inputs=3,
num_outputs=1,
max_nodes=50,
@@ -21,7 +21,7 @@ if __name__ == "__main__":
# aggregation_options=(Agg.sum,),
aggregation_options=AGG_ALL,
),
output_transform=Act.sigmoid, # the activation function for output node
output_transform=Act.standard_sigmoid, # the activation function for output node
mutation=DefaultMutation(
node_add=0.1,
conn_add=0.1,
@@ -29,7 +29,7 @@ if __name__ == "__main__":
conn_delete=0,
),
),
pop_size=100000,
pop_size=10000,
species_size=20,
compatibility_threshold=2,
survival_threshold=0.01, # magic
@@ -37,7 +37,7 @@ if __name__ == "__main__":
),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8,
fitness_target=-1e-3,
)
# initialize state

View File

@@ -6,7 +6,7 @@ from algorithm.neat import *
from problem.rl_env import GymNaxEnv
def action_policy(forward_func, obs):
def action_policy(randkey, forward_func, obs):
return jnp.argmax(forward_func(obs))
@@ -27,7 +27,9 @@ if __name__ == "__main__":
species_size=10,
),
),
problem=GymNaxEnv(env_name="CartPole-v1", repeat_times=5, action_policy=action_policy),
problem=GymNaxEnv(
env_name="CartPole-v1", repeat_times=5, action_policy=action_policy
),
generation_limit=10000,
fitness_target=500,
)

View File

@@ -1,14 +1,14 @@
import jax, jax.numpy as jnp
from algorithm.neat import *
from algorithm.neat.genome.hidden import AdvanceInitialize
from algorithm.neat.genome.dense import DenseInitialize
from utils.graph import topological_sort_python
from utils import *
if __name__ == '__main__':
genome = AdvanceInitialize(
num_inputs=17,
num_outputs=6,
hidden_cnt=8,
if __name__ == "__main__":
genome = DenseInitialize(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=500,
)
@@ -19,16 +19,19 @@ if __name__ == '__main__':
nodes, conns = genome.initialize(state, randkey)
network = genome.network_dict(state, nodes, conns)
print(set(network["nodes"]), set(network["conns"]))
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
print(order)
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
print(input_idx, output_idx)
print(genome.repr(state, nodes, conns))
print(network)
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid)
(symbols,
args_symbols,
input_symbols,
nodes_exprs,
output_exprs,
forward_func,) = res
res = genome.sympy_func(state, network, precision=3)
print(res)
print(symbols)
print(output_exprs[0].subs(args_symbols))
inputs = jnp.zeros(3)
print(forward_func(inputs))