106 changes: 54 additions & 52 deletions docs/examples/frameworks/jax/jax-basic_example.ipynb
@@ -15,14 +15,7 @@
{
"cell_type": "code",
"execution_count": 1,
-"metadata": {
-"execution": {
-"iopub.execute_input": "2023-07-28T07:43:41.850101Z",
-"iopub.status.busy": "2023-07-28T07:43:41.849672Z",
-"iopub.status.idle": "2023-07-28T07:43:41.853520Z",
-"shell.execute_reply": "2023-07-28T07:43:41.852990Z"
-}
-},
+"metadata": {},
"outputs": [],
"source": [
"import os\n",
@@ -50,14 +43,7 @@
{
"cell_type": "code",
"execution_count": 2,
-"metadata": {
-"execution": {
-"iopub.execute_input": "2023-07-28T07:43:41.855441Z",
-"iopub.status.busy": "2023-07-28T07:43:41.855301Z",
-"iopub.status.idle": "2023-07-28T07:43:41.986406Z",
-"shell.execute_reply": "2023-07-28T07:43:41.985500Z"
-}
-},
+"metadata": {},
"outputs": [],
"source": [
"from nvidia.dali.plugin.jax import data_iterator\n",
@@ -98,22 +84,15 @@
{
"cell_type": "code",
"execution_count": 3,
-"metadata": {
-"execution": {
-"iopub.execute_input": "2023-07-28T07:43:41.989183Z",
-"iopub.status.busy": "2023-07-28T07:43:41.988964Z",
-"iopub.status.idle": "2023-07-28T07:43:42.104446Z",
-"shell.execute_reply": "2023-07-28T07:43:42.103668Z"
-}
-},
+"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating iterators\n",
-"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f2894462ef0>\n",
-"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f28944634c0>\n"
+"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d7397b4b790>\n",
+"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d739800e530>\n"
]
}
],
@@ -145,19 +124,50 @@
{
"cell_type": "code",
"execution_count": 4,
-"metadata": {
-"execution": {
-"iopub.execute_input": "2023-07-28T07:43:43.559575Z",
-"iopub.status.busy": "2023-07-28T07:43:43.559420Z",
-"iopub.status.idle": "2023-07-28T07:43:43.618221Z",
-"shell.execute_reply": "2023-07-28T07:43:43.617532Z"
-}
-},
+"metadata": {},
"outputs": [],
"source": [
"from model import init_model, update, accuracy"
]
},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"`jax.jit` traces, compiles, and caches functions lazily on first invocation for a given input signature. During this process, XLA may capture CUDA graphs, which forbids some CUDA calls that DALI's background thread uses internally. Since subsequent calls to the JAX function with inputs of the same shape and dtype don't trigger compilation again, we can work around this by warming up with dummy inputs before starting any DALI workload:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 5,
+"metadata": {},
+"outputs": [],
+"source": [
+"import jax.numpy as jnp\n",
+"\n",
+"model = init_model()\n",
+"dummy_images = jnp.empty(\n",
+" (batch_size, image_size * image_size), dtype=jnp.float32\n",
+")\n",
+"dummy_labels = jnp.empty((batch_size, num_classes), dtype=jnp.float32)\n",
+"_ = update(model, {\"images\": dummy_images, \"labels\": dummy_labels})"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"<div class=\"alert alert-warning\">\n",
+"\n",
+" Warning<br>\n",
+" \n",
+" If you skip this step, CUDA graph capture will happen on the first call to `update` and may overlap with DALI's execution, causing CUDA errors in JAX.\n",
+" \n",
+" Alternatively, you can disable XLA command buffers entirely by setting `XLA_FLAGS=\"--xla_gpu_enable_command_buffer=\"`, at the cost of some performance.\n",
+" \n",
+"</div>"
+]
+},
{
"attachments": {},
"cell_type": "markdown",
@@ -168,38 +178,30 @@
},
{
"cell_type": "code",
-"execution_count": 5,
-"metadata": {
-"execution": {
-"iopub.execute_input": "2023-07-28T07:43:43.622376Z",
-"iopub.status.busy": "2023-07-28T07:43:43.621205Z",
-"iopub.status.idle": "2023-07-28T07:43:58.016073Z",
-"shell.execute_reply": "2023-07-28T07:43:58.015333Z"
-}
-},
+"execution_count": 6,
+"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting training\n",
"Epoch 0 sec\n",
-"Test set accuracy 0.67330002784729\n",
+"Test set accuracy 0.674500048160553\n",
"Epoch 1 sec\n",
-"Test set accuracy 0.7855000495910645\n",
+"Test set accuracy 0.7854000329971313\n",
"Epoch 2 sec\n",
-"Test set accuracy 0.8251000642776489\n",
+"Test set accuracy 0.8252000212669373\n",
"Epoch 3 sec\n",
-"Test set accuracy 0.8469000458717346\n",
+"Test set accuracy 0.847100019454956\n",
"Epoch 4 sec\n",
-"Test set accuracy 0.8616000413894653\n"
+"Test set accuracy 0.8618000149726868\n"
]
}
],
"source": [
"print(\"Starting training\")\n",
"\n",
-"model = init_model()\n",
"num_epochs = 5\n",
"\n",
"for epoch in range(num_epochs):\n",
@@ -215,7 +217,7 @@
"metadata": {
"celltoolbar": "Raw Cell Format",
"kernelspec": {
-"display_name": "Python 3",
+"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -229,9 +231,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.10.12"
+"version": "3.10.20"
}
},
"nbformat": 4,
-"nbformat_minor": 2
+"nbformat_minor": 4
}
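
The warm-up cell added in this change relies on `jax.jit` reusing its compiled function for inputs of the same shape and dtype. Below is a minimal sketch of that caching behaviour, independent of DALI and of the notebook's `update` function. The names `scale_and_sum` and `trace_count` are hypothetical and introduced only for illustration. The counter is a Python side effect, so it runs only while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only while JAX traces


@jax.jit
def scale_and_sum(x):
    # Python side effects run during tracing, not on cached executions.
    global trace_count
    trace_count += 1
    return jnp.sum(x * 2.0)


# Warm up with a dummy input that has the shape and dtype of the real data.
dummy = jnp.zeros((4, 8), dtype=jnp.float32)
scale_and_sum(dummy).block_until_ready()
count_after_warmup = trace_count

# Same shape and dtype: the cached compilation is reused, no new trace.
real = jnp.ones((4, 8), dtype=jnp.float32)
result = float(scale_and_sum(real))
count_after_real = trace_count

# A different shape is a new input signature and triggers another trace.
other = jnp.ones((2, 8), dtype=jnp.float32)
scale_and_sum(other).block_until_ready()
count_after_new_shape = trace_count
```

The sketch suggests why the dummy inputs in the warm-up cell should match the shape and dtype of the batches the iterator later produces: a mismatch would be a new input signature, so compilation would happen again after DALI has started.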