update problem and pipeline
This commit is contained in:
@@ -7,8 +7,6 @@ import numpy as np
|
||||
|
||||
from tensorneat.algorithm import BaseAlgorithm
|
||||
from tensorneat.problem import BaseProblem
|
||||
from tensorneat.problem.rl_env import RLEnv
|
||||
from tensorneat.problem.func_fit import FuncFit
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
@@ -20,10 +18,8 @@ class Pipeline(StatefulBaseClass):
|
||||
seed: int = 42,
|
||||
fitness_target: float = 1,
|
||||
generation_limit: int = 1000,
|
||||
pre_update: bool = False,
|
||||
update_batch_size: int = 10000,
|
||||
save_dir=None,
|
||||
is_save: bool = False,
|
||||
save_dir=None,
|
||||
):
|
||||
assert problem.jitable, "Currently, problem must be jitable"
|
||||
|
||||
@@ -36,7 +32,6 @@ class Pipeline(StatefulBaseClass):
|
||||
|
||||
np.random.seed(self.seed)
|
||||
|
||||
# TODO: make each algorithm's input_num and output_num
|
||||
assert (
|
||||
algorithm.num_inputs == self.problem.input_shape[-1]
|
||||
), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"
|
||||
@@ -44,22 +39,6 @@ class Pipeline(StatefulBaseClass):
|
||||
self.best_genome = None
|
||||
self.best_fitness = float("-inf")
|
||||
self.generation_timestamp = None
|
||||
self.pre_update = pre_update
|
||||
self.update_batch_size = update_batch_size
|
||||
if pre_update:
|
||||
if isinstance(problem, RLEnv):
|
||||
assert problem.record_episode, "record_episode must be True"
|
||||
self.fetch_data = lambda episode: episode["obs"]
|
||||
elif isinstance(problem, FuncFit):
|
||||
assert problem.return_data, "return_data must be True"
|
||||
self.fetch_data = lambda data: data
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
if isinstance(problem, RLEnv):
|
||||
assert not problem.record_episode, "record_episode must be False"
|
||||
elif isinstance(problem, FuncFit):
|
||||
assert not problem.return_data, "return_data must be False"
|
||||
self.is_save = is_save
|
||||
|
||||
if is_save:
|
||||
@@ -79,14 +58,6 @@ class Pipeline(StatefulBaseClass):
|
||||
print("initializing")
|
||||
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
||||
|
||||
if self.pre_update:
|
||||
# initial with mean = 0 and std = 1
|
||||
state = state.register(
|
||||
data=jax.random.normal(
|
||||
state.randkey, (self.update_batch_size, self.algorithm.num_inputs)
|
||||
)
|
||||
)
|
||||
|
||||
state = self.algorithm.setup(state)
|
||||
state = self.problem.setup(state)
|
||||
|
||||
@@ -112,49 +83,9 @@ class Pipeline(StatefulBaseClass):
|
||||
state, pop
|
||||
)
|
||||
|
||||
if self.pre_update:
|
||||
# update the population
|
||||
_, pop_transformed = jax.vmap(
|
||||
self.algorithm.update_by_batch, in_axes=(None, None, 0)
|
||||
)(state, state.data, pop_transformed)
|
||||
|
||||
# raw_data: (Pop, Batch, num_inputs)
|
||||
fitnesses, raw_data = jax.vmap(
|
||||
self.problem.evaluate, in_axes=(None, 0, None, 0)
|
||||
)(state, keys, self.algorithm.forward, pop_transformed)
|
||||
|
||||
# update population
|
||||
pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))(
|
||||
state, pop_transformed
|
||||
)
|
||||
state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns)
|
||||
|
||||
# update data for next generation
|
||||
data = self.fetch_data(raw_data)
|
||||
assert (
|
||||
data.ndim == 3
|
||||
and data.shape[0] == self.pop_size
|
||||
and data.shape[2] == self.algorithm.num_inputs
|
||||
)
|
||||
# reshape to (Pop * Batch, num_inputs)
|
||||
data = data.reshape(
|
||||
data.shape[0] * data.shape[1], self.algorithm.num_inputs
|
||||
)
|
||||
# shuffle
|
||||
data = jax.random.permutation(randkey_, data, axis=0)
|
||||
# cutoff or expand
|
||||
if data.shape[0] >= self.update_batch_size:
|
||||
data = data[: self.update_batch_size] # cutoff
|
||||
else:
|
||||
data = (
|
||||
jnp.full(state.data.shape, jnp.nan).at[: data.shape[0]].set(data)
|
||||
) # expand
|
||||
state = state.update(data=data)
|
||||
|
||||
else:
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, keys, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, keys, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
|
||||
# replace nan with -inf
|
||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
||||
|
||||
Reference in New Issue
Block a user