add some test
This commit is contained in:
34
test/test.py
Normal file
34
test/test.py
Normal file
@@ -0,0 +1,34 @@
|
||||
###shows the difference in loss between using jnp.where() and boolean indexing
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
def binary_cross_entropy(prediction, target):
|
||||
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))
|
||||
|
||||
|
||||
preds = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #predictions
|
||||
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100, )).reshape((-1,1)) #the annotated labels y
|
||||
|
||||
pair_lab = LABELS - LABELS.T
|
||||
pair_pred = preds - preds.T
|
||||
|
||||
ptk = jnp.abs(pair_lab) > 0.25
|
||||
|
||||
pair_labw = jnp.where(ptk, pair_lab, -jnp.nan)
|
||||
pair_labm = pair_lab[ptk]
|
||||
|
||||
pair_labw = jnp.where(pair_labw > 0, True, False)
|
||||
pair_labm = jnp.where(pair_labm > 0, True, False)
|
||||
|
||||
pair_predw = jnp.where(ptk, pair_pred, -jnp.nan)
|
||||
pair_predm = pair_pred[ptk]
|
||||
|
||||
pair_predw = jax.nn.sigmoid(pair_predw)
|
||||
pair_predm = jax.nn.sigmoid(pair_predm)
|
||||
|
||||
lossw = binary_cross_entropy(pair_predw, pair_labw)
|
||||
lossm = binary_cross_entropy(pair_predm, pair_labm)
|
||||
|
||||
print("loss using jnp.where()", jnp.mean(lossw, where=~jnp.isnan(lossw)))
|
||||
print("loss using boolean indexing", jnp.mean(lossm, where=~jnp.isnan(lossm)))
|
||||
Reference in New Issue
Block a user