From de2d906656d642563c781b88c71539fadb3a6166 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Wed, 12 Feb 2025 21:37:56 +0800 Subject: [PATCH] add some test --- test/ranknet.py | 140 +++++++++++++++++++++++++++++++++++++++++++ test/ranknet_neat.py | 125 ++++++++++++++++++++++++++++++++++++++ test/test.ipynb | 118 ++++++++++++++++++++++++++++++++++++ test/test.py | 34 +++++++++++ 4 files changed, 417 insertions(+) create mode 100644 test/ranknet.py create mode 100644 test/ranknet_neat.py create mode 100644 test/test.ipynb create mode 100644 test/test.py diff --git a/test/ranknet.py b/test/ranknet.py new file mode 100644 index 0000000..1d232f0 --- /dev/null +++ b/test/ranknet.py @@ -0,0 +1,140 @@ +# import RankNet +from tensorneat import algorithm, genome, common +from tensorneat.pipeline import Pipeline +from tensorneat.genome import BiasNode +from tensorneat.genome.operations import mutation +from tensorneat.common import ACT, AGG +import jax, jax.numpy as jnp +from tensorneat.problem import BaseProblem + +data_num = 100 +input_size = 768 # Each network (genome) should have input size 768 + +# The problem is to optimize a RankNet utilizing NEAT + + +def binary_cross_entropy(prediction, target): + return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction)) + + +# Create dataset (100 samples of vectors with 768 features) +INPUTS = jax.random.uniform( + jax.random.PRNGKey(0), (data_num, input_size) +) # the input data x +LABELS = jax.random.uniform(jax.random.PRNGKey(0), (data_num, 1)) # the annotated labels y +# True (1): >=; False (0): < +pairwise_labels = jnp.where((LABELS - LABELS.T) >= 0, True, False) + +print(f"{INPUTS.shape=}, {LABELS.shape=}") + + +# Define the custom Problem +class CustomProblem(BaseProblem): + + jitable = True # necessary + + def evaluate(self, state, randkey, act_func, params): + # Use ``act_func(state, params, inputs)`` to do network forward + + # print("state: ", state) + # print("params: ",params) + # print("act_func: ",act_func) + + ans_to_question = True + + # Question: This is the same as doing a forward pass for the generated network? + # Meaning the network does 100 passes for all the elements of 768 features? + if ans_to_question: + # do batch forward for all inputs (using jax.vamp). + predict = jax.vmap(act_func, in_axes=(None, None, 0))( + state, params, INPUTS + ) # should be shape (100, 1) + else: + # I misunderstood, so I have to create a RankNet myself to predict the output + # Setting up with the values present in the genome + current_node = state.species.idx2species + current_node_weights = state.pop_conns[current_node] + net = RankNet.RankNet(input_size, current_node_weights) + predict = net.forward(INPUTS) + + pairwise_predictions = predict - predict.T # shape (100, 100) + p = jax.nn.sigmoid(pairwise_predictions) # shape (100, 100) + + # calculate loss + loss = binary_cross_entropy(p, pairwise_labels) # shape (100, 100) + # loss with shape (100, 100), we need to reduce it to a scalar + loss = jnp.mean(loss) + + # return negative loss as fitness + # TensorNEAT maximizes fitness, equivalent to minimizing loss + return -loss + + @property + def input_shape(self): + # the input shape that the act_func expects + return (input_size,) + + @property + def output_shape(self): + # the output shape that the act_func returns + return (1,) + + def show(self, state, randkey, act_func, params, *args, **kwargs): + # showcase the performance of one individual + predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, INPUTS) + + loss = jnp.mean(jnp.square(predict - LABELS)) + + msg = "" + for i in range(INPUTS.shape[0]): + msg += f"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\n" + msg += f"loss: {loss}\n" + print(msg) + + +algorithm1 = algorithm.NEAT( + # setting values to be the same as default in python NEAT package to get same as paper authors + # tried as best I could to follow this https://neat-python.readthedocs.io/en/latest/config_file.html + pop_size=100, + survival_threshold=0.2, + min_species_size=2, + species_number_calculate_by="fitness", # either this or rank, but 'fitness' should be more in line with original paper on NEAT + # species_size=10, #nothing specified for species_size, it remains default + # modifying the values the authors explicitly mention + compatibility_threshold=3.0, # maybe need to consider this one in the future if weird results, default is 2.0 + species_elitism=2, # is 2 per default + genome=genome.DefaultGenome( + num_inputs=768, + num_outputs=1, + max_nodes=769, # must at least be same as inputs and outputs + max_conns=768, # must be 768 connections for the network to be fully connected + # 0 hidden layers per default + output_transform=common.ACT.sigmoid, + mutation=mutation.DefaultMutation( + # no allowing adding or deleting nodes + node_add=0.0, + node_delete=0.0, + # set mutation rates for edges to 0.5 + conn_add=0.5, + conn_delete=0.5, + ), + node_gene=BiasNode(), + ), +) + +problem = CustomProblem() + +pipeline = Pipeline( + algorithm1, + problem, + generation_limit=150, + fitness_target=1, + seed=42, +) +state = pipeline.setup() +# run until termination +state, best = pipeline.auto_run(state) +# show results +# pipeline.show(state, best) + +network = algorithm1.genome.network_dict(state, *best) \ No newline at end of file diff --git a/test/ranknet_neat.py b/test/ranknet_neat.py new file mode 100644 index 0000000..48fa552 --- /dev/null +++ b/test/ranknet_neat.py @@ -0,0 +1,125 @@ +###this code will throw a ValueError +from tensorneat import algorithm, genome, common +from tensorneat.pipeline import Pipeline +from tensorneat.genome.gene.node import DefaultNode +from tensorneat.genome.gene.conn import DefaultConn +from tensorneat.genome.operations import mutation +import jax, jax.numpy as jnp +from tensorneat.problem import BaseProblem + +def binary_cross_entropy(prediction, target): + return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction)) + +# Define the custom Problem +class CustomProblem(BaseProblem): + + jitable = True # necessary + + def __init__(self, inputs, labels, threshold): + self.inputs = jnp.array(inputs) #nb! already has shape (n, 768) + self.labels = jnp.array(labels).reshape((-1,1)) #nb! has shape (n), must be transformed to have shape (n, 1) + self.threshold = threshold + + # move the calculation related to pairwise_labels to problem initialization + pairwise_labels = self.labels - self.labels.T + self.pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold + # using nan istead of -inf + # as any mathmatical operation with nan will result in nan + pairwise_labels = jnp.where(self.pairs_to_keep, pairwise_labels, jnp.nan) + self.pairwise_labels = jnp.where(pairwise_labels > 0, True, False) + + + def evaluate(self, state, randkey, act_func, params): + # do batch forward for all inputs (using jax.vamp). + predict = jax.vmap(act_func, in_axes=(None, None, 0))( + state, params, self.inputs + ) # should be shape (len(labels), 1) + + #calculating pairwise labels and predictions + pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs)) + + pairwise_predictions = jnp.where(self.pairs_to_keep, pairwise_predictions, jnp.nan) + pairwise_predictions = jax.nn.sigmoid(pairwise_predictions) + + # calculate loss + loss = binary_cross_entropy(pairwise_predictions, self.pairwise_labels) # shape (len(labels), len(labels)) + # jax.debug.print("loss={}", loss) + # reduce loss to a scalar + # we need to ignore nan value here + loss = jnp.mean(loss, where=~jnp.isnan(loss)) + # return negative loss as fitness + # TensorNEAT maximizes fitness, equivalent to minimizing loss + return -loss + + @property + def input_shape(self): + # the input shape that the act_func expects + return (self.inputs.shape[1],) + + @property + def output_shape(self): + # the output shape that the act_func returns + return (1,) + + def show(self, state, randkey, act_func, params, *args, **kwargs): + # showcase the performance of one individual + predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs) + + loss = jnp.mean(jnp.square(predict - self.labels)) + + n_elements = 5 + if n_elements > len(self.inputs): + n_elements = len(self.inputs) + + msg = f"Looking at {n_elements} first elements of input\n" + for i in range(n_elements): + msg += f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n" + msg += f"total loss: {loss}\n" + print(msg) + +algorithm = algorithm.NEAT( + pop_size=10, + survival_threshold=0.2, + min_species_size=2, + compatibility_threshold=3.0, + species_elitism=2, + genome=genome.DefaultGenome( + num_inputs=768, + num_outputs=1, + max_nodes=769, # must at least be same as inputs and outputs + max_conns=768, # must be 768 connections for the network to be fully connected + output_transform=common.ACT.sigmoid, + mutation=mutation.DefaultMutation( + # no allowing adding or deleting nodes + node_add=0.0, + node_delete=0.0, + # set mutation rates for edges to 0.5 + conn_add=0.5, + conn_delete=0.5, + ), + node_gene=DefaultNode(), + conn_gene=DefaultConn(), + ), +) + + +INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768)) #the input data x +LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100, )) #the annotated labels y + +problem = CustomProblem(INPUTS, LABELS, 0.25) + +print("Setting up pipeline and running it") +print("-----------------------------------------------------------------------") +pipeline = Pipeline( + algorithm, + problem, + generation_limit=1, + fitness_target=1, + seed=42, +) + +state = pipeline.setup() +# run until termination +state, best = pipeline.auto_run(state) +# show results +pipeline.show(state, best) \ No newline at end of file diff --git a/test/test.ipynb b/test/test.ipynb new file mode 100644 index 0000000..0cbeaba --- /dev/null +++ b/test/test.ipynb @@ -0,0 +1,118 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import jax, jax.numpy as jnp" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "LABELS = jax.random.uniform(jax.random.PRNGKey(0), (5, 1)) # the annotated labels y" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "pairwise_labels = LABELS - LABELS.T" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Array([[0.57450044],\n", + " [0.09968603],\n", + " [0.39316022],\n", + " [0.8941783 ],\n", + " [0.59656656]], dtype=float32),\n", + " Array([[ 0. , 0.47481441, 0.18134022, -0.31967783, -0.02206612],\n", + " [-0.47481441, 0. , -0.2934742 , -0.79449224, -0.49688053],\n", + " [-0.18134022, 0.2934742 , 0. , -0.50101805, -0.20340633],\n", + " [ 0.31967783, 0.79449224, 0.50101805, 0. , 0.2976117 ],\n", + " [ 0.02206612, 0.49688053, 0.20340633, -0.2976117 , 0. ]], dtype=float32))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LABELS, pairwise_labels" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def binary_cross_entropy(prediction, target):\n", + " return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(0.6931472, dtype=float32, weak_type=True)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "binary_cross_entropy(0.5, 1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jax_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/test.py b/test/test.py new file mode 100644 index 0000000..19755e2 --- /dev/null +++ b/test/test.py @@ -0,0 +1,34 @@ +###shows the difference in loss between using jnp.where() and boolean indexing + +import jax +import jax.numpy as jnp + +def binary_cross_entropy(prediction, target): + return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction)) + + +preds = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #predictions +LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #the annotated labels y + +pair_lab = LABELS - LABELS.T +pair_pred = preds - preds.T + +ptk = jnp.abs(pair_lab) > 0.25 + +pair_labw = jnp.where(ptk, pair_lab, -jnp.nan) +pair_labm = pair_lab[ptk] + +pair_labw = jnp.where(pair_labw > 0, True, False) +pair_labm = jnp.where(pair_labm > 0, True, False) + +pair_predw = jnp.where(ptk, pair_pred, -jnp.nan) +pair_predm = pair_pred[ptk] + +pair_predw = jax.nn.sigmoid(pair_predw) +pair_predm = jax.nn.sigmoid(pair_predm) + +lossw = binary_cross_entropy(pair_predw, pair_labw) +lossm = binary_cross_entropy(pair_predm, pair_labm) + +print("loss using jnp.where()", jnp.mean(lossw, where=~jnp.isnan(lossw))) +print("loss using boolean indexing", jnp.mean(lossm, where=~jnp.isnan(lossm))) \ No newline at end of file