add show_details in problem;

related to https://github.com/EMI-Group/tensorneat/issues/15
This commit is contained in:
wls2002
2025-02-12 22:42:05 +08:00
parent de2d906656
commit e4f855b4f6
3 changed files with 126 additions and 28 deletions

View File

@@ -20,6 +20,7 @@ class Pipeline(StatefulBaseClass):
generation_limit: int = 1000, generation_limit: int = 1000,
is_save: bool = False, is_save: bool = False,
save_dir=None, save_dir=None,
show_problem_details: bool = False,
): ):
assert problem.jitable, "Currently, problem must be jitable" assert problem.jitable, "Currently, problem must be jitable"
@@ -54,6 +55,8 @@ class Pipeline(StatefulBaseClass):
if not os.path.exists(self.genome_dir): if not os.path.exists(self.genome_dir):
os.makedirs(self.genome_dir) os.makedirs(self.genome_dir)
self.show_problem_details = show_problem_details
def setup(self, state=State()): def setup(self, state=State()):
print("initializing") print("initializing")
state = state.register(randkey=jax.random.PRNGKey(self.seed)) state = state.register(randkey=jax.random.PRNGKey(self.seed))
@@ -99,6 +102,14 @@ class Pipeline(StatefulBaseClass):
print("start compile") print("start compile")
tic = time.time() tic = time.time()
compiled_step = jax.jit(self.step).lower(state).compile() compiled_step = jax.jit(self.step).lower(state).compile()
if self.show_problem_details:
self.compiled_pop_transform_func = (
jax.jit(jax.vmap(self.algorithm.transform, in_axes=(None, 0)))
.lower(self.algorithm.ask(state))
.compile()
)
# compiled_step = self.step # compiled_step = self.step
print( print(
f"compile finished, cost time: {time.time() - tic:.6f}s", f"compile finished, cost time: {time.time() - tic:.6f}s",
@@ -134,17 +145,20 @@ class Pipeline(StatefulBaseClass):
return state, self.best_genome return state, self.best_genome
def analysis(self, state, pop, fitnesses): def analysis(self, state, pop, fitnesses):
generation = int(state.generation) generation = int(state.generation)
valid_fitnesses = fitnesses[~np.isinf(fitnesses)] valid_fitnesses = fitnesses[~np.isinf(fitnesses)]
# avoid there is no valid fitness in the whole population
max_f, min_f, mean_f, std_f = ( if len(valid_fitnesses) == 0:
max(valid_fitnesses), max_f, min_f, mean_f, std_f = ["NaN"] * 4
min(valid_fitnesses), else:
np.mean(valid_fitnesses), max_f, min_f, mean_f, std_f = (
np.std(valid_fitnesses), max(valid_fitnesses),
) min(valid_fitnesses),
np.mean(valid_fitnesses),
np.std(valid_fitnesses),
)
new_timestamp = time.time() new_timestamp = time.time()
@@ -158,9 +172,7 @@ class Pipeline(StatefulBaseClass):
if self.is_save: if self.is_save:
# save best # save best
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx])) best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
file_name = os.path.join( file_name = os.path.join(self.genome_dir, f"{generation}.npz")
self.genome_dir, f"{generation}.npz"
)
with open(file_name, "wb") as f: with open(file_name, "wb") as f:
np.savez( np.savez(
f, f,
@@ -171,9 +183,7 @@ class Pipeline(StatefulBaseClass):
# append log # append log
with open(os.path.join(self.save_dir, "log.txt"), "a") as f: with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write( f.write(f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n")
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
print( print(
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n", f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
@@ -182,6 +192,15 @@ class Pipeline(StatefulBaseClass):
self.algorithm.show_details(state, fitnesses) self.algorithm.show_details(state, fitnesses)
if self.show_problem_details:
pop_transformed = self.compiled_pop_transform_func(
state, self.algorithm.ask(state)
)
self.problem.show_details(
state, state.randkey, self.algorithm.forward, pop_transformed
)
# show details for problem
def show(self, state, best, *args, **kwargs): def show(self, state, best, *args, **kwargs):
transformed = self.algorithm.transform(state, best) transformed = self.algorithm.transform(state, best)
return self.problem.show( return self.problem.show(

View File

@@ -33,3 +33,10 @@ class BaseProblem(StatefulBaseClass):
show how a genome perform in this problem show how a genome perform in this problem
""" """
raise NotImplementedError raise NotImplementedError
def show_details(self, state: State, randkey, act_func: Callable, pop_params, *args, **kwargs):
    """
    Optional hook for reporting problem-specific running details.

    The default implementation does nothing; subclasses may override it
    to print extra per-generation statistics. The pipeline calls it
    automatically (once per analysis step) when it is created with
    ``show_problem_details=True``.

    :param state: the current State
    :param randkey: random key available to the hook
    :param act_func: forward function used to evaluate a network
    :param pop_params: transformed parameters of the population
    :return: None
    """
    return None

View File

@@ -1,4 +1,5 @@
###this code will throw a ValueError ###this code will throw a ValueError
import numpy as np
from tensorneat import algorithm, genome, common from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline from tensorneat.pipeline import Pipeline
from tensorneat.genome.gene.node import DefaultNode from tensorneat.genome.gene.node import DefaultNode
@@ -7,17 +8,21 @@ from tensorneat.genome.operations import mutation
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem from tensorneat.problem import BaseProblem
def binary_cross_entropy(prediction, target): def binary_cross_entropy(prediction, target):
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction)) return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
# Define the custom Problem # Define the custom Problem
class CustomProblem(BaseProblem): class CustomProblem(BaseProblem):
jitable = True # necessary jitable = True # necessary
def __init__(self, inputs, labels, threshold): def __init__(self, inputs, labels, threshold):
self.inputs = jnp.array(inputs) #nb! already has shape (n, 768) self.inputs = jnp.array(inputs) # nb! already has shape (n, 768)
self.labels = jnp.array(labels).reshape((-1,1)) #nb! has shape (n), must be transformed to have shape (n, 1) self.labels = jnp.array(labels).reshape(
(-1, 1)
) # nb! has shape (n), must be transformed to have shape (n, 1)
self.threshold = threshold self.threshold = threshold
# move the calculation related to pairwise_labels to problem initialization # move the calculation related to pairwise_labels to problem initialization
@@ -28,21 +33,29 @@ class CustomProblem(BaseProblem):
pairwise_labels = jnp.where(self.pairs_to_keep, pairwise_labels, jnp.nan) pairwise_labels = jnp.where(self.pairs_to_keep, pairwise_labels, jnp.nan)
self.pairwise_labels = jnp.where(pairwise_labels > 0, True, False) self.pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
# # jit batch calculate accuracy in advance
# self.batch_cal_accuracy = jax.jit(
# jax.vmap(self.calculate_accuracy, in_axes=(None, None, 0))
# )
def evaluate(self, state, randkey, act_func, params): def evaluate(self, state, randkey, act_func, params):
# do batch forward for all inputs (using jax.vamp). # do batch forward for all inputs (using jax.vamp).
predict = jax.vmap(act_func, in_axes=(None, None, 0))( predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs state, params, self.inputs
) # should be shape (len(labels), 1) ) # should be shape (len(labels), 1)
#calculating pairwise labels and predictions # calculating pairwise labels and predictions
pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs)) pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs))
pairwise_predictions = jnp.where(self.pairs_to_keep, pairwise_predictions, jnp.nan) pairwise_predictions = jnp.where(
self.pairs_to_keep, pairwise_predictions, jnp.nan
)
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions) pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
# calculate loss # calculate loss
loss = binary_cross_entropy(pairwise_predictions, self.pairwise_labels) # shape (len(labels), len(labels)) loss = binary_cross_entropy(
pairwise_predictions, self.pairwise_labels
) # shape (len(labels), len(labels))
# jax.debug.print("loss={}", loss) # jax.debug.print("loss={}", loss)
# reduce loss to a scalar # reduce loss to a scalar
# we need to ignore nan value here # we need to ignore nan value here
@@ -61,9 +74,64 @@ class CustomProblem(BaseProblem):
# the output shape that the act_func returns # the output shape that the act_func returns
return (1,) return (1,)
def calculate_accuracy(self, state, act_func, params):
    """
    Compute the pairwise ranking accuracy of one network.

    A pair (i, j) is counted as correct when the sign of
    ``predict[i] - predict[j]`` agrees with ``self.pairwise_labels[i, j]``.
    Only pairs selected by ``self.pairs_to_keep`` contribute to the mean.

    :param state: state forwarded unchanged to ``act_func``
    :param act_func: forward function, called as ``act_func(state, params, x)``
    :param params: parameters of a single network
    :return: scalar accuracy over the kept pairs (NaN if no pair is kept)
    """
    # batch forward over all inputs
    predict = jax.vmap(act_func, in_axes=(None, None, 0))(
        state, params, self.inputs
    )  # should be shape (len(labels), 1)

    # entry [i, j] is True when sample i is predicted above sample j
    pairwise_predictions = (predict - predict.T) > 0  # shape (len(inputs), len(inputs))

    # Mask with pairs_to_keep directly. The previous version wrote NaN into
    # the excluded entries and then converted the matrix to booleans, which
    # erased the NaNs: the ``~isnan`` mask was therefore all-True and every
    # excluded pair (False in both predictions and labels) was counted as a
    # correct prediction, inflating the accuracy.
    keep_mask = jnp.asarray(self.pairs_to_keep, dtype=bool)
    accuracy = jnp.mean(
        pairwise_predictions == self.pairwise_labels,
        where=keep_mask,
    )
    return accuracy
def show_details(self, state, randkey, act_func, pop_params, *args, **kwargs):
    """
    Print max/min/mean/std of the pairwise accuracy over the population.

    The jitted, population-batched accuracy function is built on the first
    call and cached on ``self.batch_accuracy``; it closes over the
    ``act_func`` passed to that first call.

    :param state: state forwarded unchanged to ``act_func``
    :param randkey: unused here; kept for interface compatibility
    :param act_func: forward function, called as ``act_func(state, params, x)``
    :param pop_params: transformed parameters, batched along axis 0
    """
    # compile jax function when first call
    if not hasattr(self, "batch_accuracy"):

        def single_accuracy(state_, params_):
            predict = jax.vmap(act_func, in_axes=(None, None, 0))(
                state_, params_, self.inputs
            )  # should be shape (len(labels), 1)

            # entry [i, j] is True when sample i is predicted above sample j
            pairwise_predictions = (predict - predict.T) > 0

            # Average over the kept pairs only. Masking via NaN does not
            # survive the boolean conversion, so excluded pairs used to be
            # counted as correct; use pairs_to_keep as the mask instead.
            keep_mask = jnp.asarray(self.pairs_to_keep, dtype=bool)
            return jnp.mean(
                pairwise_predictions == self.pairwise_labels,
                where=keep_mask,
            )

        self.batch_accuracy = jax.jit(
            jax.vmap(single_accuracy, in_axes=(None, 0))
        )

    # calculate accuracy for the population
    accuracys = self.batch_accuracy(state, pop_params)
    accuracys = jax.device_get(accuracys)  # move accuracys from gpu to cpu

    # vectorized reductions instead of builtin max/min iterating the array
    max_a, min_a, mean_a, std_a = (
        np.max(accuracys),
        np.min(accuracys),
        np.mean(accuracys),
        np.std(accuracys),
    )
    print(
        f"\tProblem Accuracy: max: {max_a:.4f}, min: {min_a:.4f}, mean: {mean_a:.4f}, std: {std_a:.4f}\n",
    )
def show(self, state, randkey, act_func, params, *args, **kwargs): def show(self, state, randkey, act_func, params, *args, **kwargs):
# showcase the performance of one individual # showcase the performance of one individual
predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs) predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
)
loss = jnp.mean(jnp.square(predict - self.labels)) loss = jnp.mean(jnp.square(predict - self.labels))
@@ -73,16 +141,19 @@ class CustomProblem(BaseProblem):
msg = f"Looking at {n_elements} first elements of input\n" msg = f"Looking at {n_elements} first elements of input\n"
for i in range(n_elements): for i in range(n_elements):
msg += f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n" msg += (
f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
)
msg += f"total loss: {loss}\n" msg += f"total loss: {loss}\n"
print(msg) print(msg)
algorithm = algorithm.NEAT( algorithm = algorithm.NEAT(
pop_size=10, pop_size=10,
survival_threshold=0.2, survival_threshold=0.2,
min_species_size=2, min_species_size=2,
compatibility_threshold=3.0, compatibility_threshold=3.0,
species_elitism=2, species_elitism=2,
genome=genome.DefaultGenome( genome=genome.DefaultGenome(
num_inputs=768, num_inputs=768,
num_outputs=1, num_outputs=1,
@@ -103,8 +174,8 @@ algorithm = algorithm.NEAT(
) )
INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768)) #the input data x INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768)) # the input data x
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100, )) #the annotated labels y LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100,)) # the annotated labels y
problem = CustomProblem(INPUTS, LABELS, 0.25) problem = CustomProblem(INPUTS, LABELS, 0.25)
@@ -113,13 +184,14 @@ print("-----------------------------------------------------------------------")
pipeline = Pipeline( pipeline = Pipeline(
algorithm, algorithm,
problem, problem,
generation_limit=1, generation_limit=5,
fitness_target=1, fitness_target=1,
seed=42, seed=42,
show_problem_details=True,
) )
state = pipeline.setup() state = pipeline.setup()
# run until termination # run until termination
state, best = pipeline.auto_run(state) state, best = pipeline.auto_run(state)
# show results # show results
pipeline.show(state, best) pipeline.show(state, best)