Skip to content

Commit 59bb758

Browse files
Fix decimal floor/ceil (#10365) (#10364) (#10412)
close #3496, close #10365, ref pingcap/tidb#63086 Co-authored-by: ChangRui-Ryan <changrui82@gmail.com>
1 parent ec6b716 commit 59bb758

File tree

4 files changed

+697
-76
lines changed

4 files changed

+697
-76
lines changed

dbms/src/Functions/FunctionsRound.h

Lines changed: 94 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -199,77 +199,6 @@ enum class RoundingMode
199199
#endif
200200
};
201201

202-
/** Rounding functions for decimal values
203-
*/
204-
205-
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
206-
struct DecimalRoundingComputation
207-
{
208-
static_assert(IsDecimal<T>);
209-
static const size_t data_count = 1;
210-
static size_t prepare(size_t scale) { return scale; }
211-
// compute need decimal_scale to interpret decimals
212-
static inline void compute(
213-
const T * __restrict in,
214-
size_t scale,
215-
OutputType * __restrict out,
216-
ScaleType decimal_scale)
217-
{
218-
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
219-
Float64 val = in->template toFloat<Float64>(decimal_scale);
220-
221-
if constexpr (scale_mode == ScaleMode::Positive)
222-
{
223-
val = val * scale;
224-
}
225-
else if constexpr (scale_mode == ScaleMode::Negative)
226-
{
227-
val = val / scale;
228-
}
229-
230-
if constexpr (rounding_mode == RoundingMode::Round)
231-
{
232-
val = round(val);
233-
}
234-
else if constexpr (rounding_mode == RoundingMode::Floor)
235-
{
236-
val = floor(val);
237-
}
238-
else if constexpr (rounding_mode == RoundingMode::Ceil)
239-
{
240-
val = ceil(val);
241-
}
242-
else if constexpr (rounding_mode == RoundingMode::Trunc)
243-
{
244-
val = trunc(val);
245-
}
246-
247-
248-
if constexpr (scale_mode == ScaleMode::Positive)
249-
{
250-
val = val / scale;
251-
}
252-
else if constexpr (scale_mode == ScaleMode::Negative)
253-
{
254-
val = val * scale;
255-
}
256-
257-
if constexpr (std::is_same_v<T, OutputType>)
258-
{
259-
*out = ToDecimal<Float64, T>(val, decimal_scale);
260-
}
261-
else if constexpr (std::is_same_v<OutputType, Int64>)
262-
{
263-
*out = static_cast<Int64>(val);
264-
}
265-
else
266-
{
267-
; // never arrived here
268-
}
269-
}
270-
};
271-
272-
273202
/** Rounding functions for integer values.
274203
*/
275204
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
@@ -327,12 +256,74 @@ struct IntegerRoundingComputation
327256
}
328257
}
329258

330-
static ALWAYS_INLINE void compute(const T * __restrict in, size_t scale, T * __restrict out)
259+
static ALWAYS_INLINE void compute(const T * __restrict in, T scale, T * __restrict out)
331260
{
332261
*out = compute(*in, scale);
333262
}
334263
};
335264

265+
/** Rounding functions for decimal values
266+
*/
267+
268+
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
269+
struct DecimalRoundingComputation
270+
{
271+
static_assert(IsDecimal<T>);
272+
using NativeType = T::NativeType;
273+
static const size_t data_count = 1;
274+
static size_t prepare(size_t scale) { return scale; }
275+
// compute need decimal_scale to interpret decimals
276+
static inline void compute(
277+
const T * __restrict in,
278+
size_t scale,
279+
OutputType * __restrict out,
280+
NativeType decimal_scale)
281+
{
282+
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
283+
// Currently, we only use DecimalRoundingComputation for floor/ceil.
284+
// As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
285+
// So, we only handle ScaleMode::Zero here.
286+
if constexpr (scale_mode == ScaleMode::Zero)
287+
{
288+
try
289+
{
290+
if constexpr (rounding_mode == RoundingMode::Floor)
291+
{
292+
auto x = in->value;
293+
if (x < 0)
294+
x -= decimal_scale - 1;
295+
*out = static_cast<OutputType>(x / decimal_scale);
296+
}
297+
else if constexpr (rounding_mode == RoundingMode::Ceil)
298+
{
299+
auto x = in->value;
300+
if (x >= 0)
301+
x += decimal_scale - 1;
302+
*out = static_cast<OutputType>(x / decimal_scale);
303+
}
304+
else
305+
{
306+
throw Exception(
307+
"Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation",
308+
ErrorCodes::LOGICAL_ERROR);
309+
}
310+
}
311+
catch (const std::overflow_error & e)
312+
{
313+
throw Exception(
314+
"Logical error: unexpected overflow in DecimalRoundingComputation",
315+
ErrorCodes::LOGICAL_ERROR);
316+
}
317+
}
318+
else
319+
{
320+
throw Exception(
321+
"Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
322+
+ toString(scale),
323+
ErrorCodes::LOGICAL_ERROR);
324+
}
325+
}
326+
};
336327

337328
#if __SSE4_1__
338329

@@ -540,7 +531,7 @@ struct IntegerRoundingImpl
540531

541532
while (p_in < end_in)
542533
{
543-
Op::compute(p_in, scale, p_out);
534+
Op::compute(p_in, static_cast<T>(scale), p_out);
544535
++p_in;
545536
++p_out;
546537
}
@@ -606,6 +597,9 @@ struct DecimalRoundingImpl;
606597
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
607598
struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
608599
{
600+
static_assert(IsDecimal<T>);
601+
using NativeType = typename T::NativeType;
602+
609603
private:
610604
using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, Int64>;
611605
using Data = T;
@@ -616,7 +610,8 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
616610
size_t scale,
617611
typename ColumnVector<Int64>::Container & out)
618612
{
619-
ScaleType decimal_scale = in.getScale();
613+
ScaleType in_scale = in.getScale();
614+
auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
620615
const T * end_in = in.data() + in.size();
621616

622617
const T * __restrict p_in = in.data();
@@ -634,6 +629,9 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
634629
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
635630
struct DecimalRoundingImpl<T, rounding_mode, scale_mode, T>
636631
{
632+
static_assert(IsDecimal<T>);
633+
using NativeType = typename T::NativeType;
634+
637635
private:
638636
using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, T>;
639637
using Data = T;
@@ -644,7 +642,8 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, T>
644642
size_t scale,
645643
typename ColumnDecimal<T>::Container & out)
646644
{
647-
ScaleType decimal_scale = in.getScale();
645+
ScaleType in_scale = in.getScale();
646+
auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
648647
const T * end_in = in.data() + in.size();
649648

650649
const T * __restrict p_in = in.data();
@@ -698,7 +697,12 @@ struct Dispatcher
698697

699698
if constexpr (IsDecimal<OutputType>)
700699
{
701-
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), col->getData().getScale());
700+
UInt32 res_scale = 0;
701+
if constexpr (rounding_mode == RoundingMode::Round || rounding_mode == RoundingMode::Trunc)
702+
{
703+
res_scale = col->getData().getScale();
704+
}
705+
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), res_scale);
702706
typename ColumnDecimal<OutputType>::Container & vec_res = col_res->getData();
703707
applyInternal(col, vec_res, col_res, block, scale_arg, result);
704708
}
@@ -813,6 +817,20 @@ class FunctionRounding : public IFunction
813817
fmt::format("Illegal type {} of argument of function {}", arguments[0]->getName(), getName()),
814818
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
815819

820+
if constexpr (rounding_mode == RoundingMode::Ceil || rounding_mode == RoundingMode::Floor)
821+
{
822+
if (arguments[0]->isDecimal())
823+
{
824+
if (const auto * decimal_type32 = checkAndGetDataType<DataTypeDecimal32>(arguments[0].get()))
825+
return std::make_shared<DataTypeDecimal32>(decimal_type32->getPrec(), 0);
826+
else if (const auto * decimal_type64 = checkAndGetDataType<DataTypeDecimal64>(arguments[0].get()))
827+
return std::make_shared<DataTypeDecimal64>(decimal_type64->getPrec(), 0);
828+
else if (const auto * decimal_type128 = checkAndGetDataType<DataTypeDecimal128>(arguments[0].get()))
829+
return std::make_shared<DataTypeDecimal128>(decimal_type128->getPrec(), 0);
830+
else if (const auto * decimal_type256 = checkAndGetDataType<DataTypeDecimal256>(arguments[0].get()))
831+
return std::make_shared<DataTypeDecimal256>(decimal_type256->getPrec(), 0);
832+
}
833+
}
816834
return arguments[0];
817835
}
818836

0 commit comments

Comments
 (0)