Skip to content

Commit 5e10607

Browse files
committed
Start adding output metadata inference to operator schemas
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1 parent fb77745 commit 5e10607

File tree

13 files changed: 223 additions (+223) and 24 deletions (-24).

dali/operators/bbox/bb_flip.cc

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2017-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -37,7 +37,10 @@ system, that is 0.0-1.0)code")
3737
1, true)
3838
.AddOptionalArg("vertical",
3939
R"code(Flip vertical dimension.)code",
40-
0, true);
40+
0, true)
41+
.OutputDType(0, [](const OpSpec &, span<const DALIDataType> in) { return in[0]; })
42+
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
43+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
4144

4245
void BbFlipCPU::RunImpl(Workspace &ws) {
4346
const auto &input = ws.Input<CPUBackend>(0);

dali/operators/decoder/image_decoder.cc

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -191,7 +191,10 @@ Please note that GPU acceleration for JPEG 2000 decoding is only available for C
191191
.NumInput(1)
192192
.NumOutput(1)
193193
.AddParent("ImageDecoderAttr")
194-
.AddParent("CachedDecoderAttr");
194+
.AddParent("CachedDecoderAttr")
195+
.OutputDType(0, [](const OpSpec &, span<const DALIDataType>) { return DALI_UINT8; })
196+
.OutputNdim(0, [](const OpSpec &, span<const int>) { return 3; })
197+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout>) { return "HWC"; });
195198

196199
// Fused
197200

@@ -309,7 +312,10 @@ of the slice (s0, s1, s2, …).
309312
310313
Integer coordinates are interpreted as absolute coordinates, while float coordinates can be
311314
interpreted as absolute or relative coordinates, depending on the value of
312-
`normalized_shape`.)code");
315+
`normalized_shape`.)code")
316+
.OutputDType(0, [](const OpSpec &, span<const DALIDataType>) { return DALI_UINT8; })
317+
.OutputNdim(0, [](const OpSpec &, span<const int>) { return 3; })
318+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout>) { return "HWC"; });
313319

314320

315321
// Deprecated aliases

dali/operators/generic/cast.cc

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2017-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -71,7 +71,12 @@ DALI_SCHEMA(Cast)
7171
.NumOutput(1)
7272
.AllowSequences()
7373
.SupportVolumetric()
74-
.AddTypeArg("dtype", R"code(Output data type.)code");
74+
.AddTypeArg("dtype", R"code(Output data type.)code")
75+
.OutputDType(0, [](const OpSpec &spec, span<const DALIDataType>) {
76+
return spec.GetArgument<DALIDataType>("dtype");
77+
})
78+
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
79+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
7580

7681
DALI_SCHEMA(CastLike)
7782
.DocStr("Cast the first tensor to the type of the second tensor.")

dali/operators/generic/flip.cc

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -34,7 +34,10 @@ and depthwise).)code")
3434
.AddOptionalArg("depthwise", R"code(Flip the depthwise dimension.)code", 0, true)
3535
.InputLayout({"FDHWC", "FHWC", "DHWC", "HWC", "FCDHW", "FCHW", "CDHW", "CHW"})
3636
.AllowSequences()
37-
.SupportVolumetric();
37+
.SupportVolumetric()
38+
.OutputDType(0, [](const OpSpec &, span<const DALIDataType> in) { return in[0]; })
39+
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
40+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
3841

3942

4043
template <>

dali/operators/image/color/brightness_contrast.cc

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -89,7 +89,15 @@ This operator can also change the type of data.)code")
8989
.NumOutput(1)
9090
.AllowSequences()
9191
.SupportVolumetric()
92-
.InputLayout({"FHWC", "DHWC", "HWC"});
92+
.InputLayout({"FHWC", "DHWC", "HWC"})
93+
.OutputDType(0, [](const OpSpec &spec, span<const DALIDataType> in) {
94+
DALIDataType dtype;
95+
if (spec.TryGetArgument(dtype, "dtype"))
96+
return dtype;
97+
return in[0];
98+
})
99+
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
100+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
93101

94102
DALI_REGISTER_OPERATOR(BrightnessContrast, BrightnessContrastCpu, CPU)
95103
DALI_REGISTER_OPERATOR(Brightness, BrightnessContrastCpu, CPU);

dali/operators/image/color/color_twist.cc

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -50,7 +50,15 @@ they would in case of rotation.)code",
5050
If a value is not set, the input type is used.)code",
5151
DALI_UINT8)
5252
.InputLayout(0, {"HWC", "FHWC", "DHWC"})
53-
.AllowSequences();
53+
.AllowSequences()
54+
.OutputDType(0, [](const OpSpec &spec, span<const DALIDataType> in) {
55+
DALIDataType dtype;
56+
if (spec.TryGetArgument(dtype, "dtype"))
57+
return dtype;
58+
return in[0];
59+
})
60+
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
61+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
5462

5563
DALI_SCHEMA(ColorTransformBase)
5664
.DocStr(R"code(Base Schema for color transformations operators.)code")

dali/operators/image/crop/bbox_crop.cc

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -364,7 +364,16 @@ if the fraction of their area within the ROI is greater than or equal to the thr
364364
For example, when `bbox_prune_threshold=0.2` bboxes that have at least 20% of their original area within
365365
the ROI are kept, bboxes less than or equal to are pruned. If `bbox_prune_threshold=0.0`, all boxes that
366366
have some presence in the ROI are kept.)code",
367-
nullptr);
367+
nullptr)
368+
.OutputDType(0, [](const OpSpec &, span<const DALIDataType>) { return DALI_FLOAT; })
369+
.OutputDType(1, [](const OpSpec &, span<const DALIDataType>) { return DALI_FLOAT; })
370+
.OutputDType(2, [](const OpSpec &, span<const DALIDataType>) { return DALI_FLOAT; })
371+
.OutputNdim(0, [](const OpSpec &, span<const int>) { return 1; })
372+
.OutputNdim(1, [](const OpSpec &, span<const int>) { return 1; })
373+
.OutputNdim(2, [](const OpSpec &, span<const int>) { return 2; })
374+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout>) { return TensorLayout{}; })
375+
.OutputLayout(1, [](const OpSpec &, span<const TensorLayout>) { return TensorLayout{}; })
376+
.OutputLayout(2, [](const OpSpec &, span<const TensorLayout>) { return TensorLayout{}; });
368377

369378
template <int ndim>
370379
class RandomBBoxCropImpl : public OpImplBase<CPUBackend> {

dali/operators/image/crop/crop.cc

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -36,7 +36,15 @@ DALI_SCHEMA(Crop)
3636
.DeprecateArg("image_type", "0.24")
3737
.AddParent("CropAttr")
3838
.AddParent("OutOfBoundsAttr")
39-
.AddParent("SliceBase");
39+
.AddParent("SliceBase")
40+
.OutputDType(0, [](const OpSpec &spec, span<const DALIDataType> in) {
41+
DALIDataType dtype;
42+
if (spec.TryGetArgument(dtype, "dtype"))
43+
return dtype;
44+
return in[0];
45+
})
46+
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
47+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
4048

4149
// Register operator
4250
DALI_REGISTER_OPERATOR(Crop, Crop<CPUBackend>, CPU);

dali/operators/image/crop/crop_mirror_normalize.cc

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -71,7 +71,14 @@ This argument is useful when using integer outputs to improve dynamic range util
7171
This argument is useful when using unsigned integer outputs to improve dynamic range utilization.)",
7272
0.0f)
7373
.AddParent("CropAttr")
74-
.AddParent("OutOfBoundsAttr");
74+
.AddParent("OutOfBoundsAttr")
75+
.OutputDType(0, [](const OpSpec &spec, span<const DALIDataType>) {
76+
return spec.GetArgument<DALIDataType>("dtype");
77+
})
78+
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
79+
.OutputLayout(0, [](const OpSpec &spec, span<const TensorLayout>) {
80+
return spec.GetArgument<TensorLayout>("output_layout");
81+
});
7582

7683
DALI_REGISTER_OPERATOR(CropMirrorNormalize, CropMirrorNormalize<CPUBackend>, CPU);
7784

dali/operators/image/resize/resize.cc

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2017-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright (c) 2017-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -34,7 +34,15 @@ DALI_SCHEMA(Resize)
3434
.SupportVolumetric()
3535
.AllowSequences()
3636
.AddParent("ResizeAttr")
37-
.AddParent("ResamplingFilterAttr");
37+
.AddParent("ResamplingFilterAttr")
38+
.OutputDType(0, [](const OpSpec &spec, span<const DALIDataType> in) {
39+
DALIDataType dtype;
40+
if (spec.TryGetArgument(dtype, "dtype"))
41+
return dtype;
42+
return in[0];
43+
})
44+
.OutputNdim(0, [](const OpSpec &, span<const int> in) { return in[0]; })
45+
.OutputLayout(0, [](const OpSpec &, span<const TensorLayout> in) { return in[0]; });
3846

3947
template<typename Backend>
4048
Resize<Backend>::Resize(const OpSpec &spec)

0 commit comments

Comments
 (0)