-
Notifications
You must be signed in to change notification settings - Fork 661
Expand file tree
/
Copy pathresize.cc
More file actions
executable file
·112 lines (94 loc) · 3.52 KB
/
resize.cc
File metadata and controls
executable file
·112 lines (94 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// Copyright (c) 2017-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dali/operators/image/resize/resize.h"
#include <cassert>
#include "dali/pipeline/data/views.h"
namespace dali {
// Operator schema for `Resize`: one input, one regular output and, when the
// `save_attrs` argument is true, one additional attribute output (index 1).
DALI_SCHEMA(Resize)
.DocStr(R"code(Resize images.)code")
.NumInput(1)
.NumOutput(1)
// The number of extra outputs is the boolean value of `save_attrs` cast to int,
// i.e. 0 when it is false and 1 when it is true.
.AdditionalOutputsFn([](const OpSpec& spec) {
return static_cast<int>(spec.GetArgument<bool>("save_attrs"));
})
// Accepted layouts for input 0: 2D (HW) and 3D (DHW) spatial extents, with the
// channel (C) and optional frame (F) dimensions in the listed positions.
.InputLayout(0, {"HWC", "FHWC", "CHW", "FCHW", "CFHW" ,
"DHWC", "FDHWC", "CDHW", "FCDHW", "CFDHW" })
// `save_attrs` defaults to false.
.AddOptionalArg("save_attrs",
R"code(Save reshape attributes for testing.)code", false)
// `image_type` is declared and then immediately marked deprecated
// ("0.25" is presumably the version of deprecation - confirm against DeprecateArg).
.AddOptionalArg<DALIImageType>("image_type", "Image type", nullptr)
.DeprecateArg("image_type", "0.25")
.SupportVolumetric()
.AllowSequences()
// Parent schemas; presumably these contribute the size/mode and resampling
// filter arguments - see the "ResizeAttr" and "ResamplingFilterAttr" schemas.
.AddParent("ResizeAttr")
.AddParent("ResamplingFilterAttr")
// Description of the optional attribute output (index 1): INT32 data,
// one-dimensional samples, empty layout.
.OutputDType(1, DALI_INT32)
.OutputNDim(1, 1)
.OutputLayout(1, "");
/**
 * @brief Constructs the Resize operator from its specification.
 *
 * Forwards `spec` to both base classes, caches whether the attribute output
 * was requested and performs the backend-specific initialization.
 *
 * @param spec Operator specification holding the argument values.
 */
template <typename Backend>
Resize<Backend>::Resize(const OpSpec &spec)
    : StatelessOperator<Backend>(spec)
    , ResizeBase<Backend>(spec) {
  // Use the argument's *value*, not merely its presence: the schema declares
  // the additional output based on GetArgument<bool>("save_attrs"), so an
  // explicitly passed `save_attrs=false` must not enable the attribute output
  // (RunImpl would otherwise ask the workspace for output 1, which the schema
  // did not declare).
  save_attrs_ = this->spec_.template GetArgument<bool>("save_attrs");
  InitializeBackend();
}
// CPU specialization of the backend setup: hands the operator's thread count
// over to InitializeCPU.
template <>
void Resize<CPUBackend>::InitializeBackend() {
  const auto thread_count = num_threads_;
  InitializeCPU(thread_count);
}
// CPU execution: resizes input 0 into output 0, propagates the input layout
// and, if requested, fills output 1 via SaveAttrs.
template <>
void Resize<CPUBackend>::RunImpl(Workspace &ws) {
  const auto &src = ws.Input<CPUBackend>(0);
  auto &dst = ws.Output<CPUBackend>(0);

  RunResize(ws, dst, src);
  // The output keeps the same layout as the input.
  dst.SetLayout(src.GetLayout());

  // Nothing more to do unless the attribute output was requested.
  if (!save_attrs_)
    return;

  auto &attrs = ws.Output<CPUBackend>(1);
  // Debug-only sanity check: the attribute output must already have one 1D
  // sample per input sample, all of the same extent, NumSpatialDims() long.
  assert(attrs.shape().num_samples() == src.shape().num_samples() &&
         attrs.shape().sample_dim() == 1 &&
         is_uniform(attrs.shape()) &&
         attrs.shape()[0][0] == NumSpatialDims());

  auto attrs_view = view<int, 1>(attrs);
  SaveAttrs(attrs_view, src.shape());
}
// Registers Resize<CPUBackend> under the operator name `Resize` for the CPU backend.
DALI_REGISTER_OPERATOR(Resize, Resize<CPUBackend>, CPU);
// GPU specialization of the backend setup: reads the `minibatch_size` and
// `temp_buffer_hint` arguments from the spec and passes them to InitializeGPU.
template <>
void Resize<GPUBackend>::InitializeBackend() {
  const auto minibatch_size = spec_.GetArgument<int>("minibatch_size");
  const auto temp_buffer_hint = spec_.GetArgument<int64_t>("temp_buffer_hint");
  InitializeGPU(minibatch_size, temp_buffer_hint);
}
// GPU execution: resizes input 0 into output 0, propagates the input layout
// and, if requested, writes the attributes to output 1 through a staging buffer.
template <>
void Resize<GPUBackend>::RunImpl(Workspace &ws) {
  const auto &src = ws.Input<GPUBackend>(0);
  auto &dst = ws.Output<GPUBackend>(0);

  RunResize(ws, dst, src);
  // The output keeps the same layout as the input.
  dst.SetLayout(src.GetLayout());

  // Nothing more to do unless the attribute output was requested.
  if (!save_attrs_)
    return;

  auto &attrs_out = ws.Output<GPUBackend>(1);
  // Debug-only sanity check: the attribute output must already have one 1D
  // sample per input sample, all of the same extent, NumSpatialDims() long.
  assert(attrs_out.shape().num_samples() == src.shape().num_samples() &&
         attrs_out.shape().sample_dim() == 1 &&
         is_uniform(attrs_out.shape()) &&
         attrs_out.shape()[0][0] == NumSpatialDims());

  // The pinned flag is set only while the staging buffer holds no data yet.
  const bool staging_is_empty = !attr_staging_.has_data();
  if (staging_is_empty) {
    attr_staging_.set_pinned(true);
  }
  attr_staging_.Resize(attrs_out.shape(), DALI_INT32);

  // Fill the staging buffer, then copy it to the output on the workspace stream.
  auto staged_view = view<int, 1>(attr_staging_);
  SaveAttrs(staged_view, src.shape());
  attrs_out.Copy(attr_staging_, ws.stream());
}
// Registers Resize<GPUBackend> under the operator name `Resize` for the GPU backend.
DALI_REGISTER_OPERATOR(Resize, Resize<GPUBackend>, GPU);
} // namespace dali