add some test

This commit is contained in:
wls2002
2025-02-12 21:37:56 +08:00
parent 51028346fd
commit de2d906656
4 changed files with 417 additions and 0 deletions

34
test/test.py Normal file
View File

@@ -0,0 +1,34 @@
###shows the difference in loss between using jnp.where() and boolean indexing
import jax
import jax.numpy as jnp
def binary_cross_entropy(prediction, target):
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
preds = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #predictions
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #the annotated labels y
pair_lab = LABELS - LABELS.T
pair_pred = preds - preds.T
ptk = jnp.abs(pair_lab) > 0.25
pair_labw = jnp.where(ptk, pair_lab, -jnp.nan)
pair_labm = pair_lab[ptk]
pair_labw = jnp.where(pair_labw > 0, True, False)
pair_labm = jnp.where(pair_labm > 0, True, False)
pair_predw = jnp.where(ptk, pair_pred, -jnp.nan)
pair_predm = pair_pred[ptk]
pair_predw = jax.nn.sigmoid(pair_predw)
pair_predm = jax.nn.sigmoid(pair_predm)
lossw = binary_cross_entropy(pair_predw, pair_labw)
lossm = binary_cross_entropy(pair_predm, pair_labm)
print("loss using jnp.where()", jnp.mean(lossw, where=~jnp.isnan(lossw)))
print("loss using boolean indexing", jnp.mean(lossm, where=~jnp.isnan(lossm)))