{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pairwise labels and binary cross-entropy (JAX)\n",
    "\n",
    "Builds the matrix of pairwise label differences $y_i - y_j$ from random labels, then defines a binary cross-entropy loss that returns exactly $0$ (not `NaN`) for perfect predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.scipy.special import xlogy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "LABELS = jax.random.uniform(jax.random.PRNGKey(0), (5, 1))  # the annotated labels y, uniform in [0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (5, 1) - (1, 5) broadcasts to (5, 5): entry [i, j] is y_i - y_j\n",
    "pairwise_labels = LABELS - LABELS.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array([[0.57450044],\n",
       " [0.09968603],\n",
       " [0.39316022],\n",
       " [0.8941783 ],\n",
       " [0.59656656]], dtype=float32),\n",
       " Array([[ 0. , 0.47481441, 0.18134022, -0.31967783, -0.02206612],\n",
       " [-0.47481441, 0. , -0.2934742 , -0.79449224, -0.49688053],\n",
       " [-0.18134022, 0.2934742 , 0. , -0.50101805, -0.20340633],\n",
       " [ 0.31967783, 0.79449224, 0.50101805, 0. , 0.2976117 ],\n",
       " [ 0.02206612, 0.49688053, 0.20340633, -0.2976117 , 0. ]], dtype=float32))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "LABELS, pairwise_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary cross-entropy\n",
    "\n",
    "$\\mathrm{BCE}(p, t) = -\\left[t \\log p + (1 - t) \\log(1 - p)\\right]$. Computing each term with `xlogy` makes $0 \\cdot \\log 0$ evaluate to $0$, so a perfect prediction ($p = t \\in \\{0, 1\\}$) costs $0$ instead of producing `NaN`.\n",
    "\n",
    "**NOTE:** `pairwise_labels` lies in $(-1, 1)$, outside the $[0, 1]$ range a BCE target should have; it needs rescaling (e.g. `(1 + pairwise_labels) / 2`) before it can be used as a target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binary_cross_entropy(prediction, target):\n",
    "    \"\"\"Elementwise binary cross-entropy of `prediction` against `target`.\n",
    "\n",
    "    Args:\n",
    "        prediction: predicted probability p; should lie in [0, 1]. Scalar or array.\n",
    "        target: target probability t; should lie in [0, 1]. Broadcast against `prediction`.\n",
    "\n",
    "    Returns:\n",
    "        -(t * log(p) + (1 - t) * log(1 - p)). A term whose coefficient is 0\n",
    "        contributes 0 (via `xlogy`), so p == t in {0, 1} gives 0 rather than\n",
    "        NaN; a confidently wrong prediction (e.g. p=0, t=1) gives inf.\n",
    "    \"\"\"\n",
    "    # xlogy(x, y) == x * log(y), except it returns 0 when x == 0 (avoids 0 * -inf = NaN).\n",
    "    return -(xlogy(target, prediction) + xlogy(1 - target, 1 - prediction))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(0.6931472, dtype=float32, weak_type=True)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "binary_cross_entropy(0.5, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perfect predictions must cost exactly 0 (the unguarded formula returned NaN here).\n",
    "assert binary_cross_entropy(1.0, 1) == 0.0\n",
    "assert binary_cross_entropy(0.0, 0) == 0.0"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}