Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 53 additions & 3 deletions core/trino-main/src/main/java/io/trino/sql/ir/IrExpressions.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,27 @@
import com.google.common.collect.ImmutableList;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Int128;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.NumberType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.TrinoNumber;
import io.trino.sql.PlannerContext;
import io.trino.type.TypeCoercion;

import java.math.BigDecimal;
import java.util.List;

import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.spi.block.RowValueBuilder.buildRowValue;
import static io.trino.spi.function.OperatorType.DIVIDE;
import static io.trino.spi.function.OperatorType.MODULUS;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.TypeUtils.writeNativeValue;
import static io.trino.spi.type.VarcharType.VARCHAR;
Expand Down Expand Up @@ -72,7 +86,7 @@ public static boolean mayFail(PlannerContext plannerContext, Expression expressi
// These expressions need to verify their operands
case Array e -> e.elements().stream().anyMatch(element -> mayFail(plannerContext, element));
case Between e -> mayFail(plannerContext, e.value()) || mayFail(plannerContext, e.min()) || mayFail(plannerContext, e.max());
case Call e -> mayFail(e.function()) || e.arguments().stream().anyMatch(argument -> mayFail(plannerContext, argument));
case Call e -> mayFail(e) || e.arguments().stream().anyMatch(argument -> mayFail(plannerContext, argument));
case Case e -> e.whenClauses().stream().anyMatch(clause -> mayFail(plannerContext, clause.getOperand()) || mayFail(plannerContext, clause.getResult())) ||
mayFail(plannerContext, e.defaultValue());
case Cast e -> mayFail(plannerContext, e);
Expand Down Expand Up @@ -103,9 +117,45 @@ private static boolean mayFail(PlannerContext plannerContext, Cast cast)
return !cast.type().equals(VARCHAR);
}

private static boolean mayFail(ResolvedFunction function)
private static boolean mayFail(Call call)
{
return !function.neverFails() && !isDynamicFilterFunction(function.name());
ResolvedFunction function = call.function();
if (function.neverFails() || isDynamicFilterFunction(function.name())) {
return false;
}
Comment thread
losipiuk marked this conversation as resolved.
List<Expression> arguments = call.arguments();
if (isModulsOrDivide(function) && arguments.get(1) instanceof Constant divisor && !canCauseDivisionByZeroError(divisor)) {
return false;
}
return true;
}

private static boolean isModulsOrDivide(ResolvedFunction function)
{
return (function.name().equals(builtinFunctionName(MODULUS)) || function.name().equals(builtinFunctionName(DIVIDE))) && function.signature().getArity() == 2;
}

private static boolean canCauseDivisionByZeroError(Constant divisor)
{
Object value = divisor.value();
if (value == null) {
return false; // dividing by null is null
}
return switch (divisor.type()) {
case TinyintType _, SmallintType _, IntegerType _, BigintType _ -> (long) value == 0;
case DecimalType decimalType -> {
if (decimalType.isShort()) {
yield (long) value == 0;
}
yield ((Int128) value).isZero();
}
case NumberType _ -> switch (((TrinoNumber) value).toBigDecimal()) {
case TrinoNumber.BigDecimalValue(BigDecimal bigdecimal) -> bigdecimal.signum() == 0;
case TrinoNumber.Infinity _, TrinoNumber.NotANumber _ -> false;
};
case RealType _, DoubleType _ -> false; // will return NaN or ±Inf on division by 0
default -> true;
};
}

public static Expression not(Metadata metadata, Expression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
import static io.trino.sql.ir.Comparison.Operator.EQUAL;
import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL;
import static io.trino.sql.ir.Logical.Operator.AND;
import static io.trino.sql.planner.assertions.PlanMatchPattern.any;
import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
import static io.trino.sql.planner.assertions.PlanMatchPattern.filter;
import static io.trino.sql.planner.assertions.PlanMatchPattern.join;
Expand Down Expand Up @@ -277,16 +276,16 @@ public void testSubsumePartitionFilterNotConvertibleToTupleDomain()
output(
join(INNER, builder -> builder
.equiCriteria("L_INT_PART", "R_INT_COL")
.filter(new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)))
.left(
exchange(REMOTE, REPARTITION,
any(
filter(
new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "L_INT_PART"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)),
tableScan("table_int_partitioned", Map.of("L_INT_PART", "int_part", "L_STR_COL", "str_col")))))
.right(
exchange(LOCAL,
exchange(REMOTE, REPARTITION,
filter(
new Between(new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L), new Constant(INTEGER, 4L)),
new Logical(AND, ImmutableList.of(new In(new Reference(INTEGER, "R_INT_COL"), ImmutableList.of(new Constant(INTEGER, 2L), new Constant(INTEGER, 4L))), new Comparison(EQUAL, new Call(MODULUS_INTEGER, ImmutableList.of(new Reference(INTEGER, "R_INT_COL"), new Constant(INTEGER, 2L))), new Constant(INTEGER, 0L)))),
tableScan("table_unpartitioned", Map.of("R_STR_COL", "str_col", "R_INT_COL", "int_col")))))))));
}

Expand Down
Loading