Changes from 79 commits
80 commits
e09d04b
Center crop operator
mdabek-nvidia Mar 19, 2026
1b043c0
Review fixes
mdabek-nvidia Mar 19, 2026
f30e282
Review fixes
mdabek-nvidia Mar 24, 2026
3337aa9
Apply suggestion from @stiepan
mdabek-nvidia Mar 24, 2026
ebbf863
Torchvision ColorJitter and Grayscale implementations
mdabek-nvidia Mar 19, 2026
d8c5cd4
Review fixes
mdabek-nvidia Mar 19, 2026
4dc37b5
Review fixes
mdabek-nvidia Mar 23, 2026
e146f9f
Review fixes
mdabek-nvidia Mar 24, 2026
6a0038d
Review fixes
mdabek-nvidia Mar 24, 2026
f53d4ec
Review fixes
mdabek-nvidia Mar 24, 2026
dd7ff14
Review fixes
mdabek-nvidia Mar 30, 2026
0eb7bd7
Review fixes and validation renaming
mdabek-nvidia Mar 31, 2026
7fd89d9
Torchvision API - center crop operator (#6266)
mdabek-nvidia Mar 24, 2026
453d18b
Torchvision ColorJitter and Grayscale implementations
mdabek-nvidia Mar 19, 2026
2bf954c
Review fixes
mdabek-nvidia Mar 19, 2026
842dbea
Gaussian blur operator
mdabek-nvidia Mar 22, 2026
b940dc6
Review fixes
mdabek-nvidia Mar 25, 2026
7c19646
Fix call stack depth handling for error tracebacks in dynamic mode (#…
rostan-t Mar 24, 2026
21c5424
Add uniform_sample option to VideoReaderDecoder (#6258)
jantonguirao Mar 30, 2026
8f48e49
Defer DLTensor deletion when CUDA graph capture is active. (#6259)
JanuszL Mar 31, 2026
5652054
Torchvision API - ColorJitter and Grayscale operators (#6272)
mdabek-nvidia Apr 2, 2026
feb2508
Gaussian blur - review fixes
mdabek-nvidia Apr 2, 2026
52c9006
Improve type hinting in functional API
mdabek-nvidia Apr 2, 2026
bc192ab
Review fixes
mdabek-nvidia Mar 25, 2026
e5fb444
Review fixes
mdabek-nvidia Mar 19, 2026
68cfc26
Gaussian blur operator
mdabek-nvidia Mar 22, 2026
0cd9a05
Torchvision ColorJitter and Grayscale implementations
mdabek-nvidia Mar 19, 2026
6413943
Review fixes
mdabek-nvidia Mar 19, 2026
65bda1b
Gaussian blur operator
mdabek-nvidia Mar 22, 2026
68f2b3f
Torchvision Pad operators
mdabek-nvidia Mar 23, 2026
22daf70
Review fixes
mdabek-nvidia Mar 25, 2026
39895d8
Fixing annotations in functional API
mdabek-nvidia Apr 2, 2026
2f452ae
Review fixes
mdabek-nvidia Apr 2, 2026
fa9360d
Merge branch 'main' into torchvision_pad
mdabek-nvidia Apr 2, 2026
ce6f14b
Review fixes
mdabek-nvidia Apr 3, 2026
e17dcc9
Review fixes
mdabek-nvidia Mar 25, 2026
66208e4
Torchvision ColorJitter and Grayscale implementations
mdabek-nvidia Mar 19, 2026
baa817f
Review fixes
mdabek-nvidia Mar 19, 2026
8946174
Gaussian blur operator
mdabek-nvidia Mar 22, 2026
996c38a
Torchvision ColorJitter and Grayscale implementations
mdabek-nvidia Mar 19, 2026
cd76b76
Review fixes
mdabek-nvidia Mar 19, 2026
96aa07d
Gaussian blur operator
mdabek-nvidia Mar 22, 2026
89037e2
Torchvision Pad operators
mdabek-nvidia Mar 23, 2026
066f8f0
Torchvision normalize operators implementation
mdabek-nvidia Mar 23, 2026
159af28
Review fixes
mdabek-nvidia Mar 25, 2026
e86e551
Review fixes
mdabek-nvidia Apr 3, 2026
ca8bf9e
Improving type hints
mdabek-nvidia Apr 3, 2026
d08ca9b
Correct std and mean for functional API
mdabek-nvidia Apr 7, 2026
cb5848f
Merge branch 'main' into torchvision_normalize
mdabek-nvidia Apr 7, 2026
ed8d250
Typo fix
mdabek-nvidia Apr 8, 2026
45e4b5c
Review fixes
mdabek-nvidia Mar 25, 2026
f56cf44
Torchvision ColorJitter and Grayscale implementations
mdabek-nvidia Mar 19, 2026
f32200e
Review fixes
mdabek-nvidia Mar 19, 2026
70e66eb
Torchvision ColorJitter and Grayscale implementations
mdabek-nvidia Mar 19, 2026
724b2f9
Review fixes
mdabek-nvidia Mar 19, 2026
48e6ab6
Gaussian blur operator
mdabek-nvidia Mar 22, 2026
1dfa682
Torchvision ColorJitter and Grayscale implementations
mdabek-nvidia Mar 19, 2026
ea1dd01
Review fixes
mdabek-nvidia Mar 19, 2026
a32e31c
Gaussian blur operator
mdabek-nvidia Mar 22, 2026
fa48ebb
Torchvision normalize operators implementation
mdabek-nvidia Mar 23, 2026
122c676
Torchvision - user documentation
mdabek-nvidia Mar 23, 2026
3050c07
Post rebase fixes
mdabek-nvidia Mar 31, 2026
27b515b
Post rebase update
mdabek-nvidia Mar 31, 2026
593489c
Torchvision API documentation in Getting started
mdabek-nvidia Apr 7, 2026
0bbac6a
Rebase fixes
mdabek-nvidia Apr 8, 2026
6f21c9d
Review fixes
mdabek-nvidia Apr 8, 2026
db22042
Removed WAR for ndd.Batch creation
mdabek-nvidia Apr 8, 2026
4fb48b4
Moving Torchvision API test to L1
mdabek-nvidia Apr 8, 2026
f076eb9
DLPack capsule fix for PyTorch 2.7.1
mdabek-nvidia Apr 9, 2026
b9371c7
Merge branch 'main' into torchvision_documentation
mdabek-nvidia Apr 9, 2026
cc44d31
Revert "DLPack capsule fix for PyTorch 2.7.1"
mdabek-nvidia Apr 9, 2026
6566746
Review fixes
mdabek-nvidia Apr 9, 2026
bc7ef81
Torchvision ColorJitter and Grayscale implementations
mdabek-nvidia Mar 19, 2026
66613ab
Torchvision implementation of tensor and PIL conversions
mdabek-nvidia Apr 2, 2026
34f32de
Review fixes
mdabek-nvidia Apr 8, 2026
9790262
Review fixes
mdabek-nvidia Apr 10, 2026
c0c260f
Merge branch 'main' into torchvision_totensor
mdabek-nvidia Apr 12, 2026
cdc9d3e
Refactor of PipelineWithLayout to use PIL conversion functions
mdabek-nvidia Apr 12, 2026
a807873
Generalized tensor conversion
mdabek-nvidia Apr 13, 2026
4315106
Review fixes
mdabek-nvidia Apr 13, 2026
4 changes: 4 additions & 0 deletions dali/python/nvidia/dali/experimental/torchvision/__init__.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
from .v2.normalize import Normalize
from .v2.pad import Pad
from .v2.resize import Resize
from .v2.totensor import ToPureTensor, PILToTensor, ToPILImage

__all__ = [
"CenterCrop",
@@ -29,7 +30,10 @@
"Grayscale",
"Normalize",
"Pad",
"PILToTensor",
"RandomHorizontalFlip",
"RandomVerticalFlip",
"Resize",
"ToPILImage",
"ToPureTensor",
]
109 changes: 54 additions & 55 deletions dali/python/nvidia/dali/experimental/torchvision/v2/compose.py
Expand Up @@ -22,6 +22,8 @@
from nvidia.dali.backend import TensorListCPU, TensorListGPU

from .operator import _ValidateTensorOrImage
from .totensor import ToPureTensor, PILToTensor, ToPILImage
from .functional.totensor import to_pil_image

import numpy as np
import multiprocessing
@@ -119,6 +121,17 @@ def _cuda_run(self, data_input):
def _cpu_run(self, data_input):
return self.pipe.run(input_data=data_input)

@staticmethod
def _get_output_format(op_list: List[Callable[..., Sequence[_DataNode] | _DataNode]]) -> str:
output_type = "default"
for op in op_list:
if isinstance(op, (ToPureTensor, PILToTensor)):
output_type = "tensor"
elif isinstance(op, ToPILImage):
output_type = "pil"

return output_type
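The resolution rule here is "last conversion transform wins". A minimal standalone sketch of that rule, with stub classes standing in for the DALI transforms (the stubs are illustrative only, not the real operators):

```python
# Sketch of the output-format resolution in _get_output_format:
# scanning op_list in order means the last conversion transform wins.
# Stub classes below stand in for the DALI torchvision transforms.

class ToPureTensor:  # stand-in for the real ToPureTensor transform
    pass

class PILToTensor:  # stand-in
    pass

class ToPILImage:  # stand-in
    pass

def get_output_format(op_list):
    output_type = "default"
    for op in op_list:
        if isinstance(op, (ToPureTensor, PILToTensor)):
            output_type = "tensor"
        elif isinstance(op, ToPILImage):
            output_type = "pil"
    return output_type

print(get_output_format([]))                             # default
print(get_output_format([PILToTensor(), ToPILImage()]))  # pil — last one wins
```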

def __init__(
self,
op_list: List[Callable[..., Sequence[_DataNode] | _DataNode]],
@@ -127,16 +140,7 @@ def __init__(
num_threads: int = DEFAULT_NUM_THREADS,
**dali_pipeline_kwargs,
):
# TODO:
# convert_to_tensor is currently not supported and requires a user's effort
# to convert to a tensor
# ToTensor is deprecated and according to:
# https://docs.pytorch.org/vision/stable/_modules/torchvision/transforms/v2/_deprecated.html#ToTensor
# should be replaced with:
# v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
#
# self.convert_to_tensor = True if isinstance(op_list[-1], ToTensor) else False
self.convert_to_tensor = False
self.output_type = PipelineWithLayout._get_output_format(op_list)
self.device = op_list[0].device if len(op_list) > 0 else "cpu"
self.torch_device_type = "cuda" if self.device == "gpu" else "cpu"

@@ -167,14 +171,29 @@ def run(self, data_input):
if output is None:
return output

output = to_torch_tensor(output)
# ToTensor
if self.convert_to_tensor:
if output.shape[-4] > 1:
raise NotImplementedError("ToTensor does not currently work for batches")
return to_torch_tensor(output)

def output_to_tensor(self, output: torch.Tensor) -> torch.Tensor:
"""Return the pipeline output as a CHW or NCHW ``torch.Tensor``.

For HWC pipelines the axes are permuted; for CHW pipelines the tensor is returned
as-is (batch dimension handling is left to the caller).
"""
if self.get_layout() == "HWC":
# output shape: (N, H, W, C)
if output.shape[0] == 1:
return output.squeeze(0).permute(2, 0, 1) # → (C, H, W)
return output.permute(0, 3, 1, 2) # → (N, C, H, W)
# CHW — already the right layout; batch handling belongs to the subclass
return output

def output_to_pil(self, output: torch.Tensor) -> Image.Image:
"""Convert a single-sample pipeline output tensor to a ``PIL.Image``.

For CHW pipelines the tensor is permuted to HWC before conversion.
"""
return to_pil_image(output)

@abstractmethod
def get_layout(self) -> str: ...

@@ -185,7 +204,7 @@ def get_channel_reverse_idx(self) -> int: ...
def verify_layout(self, data) -> None: ...

def is_conversion_to_tensor(self) -> bool:
return self.convert_to_tensor
return self.output_type == "tensor"


class PipelineHWC(PipelineWithLayout):
@@ -221,30 +240,6 @@ def __init__(
**dali_pipeline_kwargs,
)

def _convert_tensor_to_image(self, in_tensor: torch.Tensor):

channels = self.get_channel_reverse_idx()

# TODO: consider when to convert to PIL.Image - e.g. if it make sense for channels < 3
# There is no certain method to determine if the tensor is HW, HWC, or NHWC.
# The method below checks if tensor's shape is HW or ...HWC with a single channel
if len(in_tensor.shape) == 2 or (
len(in_tensor.shape) >= 3 and in_tensor.shape[channels] == 1
):
mode = "L"
if len(in_tensor.shape) != 2:
in_tensor = in_tensor.squeeze(-1)
elif in_tensor.shape[channels] == 3:
mode = "RGB"
elif in_tensor.shape[channels] == 4:
mode = "RGBA"
else:
raise ValueError(
f"Unsupported number of channels: {in_tensor.shape[channels]}. Should be 1, 3 or 4."
)
# We need to convert tensor to CPU, PIL does not support CUDA tensors
return Image.fromarray(in_tensor.cpu().numpy(), mode=mode)

def run(self, data_input):
if isinstance(data_input, Image.Image):
_input = torch.as_tensor(np.array(data_input, copy=True)).unsqueeze(0)
@@ -254,24 +249,20 @@ def run(self, data_input):
raise ValueError("HWC layout is currently supported for PIL Images only.\
Please check if samples have the same format.")

output = super().run(_input)
output = super().run(_input) # (N, H, W, C)

if self.is_conversion_to_tensor():
return output
if self.output_type == "tensor":
return self.output_to_tensor(output)

if isinstance(output, tuple):
output = self._convert_tensor_to_image(output[0])
else:
# batches
if output.shape[0] > 1:
output_list = []
for i in range(output.shape[0]):
output_list.append(self._convert_tensor_to_image(output[i]))
output = output_list
else:
output = self._convert_tensor_to_image(output[0])
# default: PIL output
if output.shape[0] > 1:
return [self.output_to_pil(output[i]) for i in range(output.shape[0])]
return self.output_to_pil(output[0])

return output
def output_to_pil(self, output: torch.Tensor) -> Image.Image:
"""Convert a single-sample HWC (H, W, C) tensor to a ``PIL.Image``."""
# Pipeline stores data as HWC; to_pil_image expects CHW, so permute first.
return to_pil_image(output.permute(2, 0, 1))

def get_layout(self) -> str:
return "HWC"
@@ -331,6 +322,14 @@ def run(self, data_input):
if data_input.ndim == 3:
# Remove the batch dimension we added above
output = output.squeeze(0)

if self.output_type == "pil":
if output.ndim == 4:
if output.shape[0] > 1:
return [self.output_to_pil(output[i]) for i in range(output.shape[0])]
return self.output_to_pil(output[0])
return self.output_to_pil(output)

return output

def get_layout(self) -> str:
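The axis handling in `output_to_tensor` above can be sketched without DALI or torch, using NumPy (`np.transpose` playing the role of `torch.Tensor.permute`; the shapes are illustrative, not taken from the PR):

```python
import numpy as np

def output_to_tensor(output: np.ndarray, layout: str = "HWC") -> np.ndarray:
    """NumPy sketch of PipelineWithLayout.output_to_tensor's layout handling."""
    if layout == "HWC":
        # output shape: (N, H, W, C)
        if output.shape[0] == 1:
            # single sample: drop the batch dim, then HWC -> CHW
            return np.transpose(output[0], (2, 0, 1))
        # batch: NHWC -> NCHW
        return np.transpose(output, (0, 3, 1, 2))
    # CHW pipelines are already in the right layout
    return output

batch = np.zeros((1, 224, 224, 3), dtype=np.uint8)
print(output_to_tensor(batch).shape)  # (3, 224, 224)
```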
Original file line number Diff line number Diff line change
@@ -19,15 +19,19 @@
from .normalize import normalize
from .pad import pad
from .resize import resize
from .totensor import pil_to_tensor, to_tensor, to_pil_image

__all__ = [
"center_crop",
"gaussian_blur",
"horizontal_flip",
"normalize",
"pad",
"pil_to_tensor",
"resize",
"rgb_to_grayscale",
"to_grayscale",
"to_pil_image",
"to_tensor",
"vertical_flip",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from PIL import Image
import torch


def _any_to_tensor(inpt: Image.Image | np.ndarray) -> torch.Tensor:
"""
Convert a ``PIL.Image`` or ``np.ndarray`` to a ``torch.Tensor``.

Values are in [0, 255] and the dtype is ``torch.uint8``. No scaling is applied.
Mirrors ``torchvision.transforms.v2.functional.pil_to_tensor``.

Parameters
----------
inpt : PIL.Image or np.ndarray
Input image or array. For PIL inputs, modes ``L``, ``RGB``, and ``RGBA`` are supported.

Returns
-------
torch.Tensor
CHW tensor of dtype ``torch.uint8``.
"""

if isinstance(inpt, Image.Image):
arr = np.array(inpt, copy=True) # (H, W) for L, (H, W, C) for RGB/RGBA
else:
arr = inpt

if arr.ndim == 2:
arr = np.expand_dims(arr, axis=-1) # (H, W) → (H, W, 1)

return torch.from_numpy(arr).permute(2, 0, 1) # (H, W, C) → (C, H, W)


def pil_to_tensor(inpt: Image.Image) -> torch.Tensor:
"""
Convert a ``PIL.Image`` to a uint8 CHW ``torch.Tensor``.

Values are in [0, 255] and the dtype is ``torch.uint8``. No scaling is applied.
Mirrors ``torchvision.transforms.v2.functional.pil_to_tensor``.

Parameters
----------
inpt : PIL.Image
Input image. Modes ``L``, ``RGB``, and ``RGBA`` are supported.

Returns
-------
torch.Tensor
CHW tensor of dtype ``torch.uint8``.
"""
if not isinstance(inpt, Image.Image):
raise TypeError(f"Expected PIL.Image, got {type(inpt)}")

return _any_to_tensor(inpt)


def to_tensor(inpt: Image.Image) -> torch.Tensor:
"""
Convert a ``PIL.Image`` to a float32 CHW ``torch.Tensor`` with values in [0, 1].

Mirrors ``torchvision.transforms.v2.functional.to_tensor`` (deprecated in TV v2,
but kept here for compatibility).

Parameters
----------
inpt : PIL.Image
Input image. Modes ``L``, ``RGB``, and ``RGBA`` are supported.

Returns
-------
torch.Tensor
CHW tensor of dtype ``torch.float32`` with values in [0.0, 1.0].
"""
return _any_to_tensor(inpt).float() / 255.0


def to_pil_image(inpt: torch.Tensor, mode: str | None = None) -> Image.Image:
"""
Convert a CHW ``torch.Tensor`` to a ``PIL.Image``.

Mirrors ``torchvision.transforms.v2.functional.to_pil_image``.

Parameters
----------
inpt : torch.Tensor
CHW tensor. Supported channel counts: 1 (``L``), 3 (``RGB``), 4 (``RGBA``).
mode : str or None, optional
PIL image mode. If ``None`` the mode is inferred from the channel count.

Returns
-------
PIL.Image
"""
if not isinstance(inpt, torch.Tensor):
raise TypeError(f"Expected torch.Tensor, got {type(inpt)}")
if inpt.ndim != 3:
raise ValueError(f"Expected 3-D CHW tensor, got shape {tuple(inpt.shape)}")

hwc = inpt.permute(1, 2, 0).cpu() # (C, H, W) → (H, W, C)
channels = hwc.shape[-1]

if mode is None:
if channels == 1:
mode = "L"
elif channels == 3:
mode = "RGB"
elif channels == 4:
mode = "RGBA"
else:
raise ValueError(
f"Cannot infer PIL mode from {channels} channels. Pass mode explicitly."
)

arr = hwc.numpy()
if np.issubdtype(arr.dtype, np.floating) and mode != "F":
arr = (arr * 255).astype(np.uint8)

if mode == "L":
arr = arr.squeeze(-1)

return Image.fromarray(arr, mode=mode)
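The conversions in this file boil down to an axis permutation plus dtype scaling. A NumPy-only sketch of that core logic (PIL and torch replaced by plain arrays, purely for illustration — the real functions operate on `PIL.Image` and `torch.Tensor`):

```python
import numpy as np

def hwc_to_chw_uint8(img: np.ndarray) -> np.ndarray:
    """Sketch of pil_to_tensor: HW or HWC uint8 array -> CHW, no scaling."""
    if img.ndim == 2:
        img = img[..., np.newaxis]       # (H, W) -> (H, W, 1)
    return np.transpose(img, (2, 0, 1))  # (H, W, C) -> (C, H, W)

def chw_to_hwc_image(chw: np.ndarray) -> np.ndarray:
    """Sketch of to_pil_image: CHW float in [0, 1] -> HWC uint8 in [0, 255]."""
    hwc = np.transpose(chw, (1, 2, 0))   # (C, H, W) -> (H, W, C)
    if np.issubdtype(hwc.dtype, np.floating):
        hwc = (hwc * 255).astype(np.uint8)
    return hwc

gray = np.arange(6, dtype=np.uint8).reshape(2, 3)
print(hwc_to_chw_uint8(gray).shape)  # (1, 2, 3)
```

Note that, as in the source, the float path truncates rather than rounds when scaling back to uint8.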
Original file line number Diff line number Diff line change
@@ -270,7 +270,7 @@ def transform_input(inpt, device: Literal["cpu", "gpu"] = "cpu") -> ndd.Tensor |
if inpt.ndim == 3:
_input = ndd.Tensor(inpt, layout="CHW")
elif inpt.ndim > 3:
# Creating baches of NCHW
# Creating batches of NCHW
_input = ndd.as_batch(inpt, layout="CHW")
else:
raise TypeError(f"Tensor has < 3 dimensions: {inpt.ndim}, shape: {inpt.shape}")
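The ndim-based dispatch in `transform_input` can be sketched without DALI — the string return values below are hypothetical stand-ins for `ndd.Tensor` and `ndd.as_batch`, used only to make the branching testable:

```python
def transform_input_kind(ndim: int) -> str:
    """Sketch of transform_input's dispatch: a 3-D CHW array becomes a single
    tensor, >3-D becomes a batch of NCHW, anything smaller is an error."""
    if ndim == 3:
        return "tensor"   # would be ndd.Tensor(inpt, layout="CHW")
    elif ndim > 3:
        return "batch"    # would be ndd.as_batch(inpt, layout="CHW")
    raise TypeError(f"Tensor has < 3 dimensions: {ndim}")

print(transform_input_kind(3))  # tensor
print(transform_input_kind(4))  # batch
```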