Skip to content

Commit d7aaefc

Browse files
committed
Fix decimal floor/ceil (#10365)
1 parent c6dc5fc commit d7aaefc

File tree

4 files changed

+683
-73
lines changed

4 files changed

+683
-73
lines changed

dbms/src/Functions/FunctionsRound.h

Lines changed: 80 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -199,77 +199,6 @@ enum class RoundingMode
199199
#endif
200200
};
201201

202-
/** Rounding functions for decimal values
203-
*/
204-
205-
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
206-
struct DecimalRoundingComputation
207-
{
208-
static_assert(IsDecimal<T>);
209-
static const size_t data_count = 1;
210-
static size_t prepare(size_t scale) { return scale; }
211-
// compute need decimal_scale to interpret decimals
212-
static inline void compute(
213-
const T * __restrict in,
214-
size_t scale,
215-
OutputType * __restrict out,
216-
ScaleType decimal_scale)
217-
{
218-
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
219-
Float64 val = in->template toFloat<Float64>(decimal_scale);
220-
221-
if constexpr (scale_mode == ScaleMode::Positive)
222-
{
223-
val = val * scale;
224-
}
225-
else if constexpr (scale_mode == ScaleMode::Negative)
226-
{
227-
val = val / scale;
228-
}
229-
230-
if constexpr (rounding_mode == RoundingMode::Round)
231-
{
232-
val = round(val);
233-
}
234-
else if constexpr (rounding_mode == RoundingMode::Floor)
235-
{
236-
val = floor(val);
237-
}
238-
else if constexpr (rounding_mode == RoundingMode::Ceil)
239-
{
240-
val = ceil(val);
241-
}
242-
else if constexpr (rounding_mode == RoundingMode::Trunc)
243-
{
244-
val = trunc(val);
245-
}
246-
247-
248-
if constexpr (scale_mode == ScaleMode::Positive)
249-
{
250-
val = val / scale;
251-
}
252-
else if constexpr (scale_mode == ScaleMode::Negative)
253-
{
254-
val = val * scale;
255-
}
256-
257-
if constexpr (std::is_same_v<T, OutputType>)
258-
{
259-
*out = ToDecimal<Float64, T>(val, decimal_scale);
260-
}
261-
else if constexpr (std::is_same_v<OutputType, Int64>)
262-
{
263-
*out = static_cast<Int64>(val);
264-
}
265-
else
266-
{
267-
; // never arrived here
268-
}
269-
}
270-
};
271-
272-
273202
/** Rounding functions for integer values.
274203
*/
275204
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
@@ -327,12 +256,90 @@ struct IntegerRoundingComputation
327256
}
328257
}
329258

330-
static ALWAYS_INLINE void compute(const T * __restrict in, size_t scale, T * __restrict out)
259+
static ALWAYS_INLINE void compute(const T * __restrict in, T scale, T * __restrict out)
331260
{
332261
*out = compute(*in, scale);
333262
}
334263
};
335264

265+
/** Rounding functions for decimal values
266+
*/
267+
268+
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
269+
struct DecimalRoundingComputation
270+
{
271+
static_assert(IsDecimal<T>);
272+
static const size_t data_count = 1;
273+
static size_t prepare(size_t scale) { return scale; }
274+
// compute need decimal_scale to interpret decimals
275+
static inline void compute(
276+
const T * __restrict in,
277+
size_t scale,
278+
OutputType * __restrict out,
279+
ScaleType decimal_scale)
280+
{
281+
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
282+
using NativeType = T::NativeType;
283+
// Currently, we only use DecimalRoundingComputation for floor/ceil.
284+
// As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
285+
// So, we only handle ScaleMode::Zero here.
286+
if constexpr (scale_mode == ScaleMode::Zero)
287+
{
288+
using Op = IntegerRoundingComputation<NativeType, rounding_mode, ScaleMode::Negative>;
289+
auto scale_factor = intExp10OfSize<NativeType>(decimal_scale);
290+
291+
if constexpr (std::is_same_v<T, OutputType>)
292+
{
293+
Op::compute(&in->value, scale_factor, &out->value);
294+
}
295+
else if constexpr (std::is_same_v<OutputType, Int64>)
296+
{
297+
try
298+
{
299+
if constexpr (rounding_mode == RoundingMode::Floor)
300+
{
301+
auto x = in->value;
302+
if (x < 0)
303+
x -= scale_factor - 1;
304+
*out = static_cast<Int64>(x / scale_factor);
305+
}
306+
else if constexpr (rounding_mode == RoundingMode::Ceil)
307+
{
308+
auto x = in->value;
309+
if (x >= 0)
310+
x += scale_factor - 1;
311+
*out = static_cast<Int64>(x / scale_factor);
312+
}
313+
else
314+
{
315+
throw Exception(
316+
"Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation",
317+
ErrorCodes::LOGICAL_ERROR);
318+
}
319+
}
320+
catch (const std::overflow_error & e)
321+
{
322+
throw Exception(
323+
"Logical error: unexpected Type of DecimalRoundingComputation for INT result",
324+
ErrorCodes::LOGICAL_ERROR);
325+
}
326+
}
327+
else
328+
{
329+
throw Exception(
330+
"Logical error: unexpected OutputType of DecimalRoundingComputation",
331+
ErrorCodes::LOGICAL_ERROR);
332+
}
333+
}
334+
else
335+
{
336+
throw Exception(
337+
"Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
338+
+ toString(scale),
339+
ErrorCodes::LOGICAL_ERROR);
340+
}
341+
}
342+
};
336343

337344
#if __SSE4_1__
338345

@@ -540,7 +547,7 @@ struct IntegerRoundingImpl
540547

541548
while (p_in < end_in)
542549
{
543-
Op::compute(p_in, scale, p_out);
550+
Op::compute(p_in, static_cast<T>(scale), p_out);
544551
++p_in;
545552
++p_out;
546553
}

0 commit comments

Comments
 (0)