modify pipeline for "update_by_data";

fix a bug in speciate. node_delete and conn_delete now work correctly
This commit is contained in:
wls2002
2024-05-31 15:32:56 +08:00
parent 3ea9986bd4
commit 6aa9011043
12 changed files with 132 additions and 45 deletions

View File

@@ -19,9 +19,15 @@ class BaseAlgorithm:
"""transform the genome into a neural network""" """transform the genome into a neural network"""
raise NotImplementedError raise NotImplementedError
def restore(self, state, transformed):
raise NotImplementedError
def forward(self, state, inputs, transformed): def forward(self, state, inputs, transformed):
raise NotImplementedError raise NotImplementedError
def update_by_batch(self, state, batch_input, transformed):
raise NotImplementedError
@property @property
def num_inputs(self): def num_inputs(self):
raise NotImplementedError raise NotImplementedError

View File

@@ -178,15 +178,22 @@ class DefaultMutation(BaseMutation):
def no(key_, nodes_, conns_): def no(key_, nodes_, conns_):
return nodes_, conns_ return nodes_, conns_
if self.node_add > 0:
nodes, conns = jax.lax.cond( nodes, conns = jax.lax.cond(
r1 < self.node_add, mutate_add_node, no, k1, nodes, conns r1 < self.node_add, mutate_add_node, no, k1, nodes, conns
) )
if self.node_delete > 0:
nodes, conns = jax.lax.cond( nodes, conns = jax.lax.cond(
r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns
) )
if self.conn_add > 0:
nodes, conns = jax.lax.cond( nodes, conns = jax.lax.cond(
r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns
) )
if self.conn_delete > 0:
nodes, conns = jax.lax.cond( nodes, conns = jax.lax.cond(
r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns
) )

View File

@@ -117,7 +117,9 @@ class DefaultGenome(BaseGenome):
def hit(): def hit():
batch_ins, new_conn_attrs = jax.vmap( batch_ins, new_conn_attrs = jax.vmap(
self.conn_gene.update_by_batch, in_axes=(None, 1, 1), out_axes=(1, 1) self.conn_gene.update_by_batch,
in_axes=(None, 1, 1),
out_axes=(1, 1),
)(state, u_conns_[:, :, i], batch_values) )(state, u_conns_[:, :, i], batch_values)
batch_z, new_node_attrs = self.node_gene.update_by_batch( batch_z, new_node_attrs = self.node_gene.update_by_batch(
state, state,
@@ -132,12 +134,12 @@ class DefaultGenome(BaseGenome):
u_conns_.at[:, :, i].set(new_conn_attrs), u_conns_.at[:, :, i].set(new_conn_attrs),
) )
# the val of input nodes is obtained by the task, not by calculation
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond( (batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
jnp.isin(i, self.input_idx), jnp.isin(i, self.input_idx),
lambda: (batch_values, nodes_attrs_, u_conns_), lambda: (batch_values, nodes_attrs_, u_conns_),
hit, hit,
) )
# the val of input nodes is obtained by the task, not by calculation
return batch_values, nodes_attrs_, u_conns_, idx + 1 return batch_values, nodes_attrs_, u_conns_, idx + 1

View File

@@ -44,9 +44,15 @@ class NEAT(BaseAlgorithm):
nodes, conns = individual nodes, conns = individual
return self.genome.transform(state, nodes, conns) return self.genome.transform(state, nodes, conns)
def restore(self, state, transformed):
return self.genome.restore(state, transformed)
def forward(self, state, inputs, transformed): def forward(self, state, inputs, transformed):
return self.genome.forward(state, inputs, transformed) return self.genome.forward(state, inputs, transformed)
def update_by_batch(self, state, batch_input, transformed):
return self.genome.update_by_batch(state, batch_input, transformed)
@property @property
def num_inputs(self): def num_inputs(self):
return self.genome.num_inputs return self.genome.num_inputs

View File

@@ -113,6 +113,9 @@ class DefaultSpecies(BaseSpecies):
return state.pop_nodes, state.pop_conns return state.pop_nodes, state.pop_conns
def update_species(self, state, fitness): def update_species(self, state, fitness):
# set nan to -inf
fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
# update the fitness of each species # update the fitness of each species
state, species_fitness = self.update_species_fitness(state, fitness) state, species_fitness = self.update_species_fitness(state, fitness)
@@ -121,6 +124,7 @@ class DefaultSpecies(BaseSpecies):
# sort species_info by their fitness. (also push nan to the end) # sort species_info by their fitness. (also push nan to the end)
sort_indices = jnp.argsort(species_fitness)[::-1] sort_indices = jnp.argsort(species_fitness)[::-1]
state = state.update( state = state.update(
species_keys=state.species_keys[sort_indices], species_keys=state.species_keys[sort_indices],
best_fitness=state.best_fitness[sort_indices], best_fitness=state.best_fitness[sort_indices],

View File

@@ -21,11 +21,11 @@ if __name__ == "__main__":
mutation=DefaultMutation( mutation=DefaultMutation(
node_add=0.05, node_add=0.05,
conn_add=0.05, conn_add=0.05,
node_delete=0, node_delete=0.05,
conn_delete=0, conn_delete=0.05,
), ),
), ),
pop_size=100, pop_size=1000,
species_size=20, species_size=20,
compatibility_threshold=2, compatibility_threshold=2,
survival_threshold=0.01, # magic survival_threshold=0.01, # magic

View File

@@ -1,11 +1,11 @@
from functools import partial
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
import time import time
import numpy as np import numpy as np
from algorithm import BaseAlgorithm from algorithm import BaseAlgorithm
from problem import BaseProblem from problem import BaseProblem
from problem.rl_env import RLEnv
from problem.func_fit import FuncFit
from utils import State from utils import State
@@ -17,6 +17,8 @@ class Pipeline:
seed: int = 42, seed: int = 42,
fitness_target: float = 1, fitness_target: float = 1,
generation_limit: int = 1000, generation_limit: int = 1000,
pre_update: bool = False,
update_batch_size: int = 10000,
): ):
assert problem.jitable, "Currently, problem must be jitable" assert problem.jitable, "Currently, problem must be jitable"
@@ -37,10 +39,30 @@ class Pipeline:
self.best_genome = None self.best_genome = None
self.best_fitness = float("-inf") self.best_fitness = float("-inf")
self.generation_timestamp = None self.generation_timestamp = None
self.pre_update = pre_update
self.update_batch_size = update_batch_size
if pre_update:
if isinstance(problem, RLEnv):
assert problem.record_episode, "record_episode must be True"
self.fetch_data = lambda episode: episode["obs"]
elif isinstance(problem, FuncFit):
assert problem.return_data, "return_data must be True"
self.fetch_data = lambda data: data
else:
raise NotImplementedError
def setup(self, state=State()): def setup(self, state=State()):
print("initializing") print("initializing")
state = state.register(randkey=jax.random.PRNGKey(self.seed)) state = state.register(randkey=jax.random.PRNGKey(self.seed))
if self.pre_update:
# initial with mean = 0 and std = 1
state = state.register(
data=jax.random.normal(
state.randkey, (self.update_batch_size, self.algorithm.num_inputs)
)
)
state = self.algorithm.setup(state) state = self.algorithm.setup(state)
state = self.problem.setup(state) state = self.problem.setup(state)
print("initializing finished") print("initializing finished")
@@ -57,6 +79,39 @@ class Pipeline:
state, pop state, pop
) )
if self.pre_update:
# update the population
_, pop_transformed = jax.vmap(
self.algorithm.update_by_batch, in_axes=(None, None, 0)
)(state, state.data, pop_transformed)
# raw_data: (Pop, Batch, num_inputs)
fitnesses, raw_data = jax.vmap(
self.problem.evaluate, in_axes=(None, 0, None, 0)
)(state, keys, self.algorithm.forward, pop_transformed)
data = self.fetch_data(raw_data)
assert (
data.ndim == 3
and data.shape[0] == self.pop_size
and data.shape[2] == self.algorithm.num_inputs
)
# reshape to (Pop * Batch, num_inputs)
data = data.reshape(
data.shape[0] * data.shape[1], self.algorithm.num_inputs
)
# shuffle
data = jax.random.permutation(randkey_, data, axis=0)
# cutoff or expand
if data.shape[0] >= self.update_batch_size:
data = data[: self.update_batch_size] # cutoff
else:
data = (
jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data)
) # expand
state = state.update(data=data)
else:
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))( fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
state, keys, self.algorithm.forward, pop_transformed state, keys, self.algorithm.forward, pop_transformed
) )
@@ -89,24 +144,18 @@ class Pipeline:
print("Fitness limit reached!") print("Fitness limit reached!")
return state, self.best_genome return state, self.best_genome
# node = previous_pop[0][0][:, 0]
# node_count = jnp.sum(~jnp.isnan(node))
# conn = previous_pop[1][0][:, 0]
# conn_count = jnp.sum(~jnp.isnan(conn))
# if (w % 5 == 0):
# print("node_count", node_count)
# print("conn_count", conn_count)
print("Generation limit reached!") print("Generation limit reached!")
return state, self.best_genome return state, self.best_genome
def analysis(self, state, pop, fitnesses): def analysis(self, state, pop, fitnesses):
valid_fitnesses = fitnesses[~np.isnan(fitnesses)]
max_f, min_f, mean_f, std_f = ( max_f, min_f, mean_f, std_f = (
max(fitnesses), max(valid_fitnesses),
min(fitnesses), min(valid_fitnesses),
np.mean(fitnesses), np.mean(valid_fitnesses),
np.std(fitnesses), np.std(valid_fitnesses),
) )
new_timestamp = time.time() new_timestamp = time.time()
@@ -122,9 +171,9 @@ class Pipeline:
species_sizes = [int(i) for i in member_count if i > 0] species_sizes = [int(i) for i in member_count if i > 0]
print( print(
f"Generation: {self.algorithm.generation(state)}", f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
f"species: {len(species_sizes)}, {species_sizes}", f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms", f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
) )
def show(self, state, best, *args, **kwargs): def show(self, state, best, *args, **kwargs):

View File

@@ -49,6 +49,9 @@ class FuncFit(BaseProblem):
state, self.inputs, params state, self.inputs, params
) )
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict]) inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
if self.return_data:
loss, _ = self.evaluate(state, randkey, act_func, params)
else:
loss = self.evaluate(state, randkey, act_func, params) loss = self.evaluate(state, randkey, act_func, params)
loss = -loss loss = -loss

View File

@@ -4,14 +4,19 @@ from .func_fit import FuncFit
class XOR(FuncFit): class XOR(FuncFit):
@property @property
def inputs(self): def inputs(self):
return np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) return np.array(
[[0, 0], [0, 1], [1, 0], [1, 1]],
dtype=np.float32,
)
@property @property
def targets(self): def targets(self):
return np.array([[0], [1], [1], [0]]) return np.array(
[[0], [1], [1], [0]],
dtype=np.float32,
)
@property @property
def input_shape(self): def input_shape(self):

View File

@@ -16,12 +16,16 @@ class XOR3d(FuncFit):
[1, 0, 1], [1, 0, 1],
[1, 1, 0], [1, 1, 0],
[1, 1, 1], [1, 1, 1],
] ],
dtype=np.float32,
) )
@property @property
def targets(self): def targets(self):
return np.array([[0], [1], [1], [0], [1], [0], [0], [1]]) return np.array(
[[0], [1], [1], [0], [1], [0], [0], [1]],
dtype=np.float32,
)
@property @property
def input_shape(self): def input_shape(self):

View File

@@ -1,2 +1,3 @@
from .gymnax_env import GymNaxEnv from .gymnax_env import GymNaxEnv
from .brax_env import BraxEnv from .brax_env import BraxEnv
from .rl_jit import RLEnv

View File

@@ -1,5 +1,5 @@
from .activation import Act, act from .activation import Act, act, ACT_ALL
from .aggregation import Agg, agg from .aggregation import Agg, agg, AGG_ALL
from .tools import * from .tools import *
from .graph import * from .graph import *
from .state import State from .state import State