@@ -199,77 +199,6 @@ enum class RoundingMode
199199#endif
200200};
201201
202- /* * Rounding functions for decimal values
203- */
204-
205- template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
206- struct DecimalRoundingComputation
207- {
208- static_assert (IsDecimal<T>);
209- static const size_t data_count = 1 ;
210- static size_t prepare (size_t scale) { return scale; }
211- // compute need decimal_scale to interpret decimals
212- static inline void compute (
213- const T * __restrict in,
214- size_t scale,
215- OutputType * __restrict out,
216- ScaleType decimal_scale)
217- {
218- static_assert (std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
219- Float64 val = in->template toFloat <Float64>(decimal_scale);
220-
221- if constexpr (scale_mode == ScaleMode::Positive)
222- {
223- val = val * scale;
224- }
225- else if constexpr (scale_mode == ScaleMode::Negative)
226- {
227- val = val / scale;
228- }
229-
230- if constexpr (rounding_mode == RoundingMode::Round)
231- {
232- val = round (val);
233- }
234- else if constexpr (rounding_mode == RoundingMode::Floor)
235- {
236- val = floor (val);
237- }
238- else if constexpr (rounding_mode == RoundingMode::Ceil)
239- {
240- val = ceil (val);
241- }
242- else if constexpr (rounding_mode == RoundingMode::Trunc)
243- {
244- val = trunc (val);
245- }
246-
247-
248- if constexpr (scale_mode == ScaleMode::Positive)
249- {
250- val = val / scale;
251- }
252- else if constexpr (scale_mode == ScaleMode::Negative)
253- {
254- val = val * scale;
255- }
256-
257- if constexpr (std::is_same_v<T, OutputType>)
258- {
259- *out = ToDecimal<Float64, T>(val, decimal_scale);
260- }
261- else if constexpr (std::is_same_v<OutputType, Int64>)
262- {
263- *out = static_cast <Int64>(val);
264- }
265- else
266- {
267- ; // never arrived here
268- }
269- }
270- };
271-
272-
273202/* * Rounding functions for integer values.
274203 */
275204template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
@@ -327,12 +256,74 @@ struct IntegerRoundingComputation
327256 }
328257 }
329258
330- static ALWAYS_INLINE void compute (const T * __restrict in, size_t scale, T * __restrict out)
259+ static ALWAYS_INLINE void compute (const T * __restrict in, T scale, T * __restrict out)
331260 {
332261 *out = compute (*in, scale);
333262 }
334263};
335264
265+ /* * Rounding functions for decimal values
266+ */
267+
268+ template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
269+ struct DecimalRoundingComputation
270+ {
271+ static_assert (IsDecimal<T>);
272+ using NativeType = T::NativeType;
273+ static const size_t data_count = 1 ;
274+ static size_t prepare (size_t scale) { return scale; }
275+ // compute need decimal_scale to interpret decimals
276+ static inline void compute (
277+ const T * __restrict in,
278+ size_t scale,
279+ OutputType * __restrict out,
280+ NativeType decimal_scale)
281+ {
282+ static_assert (std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
283+ // Currently, we only use DecimalRoundingComputation for floor/ceil.
284+ // As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
285+ // So, we only handle ScaleMode::Zero here.
286+ if constexpr (scale_mode == ScaleMode::Zero)
287+ {
288+ try
289+ {
290+ if constexpr (rounding_mode == RoundingMode::Floor)
291+ {
292+ auto x = in->value ;
293+ if (x < 0 )
294+ x -= decimal_scale - 1 ;
295+ *out = static_cast <OutputType>(x / decimal_scale);
296+ }
297+ else if constexpr (rounding_mode == RoundingMode::Ceil)
298+ {
299+ auto x = in->value ;
300+ if (x >= 0 )
301+ x += decimal_scale - 1 ;
302+ *out = static_cast <OutputType>(x / decimal_scale);
303+ }
304+ else
305+ {
306+ throw Exception (
307+ " Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation" ,
308+ ErrorCodes::LOGICAL_ERROR);
309+ }
310+ }
311+ catch (const std::overflow_error & e)
312+ {
313+ throw Exception (
314+ " Logical error: unexpected overflow in DecimalRoundingComputation" ,
315+ ErrorCodes::LOGICAL_ERROR);
316+ }
317+ }
318+ else
319+ {
320+ throw Exception (
321+ " Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
322+ + toString (scale),
323+ ErrorCodes::LOGICAL_ERROR);
324+ }
325+ }
326+ };
336327
337328#if __SSE4_1__
338329
@@ -540,7 +531,7 @@ struct IntegerRoundingImpl
540531
541532 while (p_in < end_in)
542533 {
543- Op::compute (p_in, scale, p_out);
534+ Op::compute (p_in, static_cast <T>( scale) , p_out);
544535 ++p_in;
545536 ++p_out;
546537 }
@@ -606,6 +597,9 @@ struct DecimalRoundingImpl;
606597template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
607598struct DecimalRoundingImpl <T, rounding_mode, scale_mode, Int64>
608599{
600+ static_assert (IsDecimal<T>);
601+ using NativeType = typename T::NativeType;
602+
609603private:
610604 using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, Int64>;
611605 using Data = T;
@@ -616,7 +610,8 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
616610 size_t scale,
617611 typename ColumnVector<Int64>::Container & out)
618612 {
619- ScaleType decimal_scale = in.getScale ();
613+ ScaleType in_scale = in.getScale ();
614+ auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
620615 const T * end_in = in.data () + in.size ();
621616
622617 const T * __restrict p_in = in.data ();
@@ -634,6 +629,9 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
634629template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
635630struct DecimalRoundingImpl <T, rounding_mode, scale_mode, T>
636631{
632+ static_assert (IsDecimal<T>);
633+ using NativeType = typename T::NativeType;
634+
637635private:
638636 using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, T>;
639637 using Data = T;
@@ -644,7 +642,8 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, T>
644642 size_t scale,
645643 typename ColumnDecimal<T>::Container & out)
646644 {
647- ScaleType decimal_scale = in.getScale ();
645+ ScaleType in_scale = in.getScale ();
646+ auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
648647 const T * end_in = in.data () + in.size ();
649648
650649 const T * __restrict p_in = in.data ();
@@ -698,7 +697,12 @@ struct Dispatcher
698697
699698 if constexpr (IsDecimal<OutputType>)
700699 {
701- auto col_res = ColumnDecimal<OutputType>::create (col->getData ().size (), col->getData ().getScale ());
700+ UInt32 res_scale = 0 ;
701+ if constexpr (rounding_mode == RoundingMode::Round || rounding_mode == RoundingMode::Trunc)
702+ {
703+ res_scale = col->getData ().getScale ();
704+ }
705+ auto col_res = ColumnDecimal<OutputType>::create (col->getData ().size (), res_scale);
702706 typename ColumnDecimal<OutputType>::Container & vec_res = col_res->getData ();
703707 applyInternal (col, vec_res, col_res, block, scale_arg, result);
704708 }
@@ -813,6 +817,20 @@ class FunctionRounding : public IFunction
813817 fmt::format (" Illegal type {} of argument of function {}" , arguments[0 ]->getName (), getName ()),
814818 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
815819
820+ if constexpr (rounding_mode == RoundingMode::Ceil || rounding_mode == RoundingMode::Floor)
821+ {
822+ if (arguments[0 ]->isDecimal ())
823+ {
824+ if (const auto * decimal_type32 = checkAndGetDataType<DataTypeDecimal32>(arguments[0 ].get ()))
825+ return std::make_shared<DataTypeDecimal32>(decimal_type32->getPrec (), 0 );
826+ else if (const auto * decimal_type64 = checkAndGetDataType<DataTypeDecimal64>(arguments[0 ].get ()))
827+ return std::make_shared<DataTypeDecimal64>(decimal_type64->getPrec (), 0 );
828+ else if (const auto * decimal_type128 = checkAndGetDataType<DataTypeDecimal128>(arguments[0 ].get ()))
829+ return std::make_shared<DataTypeDecimal128>(decimal_type128->getPrec (), 0 );
830+ else if (const auto * decimal_type256 = checkAndGetDataType<DataTypeDecimal256>(arguments[0 ].get ()))
831+ return std::make_shared<DataTypeDecimal256>(decimal_type256->getPrec (), 0 );
832+ }
833+ }
816834 return arguments[0 ];
817835 }
818836
0 commit comments