diff --git a/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java b/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java index 23259b5899cb..ac76a8f26634 100644 --- a/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java +++ b/core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java @@ -16,13 +16,27 @@ import com.google.common.collect.ImmutableList; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.DecimalType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.Int128; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.NumberType; +import io.trino.spi.type.RealType; import io.trino.spi.type.RowType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.TinyintType; +import io.trino.spi.type.TrinoNumber; import io.trino.sql.PlannerContext; import io.trino.type.TypeCoercion; +import java.math.BigDecimal; import java.util.List; +import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.spi.block.RowValueBuilder.buildRowValue; +import static io.trino.spi.function.OperatorType.DIVIDE; +import static io.trino.spi.function.OperatorType.MODULUS; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.TypeUtils.writeNativeValue; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -72,7 +86,7 @@ public static boolean mayFail(PlannerContext plannerContext, Expression expressi // These expressions need to verify their operands case Array e -> e.elements().stream().anyMatch(element -> mayFail(plannerContext, element)); case Between e -> mayFail(plannerContext, e.value()) || mayFail(plannerContext, e.min()) || mayFail(plannerContext, e.max()); - case Call e -> mayFail(e.function()) || e.arguments().stream().anyMatch(argument -> mayFail(plannerContext, argument)); + case Call e -> mayFail(e) || e.arguments().stream().anyMatch(argument -> mayFail(plannerContext, argument)); case Case e -> e.whenClauses().stream().anyMatch(clause -> mayFail(plannerContext, clause.getOperand()) || mayFail(plannerContext, clause.getResult())) || mayFail(plannerContext, e.defaultValue()); case Cast e -> mayFail(plannerContext, e); @@ -103,9 +117,45 @@ private static boolean mayFail(PlannerContext plannerContext, Cast cast) return !cast.type().equals(VARCHAR); } - private static boolean mayFail(ResolvedFunction function) + private static boolean mayFail(Call call) { - return !function.neverFails() && !isDynamicFilterFunction(function.name()); + ResolvedFunction function = call.function(); + if (function.neverFails() || isDynamicFilterFunction(function.name())) { + return false; + } + List arguments = call.arguments(); + if (isModulsOrDivide(function) && arguments.get(1) instanceof Constant divisor && !canCauseDivisionByZeroError(divisor)) { + return false; + } + return true; + } + + private static boolean isModulsOrDivide(ResolvedFunction function) + { + return (function.name().equals(builtinFunctionName(MODULUS)) || function.name().equals(builtinFunctionName(DIVIDE))) && function.signature().getArity() == 2; + } + + private static boolean canCauseDivisionByZeroError(Constant divisor) + { + Object value = divisor.value(); + if (value == null) { + return false; // dividing by null is null + } + return switch (divisor.type()) { + case TinyintType _, SmallintType _, IntegerType _, BigintType _ -> (long) value == 0; + case DecimalType decimalType -> { + if (decimalType.isShort()) { + yield (long) value == 0; + } + yield ((Int128) value).isZero(); + } + case NumberType _ -> switch (((TrinoNumber) value).toBigDecimal()) { + case TrinoNumber.BigDecimalValue(BigDecimal bigdecimal) -> bigdecimal.signum() == 0; + case TrinoNumber.Infinity _, TrinoNumber.NotANumber _ -> false; + }; + case RealType _, DoubleType _ -> false; // will return NaN or ±Inf on division by 0 + default -> true; + }; } public static Expression not(Metadata metadata, Expression expression) diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java index e177d5415f39..0826199a986d 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/optimizer/TestHivePlans.java @@ -60,7 +60,6 @@ import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL; import static io.trino.sql.ir.Logical.Operator.AND; -import static io.trino.sql.planner.assertions.PlanMatchPattern.any; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; @@ -277,16 +276,16 @@ public void testSubsumePartitionFilterNotConvertibleToTupleDomain() output( join(INNER, builder -> builder .equiCriteria("L_INT_PART", "R_INT_COL") - .filter(new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L))) .left( exchange(REMOTE, REPARTITION, - any( + filter( + new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "L_INT_PART"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)), tableScan("table_int_partitioned", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col"))))) .right( exchange(LOCAL, exchange(REMOTE, REPARTITION, filter( - new Between(new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)), + new Logical(AND, ImmutableList.of(new In(new Reference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)))), tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col"))))))))); }