# Source: tensorneat-mend/test/ranknet_neat.py (198 lines, 7.0 KiB, Python)

# NOTE: this code will throw a ValueError.
import numpy as np
from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline
from tensorneat.genome.gene.node import DefaultNode
from tensorneat.genome.gene.conn import DefaultConn
from tensorneat.genome.operations import mutation
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem
def binary_cross_entropy(prediction, target, eps=1e-7):
    """Return the element-wise binary cross-entropy between probabilities and targets.

    Args:
        prediction: Predicted probabilities, expected to lie in [0, 1].
        target: Target values (0/1 or boolean), broadcastable against
            ``prediction``.
        eps: Probabilities are clipped to ``[eps, 1 - eps]`` before the
            logarithm is taken.

    Returns:
        An array with the broadcast shape of the inputs holding the
        per-element loss. NaN predictions still yield NaN losses.
    """
    # Without clipping, a prediction of exactly 0 or 1 produces log(0) = -inf,
    # and the product 0 * -inf turns a perfectly correct prediction into NaN.
    # jnp.clip propagates NaN, so callers that mark entries with NaN on
    # purpose still get NaN back for those entries.
    prediction = jnp.clip(prediction, eps, 1 - eps)
    return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
# Define the custom Problem
class CustomProblem(BaseProblem):
    """Pairwise-ranking problem: learn to order samples by their labels.

    For every pair of samples (i, j) whose labels differ by more than
    ``threshold``, the network should score sample i above sample j exactly
    when ``labels[i] > labels[j]``. Fitness is the negated mean binary
    cross-entropy of ``sigmoid(score_i - score_j)`` against that ordering,
    averaged over the retained pairs.
    """

    jitable = True  # necessary

    def __init__(self, inputs, labels, threshold):
        """Store the dataset and precompute the pairwise label information.

        Args:
            inputs: Sample matrix; ``input_shape`` reads ``inputs.shape[1]``,
                so it must be 2-D, (n_samples, n_features).
            labels: One label per sample; reshaped to a column vector here.
            threshold: Pairs whose absolute label difference is not strictly
                greater than this value are ignored by loss and accuracy.
        """
        self.inputs = jnp.array(inputs)
        # A column vector (n, 1) lets `labels - labels.T` broadcast to the
        # full (n, n) matrix of label differences.
        self.labels = jnp.array(labels).reshape((-1, 1))
        self.threshold = threshold

        # Everything below depends only on the labels, so it is computed once
        # here instead of on every evaluation.
        label_diffs = self.labels - self.labels.T
        self.pairs_to_keep = jnp.abs(label_diffs) > self.threshold
        # True where sample i should rank above sample j; ignored pairs are
        # stored as False.
        self.pairwise_labels = self.pairs_to_keep & (label_diffs > 0)

    def _pairwise_score_diffs(self, state, act_func, params):
        """Run the network on all inputs and return score[i] - score[j] for all pairs."""
        # Batch forward over the sample axis with jax.vmap.
        # Assumes act_func returns shape (1,) per sample, so `predict` is
        # (n, 1) and the subtraction broadcasts to (n, n) -- TODO confirm.
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, self.inputs
        )
        return predict - predict.T

    def _pairwise_accuracy(self, state, act_func, params):
        """Return the fraction of retained pairs whose predicted order matches the labels."""
        predicted_order = self._pairwise_score_diffs(state, act_func, params) > 0
        # Mask with `pairs_to_keep` directly. Masking with NaN and testing
        # isnan *after* the boolean comparison loses the mask (a boolean array
        # holds no NaN), which lets ignored pairs count as correct.
        return jnp.mean(
            predicted_order == self.pairwise_labels,
            where=self.pairs_to_keep,
        )

    def evaluate(self, state, randkey, act_func, params):
        """Return the fitness (negated mean pairwise loss) of one individual.

        Args:
            state: Passed through to ``act_func``.
            randkey: Unused here.
            act_func: Forward function called as ``act_func(state, params, x)``.
            params: Parameters of the individual being evaluated.

        Returns:
            A scalar: minus the mean binary cross-entropy over retained pairs.
        """
        score_diffs = self._pairwise_score_diffs(state, act_func, params)
        # Mark ignored pairs with NaN: any arithmetic on NaN stays NaN, so the
        # mark survives the sigmoid and the loss computation below.
        score_diffs = jnp.where(self.pairs_to_keep, score_diffs, jnp.nan)
        pairwise_predictions = jax.nn.sigmoid(score_diffs)

        loss = binary_cross_entropy(
            pairwise_predictions, self.pairwise_labels
        )  # shape (len(labels), len(labels))

        # Reduce to a scalar while skipping the NaN entries.
        loss = jnp.mean(loss, where=~jnp.isnan(loss))

        # TensorNEAT maximizes fitness, which is equivalent to minimizing loss.
        return -loss

    @property
    def input_shape(self):
        """Shape of a single input sample, as expected by ``act_func``."""
        return (self.inputs.shape[1],)

    @property
    def output_shape(self):
        """Shape of a single output returned by ``act_func``."""
        return (1,)

    def calculate_accuracy(self, state, act_func, params):
        """Return the pairwise ordering accuracy of one individual.

        Only pairs selected by ``self.pairs_to_keep`` contribute; the result
        is NaN if no pair is retained.
        """
        return self._pairwise_accuracy(state, act_func, params)

    def show_details(self, state, randkey, act_func, pop_params, *args, **kwargs):
        """Print max/min/mean/std of the pairwise accuracy over the population.

        Args:
            state: Passed through to ``act_func``.
            randkey: Unused here.
            act_func: Forward function; the one given on the FIRST call is
                captured by the cached, jitted batch function.
            pop_params: Population parameters with a leading population axis.
        """
        # Build and jit the batched accuracy function on the first call only.
        if not hasattr(self, "batch_accuracy"):

            def single_accuracy(state_, params_):
                return self._pairwise_accuracy(state_, act_func, params_)

            self.batch_accuracy = jax.jit(
                jax.vmap(single_accuracy, in_axes=(None, 0))
            )

        # Compute accuracy for the whole population, then move it to host memory.
        accuracys = jax.device_get(self.batch_accuracy(state, pop_params))
        max_a, min_a, mean_a, std_a = (
            np.max(accuracys),
            np.min(accuracys),
            np.mean(accuracys),
            np.std(accuracys),
        )
        print(
            f"\tProblem Accuracy: max: {max_a:.4f}, min: {min_a:.4f}, mean: {mean_a:.4f}, std: {std_a:.4f}\n",
        )

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Print the first few targets and predictions of one individual."""
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, self.inputs
        )
        # NOTE: the value reported as "total loss" is the mean squared error
        # between raw predictions and labels, not the ranking loss that
        # `evaluate` optimizes.
        loss = jnp.mean(jnp.square(predict - self.labels))

        n_elements = min(5, len(self.inputs))
        parts = [f"Looking at {n_elements} first elements of input\n"]
        parts.extend(
            f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
            for i in range(n_elements)
        )
        parts.append(f"total loss: {loss}\n")
        print("".join(parts))
# NOTE: this rebinding shadows the imported `tensorneat.algorithm` module;
# from here on `algorithm` refers to the NEAT instance.
algorithm = algorithm.NEAT(
    pop_size=10,
    survival_threshold=0.2,
    min_species_size=2,
    compatibility_threshold=3.0,
    species_elitism=2,
    genome=genome.DefaultGenome(
        num_inputs=768,
        num_outputs=1,
        max_nodes=769,  # must at least be same as inputs and outputs
        max_conns=768,  # must be 768 connections for the network to be fully connected
        output_transform=common.ACT.sigmoid,
        mutation=mutation.DefaultMutation(
            # do not allow adding or deleting nodes
            node_add=0.0,
            node_delete=0.0,
            # set mutation rates for edges to 0.5
            conn_add=0.5,
            conn_delete=0.5,
        ),
        node_gene=DefaultNode(),
        conn_gene=DefaultConn(),
    ),
)

# Draw inputs and labels from independent PRNG keys. Passing the same key to
# two sampling calls reuses the same random stream, so the synthetic labels
# would not be independent of the synthetic inputs.
inputs_key, labels_key = jax.random.split(jax.random.PRNGKey(0))
INPUTS = jax.random.uniform(inputs_key, (100, 768))  # the input data x
LABELS = jax.random.uniform(labels_key, (100,))  # the annotated labels y

problem = CustomProblem(INPUTS, LABELS, 0.25)

print("Setting up pipeline and running it")
print("-----------------------------------------------------------------------")

# NOTE: CustomProblem.evaluate returns -loss, which is never positive, so
# fitness_target=1 cannot be reached and the run ends at generation_limit.
pipeline = Pipeline(
    algorithm,
    problem,
    generation_limit=5,
    fitness_target=1,
    seed=42,
    show_problem_details=True,
)

state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)