Modify pipeline for "update_by_data";
fix bug in speciate. Currently, node_delete and conn_delete work successfully.
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
from functools import partial
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from algorithm import BaseAlgorithm
|
||||
from problem import BaseProblem
|
||||
from problem.rl_env import RLEnv
|
||||
from problem.func_fit import FuncFit
|
||||
from utils import State
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ class Pipeline:
|
||||
seed: int = 42,
|
||||
fitness_target: float = 1,
|
||||
generation_limit: int = 1000,
|
||||
pre_update: bool = False,
|
||||
update_batch_size: int = 10000,
|
||||
):
|
||||
assert problem.jitable, "Currently, problem must be jitable"
|
||||
|
||||
@@ -37,10 +39,30 @@ class Pipeline:
|
||||
self.best_genome = None
|
||||
self.best_fitness = float("-inf")
|
||||
self.generation_timestamp = None
|
||||
self.pre_update = pre_update
|
||||
self.update_batch_size = update_batch_size
|
||||
if pre_update:
|
||||
if isinstance(problem, RLEnv):
|
||||
assert problem.record_episode, "record_episode must be True"
|
||||
self.fetch_data = lambda episode: episode["obs"]
|
||||
elif isinstance(problem, FuncFit):
|
||||
assert problem.return_data, "return_data must be True"
|
||||
self.fetch_data = lambda data: data
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
||||
|
||||
if self.pre_update:
|
||||
# initialize data with standard-normal samples (mean = 0, std = 1)
|
||||
state = state.register(
|
||||
data=jax.random.normal(
|
||||
state.randkey, (self.update_batch_size, self.algorithm.num_inputs)
|
||||
)
|
||||
)
|
||||
|
||||
state = self.algorithm.setup(state)
|
||||
state = self.problem.setup(state)
|
||||
print("initializing finished")
|
||||
@@ -57,9 +79,42 @@ class Pipeline:
|
||||
state, pop
|
||||
)
|
||||
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, keys, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
if self.pre_update:
|
||||
# update the population
|
||||
_, pop_transformed = jax.vmap(
|
||||
self.algorithm.update_by_batch, in_axes=(None, None, 0)
|
||||
)(state, state.data, pop_transformed)
|
||||
|
||||
# raw_data: (Pop, Batch, num_inputs)
|
||||
fitnesses, raw_data = jax.vmap(
|
||||
self.problem.evaluate, in_axes=(None, 0, None, 0)
|
||||
)(state, keys, self.algorithm.forward, pop_transformed)
|
||||
|
||||
data = self.fetch_data(raw_data)
|
||||
assert (
|
||||
data.ndim == 3
|
||||
and data.shape[0] == self.pop_size
|
||||
and data.shape[2] == self.algorithm.num_inputs
|
||||
)
|
||||
# reshape to (Pop * Batch, num_inputs)
|
||||
data = data.reshape(
|
||||
data.shape[0] * data.shape[1], self.algorithm.num_inputs
|
||||
)
|
||||
# shuffle
|
||||
data = jax.random.permutation(randkey_, data, axis=0)
|
||||
# cutoff or expand
|
||||
if data.shape[0] >= self.update_batch_size:
|
||||
data = data[: self.update_batch_size] # cutoff
|
||||
else:
|
||||
data = (
|
||||
jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data)
|
||||
) # expand
|
||||
state = state.update(data=data)
|
||||
|
||||
else:
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, keys, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
|
||||
state = self.algorithm.tell(state, fitnesses)
|
||||
|
||||
@@ -89,24 +144,18 @@ class Pipeline:
|
||||
print("Fitness limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
# node = previous_pop[0][0][:, 0]
|
||||
# node_count = jnp.sum(~jnp.isnan(node))
|
||||
# conn = previous_pop[1][0][:, 0]
|
||||
# conn_count = jnp.sum(~jnp.isnan(conn))
|
||||
# if (w % 5 == 0):
|
||||
# print("node_count", node_count)
|
||||
# print("conn_count", conn_count)
|
||||
|
||||
print("Generation limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
def analysis(self, state, pop, fitnesses):
|
||||
|
||||
valid_fitnesses = fitnesses[~np.isnan(fitnesses)]
|
||||
|
||||
max_f, min_f, mean_f, std_f = (
|
||||
max(fitnesses),
|
||||
min(fitnesses),
|
||||
np.mean(fitnesses),
|
||||
np.std(fitnesses),
|
||||
max(valid_fitnesses),
|
||||
min(valid_fitnesses),
|
||||
np.mean(valid_fitnesses),
|
||||
np.std(valid_fitnesses),
|
||||
)
|
||||
|
||||
new_timestamp = time.time()
|
||||
@@ -122,9 +171,9 @@ class Pipeline:
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
print(
|
||||
f"Generation: {self.algorithm.generation(state)}",
|
||||
f"species: {len(species_sizes)}, {species_sizes}",
|
||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms",
|
||||
f"Generation: {self.algorithm.generation(state)}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
|
||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||
)
|
||||
|
||||
def show(self, state, best, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user