@@ -199,77 +199,6 @@ enum class RoundingMode
199199#endif
200200};
201201
202- /* * Rounding functions for decimal values
203- */
204-
205- template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
206- struct DecimalRoundingComputation
207- {
208- static_assert (IsDecimal<T>);
209- static const size_t data_count = 1 ;
210- static size_t prepare (size_t scale) { return scale; }
211- // compute need decimal_scale to interpret decimals
212- static inline void compute (
213- const T * __restrict in,
214- size_t scale,
215- OutputType * __restrict out,
216- ScaleType decimal_scale)
217- {
218- static_assert (std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
219- Float64 val = in->template toFloat <Float64>(decimal_scale);
220-
221- if constexpr (scale_mode == ScaleMode::Positive)
222- {
223- val = val * scale;
224- }
225- else if constexpr (scale_mode == ScaleMode::Negative)
226- {
227- val = val / scale;
228- }
229-
230- if constexpr (rounding_mode == RoundingMode::Round)
231- {
232- val = round (val);
233- }
234- else if constexpr (rounding_mode == RoundingMode::Floor)
235- {
236- val = floor (val);
237- }
238- else if constexpr (rounding_mode == RoundingMode::Ceil)
239- {
240- val = ceil (val);
241- }
242- else if constexpr (rounding_mode == RoundingMode::Trunc)
243- {
244- val = trunc (val);
245- }
246-
247-
248- if constexpr (scale_mode == ScaleMode::Positive)
249- {
250- val = val / scale;
251- }
252- else if constexpr (scale_mode == ScaleMode::Negative)
253- {
254- val = val * scale;
255- }
256-
257- if constexpr (std::is_same_v<T, OutputType>)
258- {
259- *out = ToDecimal<Float64, T>(val, decimal_scale);
260- }
261- else if constexpr (std::is_same_v<OutputType, Int64>)
262- {
263- *out = static_cast <Int64>(val);
264- }
265- else
266- {
267- ; // never arrived here
268- }
269- }
270- };
271-
272-
273202/* * Rounding functions for integer values.
274203 */
275204template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
@@ -327,12 +256,90 @@ struct IntegerRoundingComputation
327256 }
328257 }
329258
330- static ALWAYS_INLINE void compute (const T * __restrict in, size_t scale, T * __restrict out)
259+ static ALWAYS_INLINE void compute (const T * __restrict in, T scale, T * __restrict out)
331260 {
332261 *out = compute (*in, scale);
333262 }
334263};
335264
265+ /* * Rounding functions for decimal values
266+ */
267+
268+ template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
269+ struct DecimalRoundingComputation
270+ {
271+ static_assert (IsDecimal<T>);
272+ static const size_t data_count = 1 ;
273+ static size_t prepare (size_t scale) { return scale; }
274+ // compute need decimal_scale to interpret decimals
275+ static inline void compute (
276+ const T * __restrict in,
277+ size_t scale,
278+ OutputType * __restrict out,
279+ ScaleType decimal_scale)
280+ {
281+ static_assert (std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
282+ using NativeType = T::NativeType;
283+ // Currently, we only use DecimalRoundingComputation for floor/ceil.
284+ // As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
285+ // So, we only handle ScaleMode::Zero here.
286+ if constexpr (scale_mode == ScaleMode::Zero)
287+ {
288+ using Op = IntegerRoundingComputation<NativeType, rounding_mode, ScaleMode::Negative>;
289+ auto scale_factor = intExp10OfSize<NativeType>(decimal_scale);
290+
291+ if constexpr (std::is_same_v<T, OutputType>)
292+ {
293+ Op::compute (&in->value , scale_factor, &out->value );
294+ }
295+ else if constexpr (std::is_same_v<OutputType, Int64>)
296+ {
297+ try
298+ {
299+ if constexpr (rounding_mode == RoundingMode::Floor)
300+ {
301+ auto x = in->value ;
302+ if (x < 0 )
303+ x -= scale_factor - 1 ;
304+ *out = static_cast <Int64>(x / scale_factor);
305+ }
306+ else if constexpr (rounding_mode == RoundingMode::Ceil)
307+ {
308+ auto x = in->value ;
309+ if (x >= 0 )
310+ x += scale_factor - 1 ;
311+ *out = static_cast <Int64>(x / scale_factor);
312+ }
313+ else
314+ {
315+ throw Exception (
316+ " Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation" ,
317+ ErrorCodes::LOGICAL_ERROR);
318+ }
319+ }
320+ catch (const std::overflow_error & e)
321+ {
322+ throw Exception (
323+ " Logical error: unexpected Type of DecimalRoundingComputation for INT result" ,
324+ ErrorCodes::LOGICAL_ERROR);
325+ }
326+ }
327+ else
328+ {
329+ throw Exception (
330+ " Logical error: unexpected OutputType of DecimalRoundingComputation" ,
331+ ErrorCodes::LOGICAL_ERROR);
332+ }
333+ }
334+ else
335+ {
336+ throw Exception (
337+ " Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
338+ + toString (scale),
339+ ErrorCodes::LOGICAL_ERROR);
340+ }
341+ }
342+ };
336343
337344#if __SSE4_1__
338345
@@ -540,7 +547,7 @@ struct IntegerRoundingImpl
540547
541548 while (p_in < end_in)
542549 {
543- Op::compute (p_in, scale, p_out);
550+ Op::compute (p_in, static_cast <T>( scale) , p_out);
544551 ++p_in;
545552 ++p_out;
546553 }
0 commit comments