add show_details in problem;
releated to https://github.com/EMI-Group/tensorneat/issues/15
This commit is contained in:
@@ -20,6 +20,7 @@ class Pipeline(StatefulBaseClass):
|
||||
generation_limit: int = 1000,
|
||||
is_save: bool = False,
|
||||
save_dir=None,
|
||||
show_problem_details: bool = False,
|
||||
):
|
||||
assert problem.jitable, "Currently, problem must be jitable"
|
||||
|
||||
@@ -54,6 +55,8 @@ class Pipeline(StatefulBaseClass):
|
||||
if not os.path.exists(self.genome_dir):
|
||||
os.makedirs(self.genome_dir)
|
||||
|
||||
self.show_problem_details = show_problem_details
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
||||
@@ -99,6 +102,14 @@ class Pipeline(StatefulBaseClass):
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
compiled_step = jax.jit(self.step).lower(state).compile()
|
||||
|
||||
if self.show_problem_details:
|
||||
self.compiled_pop_transform_func = (
|
||||
jax.jit(jax.vmap(self.algorithm.transform, in_axes=(None, 0)))
|
||||
.lower(self.algorithm.ask(state))
|
||||
.compile()
|
||||
)
|
||||
|
||||
# compiled_step = self.step
|
||||
print(
|
||||
f"compile finished, cost time: {time.time() - tic:.6f}s",
|
||||
@@ -138,7 +149,10 @@ class Pipeline(StatefulBaseClass):
|
||||
generation = int(state.generation)
|
||||
|
||||
valid_fitnesses = fitnesses[~np.isinf(fitnesses)]
|
||||
|
||||
# avoid there is no valid fitness in the whole population
|
||||
if len(valid_fitnesses) == 0:
|
||||
max_f, min_f, mean_f, std_f = ["NaN"] * 4
|
||||
else:
|
||||
max_f, min_f, mean_f, std_f = (
|
||||
max(valid_fitnesses),
|
||||
min(valid_fitnesses),
|
||||
@@ -158,9 +172,7 @@ class Pipeline(StatefulBaseClass):
|
||||
if self.is_save:
|
||||
# save best
|
||||
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
|
||||
file_name = os.path.join(
|
||||
self.genome_dir, f"{generation}.npz"
|
||||
)
|
||||
file_name = os.path.join(self.genome_dir, f"{generation}.npz")
|
||||
with open(file_name, "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
@@ -171,9 +183,7 @@ class Pipeline(StatefulBaseClass):
|
||||
|
||||
# append log
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
f.write(f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n")
|
||||
|
||||
print(
|
||||
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
@@ -182,6 +192,15 @@ class Pipeline(StatefulBaseClass):
|
||||
|
||||
self.algorithm.show_details(state, fitnesses)
|
||||
|
||||
if self.show_problem_details:
|
||||
pop_transformed = self.compiled_pop_transform_func(
|
||||
state, self.algorithm.ask(state)
|
||||
)
|
||||
self.problem.show_details(
|
||||
state, state.randkey, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
# show details for problem
|
||||
|
||||
def show(self, state, best, *args, **kwargs):
|
||||
transformed = self.algorithm.transform(state, best)
|
||||
return self.problem.show(
|
||||
|
||||
@@ -33,3 +33,10 @@ class BaseProblem(StatefulBaseClass):
|
||||
show how a genome perform in this problem
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def show_details(self, state: State, randkey, act_func: Callable, pop_params, *args, **kwargs):
|
||||
"""
|
||||
show the running details of the problem
|
||||
this function will be automaticly call in pipeline.auto_run()
|
||||
"""
|
||||
pass
|
||||
@@ -1,4 +1,5 @@
|
||||
###this code will throw a ValueError
|
||||
import numpy as np
|
||||
from tensorneat import algorithm, genome, common
|
||||
from tensorneat.pipeline import Pipeline
|
||||
from tensorneat.genome.gene.node import DefaultNode
|
||||
@@ -7,9 +8,11 @@ from tensorneat.genome.operations import mutation
|
||||
import jax, jax.numpy as jnp
|
||||
from tensorneat.problem import BaseProblem
|
||||
|
||||
|
||||
def binary_cross_entropy(prediction, target):
|
||||
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
|
||||
|
||||
|
||||
# Define the custom Problem
|
||||
class CustomProblem(BaseProblem):
|
||||
|
||||
@@ -17,7 +20,9 @@ class CustomProblem(BaseProblem):
|
||||
|
||||
def __init__(self, inputs, labels, threshold):
|
||||
self.inputs = jnp.array(inputs) # nb! already has shape (n, 768)
|
||||
self.labels = jnp.array(labels).reshape((-1,1)) #nb! has shape (n), must be transformed to have shape (n, 1)
|
||||
self.labels = jnp.array(labels).reshape(
|
||||
(-1, 1)
|
||||
) # nb! has shape (n), must be transformed to have shape (n, 1)
|
||||
self.threshold = threshold
|
||||
|
||||
# move the calculation related to pairwise_labels to problem initialization
|
||||
@@ -28,6 +33,10 @@ class CustomProblem(BaseProblem):
|
||||
pairwise_labels = jnp.where(self.pairs_to_keep, pairwise_labels, jnp.nan)
|
||||
self.pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
|
||||
|
||||
# # jit batch calculate accuracy in advance
|
||||
# self.batch_cal_accuracy = jax.jit(
|
||||
# jax.vmap(self.calculate_accuracy, in_axes=(None, None, 0))
|
||||
# )
|
||||
|
||||
def evaluate(self, state, randkey, act_func, params):
|
||||
# do batch forward for all inputs (using jax.vamp).
|
||||
@@ -38,11 +47,15 @@ class CustomProblem(BaseProblem):
|
||||
# calculating pairwise labels and predictions
|
||||
pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs))
|
||||
|
||||
pairwise_predictions = jnp.where(self.pairs_to_keep, pairwise_predictions, jnp.nan)
|
||||
pairwise_predictions = jnp.where(
|
||||
self.pairs_to_keep, pairwise_predictions, jnp.nan
|
||||
)
|
||||
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
|
||||
|
||||
# calculate loss
|
||||
loss = binary_cross_entropy(pairwise_predictions, self.pairwise_labels) # shape (len(labels), len(labels))
|
||||
loss = binary_cross_entropy(
|
||||
pairwise_predictions, self.pairwise_labels
|
||||
) # shape (len(labels), len(labels))
|
||||
# jax.debug.print("loss={}", loss)
|
||||
# reduce loss to a scalar
|
||||
# we need to ignore nan value here
|
||||
@@ -61,9 +74,64 @@ class CustomProblem(BaseProblem):
|
||||
# the output shape that the act_func returns
|
||||
return (1,)
|
||||
|
||||
def calculate_accuracy(self, state, act_func, params):
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, self.inputs
|
||||
) # should be shape (len(labels), 1)
|
||||
|
||||
# calculating pairwise labels and predictions
|
||||
pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs))
|
||||
pairwise_predictions = jnp.where(
|
||||
self.pairs_to_keep, pairwise_predictions, jnp.nan
|
||||
)
|
||||
|
||||
pairwise_predictions = jnp.where(pairwise_predictions > 0, True, False)
|
||||
accuracy = jnp.mean(
|
||||
pairwise_predictions == self.pairwise_labels,
|
||||
where=~jnp.isnan(pairwise_predictions),
|
||||
)
|
||||
return accuracy
|
||||
|
||||
def show_details(self, state, randkey, act_func, pop_params, *args, **kwargs):
|
||||
# compile jax function when first call
|
||||
if not hasattr(self, "batch_accuracy"):
|
||||
def single_accuracy(state_, params_):
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
state_, params_, self.inputs
|
||||
) # should be shape (len(labels), 1)
|
||||
pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs))
|
||||
pairwise_predictions = jnp.where(
|
||||
self.pairs_to_keep, pairwise_predictions, jnp.nan
|
||||
)
|
||||
|
||||
pairwise_predictions = jnp.where(pairwise_predictions > 0, True, False)
|
||||
accuracy = jnp.mean(
|
||||
pairwise_predictions == self.pairwise_labels,
|
||||
where=~jnp.isnan(pairwise_predictions),
|
||||
)
|
||||
return accuracy
|
||||
self.batch_accuracy = jax.jit(
|
||||
jax.vmap(single_accuracy, in_axes=(None, 0))
|
||||
)
|
||||
|
||||
# calculate accuracy for the population
|
||||
accuracys = self.batch_accuracy(state, pop_params)
|
||||
accuracys = jax.device_get(accuracys) # move accuracys from gpu to cpu
|
||||
max_a, min_a, mean_a, std_a = (
|
||||
max(accuracys),
|
||||
min(accuracys),
|
||||
np.mean(accuracys),
|
||||
np.std(accuracys),
|
||||
)
|
||||
print(
|
||||
f"\tProblem Accuracy: max: {max_a:.4f}, min: {min_a:.4f}, mean: {mean_a:.4f}, std: {std_a:.4f}\n",
|
||||
)
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
# showcase the performance of one individual
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs)
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, self.inputs
|
||||
)
|
||||
|
||||
loss = jnp.mean(jnp.square(predict - self.labels))
|
||||
|
||||
@@ -73,10 +141,13 @@ class CustomProblem(BaseProblem):
|
||||
|
||||
msg = f"Looking at {n_elements} first elements of input\n"
|
||||
for i in range(n_elements):
|
||||
msg += f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
|
||||
msg += (
|
||||
f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
|
||||
)
|
||||
msg += f"total loss: {loss}\n"
|
||||
print(msg)
|
||||
|
||||
|
||||
algorithm = algorithm.NEAT(
|
||||
pop_size=10,
|
||||
survival_threshold=0.2,
|
||||
@@ -113,9 +184,10 @@ print("-----------------------------------------------------------------------")
|
||||
pipeline = Pipeline(
|
||||
algorithm,
|
||||
problem,
|
||||
generation_limit=1,
|
||||
generation_limit=5,
|
||||
fitness_target=1,
|
||||
seed=42,
|
||||
show_problem_details=True,
|
||||
)
|
||||
|
||||
state = pipeline.setup()
|
||||
|
||||
Reference in New Issue
Block a user