add some test
This commit is contained in:
140
test/ranknet.py
Normal file
140
test/ranknet.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# import RankNet
|
||||
from tensorneat import algorithm, genome, common
|
||||
from tensorneat.pipeline import Pipeline
|
||||
from tensorneat.genome import BiasNode
|
||||
from tensorneat.genome.operations import mutation
|
||||
from tensorneat.common import ACT, AGG
|
||||
import jax, jax.numpy as jnp
|
||||
from tensorneat.problem import BaseProblem
|
||||
|
||||
data_num = 100
|
||||
input_size = 768 # Each network (genome) should have input size 768
|
||||
|
||||
# The problem is to optimize a RankNet utilizing NEAT
|
||||
|
||||
|
||||
def binary_cross_entropy(prediction, target):
|
||||
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
|
||||
|
||||
|
||||
# Create dataset (100 samples of vectors with 768 features)
|
||||
INPUTS = jax.random.uniform(
|
||||
jax.random.PRNGKey(0), (data_num, input_size)
|
||||
) # the input data x
|
||||
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (data_num, 1)) # the annotated labels y
|
||||
# True (1): >=; False (0): <
|
||||
pairwise_labels = jnp.where((LABELS - LABELS.T) >= 0, True, False)
|
||||
|
||||
print(f"{INPUTS.shape=}, {LABELS.shape=}")
|
||||
|
||||
|
||||
# Define the custom Problem
|
||||
class CustomProblem(BaseProblem):
|
||||
|
||||
jitable = True # necessary
|
||||
|
||||
def evaluate(self, state, randkey, act_func, params):
|
||||
# Use ``act_func(state, params, inputs)`` to do network forward
|
||||
|
||||
# print("state: ", state)
|
||||
# print("params: ",params)
|
||||
# print("act_func: ",act_func)
|
||||
|
||||
ans_to_question = True
|
||||
|
||||
# Question: This is the same as doing a forward pass for the generated network?
|
||||
# Meaning the network does 100 passes for all the elements of 768 features?
|
||||
if ans_to_question:
|
||||
# do batch forward for all inputs (using jax.vamp).
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||
state, params, INPUTS
|
||||
) # should be shape (100, 1)
|
||||
else:
|
||||
# I misunderstood, so I have to create a RankNet myself to predict the output
|
||||
# Setting up with the values present in the genome
|
||||
current_node = state.species.idx2species
|
||||
current_node_weights = state.pop_conns[current_node]
|
||||
net = RankNet.RankNet(input_size, current_node_weights)
|
||||
predict = net.forward(INPUTS)
|
||||
|
||||
pairwise_predictions = predict - predict.T # shape (100, 100)
|
||||
p = jax.nn.sigmoid(pairwise_predictions) # shape (100, 100)
|
||||
|
||||
# calculate loss
|
||||
loss = binary_cross_entropy(p, pairwise_labels) # shape (100, 100)
|
||||
# loss with shape (100, 100), we need to reduce it to a scalar
|
||||
loss = jnp.mean(loss)
|
||||
|
||||
# return negative loss as fitness
|
||||
# TensorNEAT maximizes fitness, equivalent to minimizing loss
|
||||
return -loss
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
# the input shape that the act_func expects
|
||||
return (input_size,)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
# the output shape that the act_func returns
|
||||
return (1,)
|
||||
|
||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||
# showcase the performance of one individual
|
||||
predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, INPUTS)
|
||||
|
||||
loss = jnp.mean(jnp.square(predict - LABELS))
|
||||
|
||||
msg = ""
|
||||
for i in range(INPUTS.shape[0]):
|
||||
msg += f"input: {INPUTS[i]}, target: {LABELS[i]}, predict: {predict[i]}\n"
|
||||
msg += f"loss: {loss}\n"
|
||||
print(msg)
|
||||
|
||||
|
||||
algorithm1 = algorithm.NEAT(
|
||||
# setting values to be the same as default in python NEAT package to get same as paper authors
|
||||
# tried as best I could to follow this https://neat-python.readthedocs.io/en/latest/config_file.html
|
||||
pop_size=100,
|
||||
survival_threshold=0.2,
|
||||
min_species_size=2,
|
||||
species_number_calculate_by="fitness", # either this or rank, but 'fitness' should be more in line with original paper on NEAT
|
||||
# species_size=10, #nothing specified for species_size, it remains default
|
||||
# modifying the values the authors explicitly mention
|
||||
compatibility_threshold=3.0, # maybe need to consider this one in the future if weird results, default is 2.0
|
||||
species_elitism=2, # is 2 per default
|
||||
genome=genome.DefaultGenome(
|
||||
num_inputs=768,
|
||||
num_outputs=1,
|
||||
max_nodes=769, # must at least be same as inputs and outputs
|
||||
max_conns=768, # must be 768 connections for the network to be fully connected
|
||||
# 0 hidden layers per default
|
||||
output_transform=common.ACT.sigmoid,
|
||||
mutation=mutation.DefaultMutation(
|
||||
# no allowing adding or deleting nodes
|
||||
node_add=0.0,
|
||||
node_delete=0.0,
|
||||
# set mutation rates for edges to 0.5
|
||||
conn_add=0.5,
|
||||
conn_delete=0.5,
|
||||
),
|
||||
node_gene=BiasNode(),
|
||||
),
|
||||
)
|
||||
|
||||
problem = CustomProblem()
|
||||
|
||||
pipeline = Pipeline(
|
||||
algorithm1,
|
||||
problem,
|
||||
generation_limit=150,
|
||||
fitness_target=1,
|
||||
seed=42,
|
||||
)
|
||||
state = pipeline.setup()
|
||||
# run until termination
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show results
|
||||
# pipeline.show(state, best)
|
||||
|
||||
network = algorithm1.genome.network_dict(state, *best)
|
||||
Reference in New Issue
Block a user