modify pipeline for "update_by_data";
fix bug in speciate. currently, node_delete and conn_delete can successfully work
This commit is contained in:
@@ -19,9 +19,15 @@ class BaseAlgorithm:
|
|||||||
"""transform the genome into a neural network"""
|
"""transform the genome into a neural network"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def restore(self, state, transformed):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def update_by_batch(self, state, batch_input, transformed):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_inputs(self):
|
def num_inputs(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -178,18 +178,25 @@ class DefaultMutation(BaseMutation):
|
|||||||
def no(key_, nodes_, conns_):
|
def no(key_, nodes_, conns_):
|
||||||
return nodes_, conns_
|
return nodes_, conns_
|
||||||
|
|
||||||
nodes, conns = jax.lax.cond(
|
if self.node_add > 0:
|
||||||
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
|
nodes, conns = jax.lax.cond(
|
||||||
)
|
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
|
||||||
nodes, conns = jax.lax.cond(
|
)
|
||||||
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
|
|
||||||
)
|
if self.node_delete > 0:
|
||||||
nodes, conns = jax.lax.cond(
|
nodes, conns = jax.lax.cond(
|
||||||
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
|
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
|
||||||
)
|
)
|
||||||
nodes, conns = jax.lax.cond(
|
|
||||||
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
|
if self.conn_add > 0:
|
||||||
)
|
nodes, conns = jax.lax.cond(
|
||||||
|
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.conn_delete > 0:
|
||||||
|
nodes, conns = jax.lax.cond(
|
||||||
|
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
|
||||||
|
)
|
||||||
|
|
||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
|
|||||||
@@ -117,7 +117,9 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
def hit():
|
def hit():
|
||||||
batch_ins, new_conn_attrs = jax.vmap(
|
batch_ins, new_conn_attrs = jax.vmap(
|
||||||
self.conn_gene.update_by_batch, in_axes=(None, 1, 1), out_axes=(1, 1)
|
self.conn_gene.update_by_batch,
|
||||||
|
in_axes=(None, 1, 1),
|
||||||
|
out_axes=(1, 1),
|
||||||
)(state, u_conns_[:, :, i], batch_values)
|
)(state, u_conns_[:, :, i], batch_values)
|
||||||
batch_z, new_node_attrs = self.node_gene.update_by_batch(
|
batch_z, new_node_attrs = self.node_gene.update_by_batch(
|
||||||
state,
|
state,
|
||||||
@@ -132,12 +134,12 @@ class DefaultGenome(BaseGenome):
|
|||||||
u_conns_.at[:, :, i].set(new_conn_attrs),
|
u_conns_.at[:, :, i].set(new_conn_attrs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# the val of input nodes is obtained by the task, not by calculation
|
||||||
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
|
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
|
||||||
jnp.isin(i, self.input_idx),
|
jnp.isin(i, self.input_idx),
|
||||||
lambda: (batch_values, nodes_attrs_, u_conns_),
|
lambda: (batch_values, nodes_attrs_, u_conns_),
|
||||||
hit,
|
hit,
|
||||||
)
|
)
|
||||||
# the val of input nodes is obtained by the task, not by calculation
|
|
||||||
|
|
||||||
return batch_values, nodes_attrs_, u_conns_, idx + 1
|
return batch_values, nodes_attrs_, u_conns_, idx + 1
|
||||||
|
|
||||||
|
|||||||
@@ -44,9 +44,15 @@ class NEAT(BaseAlgorithm):
|
|||||||
nodes, conns = individual
|
nodes, conns = individual
|
||||||
return self.genome.transform(state, nodes, conns)
|
return self.genome.transform(state, nodes, conns)
|
||||||
|
|
||||||
|
def restore(self, state, transformed):
|
||||||
|
return self.genome.restore(state, transformed)
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
return self.genome.forward(state, inputs, transformed)
|
return self.genome.forward(state, inputs, transformed)
|
||||||
|
|
||||||
|
def update_by_batch(self, state, batch_input, transformed):
|
||||||
|
return self.genome.update_by_batch(state, batch_input, transformed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_inputs(self):
|
def num_inputs(self):
|
||||||
return self.genome.num_inputs
|
return self.genome.num_inputs
|
||||||
|
|||||||
@@ -113,6 +113,9 @@ class DefaultSpecies(BaseSpecies):
|
|||||||
return state.pop_nodes, state.pop_conns
|
return state.pop_nodes, state.pop_conns
|
||||||
|
|
||||||
def update_species(self, state, fitness):
|
def update_species(self, state, fitness):
|
||||||
|
# set nan to -inf
|
||||||
|
fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
|
||||||
|
|
||||||
# update the fitness of each species
|
# update the fitness of each species
|
||||||
state, species_fitness = self.update_species_fitness(state, fitness)
|
state, species_fitness = self.update_species_fitness(state, fitness)
|
||||||
|
|
||||||
@@ -121,6 +124,7 @@ class DefaultSpecies(BaseSpecies):
|
|||||||
|
|
||||||
# sort species_info by their fitness. (also push nan to the end)
|
# sort species_info by their fitness. (also push nan to the end)
|
||||||
sort_indices = jnp.argsort(species_fitness)[::-1]
|
sort_indices = jnp.argsort(species_fitness)[::-1]
|
||||||
|
|
||||||
state = state.update(
|
state = state.update(
|
||||||
species_keys=state.species_keys[sort_indices],
|
species_keys=state.species_keys[sort_indices],
|
||||||
best_fitness=state.best_fitness[sort_indices],
|
best_fitness=state.best_fitness[sort_indices],
|
||||||
|
|||||||
@@ -21,11 +21,11 @@ if __name__ == "__main__":
|
|||||||
mutation=DefaultMutation(
|
mutation=DefaultMutation(
|
||||||
node_add=0.05,
|
node_add=0.05,
|
||||||
conn_add=0.05,
|
conn_add=0.05,
|
||||||
node_delete=0,
|
node_delete=0.05,
|
||||||
conn_delete=0,
|
conn_delete=0.05,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
pop_size=100,
|
pop_size=1000,
|
||||||
species_size=20,
|
species_size=20,
|
||||||
compatibility_threshold=2,
|
compatibility_threshold=2,
|
||||||
survival_threshold=0.01, # magic
|
survival_threshold=0.01, # magic
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from functools import partial
|
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithm import BaseAlgorithm
|
from algorithm import BaseAlgorithm
|
||||||
from problem import BaseProblem
|
from problem import BaseProblem
|
||||||
|
from problem.rl_env import RLEnv
|
||||||
|
from problem.func_fit import FuncFit
|
||||||
from utils import State
|
from utils import State
|
||||||
|
|
||||||
|
|
||||||
@@ -17,6 +17,8 @@ class Pipeline:
|
|||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
fitness_target: float = 1,
|
fitness_target: float = 1,
|
||||||
generation_limit: int = 1000,
|
generation_limit: int = 1000,
|
||||||
|
pre_update: bool = False,
|
||||||
|
update_batch_size: int = 10000,
|
||||||
):
|
):
|
||||||
assert problem.jitable, "Currently, problem must be jitable"
|
assert problem.jitable, "Currently, problem must be jitable"
|
||||||
|
|
||||||
@@ -37,10 +39,30 @@ class Pipeline:
|
|||||||
self.best_genome = None
|
self.best_genome = None
|
||||||
self.best_fitness = float("-inf")
|
self.best_fitness = float("-inf")
|
||||||
self.generation_timestamp = None
|
self.generation_timestamp = None
|
||||||
|
self.pre_update = pre_update
|
||||||
|
self.update_batch_size = update_batch_size
|
||||||
|
if pre_update:
|
||||||
|
if isinstance(problem, RLEnv):
|
||||||
|
assert problem.record_episode, "record_episode must be True"
|
||||||
|
self.fetch_data = lambda episode: episode["obs"]
|
||||||
|
elif isinstance(problem, FuncFit):
|
||||||
|
assert problem.return_data, "return_data must be True"
|
||||||
|
self.fetch_data = lambda data: data
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def setup(self, state=State()):
|
def setup(self, state=State()):
|
||||||
print("initializing")
|
print("initializing")
|
||||||
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
||||||
|
|
||||||
|
if self.pre_update:
|
||||||
|
# initial with mean = 0 and std = 1
|
||||||
|
state = state.register(
|
||||||
|
data=jax.random.normal(
|
||||||
|
state.randkey, (self.update_batch_size, self.algorithm.num_inputs)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
state = self.algorithm.setup(state)
|
state = self.algorithm.setup(state)
|
||||||
state = self.problem.setup(state)
|
state = self.problem.setup(state)
|
||||||
print("initializing finished")
|
print("initializing finished")
|
||||||
@@ -57,9 +79,42 @@ class Pipeline:
|
|||||||
state, pop
|
state, pop
|
||||||
)
|
)
|
||||||
|
|
||||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
if self.pre_update:
|
||||||
state, keys, self.algorithm.forward, pop_transformed
|
# update the population
|
||||||
)
|
_, pop_transformed = jax.vmap(
|
||||||
|
self.algorithm.update_by_batch, in_axes=(None, None, 0)
|
||||||
|
)(state, state.data, pop_transformed)
|
||||||
|
|
||||||
|
# raw_data: (Pop, Batch, num_inputs)
|
||||||
|
fitnesses, raw_data = jax.vmap(
|
||||||
|
self.problem.evaluate, in_axes=(None, 0, None, 0)
|
||||||
|
)(state, keys, self.algorithm.forward, pop_transformed)
|
||||||
|
|
||||||
|
data = self.fetch_data(raw_data)
|
||||||
|
assert (
|
||||||
|
data.ndim == 3
|
||||||
|
and data.shape[0] == self.pop_size
|
||||||
|
and data.shape[2] == self.algorithm.num_inputs
|
||||||
|
)
|
||||||
|
# reshape to (Pop * Batch, num_inputs)
|
||||||
|
data = data.reshape(
|
||||||
|
data.shape[0] * data.shape[1], self.algorithm.num_inputs
|
||||||
|
)
|
||||||
|
# shuffle
|
||||||
|
data = jax.random.permutation(randkey_, data, axis=0)
|
||||||
|
# cutoff or expand
|
||||||
|
if data.shape[0] >= self.update_batch_size:
|
||||||
|
data = data[: self.update_batch_size] # cutoff
|
||||||
|
else:
|
||||||
|
data = (
|
||||||
|
jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data)
|
||||||
|
) # expand
|
||||||
|
state = state.update(data=data)
|
||||||
|
|
||||||
|
else:
|
||||||
|
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||||
|
state, keys, self.algorithm.forward, pop_transformed
|
||||||
|
)
|
||||||
|
|
||||||
state = self.algorithm.tell(state, fitnesses)
|
state = self.algorithm.tell(state, fitnesses)
|
||||||
|
|
||||||
@@ -89,24 +144,18 @@ class Pipeline:
|
|||||||
print("Fitness limit reached!")
|
print("Fitness limit reached!")
|
||||||
return state, self.best_genome
|
return state, self.best_genome
|
||||||
|
|
||||||
# node = previous_pop[0][0][:, 0]
|
|
||||||
# node_count = jnp.sum(~jnp.isnan(node))
|
|
||||||
# conn = previous_pop[1][0][:, 0]
|
|
||||||
# conn_count = jnp.sum(~jnp.isnan(conn))
|
|
||||||
# if (w % 5 == 0):
|
|
||||||
# print("node_count", node_count)
|
|
||||||
# print("conn_count", conn_count)
|
|
||||||
|
|
||||||
print("Generation limit reached!")
|
print("Generation limit reached!")
|
||||||
return state, self.best_genome
|
return state, self.best_genome
|
||||||
|
|
||||||
def analysis(self, state, pop, fitnesses):
|
def analysis(self, state, pop, fitnesses):
|
||||||
|
|
||||||
|
valid_fitnesses = fitnesses[~np.isnan(fitnesses)]
|
||||||
|
|
||||||
max_f, min_f, mean_f, std_f = (
|
max_f, min_f, mean_f, std_f = (
|
||||||
max(fitnesses),
|
max(valid_fitnesses),
|
||||||
min(fitnesses),
|
min(valid_fitnesses),
|
||||||
np.mean(fitnesses),
|
np.mean(valid_fitnesses),
|
||||||
np.std(fitnesses),
|
np.std(valid_fitnesses),
|
||||||
)
|
)
|
||||||
|
|
||||||
new_timestamp = time.time()
|
new_timestamp = time.time()
|
||||||
@@ -122,9 +171,9 @@ class Pipeline:
|
|||||||
species_sizes = [int(i) for i in member_count if i > 0]
|
species_sizes = [int(i) for i in member_count if i > 0]
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Generation: {self.algorithm.generation(state)}",
|
f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||||
f"species: {len(species_sizes)}, {species_sizes}",
|
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
||||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms",
|
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
def show(self, state, best, *args, **kwargs):
|
def show(self, state, best, *args, **kwargs):
|
||||||
|
|||||||
@@ -49,7 +49,10 @@ class FuncFit(BaseProblem):
|
|||||||
state, self.inputs, params
|
state, self.inputs, params
|
||||||
)
|
)
|
||||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||||
loss = self.evaluate(state, randkey, act_func, params)
|
if self.return_data:
|
||||||
|
loss, _ = self.evaluate(state, randkey, act_func, params)
|
||||||
|
else:
|
||||||
|
loss = self.evaluate(state, randkey, act_func, params)
|
||||||
loss = -loss
|
loss = -loss
|
||||||
|
|
||||||
msg = ""
|
msg = ""
|
||||||
|
|||||||
@@ -4,14 +4,19 @@ from .func_fit import FuncFit
|
|||||||
|
|
||||||
|
|
||||||
class XOR(FuncFit):
|
class XOR(FuncFit):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inputs(self):
|
def inputs(self):
|
||||||
return np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
return np.array(
|
||||||
|
[[0, 0], [0, 1], [1, 0], [1, 1]],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def targets(self):
|
def targets(self):
|
||||||
return np.array([[0], [1], [1], [0]])
|
return np.array(
|
||||||
|
[[0], [1], [1], [0]],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_shape(self):
|
def input_shape(self):
|
||||||
|
|||||||
@@ -16,12 +16,16 @@ class XOR3d(FuncFit):
|
|||||||
[1, 0, 1],
|
[1, 0, 1],
|
||||||
[1, 1, 0],
|
[1, 1, 0],
|
||||||
[1, 1, 1],
|
[1, 1, 1],
|
||||||
]
|
],
|
||||||
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def targets(self):
|
def targets(self):
|
||||||
return np.array([[0], [1], [1], [0], [1], [0], [0], [1]])
|
return np.array(
|
||||||
|
[[0], [1], [1], [0], [1], [0], [0], [1]],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_shape(self):
|
def input_shape(self):
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
from .gymnax_env import GymNaxEnv
|
from .gymnax_env import GymNaxEnv
|
||||||
from .brax_env import BraxEnv
|
from .brax_env import BraxEnv
|
||||||
|
from .rl_jit import RLEnv
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from .activation import Act, act
|
from .activation import Act, act, ACT_ALL
|
||||||
from .aggregation import Agg, agg
|
from .aggregation import Agg, agg, AGG_ALL
|
||||||
from .tools import *
|
from .tools import *
|
||||||
from .graph import *
|
from .graph import *
|
||||||
from .state import State
|
from .state import State
|
||||||
|
|||||||
Reference in New Issue
Block a user