odify genome for the official release

This commit is contained in:
root
2024-07-10 11:24:11 +08:00
parent 075460f896
commit ee8ec84202
83 changed files with 588 additions and 611 deletions

39
examples/brax/ant.py Normal file
View File

@@ -0,0 +1,39 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from tensorneat.common import Act
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=27,
num_outputs=8,
max_nodes=100,
max_conns=200,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=Act.tanh,
),
pop_size=1000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.01,
),
),
problem=BraxEnv(
env_name="ant",
),
generation_limit=10000,
fitness_target=5000,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,48 @@
import jax
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from tensorneat.common import Act
def sample_policy(randkey, obs):
return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=17,
num_outputs=6,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=Act.tanh,
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name="halfcheetah",
max_step=1000,
obs_normalization=True,
sample_episodes=1000,
sample_policy=sample_policy,
),
generation_limit=10000,
fitness_target=5000,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

37
examples/brax/reacher.py Normal file
View File

@@ -0,0 +1,37 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from tensorneat.common import Act
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=11,
num_outputs=2,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=Act.tanh,
),
pop_size=100,
species_size=10,
),
),
problem=BraxEnv(
env_name="reacher",
),
generation_limit=10000,
fitness_target=5000,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,19 @@
import jax
from problem.rl_env import BraxEnv
def random_policy(randkey, forward_func, obs):
return jax.random.uniform(randkey, (6,), minval=-1, maxval=1)
if __name__ == "__main__":
problem = BraxEnv(env_name="walker2d", max_step=1000, action_policy=random_policy)
state = problem.setup()
randkey = jax.random.key(0)
problem.show(
state,
randkey,
act_func=lambda state, params, obs: obs,
params=None,
save_path="walker2d_random_policy",
)

60
examples/brax/walker.py Normal file
View File

@@ -0,0 +1,60 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import BraxEnv
from tensorneat.common import Act
import jax, jax.numpy as jnp
def split_right_left(randkey, forward_func, obs):
right_obs_keys = jnp.array([2, 3, 4, 11, 12, 13])
left_obs_keys = jnp.array([5, 6, 7, 14, 15, 16])
right_action_keys = jnp.array([0, 1, 2])
left_action_keys = jnp.array([3, 4, 5])
right_foot_obs = obs
left_foot_obs = obs
left_foot_obs = left_foot_obs.at[right_obs_keys].set(obs[left_obs_keys])
left_foot_obs = left_foot_obs.at[left_obs_keys].set(obs[right_obs_keys])
right_action, left_action = jax.vmap(forward_func)(jnp.stack([right_foot_obs, left_foot_obs]))
# print(right_action.shape)
# print(left_action.shape)
return jnp.concatenate([right_action, left_action])
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=17,
num_outputs=3,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=Act.tanh,
),
pop_size=1000,
species_size=10,
),
),
problem=BraxEnv(
env_name="walker2d",
max_step=1000,
action_policy=split_right_left
),
generation_limit=10000,
fitness_target=5000,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

50
examples/func_fit/xor.py Normal file
View File

@@ -0,0 +1,50 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
from tensorneat.common import ACT_ALL, AGG_ALL, Act, Agg
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DenseInitialize(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
# activation_options=(Act.tanh,),
activation_options=ACT_ALL,
aggregation_default=Agg.sum,
# aggregation_options=(Agg.sum,),
aggregation_options=AGG_ALL,
),
output_transform=Act.standard_sigmoid, # the activation function for output node
mutation=DefaultMutation(
node_add=0.1,
conn_add=0.1,
node_delete=0,
conn_delete=0,
),
),
pop_size=10000,
species_size=20,
compatibility_threshold=2,
survival_threshold=0.01, # magic
),
),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-3,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)
pipeline.save(state=state)

View File

@@ -0,0 +1,63 @@
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.hyperneat import *
from tensorneat.common import Act
from problem.func_fit import XOR3d
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=HyperNEAT(
substrate=FullSubstrate(
input_coors=[(-1, -1), (0.333, -1), (-0.333, -1), (1, -1)], # 3(XOR3d inputs) + 1(bias)
hidden_coors=[
(-1, -0.5), (0.333, -0.5), (-0.333, -0.5),
(1, -0.5),
(-1, 0),
(0.333, 0),
(-0.333, 0),
(1, 0),
(-1, 0.5),
(0.333, 0.5),
(-0.333, 0.5),
(1, 0.5),
],
output_coors=[
(0, 1), # one output
],
),
neat=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=4, # [*coor1, *coor2]
num_outputs=1, # the weight of connection between two coor1 and coor2
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.tanh, # the activation function for output node in NEAT
),
pop_size=1000,
species_size=10,
compatibility_threshold=2,
survival_threshold=0.03,
),
),
activation=Act.tanh,
activate_time=10,
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
),
problem=XOR3d(),
generation_limit=300,
fitness_target=-1e-6,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

View File

@@ -0,0 +1,47 @@
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from problem.func_fit import XOR3d
from utils.activation import ACT_ALL, Act
if __name__ == "__main__":
pipeline = Pipeline(
seed=0,
algorithm=NEAT(
species=DefaultSpecies(
genome=RecurrentGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
activate_time=5,
node_gene=NodeGeneWithoutResponse(
activation_options=ACT_ALL, activation_replace_rate=0.2
),
output_transform=Act.sigmoid,
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
),
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
),
),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

36
examples/gymnax/arcbot.py Normal file
View File

@@ -0,0 +1,36 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=6,
num_outputs=3,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(
out
), # the action of acrobot is {0, 1, 2}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name="Acrobot-v1",
),
generation_limit=10000,
fitness_target=-62,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,41 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
def action_policy(randkey, forward_func, obs):
return jnp.argmax(forward_func(obs))
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=4,
num_outputs=2,
max_nodes=50,
max_conns=100,
# output_transform=lambda out: jnp.argmax(
# out
# ), # the action of cartpole is {0, 1}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name="CartPole-v1", repeat_times=5, action_policy=action_policy
),
generation_limit=10000,
fitness_target=500,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,74 @@
import jax
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.hyperneat import *
from tensorneat.common import Act
from problem.rl_env import GymNaxEnv
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=HyperNEAT(
substrate=FullSubstrate(
input_coors=[
(-1, -1),
(-0.5, -1),
(0, -1),
(0.5, -1),
(1, -1),
], # 4(problem inputs) + 1(bias)
hidden_coors=[
(-1, -0.5),
(0.333, -0.5),
(-0.333, -0.5),
(1, -0.5),
(-1, 0),
(0.333, 0),
(-0.333, 0),
(1, 0),
(-1, 0.5),
(0.333, 0.5),
(-0.333, 0.5),
(1, 0.5),
],
output_coors=[
(-1, 1),
(1, 1), # one output
],
),
neat=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=4, # [*coor1, *coor2]
num_outputs=1, # the weight of connection between two coor1 and coor2
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.tanh, # the activation function for output node in NEAT
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
),
),
activation=Act.tanh, # the activation function for output node in HyperNEAT
activate_time=10,
output_transform=jax.numpy.argmax, # action of cartpole is in {0, 1}
),
problem=GymNaxEnv(
env_name="CartPole-v1",
),
generation_limit=300,
fitness_target=500,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,36 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=2,
num_outputs=3,
max_nodes=50,
max_conns=100,
output_transform=lambda out: jnp.argmax(
out
), # the action of mountain car is {0, 1, 2}
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name="MountainCar-v0",
),
generation_limit=10000,
fitness_target=-86,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,37 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
from tensorneat.common import Act
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=2,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=Act.tanh
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name="MountainCarContinuous-v0",
),
generation_limit=10000,
fitness_target=99,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,38 @@
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
from tensorneat.common import Act
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
),
output_transform=lambda out: Act.tanh(out)
* 2, # the action of pendulum is [-2, 2]
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name="Pendulum-v1",
),
generation_limit=10000,
fitness_target=-10,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

View File

@@ -0,0 +1,33 @@
import jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from problem.rl_env import GymNaxEnv
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=8,
num_outputs=2,
max_nodes=50,
max_conns=100,
),
pop_size=10000,
species_size=10,
),
),
problem=GymNaxEnv(
env_name="Reacher-misc",
),
generation_limit=10000,
fitness_target=90,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,37 @@
import jax, jax.numpy as jnp
from algorithm.neat import *
from algorithm.neat.genome.dense import DenseInitialize
from utils.graph import topological_sort_python
from tensorneat.common import *
if __name__ == "__main__":
genome = DenseInitialize(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=500,
)
state = genome.setup()
randkey = jax.random.PRNGKey(42)
nodes, conns = genome.initialize(state, randkey)
network = genome.network_dict(state, nodes, conns)
input_idx, output_idx = genome.get_input_idx(), genome.get_output_idx()
res = genome.sympy_func(state, network, sympy_input_transform=lambda x: 999999999*x, sympy_output_transform=SympyStandardSigmoid)
(symbols,
args_symbols,
input_symbols,
nodes_exprs,
output_exprs,
forward_func,) = res
print(symbols)
print(output_exprs[0].subs(args_symbols))
inputs = jnp.zeros(3)
print(forward_func(inputs))

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 90 KiB

View File

@@ -0,0 +1,191 @@
{
"nodes": {
"0": {
"bias": 0.13710324466228485,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"1": {
"bias": -1.4202250242233276,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"2": {
"bias": -0.4653860926628113,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"3": {
"bias": 0.5835710167884827,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"4": {
"bias": 2.187405824661255,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"5": {
"bias": 0.24963024258613586,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"6": {
"bias": -0.966821551322937,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"7": {
"bias": 0.4452081620693207,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"8": {
"bias": -0.07293166220188141,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"9": {
"bias": -0.1625899225473404,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"10": {
"bias": -0.8576332330703735,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"11": {
"bias": -0.18487468361854553,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"12": {
"bias": 1.4335486888885498,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"13": {
"bias": -0.8690621256828308,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"14": {
"bias": -0.23014676570892334,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"15": {
"bias": 0.7880322337150574,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"16": {
"bias": -0.22258250415325165,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"17": {
"bias": 0.2773352861404419,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"18": {
"bias": -0.40279051661491394,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"19": {
"bias": 1.092000961303711,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"20": {
"bias": -0.4063087999820709,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"21": {
"bias": 0.3895529806613922,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"22": {
"bias": -0.18007506430149078,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"23": {
"bias": -0.8112533092498779,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"24": {
"bias": 0.2946726381778717,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"25": {
"bias": -1.118497371673584,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"26": {
"bias": 1.3674490451812744,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"27": {
"bias": -1.6514816284179688,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"28": {
"bias": 0.9440701603889465,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"29": {
"bias": 1.564852237701416,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
},
"30": {
"bias": -0.5568665266036987,
"res": 1.0,
"agg": "sum",
"act": "sigmoid"
}
},
"conns": {

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 89 KiB

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
import networkx as nx
import matplotlib.pyplot as plt
# 创建一个空白的有向图
G = nx.DiGraph()
# 添加边
G.add_edge('A', 'B')
G.add_edge('A', 'C')
G.add_edge('B', 'C')
G.add_edge('C', 'D')
# 绘制有向图

View File

@@ -0,0 +1,25 @@
import jax, jax.numpy as jnp
import jax.random
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
def random_policy(state, params, obs):
key = jax.random.key(obs.sum())
actions = jax.random.normal(key, (4,))
# actions = actions.at[2:].set(-9999)
# return jnp.array([4, 4, 0, 1])
# return jnp.array([1, 2, 3, 4])
# return actions
return actions
if __name__ == "__main__":
problem = Jumanji_2048(
max_step=10000, repeat_times=1000, guarantee_invalid_action=False
)
state = problem.setup()
jit_evaluate = jax.jit(
lambda state, randkey: problem.evaluate(state, randkey, random_policy, None)
)
randkey = jax.random.PRNGKey(0)
reward = jit_evaluate(state, randkey)
print(reward)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,120 @@
import jax, jax.numpy as jnp
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
from tensorneat.common import Act, Agg
def rot_li(li):
return li[1:] + [li[0]]
def rot_boards(board):
def rot(a, _):
a = jnp.rot90(a)
return a, a # carry, y
# carry, np.stack(ys)
_, boards = jax.lax.scan(rot, board, jnp.arange(4, dtype=jnp.int32))
return boards
direction = ["up", "right", "down", "left"]
lr_flip_direction = ["up", "left", "down", "right"]
directions = []
lr_flip_directions = []
for _ in range(4):
direction = rot_li(direction)
lr_flip_direction = rot_li(lr_flip_direction)
directions.append(direction.copy())
lr_flip_directions.append(lr_flip_direction.copy())
full_directions = directions + lr_flip_directions
def action_policy(forward_func, obs):
board = obs.reshape(4, 4)
lr_flip_board = jnp.fliplr(board)
boards = rot_boards(board)
lr_flip_boards = rot_boards(lr_flip_board)
# stack
full_boards = jnp.concatenate([boards, lr_flip_boards], axis=0)
scores = jax.vmap(forward_func)(full_boards.reshape(8, -1))
total_score = {"up": 0, "right": 0, "down": 0, "left": 0}
for i in range(8):
dire = full_directions[i]
for j in range(4):
total_score[dire[j]] += scores[i, j]
return jnp.array(
[
total_score["up"],
total_score["right"],
total_score["down"],
total_score["left"],
]
)
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=16,
num_outputs=4,
max_nodes=100,
max_conns=1000,
node_gene=NodeGeneWithoutResponse(
activation_default=Act.sigmoid,
activation_options=(
Act.sigmoid,
Act.relu,
Act.tanh,
Act.identity,
),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
activation_replace_rate=0.02,
aggregation_replace_rate=0.02,
bias_mutate_rate=0.03,
bias_init_std=0.5,
bias_mutate_power=0.02,
bias_replace_rate=0.01,
),
conn_gene=DefaultConnGene(
weight_mutate_rate=0.015,
weight_replace_rate=0.03,
weight_mutate_power=0.05,
),
mutation=DefaultMutation(node_add=0.001, conn_add=0.002),
),
pop_size=1000,
species_size=5,
survival_threshold=0.01,
max_stagnation=7,
genome_elitism=3,
compatibility_threshold=1.2,
),
),
problem=Jumanji_2048(
max_step=1000,
repeat_times=50,
# guarantee_invalid_action=True,
guarantee_invalid_action=False,
action_policy=action_policy,
),
generation_limit=10000,
fitness_target=13000,
save_path="2048.npz",
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)

10
examples/tmp.py Normal file
View File

@@ -0,0 +1,10 @@
import jax, jax.numpy as jnp
from tensorneat.algorithm import NEAT
from tensorneat.algorithm.neat import DefaultGenome
key = jax.random.key(0)
genome = DefaultGenome(num_inputs=5, num_outputs=3, init_hidden_layers=(1, ))
state = genome.setup()
nodes, conns = genome.initialize(state, key)
print(genome.repr(state, nodes, conns))

View File

@@ -0,0 +1,6 @@
import ray
ray.init(num_gpus=2)
available_resources = ray.available_resources()
print("Available resources:", available_resources)