Skip to content

Commit 1f9dbc5

Browse files
committed
Compile the function ahead of time in the JAX example
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent 8e15229 commit 1f9dbc5

File tree

1 file changed

+52
-52
lines changed

1 file changed

+52
-52
lines changed

docs/examples/frameworks/jax/jax-basic_example.ipynb

Lines changed: 52 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,7 @@
1515
{
1616
"cell_type": "code",
1717
"execution_count": 1,
18-
"metadata": {
19-
"execution": {
20-
"iopub.execute_input": "2023-07-28T07:43:41.850101Z",
21-
"iopub.status.busy": "2023-07-28T07:43:41.849672Z",
22-
"iopub.status.idle": "2023-07-28T07:43:41.853520Z",
23-
"shell.execute_reply": "2023-07-28T07:43:41.852990Z"
24-
}
25-
},
18+
"metadata": {},
2619
"outputs": [],
2720
"source": [
2821
"import os\n",
@@ -50,14 +43,7 @@
5043
{
5144
"cell_type": "code",
5245
"execution_count": 2,
53-
"metadata": {
54-
"execution": {
55-
"iopub.execute_input": "2023-07-28T07:43:41.855441Z",
56-
"iopub.status.busy": "2023-07-28T07:43:41.855301Z",
57-
"iopub.status.idle": "2023-07-28T07:43:41.986406Z",
58-
"shell.execute_reply": "2023-07-28T07:43:41.985500Z"
59-
}
60-
},
46+
"metadata": {},
6147
"outputs": [],
6248
"source": [
6349
"from nvidia.dali.plugin.jax import data_iterator\n",
@@ -98,22 +84,15 @@
9884
{
9985
"cell_type": "code",
10086
"execution_count": 3,
101-
"metadata": {
102-
"execution": {
103-
"iopub.execute_input": "2023-07-28T07:43:41.989183Z",
104-
"iopub.status.busy": "2023-07-28T07:43:41.988964Z",
105-
"iopub.status.idle": "2023-07-28T07:43:42.104446Z",
106-
"shell.execute_reply": "2023-07-28T07:43:42.103668Z"
107-
}
108-
},
87+
"metadata": {},
10988
"outputs": [
11089
{
11190
"name": "stdout",
11291
"output_type": "stream",
11392
"text": [
11493
"Creating iterators\n",
115-
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f2894462ef0>\n",
116-
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f28944634c0>\n"
94+
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d7397b4b790>\n",
95+
"<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d739800e530>\n"
11796
]
11897
}
11998
],
@@ -145,19 +124,48 @@
145124
{
146125
"cell_type": "code",
147126
"execution_count": 4,
148-
"metadata": {
149-
"execution": {
150-
"iopub.execute_input": "2023-07-28T07:43:43.559575Z",
151-
"iopub.status.busy": "2023-07-28T07:43:43.559420Z",
152-
"iopub.status.idle": "2023-07-28T07:43:43.618221Z",
153-
"shell.execute_reply": "2023-07-28T07:43:43.617532Z"
154-
}
155-
},
127+
"metadata": {},
156128
"outputs": [],
157129
"source": [
158130
"from model import init_model, update, accuracy"
159131
]
160132
},
133+
{
134+
"cell_type": "markdown",
135+
"metadata": {},
136+
"source": [
137+
"`jax.jit` traces, compiles, and caches functions lazily on first invocation for a given input signature. During this process, XLA may capture CUDA graphs, which forbids some CUDA calls that DALI's background thread uses internally. Since subsequent calls to the JAX function with inputs of the same shape and dtype don't trigger compilation again, we can work around this by warming up with dummy inputs before starting any DALI workload:"
138+
]
139+
},
140+
{
141+
"cell_type": "code",
142+
"execution_count": 5,
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"import jax.numpy as jnp\n",
147+
"\n",
148+
"model = init_model()\n",
149+
"dummy_images = jnp.empty((batch_size, image_size * image_size), dtype=jnp.float32)\n",
150+
"dummy_labels = jnp.empty((batch_size, num_classes), dtype=jnp.float32)\n",
151+
"_ = update(model, {\"images\": dummy_images, \"labels\": dummy_labels})"
152+
]
153+
},
154+
{
155+
"cell_type": "markdown",
156+
"metadata": {},
157+
"source": [
158+
"<div class=\"alert alert-warning\">\n",
159+
"\n",
160+
" Warning<br>\n",
161+
" \n",
162+
" If you skip this step, CUDA graph capture will happen on the first call to `update` and may overlap with DALI's execution, causing CUDA errors in JAX.\n",
163+
" \n",
164+
" Alternatively, you can disable XLA command buffers entirely by setting `XLA_FLAGS=\"--xla_gpu_enable_command_buffer=\"`, at the cost of some performance.\n",
165+
" \n",
166+
"</div>"
167+
]
168+
},
161169
{
162170
"attachments": {},
163171
"cell_type": "markdown",
@@ -168,38 +176,30 @@
168176
},
169177
{
170178
"cell_type": "code",
171-
"execution_count": 5,
172-
"metadata": {
173-
"execution": {
174-
"iopub.execute_input": "2023-07-28T07:43:43.622376Z",
175-
"iopub.status.busy": "2023-07-28T07:43:43.621205Z",
176-
"iopub.status.idle": "2023-07-28T07:43:58.016073Z",
177-
"shell.execute_reply": "2023-07-28T07:43:58.015333Z"
178-
}
179-
},
179+
"execution_count": 6,
180+
"metadata": {},
180181
"outputs": [
181182
{
182183
"name": "stdout",
183184
"output_type": "stream",
184185
"text": [
185186
"Starting training\n",
186187
"Epoch 0 sec\n",
187-
"Test set accuracy 0.67330002784729\n",
188+
"Test set accuracy 0.674500048160553\n",
188189
"Epoch 1 sec\n",
189-
"Test set accuracy 0.7855000495910645\n",
190+
"Test set accuracy 0.7854000329971313\n",
190191
"Epoch 2 sec\n",
191-
"Test set accuracy 0.8251000642776489\n",
192+
"Test set accuracy 0.8252000212669373\n",
192193
"Epoch 3 sec\n",
193-
"Test set accuracy 0.8469000458717346\n",
194+
"Test set accuracy 0.847100019454956\n",
194195
"Epoch 4 sec\n",
195-
"Test set accuracy 0.8616000413894653\n"
196+
"Test set accuracy 0.8618000149726868\n"
196197
]
197198
}
198199
],
199200
"source": [
200201
"print(\"Starting training\")\n",
201202
"\n",
202-
"model = init_model()\n",
203203
"num_epochs = 5\n",
204204
"\n",
205205
"for epoch in range(num_epochs):\n",
@@ -215,7 +215,7 @@
215215
"metadata": {
216216
"celltoolbar": "Raw Cell Format",
217217
"kernelspec": {
218-
"display_name": "Python 3",
218+
"display_name": "Python 3 (ipykernel)",
219219
"language": "python",
220220
"name": "python3"
221221
},
@@ -229,9 +229,9 @@
229229
"name": "python",
230230
"nbconvert_exporter": "python",
231231
"pygments_lexer": "ipython3",
232-
"version": "3.10.12"
232+
"version": "3.10.20"
233233
}
234234
},
235235
"nbformat": 4,
236-
"nbformat_minor": 2
236+
"nbformat_minor": 4
237237
}

0 commit comments

Comments
 (0)