# Source: tensorneat-mend/test/ranknet.py
# Timestamp: 2025-02-12 21:37:56 +08:00
# Size: 140 lines, 4.9 KiB (Python)

# import RankNet
from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline
from tensorneat.genome import BiasNode
from tensorneat.genome.operations import mutation
from tensorneat.common import ACT, AGG
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem
# Synthetic dataset dimensions: how many samples there are, and the feature
# width that every evolved network (genome) must accept as its input.
data_num, input_size = 100, 768

# Goal: optimize a RankNet-style pairwise ranker whose scoring network is
# evolved with NEAT (TensorNEAT).
def binary_cross_entropy(prediction, target, eps=1e-7):
    """Element-wise binary cross-entropy between probabilities and 0/1 targets.

    Args:
        prediction: Predicted probabilities in [0, 1] (any shape).
        target: Targets broadcastable against ``prediction``; booleans are
            accepted and treated as 1.0 / 0.0.
        eps: Predictions are clipped to ``[eps, 1 - eps]`` so that a saturated
            probability of exactly 0 or 1 yields a large but finite loss
            (for float32 and wider dtypes) instead of ``inf`` or the ``nan``
            produced by ``0 * log(0)``.

    Returns:
        Array of per-element losses with the broadcast shape of the inputs.
    """
    prediction = jnp.clip(jnp.asarray(prediction), eps, 1.0 - eps)
    # Cast (possibly boolean) targets to the prediction's float dtype so the
    # arithmetic below is well defined.
    target = jnp.asarray(target, dtype=prediction.dtype)
    # log1p(-p) == log(1 - p), but more accurate when p is small.
    return -(target * jnp.log(prediction) + (1.0 - target) * jnp.log1p(-prediction))
# Create the synthetic dataset: ``data_num`` samples with ``input_size``
# features each, plus one scalar label per sample (both drawn uniformly from
# [0, 1)). Inputs and labels get separate subkeys: reusing one PRNG key for two
# draws makes them statistically dependent instead of independent.
_inputs_key, _labels_key = jax.random.split(jax.random.PRNGKey(0))
INPUTS = jax.random.uniform(_inputs_key, (data_num, input_size))  # the input data x
LABELS = jax.random.uniform(_labels_key, (data_num, 1))  # the annotated labels y

# Pairwise targets: pairwise_labels[i, j] is True (1) when LABELS[i] >= LABELS[j]
# and False (0) otherwise. Comparing (N, 1) against (1, N) broadcasts to (N, N).
# Ties, including the diagonal i == j, are labelled True.
pairwise_labels = LABELS >= LABELS.T
print(f"{INPUTS.shape=}, {LABELS.shape=}")
# Define the custom Problem
class CustomProblem(BaseProblem):
    """Pairwise (RankNet-style) ranking problem over the synthetic dataset.

    Each evolved network maps one ``input_size``-feature row of ``INPUTS`` to a
    single score. Fitness is the negative mean binary cross-entropy between
    ``sigmoid(score_i - score_j)`` and the pairwise targets ``pairwise_labels``.
    """

    jitable = True  # necessary

    @staticmethod
    def _forward_all_inputs(state, act_func, params):
        """Run the evolved network once per row of ``INPUTS``, batched with ``jax.vmap``.

        ``act_func(state, params, x)`` is one forward pass of the network on a
        single ``input_size`` vector, returning an ``output_shape`` (1,) score.
        Mapping over axis 0 of ``INPUTS`` (state and params are shared) gives one
        score per sample, so the result should have shape (data_num, 1).
        """
        return jax.vmap(act_func, in_axes=(None, None, 0))(state, params, INPUTS)

    @staticmethod
    def _pairwise_ranking_loss(predict):
        """Return the mean RankNet loss for per-sample scores ``predict`` (N, 1)."""
        # (N, 1) - (1, N) broadcasts to (N, N): entry [i, j] is score_i - score_j.
        pairwise_predictions = predict - predict.T
        # Modelled probability that sample i ranks at or above sample j.
        p = jax.nn.sigmoid(pairwise_predictions)
        # Per-pair losses have shape (N, N); reduce them to a scalar.
        return jnp.mean(binary_cross_entropy(p, pairwise_labels))

    def evaluate(self, state, randkey, act_func, params):
        """Return the fitness of one individual: the negative pairwise ranking loss.

        TensorNEAT maximizes fitness, so maximizing ``-loss`` minimizes the loss.
        ``randkey`` is unused because the evaluation is deterministic.
        """
        predict = self._forward_all_inputs(state, act_func, params)
        return -self._pairwise_ranking_loss(predict)

    @property
    def input_shape(self):
        # the input shape that the act_func expects
        return (input_size,)

    @property
    def output_shape(self):
        # the output shape that the act_func returns
        return (1,)

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Print each sample's input, target and prediction, then summary losses.

        Reports the pairwise ranking loss that ``evaluate`` optimises, and,
        separately, the mean squared error between raw predictions and ``LABELS``.
        """
        predict = self._forward_all_inputs(state, act_func, params)
        lines = [
            f"input: {x}, target: {y}, predict: {y_hat}"
            for x, y, y_hat in zip(INPUTS, LABELS, predict)
        ]
        lines.append(f"ranking loss: {self._pairwise_ranking_loss(predict)}")
        lines.append(f"mse: {jnp.mean(jnp.square(predict - LABELS))}")
        print("\n".join(lines) + "\n")
# NEAT configuration. Values are set to mirror the defaults of the neat-python
# package where possible, to match the original NEAT paper authors' setup; see
# https://neat-python.readthedocs.io/en/latest/config_file.html
algorithm1 = algorithm.NEAT(
    pop_size=100,
    survival_threshold=0.2,
    min_species_size=2,
    # Either "fitness" or "rank"; "fitness" should be more in line with the
    # original NEAT paper.
    species_number_calculate_by="fitness",
    # species_size is not set, so the library default applies.
    # Values the authors mention explicitly:
    compatibility_threshold=3.0,  # revisit if results look odd; author notes the default is 2.0
    species_elitism=2,  # author notes this is 2 by default
    genome=genome.DefaultGenome(
        # Derived from input_size so the genome always matches the dataset.
        num_inputs=input_size,
        num_outputs=1,
        # Must cover at least all input nodes plus the single output node.
        max_nodes=input_size + 1,
        # One connection per input -> output pair: the author's note is that
        # this many connections are needed for a fully connected network.
        max_conns=input_size,
        # 0 hidden layers per default (per the author's note).
        output_transform=common.ACT.sigmoid,
        mutation=mutation.DefaultMutation(
            # Adding or deleting nodes is not allowed.
            node_add=0.0,
            node_delete=0.0,
            # Edge mutation rates set to 0.5.
            conn_add=0.5,
            conn_delete=0.5,
        ),
        node_gene=BiasNode(),
    ),
)
problem = CustomProblem()

pipeline = Pipeline(
    algorithm1,
    problem,
    generation_limit=150,
    # Fitness is -loss and the binary cross-entropy loss is never negative, so
    # fitness can never exceed 0. The previous target of 1 was unreachable;
    # 0 is the theoretical optimum and is not reached in practice either,
    # so the run is effectively bounded by generation_limit.
    fitness_target=0.0,
    seed=42,
)
state = pipeline.setup()

# Run until termination (generation limit or fitness target).
state, best = pipeline.auto_run(state)

# Show results. Presumably this dispatches to CustomProblem.show, which prints
# every sample including its full input vector, so the output is long.
# pipeline.show(state, best)

# Network description of the best genome, via the genome's network_dict helper.
network = algorithm1.genome.network_dict(state, *best)