diff --git a/docs/examples/frameworks/jax/jax-basic_example.ipynb b/docs/examples/frameworks/jax/jax-basic_example.ipynb index dd33cb32cc2..39219c021e9 100644 --- a/docs/examples/frameworks/jax/jax-basic_example.ipynb +++ b/docs/examples/frameworks/jax/jax-basic_example.ipynb @@ -15,14 +15,7 @@ { "cell_type": "code", "execution_count": 1, - "metadata": { - "execution": { - "iopub.execute_input": "2023-07-28T07:43:41.850101Z", - "iopub.status.busy": "2023-07-28T07:43:41.849672Z", - "iopub.status.idle": "2023-07-28T07:43:41.853520Z", - "shell.execute_reply": "2023-07-28T07:43:41.852990Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "import os\n", @@ -50,14 +43,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "execution": { - "iopub.execute_input": "2023-07-28T07:43:41.855441Z", - "iopub.status.busy": "2023-07-28T07:43:41.855301Z", - "iopub.status.idle": "2023-07-28T07:43:41.986406Z", - "shell.execute_reply": "2023-07-28T07:43:41.985500Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "from nvidia.dali.plugin.jax import data_iterator\n", @@ -98,22 +84,15 @@ { "cell_type": "code", "execution_count": 3, - "metadata": { - "execution": { - "iopub.execute_input": "2023-07-28T07:43:41.989183Z", - "iopub.status.busy": "2023-07-28T07:43:41.988964Z", - "iopub.status.idle": "2023-07-28T07:43:42.104446Z", - "shell.execute_reply": "2023-07-28T07:43:42.103668Z" - } - }, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating iterators\n", - "\n", - "\n" + "\n", + "\n" ] } ], @@ -145,19 +124,50 @@ { "cell_type": "code", "execution_count": 4, - "metadata": { - "execution": { - "iopub.execute_input": "2023-07-28T07:43:43.559575Z", - "iopub.status.busy": "2023-07-28T07:43:43.559420Z", - "iopub.status.idle": "2023-07-28T07:43:43.618221Z", - "shell.execute_reply": "2023-07-28T07:43:43.617532Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "from model import init_model, update, accuracy" ] }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "`jax.jit` traces, compiles, and caches functions lazily on first invocation for a given input signature. During this process, XLA may capture CUDA graphs, which forbids some CUDA calls that DALI's background thread uses internally. Since subsequent calls to the JAX function with inputs of the same shape and dtype don't trigger compilation again, we can work around this by warming up with dummy inputs before starting any DALI workload:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "\n", + "model = init_model()\n", + "dummy_images = jnp.empty(\n", + " (batch_size, image_size * image_size), dtype=jnp.float32\n", + ")\n", + "dummy_labels = jnp.empty((batch_size, num_classes), dtype=jnp.float32)\n", + "_ = update(model, {\"images\": dummy_images, \"labels\": dummy_labels})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + " Warning
\n", + " \n", + " If you skip this step, CUDA graph capture will happen on the first call to `update` and may overlap with DALI's execution, causing CUDA errors in JAX.\n", + " \n", + " Alternatively, you can disable XLA command buffers entirely by setting `XLA_FLAGS=\"--xla_gpu_enable_command_buffer=\"` (the empty value after `=` is intentional: it leaves no command type enabled), at the cost of some performance.\n", + " \n", + "
" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -168,15 +178,8 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "execution": { - "iopub.execute_input": "2023-07-28T07:43:43.622376Z", - "iopub.status.busy": "2023-07-28T07:43:43.621205Z", - "iopub.status.idle": "2023-07-28T07:43:58.016073Z", - "shell.execute_reply": "2023-07-28T07:43:58.015333Z" - } - }, + "execution_count": 6, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -184,22 +187,21 @@ "text": [ "Starting training\n", "Epoch 0 sec\n", - "Test set accuracy 0.67330002784729\n", + "Test set accuracy 0.674500048160553\n", "Epoch 1 sec\n", - "Test set accuracy 0.7855000495910645\n", + "Test set accuracy 0.7854000329971313\n", "Epoch 2 sec\n", - "Test set accuracy 0.8251000642776489\n", + "Test set accuracy 0.8252000212669373\n", "Epoch 3 sec\n", - "Test set accuracy 0.8469000458717346\n", + "Test set accuracy 0.847100019454956\n", "Epoch 4 sec\n", - "Test set accuracy 0.8616000413894653\n" + "Test set accuracy 0.8618000149726868\n" ] } ], "source": [ "print(\"Starting training\")\n", "\n", - "model = init_model()\n", "num_epochs = 5\n", "\n", "for epoch in range(num_epochs):\n", @@ -215,7 +217,7 @@ "metadata": { "celltoolbar": "Raw Cell Format", "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -229,9 +231,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.20" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 }