add some test

This commit is contained in:
wls2002
2025-02-12 21:37:56 +08:00
parent 51028346fd
commit de2d906656
4 changed files with 417 additions and 0 deletions

140
test/ranknet.py Normal file
View File

@@ -0,0 +1,140 @@
# import RankNet
from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline
from tensorneat.genome import BiasNode
from tensorneat.genome.operations import mutation
from tensorneat.common import ACT, AGG
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem
# Size of the synthetic ranking dataset.
data_num = 100
input_size = 768  # Each network (genome) should have input size 768

# The problem is to optimize a RankNet utilizing NEAT
def binary_cross_entropy(prediction, target):
    """Element-wise binary cross-entropy of `prediction` against `target`."""
    positive_term = target * jnp.log(prediction)
    negative_term = (1 - target) * jnp.log(1 - prediction)
    return -(positive_term + negative_term)

# Create dataset (100 samples of vectors with 768 features)
INPUTS = jax.random.uniform(
    jax.random.PRNGKey(0), (data_num, input_size)
)  # the input data x
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (data_num, 1))  # the annotated labels y

# Pairwise comparison matrix: entry (i, j) is True when LABELS[i] >= LABELS[j].
# (A boolean comparison already yields the True/False array directly.)
pairwise_labels = (LABELS - LABELS.T) >= 0
print(f"{INPUTS.shape=}, {LABELS.shape=}")
# Define the custom Problem
class CustomProblem(BaseProblem):
    """NEAT problem that evolves a RankNet-style pairwise ranker.

    Fitness is the negative mean binary cross-entropy over all pairwise
    order predictions against the module-level INPUTS / pairwise_labels.
    """

    jitable = True  # necessary

    def evaluate(self, state, randkey, act_func, params):
        """Return the fitness (negative pairwise BCE loss) of one genome.

        ``act_func(state, params, inputs)`` performs one network forward
        pass; jax.vmap batches it over the whole dataset.  (The original
        code carried a dead `else` branch that referenced an undefined
        `RankNet` name — it could never run and has been removed.)
        """
        # do batch forward for all inputs (using jax.vmap).
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, INPUTS
        )  # should be shape (100, 1)
        # Score differences for every ordered pair of samples.
        pairwise_predictions = predict - predict.T  # shape (100, 100)
        p = jax.nn.sigmoid(pairwise_predictions)  # shape (100, 100)
        # calculate loss
        loss = binary_cross_entropy(p, pairwise_labels)  # shape (100, 100)
        # loss with shape (100, 100), we need to reduce it to a scalar
        loss = jnp.mean(loss)
        # return negative loss as fitness
        # TensorNEAT maximizes fitness, equivalent to minimizing loss
        return -loss

    @property
    def input_shape(self):
        # the input shape that the act_func expects
        return (input_size,)

    @property
    def output_shape(self):
        # the output shape that the act_func returns
        return (1,)

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Showcase one individual: per-sample predictions and its MSE."""
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, INPUTS)
        loss = jnp.mean(jnp.square(predict - LABELS))
        msg = ""
        for i in range(INPUTS.shape[0]):
            msg += f"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\n"
        msg += f"loss: {loss}\n"
        print(msg)
# NEAT configuration mirroring the defaults of the python NEAT package, to
# stay close to the original paper's setup:
# https://neat-python.readthedocs.io/en/latest/config_file.html
_mutation_cfg = mutation.DefaultMutation(
    # no allowing adding or deleting nodes
    node_add=0.0,
    node_delete=0.0,
    # set mutation rates for edges to 0.5
    conn_add=0.5,
    conn_delete=0.5,
)
_genome_cfg = genome.DefaultGenome(
    num_inputs=768,
    num_outputs=1,
    max_nodes=769,  # must at least be same as inputs and outputs
    max_conns=768,  # must be 768 connections for the network to be fully connected
    # 0 hidden layers per default
    output_transform=common.ACT.sigmoid,
    mutation=_mutation_cfg,
    node_gene=BiasNode(),
)
algorithm1 = algorithm.NEAT(
    pop_size=100,
    survival_threshold=0.2,
    min_species_size=2,
    # 'fitness' (rather than 'rank') is more in line with the original NEAT paper
    species_number_calculate_by="fitness",
    # species_size stays at its default; nothing was specified for it
    compatibility_threshold=3.0,  # default is 2.0; revisit if results look odd
    species_elitism=2,  # is 2 per default
    genome=_genome_cfg,
)

problem = CustomProblem()
pipeline = Pipeline(
    algorithm1,
    problem,
    generation_limit=150,
    fitness_target=1,
    seed=42,
)

state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
# pipeline.show(state, best)
network = algorithm1.genome.network_dict(state, *best)

125
test/ranknet_neat.py Normal file
View File

@@ -0,0 +1,125 @@
###this code will throw a ValueError
from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline
from tensorneat.genome.gene.node import DefaultNode
from tensorneat.genome.gene.conn import DefaultConn
from tensorneat.genome.operations import mutation
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem
def binary_cross_entropy(prediction, target):
    """Element-wise binary cross-entropy of `prediction` against `target`."""
    on_term = jnp.log(prediction) * target
    off_term = jnp.log(1 - prediction) * (1 - target)
    return -(on_term + off_term)
# Define the custom Problem
class CustomProblem(BaseProblem):
    """NEAT problem scoring genomes by a thresholded pairwise ranking loss."""

    jitable = True  # necessary

    def __init__(self, inputs, labels, threshold):
        # inputs already has shape (n, 768); labels arrives flat with
        # shape (n,) and is reshaped into a column vector (n, 1).
        self.inputs = jnp.array(inputs)
        self.labels = jnp.array(labels).reshape((-1, 1))
        self.threshold = threshold
        # All pairwise work is done once, at problem construction time.
        label_diffs = self.labels - self.labels.T
        # Only keep pairs whose label gap exceeds the threshold.
        self.pairs_to_keep = jnp.abs(label_diffs) > self.threshold
        # Dropped pairs are marked with nan (instead of -inf): any
        # arithmetic involving nan stays nan, so they are easy to mask.
        masked_diffs = jnp.where(self.pairs_to_keep, label_diffs, jnp.nan)
        # Binarize; nan > 0 evaluates to False for the dropped pairs.
        self.pairwise_labels = masked_diffs > 0

    def evaluate(self, state, randkey, act_func, params):
        """Fitness of one genome: negative mean pairwise BCE loss."""
        # Batch forward over every input row (using jax.vmap).
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(
            state, params, self.inputs
        )  # should be shape (len(labels), 1)
        # Pairwise score differences; dropped pairs become nan.
        score_diffs = predict - predict.T  # shape (len(inputs), len(inputs))
        score_diffs = jnp.where(self.pairs_to_keep, score_diffs, jnp.nan)
        probs = jax.nn.sigmoid(score_diffs)
        # Element-wise loss; nan entries propagate through unchanged.
        loss = binary_cross_entropy(probs, self.pairwise_labels)
        # Reduce to a scalar while ignoring the nan (dropped) entries.
        loss = jnp.mean(loss, where=~jnp.isnan(loss))
        # TensorNEAT maximizes fitness, equivalent to minimizing loss.
        return -loss

    @property
    def input_shape(self):
        # the input shape that the act_func expects
        return (self.inputs.shape[1],)

    @property
    def output_shape(self):
        # the output shape that the act_func returns
        return (1,)

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Showcase one individual: a few predictions plus its MSE."""
        predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs)
        loss = jnp.mean(jnp.square(predict - self.labels))
        n_elements = min(5, len(self.inputs))
        msg = f"Looking at {n_elements} first elements of input\n"
        for i in range(n_elements):
            msg += f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
        msg += f"total loss: {loss}\n"
        print(msg)
# NEAT algorithm configuration.
# FIX: the original rebound the name `algorithm`, shadowing the imported
# tensorneat `algorithm` module; a distinct name keeps the module usable.
neat_algorithm = algorithm.NEAT(
    pop_size=10,
    survival_threshold=0.2,
    min_species_size=2,
    compatibility_threshold=3.0,
    species_elitism=2,
    genome=genome.DefaultGenome(
        num_inputs=768,
        num_outputs=1,
        max_nodes=769,  # must at least be same as inputs and outputs
        max_conns=768,  # must be 768 connections for the network to be fully connected
        output_transform=common.ACT.sigmoid,
        mutation=mutation.DefaultMutation(
            # no allowing adding or deleting nodes
            node_add=0.0,
            node_delete=0.0,
            # set mutation rates for edges to 0.5
            conn_add=0.5,
            conn_delete=0.5,
        ),
        node_gene=DefaultNode(),
        conn_gene=DefaultConn(),
    ),
)

# Synthetic dataset: 100 random feature vectors with random scalar labels.
INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768))  # the input data x
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100,))  # the annotated labels y
problem = CustomProblem(INPUTS, LABELS, 0.25)

print("Setting up pipeline and running it")
print("-----------------------------------------------------------------------")
pipeline = Pipeline(
    neat_algorithm,
    problem,
    generation_limit=1,
    fitness_target=1,
    seed=42,
)
state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)

118
test/test.ipynb Normal file
View File

@@ -0,0 +1,118 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import jax, jax.numpy as jnp"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"LABELS = jax.random.uniform(jax.random.PRNGKey(0), (5, 1)) # the annotated labels y"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"pairwise_labels = LABELS - LABELS.T"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Array([[0.57450044],\n",
" [0.09968603],\n",
" [0.39316022],\n",
" [0.8941783 ],\n",
" [0.59656656]], dtype=float32),\n",
" Array([[ 0. , 0.47481441, 0.18134022, -0.31967783, -0.02206612],\n",
" [-0.47481441, 0. , -0.2934742 , -0.79449224, -0.49688053],\n",
" [-0.18134022, 0.2934742 , 0. , -0.50101805, -0.20340633],\n",
" [ 0.31967783, 0.79449224, 0.50101805, 0. , 0.2976117 ],\n",
" [ 0.02206612, 0.49688053, 0.20340633, -0.2976117 , 0. ]], dtype=float32))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"LABELS, pairwise_labels"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def binary_cross_entropy(prediction, target):\n",
" return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(0.6931472, dtype=float32, weak_type=True)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"binary_cross_entropy(0.5, 1)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "jax_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

34
test/test.py Normal file
View File

@@ -0,0 +1,34 @@
###shows the difference in loss between using jnp.where() and boolean indexing
import jax
import jax.numpy as jnp
def binary_cross_entropy(prediction, target):
    """Element-wise binary cross-entropy of `prediction` against `target`."""
    return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))

# Random "predictions" and labels, both shaped as (100, 1) column vectors.
preds = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #predictions
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #the annotated labels y

# Pairwise differences: entry (i, j) = v[i] - v[j].
pair_lab = LABELS - LABELS.T
pair_pred = preds - preds.T
# Keep only pairs whose label gap exceeds the threshold.
ptk = jnp.abs(pair_lab) > 0.25

# Variant 1: mask dropped pairs with nan, keeping the square shape.
# FIX: use jnp.nan, not -jnp.nan — negating NaN is still NaN, so the
# original minus sign was a misleading no-op.
pair_labw = jnp.where(ptk, pair_lab, jnp.nan)
# Variant 2: boolean indexing, flattening to only the kept pairs.
pair_labm = pair_lab[ptk]
# Binarize; the comparison itself yields the bool array (NaN > 0 is False),
# so the redundant jnp.where(cond, True, False) wrapper is dropped.
pair_labw = pair_labw > 0
pair_labm = pair_labm > 0

pair_predw = jnp.where(ptk, pair_pred, jnp.nan)
pair_predm = pair_pred[ptk]
pair_predw = jax.nn.sigmoid(pair_predw)
pair_predm = jax.nn.sigmoid(pair_predm)

lossw = binary_cross_entropy(pair_predw, pair_labw)
lossm = binary_cross_entropy(pair_predm, pair_labm)
print("loss using jnp.where()", jnp.mean(lossw, where=~jnp.isnan(lossw)))
print("loss using boolean indexing", jnp.mean(lossm, where=~jnp.isnan(lossm)))