|
15 | 15 | { |
16 | 16 | "cell_type": "code", |
17 | 17 | "execution_count": 1, |
18 | | - "metadata": { |
19 | | - "execution": { |
20 | | - "iopub.execute_input": "2023-07-28T07:43:41.850101Z", |
21 | | - "iopub.status.busy": "2023-07-28T07:43:41.849672Z", |
22 | | - "iopub.status.idle": "2023-07-28T07:43:41.853520Z", |
23 | | - "shell.execute_reply": "2023-07-28T07:43:41.852990Z" |
24 | | - } |
25 | | - }, |
| 18 | + "metadata": {}, |
26 | 19 | "outputs": [], |
27 | 20 | "source": [ |
28 | 21 | "import os\n", |
|
50 | 43 | { |
51 | 44 | "cell_type": "code", |
52 | 45 | "execution_count": 2, |
53 | | - "metadata": { |
54 | | - "execution": { |
55 | | - "iopub.execute_input": "2023-07-28T07:43:41.855441Z", |
56 | | - "iopub.status.busy": "2023-07-28T07:43:41.855301Z", |
57 | | - "iopub.status.idle": "2023-07-28T07:43:41.986406Z", |
58 | | - "shell.execute_reply": "2023-07-28T07:43:41.985500Z" |
59 | | - } |
60 | | - }, |
| 46 | + "metadata": {}, |
61 | 47 | "outputs": [], |
62 | 48 | "source": [ |
63 | 49 | "from nvidia.dali.plugin.jax import data_iterator\n", |
|
98 | 84 | { |
99 | 85 | "cell_type": "code", |
100 | 86 | "execution_count": 3, |
101 | | - "metadata": { |
102 | | - "execution": { |
103 | | - "iopub.execute_input": "2023-07-28T07:43:41.989183Z", |
104 | | - "iopub.status.busy": "2023-07-28T07:43:41.988964Z", |
105 | | - "iopub.status.idle": "2023-07-28T07:43:42.104446Z", |
106 | | - "shell.execute_reply": "2023-07-28T07:43:42.103668Z" |
107 | | - } |
108 | | - }, |
| 87 | + "metadata": {}, |
109 | 88 | "outputs": [ |
110 | 89 | { |
111 | 90 | "name": "stdout", |
112 | 91 | "output_type": "stream", |
113 | 92 | "text": [ |
114 | 93 | "Creating iterators\n", |
115 | | - "<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f2894462ef0>\n", |
116 | | - "<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f28944634c0>\n" |
| 94 | + "<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d7397b4b790>\n", |
| 95 | + "<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d739800e530>\n" |
117 | 96 | ] |
118 | 97 | } |
119 | 98 | ], |
|
145 | 124 | { |
146 | 125 | "cell_type": "code", |
147 | 126 | "execution_count": 4, |
148 | | - "metadata": { |
149 | | - "execution": { |
150 | | - "iopub.execute_input": "2023-07-28T07:43:43.559575Z", |
151 | | - "iopub.status.busy": "2023-07-28T07:43:43.559420Z", |
152 | | - "iopub.status.idle": "2023-07-28T07:43:43.618221Z", |
153 | | - "shell.execute_reply": "2023-07-28T07:43:43.617532Z" |
154 | | - } |
155 | | - }, |
| 127 | + "metadata": {}, |
156 | 128 | "outputs": [], |
157 | 129 | "source": [ |
158 | 130 | "from model import init_model, update, accuracy" |
159 | 131 | ] |
160 | 132 | }, |
| 133 | + { |
| 134 | + "cell_type": "markdown", |
| 135 | + "metadata": {}, |
| 136 | + "source": [ |
| 137 | + "`jax.jit` traces, compiles, and caches functions lazily on first invocation for a given input signature. During this process, XLA may capture CUDA graphs, which forbids some CUDA calls that DALI's background thread uses internally. Since subsequent calls to the JAX function with inputs of the same shape and dtype don't trigger compilation again, we can work around this by warming up with dummy inputs before starting any DALI workload:" |
| 138 | + ] |
| 139 | + }, |
| 140 | + { |
| 141 | + "cell_type": "code", |
| 142 | + "execution_count": 5, |
| 143 | + "metadata": {}, |
| 144 | + "outputs": [], |
| 145 | + "source": [ |
| 146 | + "import jax.numpy as jnp\n", |
| 147 | + "\n", |
| 148 | + "model = init_model()\n", |
| 149 | + "dummy_images = jnp.empty((batch_size, image_size * image_size), dtype=jnp.float32)\n", |
| 150 | + "dummy_labels = jnp.empty((batch_size, num_classes), dtype=jnp.float32)\n", |
| 151 | + "_ = update(model, {\"images\": dummy_images, \"labels\": dummy_labels})" |
| 152 | + ] |
| 153 | + }, |
| 154 | + { |
| 155 | + "cell_type": "markdown", |
| 156 | + "metadata": {}, |
| 157 | + "source": [ |
| 158 | + "<div class=\"alert alert-warning\">\n", |
| 159 | + "\n", |
| 160 | + " Warning<br>\n", |
| 161 | + " \n", |
| 162 | + " If you skip this step, CUDA graph capture will happen on the first call to `update` and may overlap with DALI's execution, causing CUDA errors in JAX.\n", |
| 163 | + " \n", |
| 164 | + " Alternatively, you can disable XLA command buffers entirely by setting `XLA_FLAGS=\"--xla_gpu_enable_command_buffer=\"`, at the cost of some performance.\n", |
| 165 | + " \n", |
| 166 | + "</div>" |
| 167 | + ] |
| 168 | + }, |
161 | 169 | { |
162 | 170 | "attachments": {}, |
163 | 171 | "cell_type": "markdown", |
|
168 | 176 | }, |
169 | 177 | { |
170 | 178 | "cell_type": "code", |
171 | | - "execution_count": 5, |
172 | | - "metadata": { |
173 | | - "execution": { |
174 | | - "iopub.execute_input": "2023-07-28T07:43:43.622376Z", |
175 | | - "iopub.status.busy": "2023-07-28T07:43:43.621205Z", |
176 | | - "iopub.status.idle": "2023-07-28T07:43:58.016073Z", |
177 | | - "shell.execute_reply": "2023-07-28T07:43:58.015333Z" |
178 | | - } |
179 | | - }, |
| 179 | + "execution_count": 6, |
| 180 | + "metadata": {}, |
180 | 181 | "outputs": [ |
181 | 182 | { |
182 | 183 | "name": "stdout", |
183 | 184 | "output_type": "stream", |
184 | 185 | "text": [ |
185 | 186 | "Starting training\n", |
186 | 187 | "Epoch 0 sec\n", |
187 | | - "Test set accuracy 0.67330002784729\n", |
| 188 | + "Test set accuracy 0.674500048160553\n", |
188 | 189 | "Epoch 1 sec\n", |
189 | | - "Test set accuracy 0.7855000495910645\n", |
| 190 | + "Test set accuracy 0.7854000329971313\n", |
190 | 191 | "Epoch 2 sec\n", |
191 | | - "Test set accuracy 0.8251000642776489\n", |
| 192 | + "Test set accuracy 0.8252000212669373\n", |
192 | 193 | "Epoch 3 sec\n", |
193 | | - "Test set accuracy 0.8469000458717346\n", |
| 194 | + "Test set accuracy 0.847100019454956\n", |
194 | 195 | "Epoch 4 sec\n", |
195 | | - "Test set accuracy 0.8616000413894653\n" |
| 196 | + "Test set accuracy 0.8618000149726868\n" |
196 | 197 | ] |
197 | 198 | } |
198 | 199 | ], |
199 | 200 | "source": [ |
200 | 201 | "print(\"Starting training\")\n", |
201 | 202 | "\n", |
202 | | - "model = init_model()\n", |
203 | 203 | "num_epochs = 5\n", |
204 | 204 | "\n", |
205 | 205 | "for epoch in range(num_epochs):\n", |
|
215 | 215 | "metadata": { |
216 | 216 | "celltoolbar": "Raw Cell Format", |
217 | 217 | "kernelspec": { |
218 | | - "display_name": "Python 3", |
| 218 | + "display_name": "Python 3 (ipykernel)", |
219 | 219 | "language": "python", |
220 | 220 | "name": "python3" |
221 | 221 | }, |
|
229 | 229 | "name": "python", |
230 | 230 | "nbconvert_exporter": "python", |
231 | 231 | "pygments_lexer": "ipython3", |
232 | | - "version": "3.10.12" |
| 232 | + "version": "3.10.20" |
233 | 233 | } |
234 | 234 | }, |
235 | 235 | "nbformat": 4, |
236 | | - "nbformat_minor": 2 |
| 236 | + "nbformat_minor": 4 |
237 | 237 | } |
0 commit comments