add show_details in problem;

releated to https://github.com/EMI-Group/tensorneat/issues/15
This commit is contained in:
wls2002
2025-02-12 22:42:05 +08:00
parent de2d906656
commit e4f855b4f6
3 changed files with 126 additions and 28 deletions

View File

@@ -1,4 +1,5 @@
###this code will throw a ValueError
import numpy as np
from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline
from tensorneat.genome.gene.node import DefaultNode
@@ -7,17 +8,21 @@ from tensorneat.genome.operations import mutation
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem
def binary_cross_entropy(prediction, target):
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
# Define the custom Problem
class CustomProblem(BaseProblem):
jitable = True # necessary
def __init__(self, inputs, labels, threshold):
self.inputs = jnp.array(inputs) #nb! already has shape (n, 768)
self.labels = jnp.array(labels).reshape((-1,1)) #nb! has shape (n), must be transformed to have shape (n, 1)
self.inputs = jnp.array(inputs) # nb! already has shape (n, 768)
self.labels = jnp.array(labels).reshape(
(-1, 1)
) # nb! has shape (n), must be transformed to have shape (n, 1)
self.threshold = threshold
# move the calculation related to pairwise_labels to problem initialization
@@ -28,21 +33,29 @@ class CustomProblem(BaseProblem):
pairwise_labels = jnp.where(self.pairs_to_keep, pairwise_labels, jnp.nan)
self.pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
# # jit batch calculate accuracy in advance
# self.batch_cal_accuracy = jax.jit(
# jax.vmap(self.calculate_accuracy, in_axes=(None, None, 0))
# )
def evaluate(self, state, randkey, act_func, params):
# do batch forward for all inputs (using jax.vamp).
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
) # should be shape (len(labels), 1)
#calculating pairwise labels and predictions
# calculating pairwise labels and predictions
pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs))
pairwise_predictions = jnp.where(self.pairs_to_keep, pairwise_predictions, jnp.nan)
pairwise_predictions = jnp.where(
self.pairs_to_keep, pairwise_predictions, jnp.nan
)
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
# calculate loss
loss = binary_cross_entropy(pairwise_predictions, self.pairwise_labels) # shape (len(labels), len(labels))
loss = binary_cross_entropy(
pairwise_predictions, self.pairwise_labels
) # shape (len(labels), len(labels))
# jax.debug.print("loss={}", loss)
# reduce loss to a scalar
# we need to ignore nan value here
@@ -61,9 +74,64 @@ class CustomProblem(BaseProblem):
# the output shape that the act_func returns
return (1,)
def calculate_accuracy(self, state, act_func, params):
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
) # should be shape (len(labels), 1)
# calculating pairwise labels and predictions
pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs))
pairwise_predictions = jnp.where(
self.pairs_to_keep, pairwise_predictions, jnp.nan
)
pairwise_predictions = jnp.where(pairwise_predictions > 0, True, False)
accuracy = jnp.mean(
pairwise_predictions == self.pairwise_labels,
where=~jnp.isnan(pairwise_predictions),
)
return accuracy
def show_details(self, state, randkey, act_func, pop_params, *args, **kwargs):
# compile jax function when first call
if not hasattr(self, "batch_accuracy"):
def single_accuracy(state_, params_):
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state_, params_, self.inputs
) # should be shape (len(labels), 1)
pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs))
pairwise_predictions = jnp.where(
self.pairs_to_keep, pairwise_predictions, jnp.nan
)
pairwise_predictions = jnp.where(pairwise_predictions > 0, True, False)
accuracy = jnp.mean(
pairwise_predictions == self.pairwise_labels,
where=~jnp.isnan(pairwise_predictions),
)
return accuracy
self.batch_accuracy = jax.jit(
jax.vmap(single_accuracy, in_axes=(None, 0))
)
# calculate accuracy for the population
accuracys = self.batch_accuracy(state, pop_params)
accuracys = jax.device_get(accuracys) # move accuracys from gpu to cpu
max_a, min_a, mean_a, std_a = (
max(accuracys),
min(accuracys),
np.mean(accuracys),
np.std(accuracys),
)
print(
f"\tProblem Accuracy: max: {max_a:.4f}, min: {min_a:.4f}, mean: {mean_a:.4f}, std: {std_a:.4f}\n",
)
def show(self, state, randkey, act_func, params, *args, **kwargs):
# showcase the performance of one individual
predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs)
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)
loss = jnp.mean(jnp.square(predict - self.labels))
@@ -73,16 +141,19 @@ class CustomProblem(BaseProblem):
msg = f"Looking at {n_elements} first elements of input\n"
for i in range(n_elements):
msg += f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
msg += (
f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
)
msg += f"total loss: {loss}\n"
print(msg)
algorithm = algorithm.NEAT(
pop_size=10,
survival_threshold=0.2,
min_species_size=2,
compatibility_threshold=3.0,
species_elitism=2,
compatibility_threshold=3.0,
species_elitism=2,
genome=genome.DefaultGenome(
num_inputs=768,
num_outputs=1,
@@ -103,8 +174,8 @@ algorithm = algorithm.NEAT(
)
INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768)) #the input data x
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100, )) #the annotated labels y
INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768)) # the input data x
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100,)) # the annotated labels y
problem = CustomProblem(INPUTS, LABELS, 0.25)
@@ -113,13 +184,14 @@ print("-----------------------------------------------------------------------")
pipeline = Pipeline(
algorithm,
problem,
generation_limit=1,
generation_limit=5,
fitness_target=1,
seed=42,
show_problem_details=True,
)
state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)
pipeline.show(state, best)