Add RankNet NEAT test scripts (test/ranknet.py, test/ranknet_neat.py) plus an exploratory notebook and a jnp.where vs boolean-indexing comparison script
This commit is contained in:
140
test/ranknet.py
Normal file
140
test/ranknet.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
# import RankNet
|
||||||
|
from tensorneat import algorithm, genome, common
|
||||||
|
from tensorneat.pipeline import Pipeline
|
||||||
|
from tensorneat.genome import BiasNode
|
||||||
|
from tensorneat.genome.operations import mutation
|
||||||
|
from tensorneat.common import ACT, AGG
|
||||||
|
import jax, jax.numpy as jnp
|
||||||
|
from tensorneat.problem import BaseProblem
|
||||||
|
|
||||||
|
data_num = 100
|
||||||
|
input_size = 768 # Each network (genome) should have input size 768
|
||||||
|
|
||||||
|
# The problem is to optimize a RankNet utilizing NEAT
|
||||||
|
|
||||||
|
|
||||||
|
def binary_cross_entropy(prediction, target):
|
||||||
|
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
|
||||||
|
|
||||||
|
|
||||||
|
# Create dataset (100 samples of vectors with 768 features)
|
||||||
|
INPUTS = jax.random.uniform(
|
||||||
|
jax.random.PRNGKey(0), (data_num, input_size)
|
||||||
|
) # the input data x
|
||||||
|
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (data_num, 1)) # the annotated labels y
|
||||||
|
# True (1): >=; False (0): <
|
||||||
|
pairwise_labels = jnp.where((LABELS - LABELS.T) >= 0, True, False)
|
||||||
|
|
||||||
|
print(f"{INPUTS.shape=}, {LABELS.shape=}")
|
||||||
|
|
||||||
|
|
||||||
|
# Define the custom Problem
|
||||||
|
class CustomProblem(BaseProblem):
|
||||||
|
|
||||||
|
jitable = True # necessary
|
||||||
|
|
||||||
|
def evaluate(self, state, randkey, act_func, params):
|
||||||
|
# Use ``act_func(state, params, inputs)`` to do network forward
|
||||||
|
|
||||||
|
# print("state: ", state)
|
||||||
|
# print("params: ",params)
|
||||||
|
# print("act_func: ",act_func)
|
||||||
|
|
||||||
|
ans_to_question = True
|
||||||
|
|
||||||
|
# Question: This is the same as doing a forward pass for the generated network?
|
||||||
|
# Meaning the network does 100 passes for all the elements of 768 features?
|
||||||
|
if ans_to_question:
|
||||||
|
# do batch forward for all inputs (using jax.vamp).
|
||||||
|
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||||
|
state, params, INPUTS
|
||||||
|
) # should be shape (100, 1)
|
||||||
|
else:
|
||||||
|
# I misunderstood, so I have to create a RankNet myself to predict the output
|
||||||
|
# Setting up with the values present in the genome
|
||||||
|
current_node = state.species.idx2species
|
||||||
|
current_node_weights = state.pop_conns[current_node]
|
||||||
|
net = RankNet.RankNet(input_size, current_node_weights)
|
||||||
|
predict = net.forward(INPUTS)
|
||||||
|
|
||||||
|
pairwise_predictions = predict - predict.T # shape (100, 100)
|
||||||
|
p = jax.nn.sigmoid(pairwise_predictions) # shape (100, 100)
|
||||||
|
|
||||||
|
# calculate loss
|
||||||
|
loss = binary_cross_entropy(p, pairwise_labels) # shape (100, 100)
|
||||||
|
# loss with shape (100, 100), we need to reduce it to a scalar
|
||||||
|
loss = jnp.mean(loss)
|
||||||
|
|
||||||
|
# return negative loss as fitness
|
||||||
|
# TensorNEAT maximizes fitness, equivalent to minimizing loss
|
||||||
|
return -loss
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shape(self):
|
||||||
|
# the input shape that the act_func expects
|
||||||
|
return (input_size,)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
# the output shape that the act_func returns
|
||||||
|
return (1,)
|
||||||
|
|
||||||
|
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||||
|
# showcase the performance of one individual
|
||||||
|
predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, INPUTS)
|
||||||
|
|
||||||
|
loss = jnp.mean(jnp.square(predict - LABELS))
|
||||||
|
|
||||||
|
msg = ""
|
||||||
|
for i in range(INPUTS.shape[0]):
|
||||||
|
msg += f"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\n"
|
||||||
|
msg += f"loss: {loss}\n"
|
||||||
|
print(msg)
|
||||||
|
|
||||||
|
|
||||||
|
algorithm1 = algorithm.NEAT(
|
||||||
|
# setting values to be the same as default in python NEAT package to get same as paper authors
|
||||||
|
# tried as best I could to follow this https://neat-python.readthedocs.io/en/latest/config_file.html
|
||||||
|
pop_size=100,
|
||||||
|
survival_threshold=0.2,
|
||||||
|
min_species_size=2,
|
||||||
|
species_number_calculate_by="fitness", # either this or rank, but 'fitness' should be more in line with original paper on NEAT
|
||||||
|
# species_size=10, #nothing specified for species_size, it remains default
|
||||||
|
# modifying the values the authors explicitly mention
|
||||||
|
compatibility_threshold=3.0, # maybe need to consider this one in the future if weird results, default is 2.0
|
||||||
|
species_elitism=2, # is 2 per default
|
||||||
|
genome=genome.DefaultGenome(
|
||||||
|
num_inputs=768,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=769, # must at least be same as inputs and outputs
|
||||||
|
max_conns=768, # must be 768 connections for the network to be fully connected
|
||||||
|
# 0 hidden layers per default
|
||||||
|
output_transform=common.ACT.sigmoid,
|
||||||
|
mutation=mutation.DefaultMutation(
|
||||||
|
# no allowing adding or deleting nodes
|
||||||
|
node_add=0.0,
|
||||||
|
node_delete=0.0,
|
||||||
|
# set mutation rates for edges to 0.5
|
||||||
|
conn_add=0.5,
|
||||||
|
conn_delete=0.5,
|
||||||
|
),
|
||||||
|
node_gene=BiasNode(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
problem = CustomProblem()
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm1,
|
||||||
|
problem,
|
||||||
|
generation_limit=150,
|
||||||
|
fitness_target=1,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
state = pipeline.setup()
|
||||||
|
# run until termination
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
|
# show results
|
||||||
|
# pipeline.show(state, best)
|
||||||
|
|
||||||
|
network = algorithm1.genome.network_dict(state, *best)
|
||||||
125
test/ranknet_neat.py
Normal file
125
test/ranknet_neat.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
### NOTE: this script currently raises a ValueError when run
|
||||||
|
from tensorneat import algorithm, genome, common
|
||||||
|
from tensorneat.pipeline import Pipeline
|
||||||
|
from tensorneat.genome.gene.node import DefaultNode
|
||||||
|
from tensorneat.genome.gene.conn import DefaultConn
|
||||||
|
from tensorneat.genome.operations import mutation
|
||||||
|
import jax, jax.numpy as jnp
|
||||||
|
from tensorneat.problem import BaseProblem
|
||||||
|
|
||||||
|
def binary_cross_entropy(prediction, target):
|
||||||
|
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
|
||||||
|
|
||||||
|
# Define the custom Problem
|
||||||
|
class CustomProblem(BaseProblem):
|
||||||
|
|
||||||
|
jitable = True # necessary
|
||||||
|
|
||||||
|
def __init__(self, inputs, labels, threshold):
|
||||||
|
self.inputs = jnp.array(inputs) #nb! already has shape (n, 768)
|
||||||
|
self.labels = jnp.array(labels).reshape((-1,1)) #nb! has shape (n), must be transformed to have shape (n, 1)
|
||||||
|
self.threshold = threshold
|
||||||
|
|
||||||
|
# move the calculation related to pairwise_labels to problem initialization
|
||||||
|
pairwise_labels = self.labels - self.labels.T
|
||||||
|
self.pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold
|
||||||
|
# using nan istead of -inf
|
||||||
|
# as any mathmatical operation with nan will result in nan
|
||||||
|
pairwise_labels = jnp.where(self.pairs_to_keep, pairwise_labels, jnp.nan)
|
||||||
|
self.pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(self, state, randkey, act_func, params):
|
||||||
|
# do batch forward for all inputs (using jax.vamp).
|
||||||
|
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||||
|
state, params, self.inputs
|
||||||
|
) # should be shape (len(labels), 1)
|
||||||
|
|
||||||
|
#calculating pairwise labels and predictions
|
||||||
|
pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs))
|
||||||
|
|
||||||
|
pairwise_predictions = jnp.where(self.pairs_to_keep, pairwise_predictions, jnp.nan)
|
||||||
|
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
|
||||||
|
|
||||||
|
# calculate loss
|
||||||
|
loss = binary_cross_entropy(pairwise_predictions, self.pairwise_labels) # shape (len(labels), len(labels))
|
||||||
|
# jax.debug.print("loss={}", loss)
|
||||||
|
# reduce loss to a scalar
|
||||||
|
# we need to ignore nan value here
|
||||||
|
loss = jnp.mean(loss, where=~jnp.isnan(loss))
|
||||||
|
# return negative loss as fitness
|
||||||
|
# TensorNEAT maximizes fitness, equivalent to minimizing loss
|
||||||
|
return -loss
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shape(self):
|
||||||
|
# the input shape that the act_func expects
|
||||||
|
return (self.inputs.shape[1],)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
# the output shape that the act_func returns
|
||||||
|
return (1,)
|
||||||
|
|
||||||
|
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||||
|
# showcase the performance of one individual
|
||||||
|
predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs)
|
||||||
|
|
||||||
|
loss = jnp.mean(jnp.square(predict - self.labels))
|
||||||
|
|
||||||
|
n_elements = 5
|
||||||
|
if n_elements > len(self.inputs):
|
||||||
|
n_elements = len(self.inputs)
|
||||||
|
|
||||||
|
msg = f"Looking at {n_elements} first elements of input\n"
|
||||||
|
for i in range(n_elements):
|
||||||
|
msg += f"for input i: {i}, target: {self.labels[i]}, predict: {predict[i]}\n"
|
||||||
|
msg += f"total loss: {loss}\n"
|
||||||
|
print(msg)
|
||||||
|
|
||||||
|
algorithm = algorithm.NEAT(
|
||||||
|
pop_size=10,
|
||||||
|
survival_threshold=0.2,
|
||||||
|
min_species_size=2,
|
||||||
|
compatibility_threshold=3.0,
|
||||||
|
species_elitism=2,
|
||||||
|
genome=genome.DefaultGenome(
|
||||||
|
num_inputs=768,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=769, # must at least be same as inputs and outputs
|
||||||
|
max_conns=768, # must be 768 connections for the network to be fully connected
|
||||||
|
output_transform=common.ACT.sigmoid,
|
||||||
|
mutation=mutation.DefaultMutation(
|
||||||
|
# no allowing adding or deleting nodes
|
||||||
|
node_add=0.0,
|
||||||
|
node_delete=0.0,
|
||||||
|
# set mutation rates for edges to 0.5
|
||||||
|
conn_add=0.5,
|
||||||
|
conn_delete=0.5,
|
||||||
|
),
|
||||||
|
node_gene=DefaultNode(),
|
||||||
|
conn_gene=DefaultConn(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768)) #the input data x
|
||||||
|
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100, )) #the annotated labels y
|
||||||
|
|
||||||
|
problem = CustomProblem(INPUTS, LABELS, 0.25)
|
||||||
|
|
||||||
|
print("Setting up pipeline and running it")
|
||||||
|
print("-----------------------------------------------------------------------")
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm,
|
||||||
|
problem,
|
||||||
|
generation_limit=1,
|
||||||
|
fitness_target=1,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
state = pipeline.setup()
|
||||||
|
# run until termination
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
|
# show results
|
||||||
|
pipeline.show(state, best)
|
||||||
118
test/test.ipynb
Normal file
118
test/test.ipynb
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import jax, jax.numpy as jnp"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"LABELS = jax.random.uniform(jax.random.PRNGKey(0), (5, 1)) # the annotated labels y"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"pairwise_labels = LABELS - LABELS.T"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"(Array([[0.57450044],\n",
|
||||||
|
" [0.09968603],\n",
|
||||||
|
" [0.39316022],\n",
|
||||||
|
" [0.8941783 ],\n",
|
||||||
|
" [0.59656656]], dtype=float32),\n",
|
||||||
|
" Array([[ 0. , 0.47481441, 0.18134022, -0.31967783, -0.02206612],\n",
|
||||||
|
" [-0.47481441, 0. , -0.2934742 , -0.79449224, -0.49688053],\n",
|
||||||
|
" [-0.18134022, 0.2934742 , 0. , -0.50101805, -0.20340633],\n",
|
||||||
|
" [ 0.31967783, 0.79449224, 0.50101805, 0. , 0.2976117 ],\n",
|
||||||
|
" [ 0.02206612, 0.49688053, 0.20340633, -0.2976117 , 0. ]], dtype=float32))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"LABELS, pairwise_labels"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def binary_cross_entropy(prediction, target):\n",
|
||||||
|
" return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"Array(0.6931472, dtype=float32, weak_type=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"binary_cross_entropy(0.5, 1)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "jax_env",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.14"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
34
test/test.py
Normal file
34
test/test.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
### Demonstrates that jnp.where() masking and boolean indexing yield the same mean loss
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
def binary_cross_entropy(prediction, target):
|
||||||
|
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
|
||||||
|
|
||||||
|
|
||||||
|
preds = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #predictions
|
||||||
|
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #the annotated labels y
|
||||||
|
|
||||||
|
pair_lab = LABELS - LABELS.T
|
||||||
|
pair_pred = preds - preds.T
|
||||||
|
|
||||||
|
ptk = jnp.abs(pair_lab) > 0.25
|
||||||
|
|
||||||
|
pair_labw = jnp.where(ptk, pair_lab, -jnp.nan)
|
||||||
|
pair_labm = pair_lab[ptk]
|
||||||
|
|
||||||
|
pair_labw = jnp.where(pair_labw > 0, True, False)
|
||||||
|
pair_labm = jnp.where(pair_labm > 0, True, False)
|
||||||
|
|
||||||
|
pair_predw = jnp.where(ptk, pair_pred, -jnp.nan)
|
||||||
|
pair_predm = pair_pred[ptk]
|
||||||
|
|
||||||
|
pair_predw = jax.nn.sigmoid(pair_predw)
|
||||||
|
pair_predm = jax.nn.sigmoid(pair_predm)
|
||||||
|
|
||||||
|
lossw = binary_cross_entropy(pair_predw, pair_labw)
|
||||||
|
lossm = binary_cross_entropy(pair_predm, pair_labm)
|
||||||
|
|
||||||
|
print("loss using jnp.where()", jnp.mean(lossw, where=~jnp.isnan(lossw)))
|
||||||
|
print("loss using boolean indexing", jnp.mean(lossm, where=~jnp.isnan(lossm)))
|
||||||
Reference in New Issue
Block a user