From 45b4a928ce03e706900ff135cd95d9a113f799e4 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Thu, 30 Apr 2026 15:18:47 +0100 Subject: [PATCH 1/4] #15385 Spark: Support variant_get predicate pushdown for file skipping copilot's solution to why pushdown wasn't working, independent of qlong's #15385 I plan to take qlong's and pull what is extra from this one. --- .../iceberg/expressions/ExpressionUtil.java | 18 +++++++ .../org/apache/iceberg/spark/Spark3Util.java | 10 ++++ .../apache/iceberg/spark/SparkV2Filters.java | 50 ++++++++++++++++++- .../iceberg/spark/sql/TestFilterPushDown.java | 4 +- 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java b/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java index af24ce40cac8..202821a8b968 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java +++ b/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java @@ -239,6 +239,24 @@ public static String describe(Term term) { + "(" + describe(((BoundTransform) term).ref()) + ")"; + } else if (term instanceof UnboundExtract) { + UnboundExtract extract = (UnboundExtract) term; + return "variant_get(" + + extract.ref().name() + + ", '" + + extract.path() + + "', '" + + extract.type().toString() + + "')"; + } else if (term instanceof BoundExtract) { + BoundExtract extract = (BoundExtract) term; + return "variant_get(" + + extract.ref().name() + + ", '" + + extract.path() + + "', '" + + extract.type().toString() + + "')"; } else if (term instanceof NamedReference) { return ((NamedReference) term).name(); } else if (term instanceof BoundReference) { diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java index 064e4f7d6dc7..eb0fb07a0370 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java +++ 
b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java @@ -44,6 +44,7 @@ import org.apache.iceberg.expressions.BoundPredicate; import org.apache.iceberg.expressions.ExpressionVisitors; import org.apache.iceberg.expressions.Term; +import org.apache.iceberg.expressions.UnboundExtract; import org.apache.iceberg.expressions.UnboundPredicate; import org.apache.iceberg.expressions.UnboundTerm; import org.apache.iceberg.expressions.UnboundTransform; @@ -714,6 +715,15 @@ private static String sqlString(UnboundTerm term) { } else if (term instanceof UnboundTransform) { UnboundTransform transform = (UnboundTransform) term; return transform.transform().toString() + "(" + transform.ref().name() + ")"; + } else if (term instanceof UnboundExtract) { + UnboundExtract extract = (UnboundExtract) term; + return "variant_get(" + + extract.ref().name() + + ", '" + + extract.path() + + "', '" + + extract.type().toString() + + "')"; } else { throw new UnsupportedOperationException("Cannot convert term to SQL: " + term); } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java index 57b9d61e38bd..2f9e845d613a 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java @@ -354,7 +354,7 @@ private static T childAtIndex(Predicate predicate, int index) { private static boolean canConvertToTerm( org.apache.spark.sql.connector.expressions.Expression expr) { - return isRef(expr) || isSystemFunc(expr); + return isRef(expr) || isSystemFunc(expr) || isVariantGetFunc(expr); } private static boolean isRef(org.apache.spark.sql.connector.expressions.Expression expr) { @@ -440,12 +440,58 @@ private static UnboundTerm toTerm(T input) { if (input instanceof NamedReference) { return Expressions.ref(SparkUtil.toColumnName((NamedReference) input)); } else if (input 
instanceof UserDefinedScalarFunc) { - return udfToTerm((UserDefinedScalarFunc) input); + UserDefinedScalarFunc udf = (UserDefinedScalarFunc) input; + if (isVariantGetFunc(udf)) { + return variantGetToTerm(udf); + } + return udfToTerm(udf); } else { return null; } } + private static boolean isVariantGetFunc( + org.apache.spark.sql.connector.expressions.Expression expr) { + if (!(expr instanceof UserDefinedScalarFunc)) { + return false; + } + UserDefinedScalarFunc udf = (UserDefinedScalarFunc) expr; + String name = udf.name().toLowerCase(Locale.ROOT); + if (!name.equals("variant_get") && !name.equals("try_variant_get")) { + return false; + } + org.apache.spark.sql.connector.expressions.Expression[] children = udf.children(); + return children.length == 3 + && isRef(children[0]) + && isLiteral(children[1]) + && isLiteral(children[2]); + } + + private static UnboundTerm variantGetToTerm(UserDefinedScalarFunc udf) { + org.apache.spark.sql.connector.expressions.Expression[] children = udf.children(); + String colName = SparkUtil.toColumnName((NamedReference) children[0]); + String path = convertLiteral((Literal) children[1]).toString(); + String sparkTypeName = convertLiteral((Literal) children[2]).toString(); + String icebergTypeName = sparkTypeNameToIceberg(sparkTypeName); + try { + return Expressions.extract(colName, path, icebergTypeName); + } catch (IllegalArgumentException e) { + return null; + } + } + + private static String sparkTypeNameToIceberg(String sparkTypeName) { + switch (sparkTypeName.toLowerCase(Locale.ROOT)) { + case "bigint": + return "long"; + case "tinyint": + case "smallint": + return "int"; + default: + return sparkTypeName; + } + } + @SuppressWarnings("checkstyle:CyclomaticComplexity") private static UnboundTerm udfToTerm(UserDefinedScalarFunc udf) { org.apache.spark.sql.connector.expressions.Expression[] children = udf.children(); diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java 
b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java index e5a9d63b68d6..141f6b3f949f 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java @@ -616,7 +616,7 @@ public void testVariantExtractFiltering() { checkFilters( "try_variant_get(data, '$.num', 'int') IS NULL", "isnull(try_variant_get(data, $.num, IntegerType, false, Some(UTC)))", - "", + "variant_get(data, '$.num', 'int') IS NULL", ImmutableList.of(row(4L, null))); checkFilters( @@ -634,7 +634,7 @@ public void testVariantExtractFiltering() { checkFilters( "try_variant_get(data, '$.num', 'int') IN (25, 35)", "try_variant_get(data, $.num, IntegerType, false, Some(UTC)) IN (25,35)", - "", + "variant_get(data, '$.num', 'int') IN (25, 35)", ImmutableList.of( row(1L, toSparkVariantRow("foo", 25)), row(3L, toSparkVariantRow("baz", 35)))); From 4635ef5147b0f154f2f2f9e0703ee27f1228669e Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 24 Apr 2026 21:40:17 +0100 Subject: [PATCH 2/4] #15510 Support row group skipping for shredded variant columns ParquetMetricsRowGroupFilter.compareVariant() implements comparisons for variants, including NaN, null, uuid, set membership. * New ParquetVariantUtil splitter regexp can't cope with columns called ]. That's OK as normalisation forbids that and empty paths * Copilot wrote the tests so it's over-verbose, but thorough. * including NaN behaviour and string truncation on max values. * lazy build of variant mapping info, cached for entire file. There's no concurrency handling here in the build up of that lazy structure, once happy with the design it'll need to be locked down better. TestSparkVariantFilterPushDown to test variant pushdown through spark This needs a modified spark, currently. 
A package private counter is exported from ParquetMetricsRowGroupFilter to assist in testing as it can assess #of lookups during planning and execution. --- .../InclusiveMetricsEvaluator.java | 13 +- .../parquet/ParquetMetricsRowGroupFilter.java | 402 ++++++++- .../iceberg/parquet/ParquetVariantUtil.java | 45 + .../parquet/TestParquetVariantUtil.java | 53 ++ .../TestShreddedVariantRowGroupFilter.java | 796 ++++++++++++++++++ .../apache/iceberg/spark/SparkV2Filters.java | 30 +- .../TestSparkVariantFilterPushDown.java | 456 ++++++++++ .../apache/iceberg/spark/TestSpark3Util.java | 16 + .../iceberg/spark/TestSparkV2Filters.java | 60 ++ 9 files changed, 1848 insertions(+), 23 deletions(-) create mode 100644 parquet/src/test/java/org/apache/iceberg/parquet/TestParquetVariantUtil.java create mode 100644 parquet/src/test/java/org/apache/iceberg/parquet/TestShreddedVariantRowGroupFilter.java create mode 100644 spark/v4.1/spark/src/test/java/org/apache/iceberg/parquet/TestSparkVariantFilterPushDown.java diff --git a/api/src/main/java/org/apache/iceberg/expressions/InclusiveMetricsEvaluator.java b/api/src/main/java/org/apache/iceberg/expressions/InclusiveMetricsEvaluator.java index 81cbbe785519..2437f5fcd09a 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/InclusiveMetricsEvaluator.java +++ b/api/src/main/java/org/apache/iceberg/expressions/InclusiveMetricsEvaluator.java @@ -21,6 +21,7 @@ import static org.apache.iceberg.expressions.Expressions.rewriteNot; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Collection; import java.util.Comparator; import java.util.Map; @@ -631,7 +632,17 @@ private boolean isNonNullPreserving(Bound term) { } } + /** + * Build a variant from the buffer, regardless of the ordering of the incoming buffer. 
+ * + * @param buffer source data + * @return variant instance + */ private static VariantObject parseBounds(ByteBuffer buffer) { - return Variant.from(buffer).value().asObject(); + final ByteBuffer littleEnded = + buffer.order() == ByteOrder.LITTLE_ENDIAN + ? buffer + : buffer.duplicate().order(ByteOrder.LITTLE_ENDIAN); + return Variant.from(littleEnded).value().asObject(); } } diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java index cae9447513c0..f5aefbe56726 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java @@ -21,19 +21,26 @@ import java.nio.ByteBuffer; import java.util.Collection; import java.util.Comparator; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import java.util.stream.Collectors; import org.apache.iceberg.Schema; import org.apache.iceberg.expressions.Binder; import org.apache.iceberg.expressions.Bound; +import org.apache.iceberg.expressions.BoundExtract; +import org.apache.iceberg.expressions.BoundPredicate; import org.apache.iceberg.expressions.BoundReference; import org.apache.iceberg.expressions.Expression; import org.apache.iceberg.expressions.ExpressionVisitors; import org.apache.iceberg.expressions.ExpressionVisitors.BoundExpressionVisitor; import org.apache.iceberg.expressions.Expressions; import org.apache.iceberg.expressions.Literal; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.types.Comparators; 
import org.apache.iceberg.types.Type; @@ -42,16 +49,42 @@ import org.apache.parquet.column.statistics.Statistics; import org.apache.parquet.hadoop.metadata.BlockMetaData; import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class ParquetMetricsRowGroupFilter { - private static final int IN_PREDICATE_LIMIT = 200; + public static final Logger LOG = LoggerFactory.getLogger(ParquetMetricsRowGroupFilter.class); + private static final int IN_PREDICATE_LIMIT = 200; private final Schema schema; private final Expression expr; + /** Is special handling of variants needed? */ + private final boolean hasVariantPredicates; + + /** + * variantColumnNames is file-schema-scoped (same for all row groups in a file), so cache it. * + */ + private MessageType cachedVariantFileSchema = null; + + /** Map of schema ID to variant column names. */ + private Map variantColumnNames = null; + + /** Set of variant top-level column names, derived from variantColumnNames values. */ + private Set variantTopLevelNames = null; + + /** + * For testing, especially while there are requirements for predecessor patches to be applied. + * + *

This permits assertions to be made that variant predicate pushdown reached this far and + * processed shredded columns. + */ + private static final AtomicLong VARIANT_PREDICATES_SHREDDED_METRICS = new AtomicLong(); + public ParquetMetricsRowGroupFilter(Schema schema, Expression unbound) { this(schema, unbound, true); } @@ -60,27 +93,44 @@ public ParquetMetricsRowGroupFilter(Schema schema, Expression unbound, boolean c this.schema = schema; StructType struct = schema.asStruct(); this.expr = Binder.bind(struct, Expressions.rewriteNot(unbound), caseSensitive); + this.hasVariantPredicates = hasVariantPredicates(expr); } /** - * Test whether the file may contain records that match the expression. + * Test whether the row group may contain records that match the expression. * * @param fileSchema schema for the Parquet file * @param rowGroup metadata for a row group * @return false if the file cannot contain rows that match the expression, true otherwise. */ + @SuppressWarnings("ReferenceEquality") public boolean shouldRead(MessageType fileSchema, BlockMetaData rowGroup) { + // identity check: same object means same file schema, no need to recompute variant column names + // REVIST: would the cached schema ever change within a file? + if (hasVariantPredicates && fileSchema != cachedVariantFileSchema) { + cachedVariantFileSchema = fileSchema; + variantColumnNames = buildVariantColumnNames(fileSchema); + variantTopLevelNames = ImmutableSet.copyOf(variantColumnNames.values()); + } return new MetricsEvalVisitor().eval(fileSchema, rowGroup); } private static final boolean ROWS_MIGHT_MATCH = true; private static final boolean ROWS_CANNOT_MATCH = false; + private record VariantColumnInfo(PrimitiveType type, int id, ColumnChunkMetaData chunkMetaData) {} + + /** Evaluate metrics, including variants. 
*/ private class MetricsEvalVisitor extends BoundExpressionVisitor { private Map> stats = null; private Map valueCounts = null; private Map> conversions = null; + // ID-less columns collected during the main column scan for lazy variantInfoByColumnPath build + private List shreddedVariantColumns = null; + // Built lazily on the first compareVariant() call; null means not yet built + private Map variantInfoByColumnPath = null; + private boolean eval(MessageType fileSchema, BlockMetaData rowGroup) { if (rowGroup.getRowCount() <= 0) { return ROWS_CANNOT_MATCH; @@ -89,6 +139,8 @@ private boolean eval(MessageType fileSchema, BlockMetaData rowGroup) { this.stats = Maps.newHashMap(); this.valueCounts = Maps.newHashMap(); this.conversions = Maps.newHashMap(); + this.shreddedVariantColumns = hasVariantPredicates ? Lists.newArrayList() : null; + this.variantInfoByColumnPath = null; for (ColumnChunkMetaData col : rowGroup.getColumns()) { PrimitiveType colType = fileSchema.getType(col.getPath().toArray()).asPrimitiveType(); if (colType.getId() != null) { @@ -97,12 +149,28 @@ private boolean eval(MessageType fileSchema, BlockMetaData rowGroup) { stats.put(id, col.getStatistics()); valueCounts.put(id, col.getValueCount()); conversions.put(id, ParquetConversions.converterFromParquet(colType, icebergType)); + } else if (shreddedVariantColumns != null) { + // Shredded variant typed_value columns have no Iceberg field ID, just parquet ones. + // Pre-filter to only those under known variant top-level column names; buildVariantInfo() + // then just copies the list into a map on the first compareVariant() call. + ColumnPath columnPath = col.getPath(); + String[] pathParts = columnPath.toArray(); + if (pathParts.length > 0 && variantTopLevelNames.contains(pathParts[0])) { + shreddedVariantColumns.add(new VariantColumnInfo(colType, -1, col)); + } } } - return ExpressionVisitors.visitEvaluator(expr, this); } + /** Build variantInfoByColumnPath lazily on the first compareVariant() call. 
*/ + private void buildVariantInfo() { + variantInfoByColumnPath = Maps.newHashMap(); + for (VariantColumnInfo colInfo : shreddedVariantColumns) { + variantInfoByColumnPath.put(colInfo.chunkMetaData().getPath(), colInfo); + } + } + @Override public Boolean alwaysTrue() { return ROWS_MIGHT_MATCH; // all rows match @@ -560,6 +628,239 @@ private T max(Statistics statistics, int id) { return (T) conversions.get(id).apply(statistics.genericGetMax()); } + @Override + public Boolean predicate(BoundPredicate pred) { + if (pred.term() instanceof BoundExtract term) { + // it's a variant predicate: process accordingly. + return compareVariant(pred, term); + } else { + return super.predicate(pred); + } + } + + /** + * Compare a variant with the predicate. For floats and doubles, expects the parquet writer to + * have normalized -0.0 to +0.0. + * + * @param pred predicate + * @param extract extracted variant reference + * @param type of predicate + * @return true if the file rows should be read (i.e. false iff we are confident they can be + * skipped) + */ + private boolean compareVariant(BoundPredicate pred, BoundExtract extract) { + if (variantInfoByColumnPath == null) { + // TODO: concurrency ? + buildVariantInfo(); + } + int fieldId = extract.ref().fieldId(); + LOG.info("comparing variant {}", extract); + String colName = variantColumnNames.get(fieldId); + if (colName == null) { + // not in the variant columns + return ROWS_MIGHT_MATCH; + } + // Build the column path of which a shredded field would have + ColumnPath columnPath = ParquetVariantUtil.shreddedColumnPath(colName, extract.path()); + final VariantColumnInfo columnInfo = variantInfoByColumnPath.get(columnPath); + if (columnInfo == null) { + // the column isn't shredded in this file, so no statistics are available. + return ROWS_MIGHT_MATCH; + } + // increment shredded metrics counter. + VARIANT_PREDICATES_SHREDDED_METRICS.incrementAndGet(); + + // now do the evaluation. 
+ LOG.info("Evaluating column {} with info {}", columnPath, columnInfo); + PrimitiveType parquetType = columnInfo.type(); + final ColumnChunkMetaData col = columnInfo.chunkMetaData; + Statistics colStats = col.getStatistics(); + long valueCount = col.getValueCount(); + if (parquetType == null || colStats == null) { + // no type info or column stats + return ROWS_MIGHT_MATCH; + } + + // everything here onwards expects colStats != null + if (pred.isUnaryPredicate()) { + return evalUnaryPredicate(pred, colStats, valueCount); + } + final Boolean outcome = evalShreddedColumnStats(colStats, valueCount); + if (outcome != null) { + return outcome; + } + if (pred.isSetPredicate()) { + // set check + return evalMembershipPredicateOnShreddedVariant(pred, extract, parquetType, colStats); + } + if (!pred.isLiteralPredicate()) { + return ROWS_MIGHT_MATCH; + } + // get this far: it's a shredded variant with column statistics + return evalBinaryPredicateOnShreddedVariant(pred, extract, parquetType, colStats); + } + + /** + * Evaluate the statistics, return an Boolean value if there was enough information to make a + * decision. + * + *

This is a bit contrived, but it keeps the complexity of {@code compareVariant()} below the + * checkstyle limit. + * + * @param colStats column statistics + * @param valueCount number of values. + * @return a Boolean which is null if no conclusion is reached. + */ + private Boolean evalShreddedColumnStats(Statistics colStats, long valueCount) { + if (colStats.isEmpty()) { + return ROWS_MIGHT_MATCH; + } + if (allNulls(colStats, valueCount)) { + // there are no non-null values, therefore all comparators will be false + return ROWS_CANNOT_MATCH; + } + if (minMaxUndefined(colStats)) { + // min or max undefined, so ranged comparisons not possible + return ROWS_MIGHT_MATCH; + } + return null; + } + + /** + * Evaluate a predicate against two shredded values: should the rowgroup be read? + * + * @param pred predicate + * @param extract bound extract + * @param parquetType the parquet type of the column + * @param colStats column statistics. + * @param type of the arguments + * @return true if the rowgroup should be read. + */ + @SuppressWarnings("unchecked") + private boolean evalBinaryPredicateOnShreddedVariant( + BoundPredicate pred, + BoundExtract extract, + PrimitiveType parquetType, + Statistics colStats) { + // it's a binary predicate with valid results from comparisons with + // the stats. So get their min and max, compare them with the literal value. + Literal lit = pred.asLiteralPredicate().literal(); + // get the type converter for the evaluation + Function converter = + ParquetConversions.converterFromParquet(parquetType, extract.type()); + T min = (T) converter.apply(colStats.genericGetMin()); + T max = (T) converter.apply(colStats.genericGetMax()); + + final int minVsLiteral = lit.comparator().compare(min, lit.value()); + final int literalVsMax = lit.comparator().compare(lit.value(), max); + + // is the min-max range within that of the predicate?
+ boolean inRange = + switch (pred.op()) { + // min value is less than the literal + case LT -> minVsLiteral < 0; + // min value is less than or equal to the literal + case LT_EQ -> minVsLiteral <= 0; + // max value is > the literal + case GT -> literalVsMax < 0; + // max value is > or == to the literal + case GT_EQ -> literalVsMax <= 0; + // min <= lit and max >= lit so one element + // may match lit. + case EQ -> minVsLiteral <= 0 && literalVsMax <= 0; + // anything else + default -> true; + }; + + return inRange; + } + + /** + * Evaluate an IN predicate against a shredded variant column's min/max statistics. + * + *

A row group can be skipped if no value in the IN set falls within [min, max]. + * + * @param pred IN predicate + * @param extract the bound extract term + * @param parquetType the Parquet type of the shredded column + * @param colStats column statistics + * @param type of the predicate values + * @return true if the row group might match, false if it cannot match + */ + @SuppressWarnings("unchecked") + private boolean evalMembershipPredicateOnShreddedVariant( + BoundPredicate pred, + BoundExtract extract, + PrimitiveType parquetType, + Statistics colStats) { + + if (pred.op() != Expression.Operation.IN) { + // not looking at other set member operations. + return ROWS_MIGHT_MATCH; + } + Set literalSet = pred.asSetPredicate().literalSet(); + + if (literalSet.size() > IN_PREDICATE_LIMIT) { + return ROWS_MIGHT_MATCH; + } + LOG.info("Set membership evaluation"); + Function converter = + ParquetConversions.converterFromParquet(parquetType, extract.type()); + T min = (T) converter.apply(colStats.genericGetMin()); + T max = (T) converter.apply(colStats.genericGetMax()); + + // keep only values that are >= min + Collection candidates = + literalSet.stream().filter(v -> pred.term().comparator().compare(min, v) <= 0).toList(); + + // nothing is above the minimum + if (candidates.isEmpty()) { + return ROWS_CANNOT_MATCH; + } + + // second filter: keep only values that are <= max + candidates = + candidates.stream().filter(v -> pred.term().comparator().compare(max, v) >= 0).toList(); + + final boolean match = candidates.isEmpty() ? ROWS_CANNOT_MATCH : ROWS_MIGHT_MATCH; + LOG.info("Outcome match={}", match); + return match; + } + + /** + * Evaluate a Unary Predicate. Pulled out of the main compareVariant() call due to checkstyle + * rejecting the complexity of that method. + * + * @param pred predicate being evaluated. 
+ * @param colStats column statistics + * @param valueCount number of elements in the rowgroup + * @param type of predicate + * @return true if the rowgroup should be read. + */ + private boolean evalUnaryPredicate( + BoundPredicate pred, Statistics colStats, long valueCount) { + LOG.info("Evaluating unary predicate: {}", pred.op()); + switch (pred.op()) { + case IS_NULL -> { + // If every row has a non-null typed value, no row can match IS_NULL + if (!colStats.isEmpty() && colStats.isNumNullsSet() && colStats.getNumNulls() == 0) { + return ROWS_CANNOT_MATCH; + } + return ROWS_MIGHT_MATCH; + } + case NOT_NULL -> { + // If every row has a null typed value, no row can match NOT_NULL + if (!colStats.isEmpty() && allNulls(colStats, valueCount)) { + return ROWS_CANNOT_MATCH; + } + return ROWS_MIGHT_MATCH; + } + default -> { + return ROWS_MIGHT_MATCH; + } + } + } + @Override public Boolean handleNonReference(Bound term) { return ROWS_MIGHT_MATCH; @@ -584,6 +885,8 @@ static boolean nullMinMax(Statistics statistics) { /** * The internal logic of Parquet-MR says that if numNulls is set but hasNonNull value is false, * then the min/max of the column are undefined. + * + *

Parquet also sets this for a float/double if there is a NaN in the row group. */ static boolean minMaxUndefined(Statistics statistics) { return (statistics.isNumNullsSet() && !statistics.hasNonNullValue()) || nullMinMax(statistics); @@ -596,4 +899,97 @@ static boolean allNulls(Statistics statistics, long valueCount) { private static boolean mayContainNull(Statistics statistics) { return !statistics.isNumNullsSet() || statistics.getNumNulls() > 0; } + + /** + * Scan the schema for variant types and build a map of variant columns. + * + * @param fileSchema file schema. + * @return the map of variant column names, may be empty. + */ + private Map buildVariantColumnNames(MessageType fileSchema) { + LOG.info("Building variant column names..."); + Map names = Maps.newHashMap(); + for (org.apache.parquet.schema.Type field : fileSchema.getFields()) { + if (field.getId() != null) { + int id = field.getId().intValue(); + Type icebergType = schema.findType(id); + if (icebergType != null && icebergType.isVariantType()) { + names.put(id, field.getName()); + } + } + } + LOG.info("Found {} names", names.size()); + return names; + } + + /** + * Does an expression have variant predicates? + * + * @param expr expression to evaluate. + * @return true if any part of the expression refers to a variant. + */ + private static boolean hasVariantPredicates(Expression expr) { + return ExpressionVisitors.visit(expr, HasVariantPredicateVisitor.INSTANCE); + } + + /** + * Visitor for scanning an expression for variants. + * + *

This isn't trying to evaluate the expression, so and/or/not aggregate the results, rather + * than apply boolean predicates on child nodes. + */ + private static final class HasVariantPredicateVisitor extends BoundExpressionVisitor { + static final HasVariantPredicateVisitor INSTANCE = new HasVariantPredicateVisitor(); + + @Override + public Boolean alwaysTrue() { + return false; + } + + @Override + public Boolean alwaysFalse() { + return false; + } + + @Override + public Boolean not(Boolean result) { + return result; + } + + @Override + public Boolean and(Boolean leftResult, Boolean rightResult) { + return leftResult || rightResult; + } + + @Override + public Boolean or(Boolean leftResult, Boolean rightResult) { + return leftResult || rightResult; + } + + @Override + public Boolean predicate(BoundPredicate pred) { + return pred.term() instanceof BoundExtract; + } + + @Override + public Boolean handleNonReference(Bound term) { + return false; + } + } + + /** + * The number of times shredded metric predicates have been evaluated. + * + * @return zero or a positive integer + */ + @VisibleForTesting + static long variantPredicatesShreddedMetrics() { + return VARIANT_PREDICATES_SHREDDED_METRICS.get(); + } + + /** Reset the shredded metrics counter. 
*/ + @VisibleForTesting + static void resetShreddedMetricsCounter() { + VARIANT_PREDICATES_SHREDDED_METRICS.set(0); + } } diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java index ac418a1127bd..a0a4667617be 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantUtil.java @@ -27,6 +27,8 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.apache.iceberg.expressions.PathUtil; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; @@ -42,6 +44,7 @@ import org.apache.iceberg.variants.VariantValue; import org.apache.iceberg.variants.VariantVisitor; import org.apache.iceberg.variants.Variants; +import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.LogicalTypeAnnotation; @@ -548,4 +551,46 @@ private static Type shreddedPrimitive( return Types.optional(primitive).as(annotation).length(length).named("typed_value"); } } + + /** Matches bracket-notation path segments like {@code ['fieldName']} in normalized paths. */ + private static final Pattern PATH_SEGMENT_PATTERN = Pattern.compile("\\['([^']+)'\\]"); + + /** + * Build the Parquet column path for a shredded variant field. + * + *

For a variant column named {@code v} and path {@code $['price']}, the physical Parquet + * column path is {@code ["v", "typed_value", "price", "typed_value"]}. For nested paths like + * {@code $['user']['name']}, the path is {@code ["v", "typed_value", "user", "typed_value", + * "name", "typed_value"]}. For the root path {@code $}, it is {@code ["v", "typed_value"]}. + */ + static ColumnPath shreddedColumnPath(String colName, String normalizedPath) { + List segments = parsePathSegments(normalizedPath); + String[] pathArray = new String[(1 + segments.size()) * 2]; + pathArray[0] = colName; + pathArray[1] = "typed_value"; + for (int i = 0, offset = 2; i < segments.size(); i++, offset += 2) { + pathArray[offset] = segments.get(i); + pathArray[offset + 1] = "typed_value"; + } + return ColumnPath.get(pathArray); + } + + /** + * Parse a normalized RFC9535 bracket-notation path into field name segments. The root path {@code + * $} returns an empty list. Paths are produced by {@link + * org.apache.iceberg.expressions.PathUtil#toNormalizedPath} and always use single-quoted bracket + * notation with identifiers that contain only letters, digits, underscores, and high-Unicode + * characters (no escape sequences). + * + * @param normalizedPath path to parse + * @return the path segments. 
+ */ + static List parsePathSegments(String normalizedPath) { + List segments = Lists.newArrayList(); + Matcher matcher = PATH_SEGMENT_PATTERN.matcher(normalizedPath); + while (matcher.find()) { + segments.add(matcher.group(1)); + } + return segments; + } } diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetVariantUtil.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetVariantUtil.java new file mode 100644 index 000000000000..9234b925b3d2 --- /dev/null +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetVariantUtil.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.parquet; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.FieldSource; + +/** Tests for {@code ParquetVariantUtil} */ +public class TestParquetVariantUtil { + + @ParameterizedTest + @FieldSource("PATH_MAPPINGS") + public void testParsePathSegments(Parsing args) { + assertThat(ParquetVariantUtil.parsePathSegments(args.input)) + .describedAs("Parsing of path segments %s", args.input) + .containsExactly(args.output); + } + + /** + * A variant path parses to a sequence of column names. + * + * @param input input string + * @param output varargs list of result + */ + private record Parsing(String input, String... output) {} + + private static final List PATH_MAPPINGS = + List.of( + new Parsing("$"), + new Parsing("$['a']['b']", "a", "b"), + new Parsing("$['0']['1']", "0", "1"), + new Parsing("$['user']['firstName']", "user", "firstName"), + new Parsing("$['_under_score']", "_under_score")); +} diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestShreddedVariantRowGroupFilter.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestShreddedVariantRowGroupFilter.java new file mode 100644 index 000000000000..367025c0b1b6 --- /dev/null +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestShreddedVariantRowGroupFilter.java @@ -0,0 +1,796 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.parquet; + +import static org.apache.iceberg.expressions.Expressions.and; +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.extract; +import static org.apache.iceberg.expressions.Expressions.greaterThan; +import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.in; +import static org.apache.iceberg.expressions.Expressions.isNull; +import static org.apache.iceberg.expressions.Expressions.lessThan; +import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; +import static org.apache.iceberg.expressions.Expressions.not; +import static org.apache.iceberg.expressions.Expressions.notNull; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.List; +import java.util.Set; +import java.util.UUID; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.parquet.InternalWriter; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.expressions.UnboundTerm; +import org.apache.iceberg.inmemory.InMemoryOutputFile; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.variants.ShreddedObject; +import 
org.apache.iceberg.variants.Variant; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantTestUtil; +import org.apache.iceberg.variants.Variants; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.schema.MessageType; +import org.assertj.core.api.AbstractBooleanAssert; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Tests for shredded variant row-group skipping in {@link ParquetMetricsRowGroupFilter}. + * + *

Verifies that when a variant column has shredded fields (e.g. {@code $.price}), the filter can + * use column chunk statistics to skip row groups that cannot match a predicate. + */ +class TestShreddedVariantRowGroupFilter { + + private static final Logger LOG = + LoggerFactory.getLogger(TestShreddedVariantRowGroupFilter.class); + + private static final Schema SCHEMA = + new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "var", Types.VariantType.get())); + + private static final VariantMetadata METADATA = + VariantMetadata.from(VariantTestUtil.createMetadata(Set.of("price", "name", "user"), true)); + + private static final VariantMetadata UUID_METADATA = + VariantMetadata.from(VariantTestUtil.createMetadata(Set.of("deviceid"), true)); + + /** Price field evaluated as an integer. */ + private static final UnboundTerm PRICE = extract("var", "$.price", "int"); + + /** Name field evaluated as a string. */ + private static final UnboundTerm NAME = extract("var", "$.name", "string"); + + /** Device ID as a UUID. */ + private static final UnboundTerm DEVICEID_UUID = extract("var", "$.deviceid", "uuid"); + + /** Zero UUID. */ + private static final UUID UUID_ZERO = UUID.fromString("00000000-0000-0000-0000-000000000000"); + + private static final UUID UUID_LOW = UUID.fromString("00000000-0000-0000-0000-000000000001"); + private static final UUID UUID_MID = UUID.fromString("00000000-0000-0000-0000-000000000005"); + private static final UUID UUID_HIGH = UUID.fromString("00000000-0000-0000-0000-00000000000a"); + + /** IUnknown is above the high marker. 
*/ + private static final UUID UUID_ABOVE_HIGH = + UUID.fromString("00000000-0000-0000-C000-000000000046"); + + @BeforeEach + void before() { + ParquetMetricsRowGroupFilter.resetShreddedMetricsCounter(); + } + + @Test + void testShreddedIntLessThan() throws IOException { + // Row group has prices 10..14; predicate price < 10 should skip + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(lessThan(PRICE, 10), variants)) + .as("Should skip: all prices >= 10, predicate requires price < 10") + .isFalse(); + } + + @Test + void testShreddedIntLessThanOverlapping() throws IOException { + // Row group has prices 10..14; predicate price < 11 should read + List variants = intPriceVariants(10, 11, 12, 13, 14); + + assertThat(shouldRead(lessThan(PRICE, 11), variants)) + .as("Should read: price range [10,14] overlaps price < 11") + .isTrue(); + } + + @Test + void testShreddedIntLessThanOrEqual() throws IOException { + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(lessThanOrEqual(PRICE, 9), variants)) + .as("Should skip: all prices >= 10, predicate requires price <= 9") + .isFalse(); + + assertThat(shouldRead(lessThanOrEqual(PRICE, 10), variants)) + .as("Should read: min price == 10 matches price <= 10") + .isTrue(); + } + + @Test + void testShreddedIntGreaterThan() throws IOException { + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(greaterThan(PRICE, 14), variants)) + .as("Should skip: all prices <= 14, predicate requires price > 14") + .isFalse(); + + assertThat(shouldRead(greaterThan(PRICE, 13), variants)) + .as("Should read: price range [10,14] overlaps price > 13") + .isTrue(); + } + + @Test + void testShreddedIntGreaterThanOrEqual() throws IOException { + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(greaterThanOrEqual(PRICE, 15), variants)) + .as("Should skip: all prices <= 14, predicate requires price >= 15") + .isFalse(); + + 
assertThat(shouldRead(greaterThanOrEqual(PRICE, 14), variants)) + .as("Should read: max price == 14 matches price >= 14") + .isTrue(); + } + + @Test + void testShreddedIntEqual() throws IOException { + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(equal(PRICE, 9), variants)).as("Should skip: 9 < min price 10").isFalse(); + assertThat(shouldRead(equal(PRICE, 15), variants)) + .as("Should skip: 15 > max price 14") + .isFalse(); + assertThat(shouldRead(equal(PRICE, 12), variants)) + .as("Should read: 12 is within price range [10,14]") + .isTrue(); + } + + @Test + void testShreddedIsNullNoNulls() throws IOException { + // All rows have the typed price value (no nulls in typed_value) + List variants = intPriceVariants(10, 11, 12); + assertThat(shouldRead(isNull(PRICE), variants)) + .as("Should skip: all rows have typed price, IS_NULL cannot match") + .isFalse(); + } + + @Test + void testShreddedNotNullAllNulls() throws IOException { + // All rows have $.price as variant null; typed_value will be null for all rows + ImmutableList.Builder builder = ImmutableList.builder(); + for (int i = 0; i < 3; i++) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.ofNull()); // explicit variant null → typed_value is null + builder.add(Variant.of(METADATA, obj)); + } + List variants = builder.build(); + + // Shredding function says price is an int; but values are variant nulls → typed_value all null + ShreddedObject example = Variants.object(METADATA); + example.put("price", Variants.of(0)); + VariantShreddingFunction shreddingFunc = + (id, name) -> ParquetVariantUtil.toParquetSchema(example); + + assertThat(shouldRead(notNull(PRICE), variants, shreddingFunc)) + .as("Should skip: no rows have typed price, NOT_NULL cannot match") + .isFalse(); + } + + @Test + void testUnshreddedPathMightMatch() throws IOException { + // $.name is not shredded; filter should return MIGHT_MATCH (true) + List variants = intPriceVariants(10, 11, 
12); + // Use the price shredding function — $.name is not shredded so falls back + assertThat(shouldRead(equal(NAME, "foo"), variants)) + .as("Should read: unshredded path must fall back to MIGHT_MATCH") + .isTrue(); + } + + @Test + void testNestedShreddedPath() throws IOException { + // Shred $.user.name as a string + ShreddedObject user1 = Variants.object(METADATA); + user1.put("name", Variants.of("alice")); + ShreddedObject row1 = Variants.object(METADATA); + row1.put("user", user1); + + ShreddedObject user2 = Variants.object(METADATA); + user2.put("name", Variants.of("bob")); + ShreddedObject row2 = Variants.object(METADATA); + row2.put("user", user2); + + List variants = + ImmutableList.of(Variant.of(METADATA, row1), Variant.of(METADATA, row2)); + + // Shred $.user as an object with a shredded .name inside + VariantShreddingFunction shreddingFunc = (id, name) -> ParquetVariantUtil.toParquetSchema(row1); + + // The path $.user is shredded but $.user.name is nested deeper — + // the filter should handle multi-level typed_value paths correctly + final UnboundTerm username = extract("var", "$.user.name", "string"); + assertThat(shouldRead(equal(username, "charlie"), variants, shreddingFunc)) + .as("Should skip: 'charlie' is outside the [alice, bob] range") + .isFalse(); + + assertThat(shouldRead(equal(username, "alice"), variants, shreddingFunc)) + .as("Should read: 'alice' is within the range") + .isTrue(); + } + + @Test + void testShreddedLongPredicates() throws IOException { + List variants = longPriceVariants(100L, 101L, 102L, 103L, 104L); + + final UnboundTerm term = extract("var", "$.price", "long"); + assertNotRead(lessThan(term, 100L), variants, "< 100"); + assertIsRead(lessThanOrEqual(term, 100L), variants, "<= 100"); + assertIsRead(lessThan(term, 105L), variants, "< 105"); + assertIsRead(lessThan(term, 1119L), variants, "< 119L"); + + assertNotRead(greaterThan(term, 104L), variants, "> 104"); + assertIsRead(greaterThanOrEqual(term, 104L), variants, " >= 
104"); + assertIsRead(greaterThan(term, 103L), variants, "> 103"); + assertIsRead(greaterThan(term, 99L), variants, "> 99"); + + assertNotRead(equal(term, 99L), variants, "= 99"); + assertNotRead(equal(term, 105L), variants, "= 105"); + assertIsRead(equal(term, 102L), variants, "= 102"); + assertNotRead(equal(term, 99L), variants, "= 105"); + // some not terms to see how they come out + assertIsRead(not(equal(term, 108L)), variants, "!(= 108)"); + assertIsRead(not((greaterThan(term, 104L))), variants, "!(> 104)"); + } + + private AbstractBooleanAssert assertExpression( + Expression expr, List variants, String text) throws IOException { + return assertThat(shouldRead(expr, variants)) + .as( + "Predicate '%s' on `range [%s, %s]", + text, variants.get(0).value(), variants.get(variants.size() - 1)); + } + + private void assertIsRead(Expression expr, List variants, String text) + throws IOException { + assertExpression(expr, variants, text).isTrue(); + LOG.info("Predicate '{}' succeeded", text); + } + + private void assertNotRead(Expression expr, List variants, String text) + throws IOException { + assertExpression(expr, variants, text).isFalse(); + LOG.info("Predicate '{}' succeessfully rejected", text); + } + + @Test + void testShreddedFloatPredicates() throws IOException { + List variants = floatPriceVariants(1.0F, 2.0F, 3.0F); + + final UnboundTerm term = extract("var", "$.price", "float"); + assertNotRead(lessThan(term, 1.0F), variants, "< 1.0F"); + assertNotRead(greaterThan(term, 3.0F), variants, "> 3.0F"); + // float equality is always dubious + assertIsRead(equal(term, 2.0F), variants, "= 2.0F"); + } + + @Test + void testShreddedFloatNaNDropsMinMax() throws IOException { + // When any value is NaN, Parquet drops min/max → filter must return MIGHT_MATCH + List variants = floatPriceVariants(1.0F, Float.NaN, 3.0F); + + // Even though 99.0 is nowhere near [1.0, NaN, 3.0], stats are dropped so must read + final UnboundTerm term = extract("var", "$.price", "float"); + 
assertIsRead(equal(term, 99.0F), variants, "= 99.0F with a NaN in the row group"); + assertIsRead( + greaterThan(term, 1000.0F), + variants, + "> 1000; NaN in row group causes Parquet to drop min/max → MIGHT_MATCH"); + } + + @Test + void testShreddedDoublePredicates() throws IOException { + List variants = doublePriceVariants(1.0D, 2.0D, 3.0D); + + final UnboundTerm term = extract("var", "$.price", "double"); + assertNotRead(lessThan(term, 1.0D), variants, "< 1.0D"); + assertNotRead(greaterThan(term, 3.0D), variants, "> 3.0D"); + assertIsRead(equal(term, 2.0D), variants, "= 2.0D"); + } + + @Test + void testShreddedDoubleNaNDropsMinMax() throws IOException { + // When any value is NaN, Parquet drops min/max → filter MUST return MIGHT_MATCH + List variants = doublePriceVariants(1.0D, Double.NaN, 3.0D); + VariantShreddingFunction shreddingFunc = doublePriceShreddingFunc(); + + final UnboundTerm term = extract("var", "$.price", "double"); + assertThat(shouldRead(equal(term, 99.0D), variants, shreddingFunc)) + .as("Should read: NaN in row group causes Parquet to drop min/max → MIGHT_MATCH") + .isTrue(); + assertThat(shouldRead(greaterThan(term, 1000.0D), variants, shreddingFunc)) + .as("Should read: NaN in row group causes Parquet to drop min/max → MIGHT_MATCH") + .isTrue(); + } + + @Test + void testShreddedIntNotEqual() throws IOException { + // NOT_EQ always returns MIGHT_MATCH + final UnboundTerm term = PRICE; + + assertIsRead(not(equal(term, 10)), intPriceVariants(10, 10, 10), "!(= 10) all values are 10"); + assertIsRead( + not(equal(term, 10)), intPriceVariants(10, 11, 12, 13, 14), "!(= 10) range [10,14]"); + assertIsRead( + not(equal(term, 99)), intPriceVariants(10, 11, 12, 13, 14), "!(= 99) range [10,14]"); + } + + @Test + void testShreddedStringPredicates() throws IOException { + // $.name shredded as a string field; names span "alice".."charlie" + List variants = nameStringVariants("alice", "bob", "charlie"); + VariantShreddingFunction shreddingFunc = 
nameStringShreddingFunc(); + + final UnboundTerm term = NAME; + + // Nothing is lexicographically less than the minimum "alice" + assertThat(shouldRead(lessThan(term, "alice"), variants, shreddingFunc)) + .as("Should skip: 'alice' is min, nothing is < 'alice'") + .isFalse(); + + // Nothing is greater than the maximum "charlie" + assertThat(shouldRead(greaterThan(term, "charlie"), variants, shreddingFunc)) + .as("Should skip: 'charlie' is max, nothing is > 'charlie'") + .isFalse(); + + // "david" sorts after max "charlie" → EQ cannot match + assertThat(shouldRead(equal(term, "david"), variants, shreddingFunc)) + .as("Should skip: 'david' > max 'charlie'") + .isFalse(); + + // "aardvark" sorts before min "alice" → EQ cannot match + assertThat(shouldRead(equal(term, "aardvark"), variants, shreddingFunc)) + .as("Should skip: 'aardvark' < min 'alice'") + .isFalse(); + + // "bob" is within [alice, charlie] → EQ might match + assertThat(shouldRead(equal(term, "bob"), variants, shreddingFunc)) + .as("Should read: 'bob' is within [alice, charlie]") + .isTrue(); + + // GT "alice" → might match (max "charlie" > "alice") + assertThat(shouldRead(greaterThan(term, "alice"), variants, shreddingFunc)) + .as("Should read: range [alice, charlie] overlaps > 'alice'") + .isTrue(); + } + + @Test + void testShreddedLiteralAllNulls() throws IOException { + // All typed_value entries are null: any literal predicate must return CANNOT_MATCH + ImmutableList.Builder builder = ImmutableList.builder(); + for (int i = 0; i < 3; i++) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.ofNull()); // variant null → typed_value written as null + builder.add(Variant.of(METADATA, obj)); + } + List variants = builder.build(); + + // Shredding function defines price as int so the typed_value column exists in the schema + ShreddedObject example = Variants.object(METADATA); + example.put("price", Variants.of(0)); + VariantShreddingFunction shreddingFunc = + (id, name) -> 
ParquetVariantUtil.toParquetSchema(example); + + final UnboundTerm term = PRICE; + assertThat(shouldRead(equal(term, 5), variants, shreddingFunc)) + .as("Should skip: all typed_value are null, EQ cannot match") + .isFalse(); + assertThat(shouldRead(greaterThan(term, 0), variants, shreddingFunc)) + .as("Should skip: all typed_value are null, GT cannot match") + .isFalse(); + assertThat(shouldRead(lessThan(term, 100), variants, shreddingFunc)) + .as("Should skip: all typed_value are null, LT cannot match") + .isFalse(); + } + + @Test + void testShreddedIsNullSomeNulls() throws IOException { + // Mix of typed_value (non-null) and null typed_value — IS_NULL should read + ShreddedObject example = Variants.object(METADATA); + example.put("price", Variants.of(0)); + VariantShreddingFunction shreddingFunc = + (id, name) -> ParquetVariantUtil.toParquetSchema(example); + + ImmutableList.Builder builder = ImmutableList.builder(); + for (int price : new int[] {10, 11, 12}) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.of(price)); + builder.add(Variant.of(METADATA, obj)); + } + for (int i = 0; i < 2; i++) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.ofNull()); // null typed_value + builder.add(Variant.of(METADATA, obj)); + } + + assertThat(shouldRead(isNull(PRICE), builder.build(), shreddingFunc)) + .as("Should read: some rows have null typed_value, IS_NULL might match") + .isTrue(); + } + + @Test + void testShreddedNotNullSomeValues() throws IOException { + // Mix of typed_value (non-null) and null typed_value — NOT_NULL should read + ShreddedObject example = Variants.object(METADATA); + example.put("price", Variants.of(0)); + VariantShreddingFunction shreddingFunc = + (id, name) -> ParquetVariantUtil.toParquetSchema(example); + + ImmutableList.Builder builder = ImmutableList.builder(); + for (int price : new int[] {10, 11, 12}) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", 
Variants.of(price)); + builder.add(Variant.of(METADATA, obj)); + } + for (int i = 0; i < 2; i++) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.ofNull()); // null typed_value + builder.add(Variant.of(METADATA, obj)); + } + + assertThat(shouldRead(notNull(PRICE), builder.build(), shreddingFunc)) + .as("Should read: some rows have non-null typed_value, NOT_NULL might match") + .isTrue(); + } + + @Test + void testShreddedStringMaxIsAllMaxCodepoints() throws IOException { + // U+10FFFF is the highest Unicode code point; its UTF-8 encoding ends in bytes 0xBF 0xBF, + // which approach the 0xFF boundary. A string of 17 such characters is longer than the + // Iceberg-metrics truncation limit of 16 code points, so UnicodeUtil.truncateStringMax + // returns null (no valid upper bound can be computed by incrementing any code point because + // every code point is already at the maximum). This tests that the Parquet column chunk + // statistics — which store the exact bytes, not a truncated approximation — still allow the + // filter to make correct skip decisions. 
+ final String maxCodepoint = "\uDBFF\uDFFF"; // U+10FFFF as a Java surrogate pair + final String allMax = maxCodepoint.repeat(17); // > 16 code points: truncation returns null + List variants = nameStringVariants("alpha", allMax); + VariantShreddingFunction shreddingFunc = nameStringShreddingFunc(); + + final UnboundTerm term = NAME; + + // Exact Parquet stats: max is known, so GT(max) can skip + assertThat(shouldRead(greaterThan(term, allMax), variants, shreddingFunc)) + .as("Should skip: nothing is greater than the all-U+10FFFF max") + .isFalse(); + + // EQ at the exact max: value is present, cannot skip + assertThat(shouldRead(equal(term, allMax), variants, shreddingFunc)) + .as("Should read: all-U+10FFFF string is the max value") + .isTrue(); + + // One extra U+10FFFF pushes the literal beyond the stored max + assertThat(shouldRead(equal(term, allMax + maxCodepoint), variants, shreddingFunc)) + .as("Should skip: literal is strictly greater than the max") + .isFalse(); + + // LT at min: nothing below "alpha" + assertThat(shouldRead(lessThan(term, "alpha"), variants, shreddingFunc)) + .as("Should skip: nothing is less than min 'alpha'") + .isFalse(); + + // Mid-range value: within [alpha, allMax] + assertThat(shouldRead(equal(term, "beta"), variants, shreddingFunc)) + .as("Should read: 'beta' is within [alpha, all-U+10FFFF]") + .isTrue(); + } + + @Test + void testShreddedSingleValueAtExactBoundary() throws IOException { + // When a row group contains only one distinct value (min == max), predicates at that exact + // boundary must be evaluated correctly: LT and GT should skip (the value cannot satisfy them), + // while LT_EQ, GT_EQ, and EQ should read (the value might satisfy them). 
+ List variants = intPriceVariants(10, 10, 10); // all rows have price = 10 + final UnboundTerm term = PRICE; + + // min=10, LT(10): minVsLiteral=0, 0 < 0 is false → CANNOT_MATCH + assertNotRead(lessThan(term, 10), variants, "< 10 when all values are 10"); + // max=10, GT(10): literalVsMax=0, 0 < 0 is false → CANNOT_MATCH + assertNotRead(greaterThan(term, 10), variants, "> 10 when all values are 10"); + // min=10, LT_EQ(10): minVsLiteral=0, 0 <= 0 is true → MIGHT_MATCH + assertIsRead(lessThanOrEqual(term, 10), variants, "<= 10 when all values are 10"); + // max=10, GT_EQ(10): literalVsMax=0, 0 <= 0 is true → MIGHT_MATCH + assertIsRead(greaterThanOrEqual(term, 10), variants, ">= 10 when all values are 10"); + // EQ(10): both comparisons are 0 → MIGHT_MATCH + assertIsRead(equal(term, 10), variants, "= 10 when all values are 10"); + } + + @Test + void testShreddedAndCompoundPredicate() throws IOException { + // AND over two variant extract predicates: the row group can be skipped if either arm cannot + // match, and must be read only when both arms might match. + List variants = intPriceVariants(10, 11, 12); // min=10, max=12 + final UnboundTerm term = PRICE; + + // Both arms overlap the range [10,12] → MIGHT_MATCH + assertIsRead( + and(greaterThan(term, 5), lessThan(term, 15)), variants, "5 < price < 15 overlaps [10,12]"); + // two predicates will be evaluated here, min and max. + // this highlights that an in-range query will be expensive on unshredded numbers + // as they will need to be read twice. 
+ assertShreddedMetricsProcessed(2); + + // GT(20) cannot match (max=12 < 20) → AND short-circuits to CANNOT_MATCH + assertNotRead( + and(greaterThan(term, 20), lessThan(term, 25)), + variants, + "20 < price < 25 is above the range [10,12]"); + + // LT(8) cannot match (min=10 > 8) → AND short-circuits to CANNOT_MATCH + assertNotRead( + and(greaterThan(term, 5), lessThan(term, 8)), + variants, + "5 < price < 8 is below the range [10,12]"); + } + + @Test + void testShreddedSetMembership() throws IOException { + List variants = intPriceVariants(10, 11, 12); // min=10, max=12 + final UnboundTerm term = PRICE; + + // All set values are below the range — row group can be skipped + assertNotRead(in(term, 1, 2, 3), variants, "IN {1,2,3} with range [10..12]"); + assertShreddedMetricsProcessed(1); + + // All set values are above the range — row group can be skipped + assertNotRead(in(term, 20, 30), variants, "IN {20,30} with range [10..12]"); + // Single value that is below the range — row group can be skipped + assertNotRead(in(term, 5), variants, "IN {5} with range [10..12]"); + assertNotRead(in(term, -100, 400), variants, "IN {-100, 400} with range [10..12]"); + // At least one set value is within the range — row group must be read + assertIsRead(in(term, 9, 10), variants, "IN {9,10} with range [10..12]"); + assertIsRead(in(term, 12, 13), variants, "IN {12,13} with range [10..12]"); + assertIsRead(in(term, 10, 11, 12), variants, "IN {10,11,12} with range [10..12]"); + // Exact boundary matches — row group must be read + assertIsRead(in(term, 10), variants, "IN {10} with range [10..12]"); + assertIsRead(in(term, 12), variants, "IN {12} with range [10..12]"); + } + + @Test + void testShreddedUUIDEqual() throws IOException { + // Row group has deviceid range [UUID_LOW, UUID_HIGH] + List variants = uuidDeviceIdVariants(UUID_LOW, UUID_MID, UUID_HIGH); + final UnboundTerm term = DEVICEID_UUID; + + // EQ with a value strictly below the range — row group can be skipped + 
assertNotRead(equal(term, UUID_ZERO), variants, "= nil UUID, range [LOW, HIGH]"); + + // EQ with a value strictly above the range — row group can be skipped + assertNotRead(equal(term, UUID_ABOVE_HIGH), variants, "= UUID 100, range [LOW, HIGH]"); + + // EQ with a value within the range — row group must be read + assertIsRead(equal(term, UUID_MID), variants, "= UUID_MID within [LOW, HIGH]"); + + // EQ with exact boundary values — row group must be read + assertIsRead(equal(term, UUID_LOW), variants, "= UUID_LOW, lower boundary"); + assertIsRead(equal(term, UUID_HIGH), variants, "= UUID_HIGH, upper boundary"); + } + + @Test + void testShreddedUUIDNotEqual() throws IOException { + List variants = uuidDeviceIdVariants(UUID_LOW, UUID_MID, UUID_HIGH); + final UnboundTerm term = DEVICEID_UUID; + + // NOT_EQ never uses min/max to skip — always MIGHT_MATCH + assertIsRead(not(equal(term, UUID_MID)), variants, "!= UUID_MID with range [LOW, HIGH]"); + assertIsRead( + not(equal(term, UUID_MID)), + uuidDeviceIdVariants(UUID_MID, UUID_MID, UUID_MID), + "!= UUID_MID, all values equal UUID_MID"); + } + + @Test + void testShreddedUUIDIn() throws IOException { + ParquetMetricsRowGroupFilter.resetShreddedMetricsCounter(); + + // Row group has deviceid range [UUID_LOW, UUID_HIGH] + List variants = uuidDeviceIdVariants(UUID_LOW, UUID_MID, UUID_HIGH); + final UnboundTerm term = DEVICEID_UUID; + + // All set values are below the range — row group can be skipped + assertNotRead(in(term, UUID_ZERO), variants, "IN {nil UUID}, range [LOW, HIGH]"); + int expected = 1; + assertShreddedMetricsProcessed(expected++); + + // All set values are above the range — row group can be skipped + assertNotRead(in(term, UUID_ABOVE_HIGH), variants, "IN {UUID 100}, range [LOW, HIGH]"); + assertShreddedMetricsProcessed(expected++); + + assertNotRead( + in(term, UUID_ABOVE_HIGH, UUID.fromString("00000000-0000-0000-0000-0000000000c8")), + variants, + "IN {UUID 100, UUID 200}, range [LOW, HIGH]"); + 
assertShreddedMetricsProcessed(expected++); + + // At least one set value is within the range — row group must be read + assertIsRead(in(term, UUID_MID), variants, "IN {UUID_MID} within [LOW, HIGH]"); + assertShreddedMetricsProcessed(expected++); + assertIsRead( + in(term, UUID_ZERO, UUID_MID), variants, "IN {below, UUID_MID} straddles [LOW, HIGH]"); + assertShreddedMetricsProcessed(expected++); + assertIsRead( + in(term, UUID_LOW, UUID_MID, UUID_HIGH), + variants, + "IN {LOW, MID, HIGH} covers [LOW, HIGH]"); + assertShreddedMetricsProcessed(expected++); + + // Exact boundary values — row group must be read + assertIsRead(in(term, UUID_LOW), variants, "IN {UUID_LOW}, lower boundary"); + assertShreddedMetricsProcessed(expected++); + assertIsRead(in(term, UUID_HIGH), variants, "IN {UUID_HIGH}, upper boundary"); + assertShreddedMetricsProcessed(expected++); + } + + private static void assertShreddedMetricsProcessed(final int expected) { + assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetrics()) + .describedAs("Count of shredded metrics filtered on in predicates") + .isEqualTo(expected); + } + + // --- helpers --- + + private List intPriceVariants(int... prices) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (int price : prices) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.of(price)); + builder.add(Variant.of(METADATA, obj)); + } + return builder.build(); + } + + private List longPriceVariants(long... prices) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (long price : prices) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.of(price)); + builder.add(Variant.of(METADATA, obj)); + } + return builder.build(); + } + + private List floatPriceVariants(float... 
prices) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (float price : prices) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.of(price)); + builder.add(Variant.of(METADATA, obj)); + } + return builder.build(); + } + + private List doublePriceVariants(double... prices) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (double price : prices) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.of(price)); + builder.add(Variant.of(METADATA, obj)); + } + return builder.build(); + } + + private VariantShreddingFunction floatPriceShreddingFunc() { + ShreddedObject example = Variants.object(METADATA); + example.put("price", Variants.of(0.0F)); + return (id, name) -> ParquetVariantUtil.toParquetSchema(example); + } + + private VariantShreddingFunction doublePriceShreddingFunc() { + ShreddedObject example = Variants.object(METADATA); + example.put("price", Variants.of(0.0D)); + return (id, name) -> ParquetVariantUtil.toParquetSchema(example); + } + + private List nameStringVariants(String... names) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (String name : names) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("name", Variants.of(name)); + builder.add(Variant.of(METADATA, obj)); + } + return builder.build(); + } + + private VariantShreddingFunction nameStringShreddingFunc() { + ShreddedObject example = Variants.object(METADATA); + example.put("name", Variants.of("x")); + return (id, name) -> ParquetVariantUtil.toParquetSchema(example); + } + + private List uuidDeviceIdVariants(UUID... 
uuids) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (UUID uuid : uuids) { + ShreddedObject obj = Variants.object(UUID_METADATA); + obj.put("deviceid", Variants.ofUUID(uuid)); + builder.add(Variant.of(UUID_METADATA, obj)); + } + return builder.build(); + } + + /** + * Should the expression, when evaluated against the fully shredded set of variants, require the + * RowGroup to be read? + * + * @param expr expression + * @param variants list of variants + */ + private boolean shouldRead(Expression expr, List variants) throws IOException { + // Derive the shredding schema from the first variant's structure + VariantShreddingFunction shreddingFunc = + (id, name) -> ParquetVariantUtil.toParquetSchema(variants.get(0).value()); + return shouldRead(expr, variants, shreddingFunc); + } + + /** + * Should a set of variants, shredded with the supplied shredding function, be read? + * + * @param expr expression + * @param variants list of variants + * @param shreddingFunc shredding function + * @return true if a file containing only these variants should be read. 
+ */ + private boolean shouldRead( + Expression expr, List variants, VariantShreddingFunction shreddingFunc) + throws IOException { + OutputFile out = new InMemoryOutputFile(); + GenericRecord record = GenericRecord.create(SCHEMA); + + FileAppender writer = + Parquet.write(out) + .schema(SCHEMA) + .variantShreddingFunc(shreddingFunc) + .createWriterFunc(fileSchema -> InternalWriter.create(SCHEMA.asStruct(), fileSchema)) + .build(); + + try (writer) { + for (int i = 0; i < variants.size(); i++) { + record.setField("id", (long) i); + record.setField("var", variants.get(i)); + writer.add(record); + } + } + + try (ParquetFileReader reader = ParquetFileReader.open(ParquetIO.file(out.toInputFile()))) { + BlockMetaData rowGroup = reader.getRowGroups().get(0); + MessageType fileSchema = reader.getFileMetaData().getSchema(); + return new ParquetMetricsRowGroupFilter(SCHEMA, expr, true).shouldRead(fileSchema, rowGroup); + } + } +} diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java index 2f9e845d613a..9ca5d6f61b7a 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java @@ -452,19 +452,15 @@ private static UnboundTerm toTerm(T input) { private static boolean isVariantGetFunc( org.apache.spark.sql.connector.expressions.Expression expr) { - if (!(expr instanceof UserDefinedScalarFunc)) { + if (!(expr instanceof UserDefinedScalarFunc udf)) { return false; } - UserDefinedScalarFunc udf = (UserDefinedScalarFunc) expr; String name = udf.name().toLowerCase(Locale.ROOT); - if (!name.equals("variant_get") && !name.equals("try_variant_get")) { - return false; - } - org.apache.spark.sql.connector.expressions.Expression[] children = udf.children(); - return children.length == 3 - && isRef(children[0]) - && isLiteral(children[1]) - && isLiteral(children[2]); + 
return ("variant_get".equals(name) || "try_variant_get".equals(name)) + && udf.children().length == 3 + && isRef(udf.children()[0]) + && isLiteral(udf.children()[1]) + && isLiteral(udf.children()[2]); } private static UnboundTerm variantGetToTerm(UserDefinedScalarFunc udf) { @@ -481,15 +477,11 @@ private static UnboundTerm variantGetToTerm(UserDefinedScalarFunc udf) { } private static String sparkTypeNameToIceberg(String sparkTypeName) { - switch (sparkTypeName.toLowerCase(Locale.ROOT)) { - case "bigint": - return "long"; - case "tinyint": - case "smallint": - return "int"; - default: - return sparkTypeName; - } + return switch (sparkTypeName.toLowerCase(Locale.ROOT)) { + case "bigint" -> "long"; + case "tinyint", "smallint" -> "int"; + default -> sparkTypeName; + }; } @SuppressWarnings("checkstyle:CyclomaticComplexity") diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/parquet/TestSparkVariantFilterPushDown.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/parquet/TestSparkVariantFilterPushDown.java new file mode 100644 index 000000000000..79eb62a3ef94 --- /dev/null +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/parquet/TestSparkVariantFilterPushDown.java @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.parquet; + +import static org.apache.iceberg.PlanningMode.DISTRIBUTED; +import static org.apache.iceberg.PlanningMode.LOCAL; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; +import org.apache.iceberg.Parameter; +import org.apache.iceberg.ParameterizedTestExtension; +import org.apache.iceberg.Parameters; +import org.apache.iceberg.PlanningMode; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogConfig; +import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.spark.sql.execution.SparkPlan; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Integration tests for Spark SQL variant predicate pushdown via the Iceberg data source. After + * creating a table with variants (parameterization includes shredded/unshredded), queries are + * issued against the table. + * + *

Each test verifies both the rows returned (correctness) and which pushed Iceberg scan filters + * were present in the physical plan. + * + *

They also make assertions on how many times {@code + * ParquetMetricsRowGroupFilter.compareVariant()} has been invoked on shredded data, which is why it + * needs to be in the same package as that class. + */ +@ExtendWith(ParameterizedTestExtension.class) +public class TestSparkVariantFilterPushDown extends TestBaseWithCatalog { + + public static final String SPARK_VARIANT_GET = + "variant_get(nested, $.varcategory, IntegerType, true, Some(UTC))"; + public static final String ICEBERG_VARCAT = "variant_get(nested, '$.varcategory', 'int')"; + public static final String SPARK_ISNOTNULL_NESTED = "isnotnull(nested)"; + public static final String ICEBERG_NESTED_ISNOTNULL = "nested IS NOT NULL"; + + @Parameters( + name = + "catalogName = {0}, implementation = {1}, config = {2}, planningMode = {3} shredded= {4}") + public static Object[][] parameters() { + return new Object[][] { + // unshredded: the reference + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + LOCAL, + false + }, + // local planning and shredded + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + LOCAL, + true + }, + // distributed planning and shredded + { + SparkCatalogConfig.HADOOP.catalogName(), + SparkCatalogConfig.HADOOP.implementation(), + SparkCatalogConfig.HADOOP.properties(), + DISTRIBUTED, + true + }, + }; + } + + public static final Logger LOG = LoggerFactory.getLogger(TestSparkVariantFilterPushDown.class); + + @Parameter(index = 3) + private PlanningMode planningMode; + + /** Should the variant be shredded? */ + @Parameter(index = 4) + private boolean shredded; + + /** Number of categories; each row has {@code category = id}. 
*/ + private static final int NUM_CATEGORIES = 20; + + public TestSparkVariantFilterPushDown() {} + + @BeforeEach + public void createTable() { + LOG.info("Creating Spark Table with shredding {}", shredded); + if (shredded) { + spark.conf().set(SparkSQLProperties.SHRED_VARIANTS, "true"); + } + sql( + """ + CREATE TABLE %s (id BIGINT, category INT, nested VARIANT, arr VARIANT) + USING iceberg + TBLPROPERTIES ('format-version'='3', + 'read.parquet.vectorization.enabled'='false', + 'write.parquet.shred-variants'='%s')""", + selectTarget(), shredded); + configurePlanningMode(planningMode); + buildDataset(); + } + + @AfterEach + public void dropTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + /** + * Build a dataset in sql, using parse_json to create the variant data. Records are {@code (id, + * category, parse_json(nested_json), parse_json(arr_json))} + */ + private void buildDataset() { + + String values = + IntStream.range(0, NUM_CATEGORIES) + .mapToObj( + n -> + String.format( + "(%d, %d, parse_json('{\"varid\": %d, \"varcategory\": %d}'), parse_json('[%d]'))", + (long) n, n, n, n, n)) + .collect(Collectors.joining(", ")); + + sql("INSERT INTO %s VALUES %s", selectTarget(), values); + } + + /** + * Baseline: filter on a plain INT column, project the id. Iceberg pushes the predicate fully to + * the scan; Spark still evaluates the predicate post-scan as a safety check. + */ + @TestTemplate + public void filterCategoryProjectId() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + "category = 5", + "isnotnull(category) AND (category = 5)", + "category IS NOT NULL, category = 5", + 0, 0, + ImmutableList.of(row(5L)))); + } + + /** + * Filter using variant field extraction (equality). Iceberg pushes both a null check and the + * equality predicate on the variant column; Spark also evaluates the filter post-scan. 
+ */ + @TestTemplate + public void filterVariantCategoryProjectId() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + ICEBERG_VARCAT + " = 5", + SPARK_ISNOTNULL_NESTED + " AND (" + SPARK_VARIANT_GET + " = 5)", + ICEBERG_NESTED_ISNOTNULL + ", " + ICEBERG_VARCAT + " = 5", + 2, 2, + ImmutableList.of(row(5L)))); + } + + /** Use the greater than and less than predicates in a query. Doubles the number of scans. */ + @TestTemplate + public void filterVariantCategoryInRange() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + ICEBERG_VARCAT + " > 4 AND " + ICEBERG_VARCAT + " < 7", + "(" + + SPARK_ISNOTNULL_NESTED + + " AND (" + + SPARK_VARIANT_GET + + " > 4))" + + " AND (" + + SPARK_VARIANT_GET + + " < 7)", + ICEBERG_NESTED_ISNOTNULL + + ", " + + ICEBERG_VARCAT + + " > 4, " + + ICEBERG_VARCAT + + " < 7", + 4, 4, + ImmutableList.of(row(5L), row(6L)))); + } + + /** Use the greater than and less than predicates in a query. Doubles the number of scans. */ + @TestTemplate + public void filterVariantCategoryGreateThanEquals() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + ICEBERG_VARCAT + " >= 4 AND " + ICEBERG_VARCAT + " <= 7", + "(" + + SPARK_ISNOTNULL_NESTED + + " AND (" + + SPARK_VARIANT_GET + + " >= 4))" + + " AND (" + + SPARK_VARIANT_GET + + " <= 7)", + ICEBERG_NESTED_ISNOTNULL + + ", " + + ICEBERG_VARCAT + + " >= 4, " + + ICEBERG_VARCAT + + " <= 7", + 4, 4, + ImmutableList.of(row(4L), row(5L), row(6L), row(7L)))); + } + + /** + * Project a variant field and filter on a different variant field. Iceberg pushes both a null + * check and the equality predicate; the projection is evaluated post-scan. 
+ */ + @TestTemplate + public void filterVariantCategoryProjectVariantId() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "variant_get(nested, '$.varid', 'int')", + ICEBERG_VARCAT + " = 5", + SPARK_ISNOTNULL_NESTED + " AND (" + SPARK_VARIANT_GET + " = 5)", + ICEBERG_NESTED_ISNOTNULL + ", " + ICEBERG_VARCAT + " = 5", + 2, 2, + ImmutableList.of(row(5)))); + } + + /** IN predicate on a variant field using {@code variant_get}. */ + @TestTemplate + public void filterVariantCategorySetMembership() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + ICEBERG_VARCAT + " IN (5, 10)", + SPARK_VARIANT_GET + " IN (5,10)", + ICEBERG_VARCAT + " IN (5, 10)", // no null check + 4, 4, + ImmutableList.of(row(5L), row(10L)))); + } + + /** + * Set members are all above or below the categories. The filter string this produces is slightly + * different from that of {@link #filterVariantCategorySetMembership()}. + */ + @TestTemplate + public void filterVariantCategorySetMembership2() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + ICEBERG_VARCAT + " IN (100, 400)", + SPARK_VARIANT_GET + " IN (100,400)", + "", + 0, 0, + ImmutableList.of())); + } + + /** + * Set membership of a single element is remapped to equality, filtering takes place in + * planning. + */ + @TestTemplate + public void filterVariantCategorySetMembership3() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + ICEBERG_VARCAT + " IN (100)", + SPARK_ISNOTNULL_NESTED + " AND (" + SPARK_VARIANT_GET + " = 100)", + "", + 0, 0, + ImmutableList.of())); + } + + /** + * Set membership of a single element is remapped to equality, and if that element is in range, a + * pushed down predicate is evaluated. 
+ */ + @TestTemplate + public void filterVariantCategorySetMembership4() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + ICEBERG_VARCAT + " IN (4)", + SPARK_ISNOTNULL_NESTED + " AND (" + SPARK_VARIANT_GET + " = 4)", + ICEBERG_NESTED_ISNOTNULL + ", " + ICEBERG_VARCAT + " = 4", + 2, 2, + ImmutableList.of(row(4L)))); + } + + /** + * Evaluation of the IS NULL predicate. + */ + @TestTemplate + public void filterVariantCategoryIsNull() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + ICEBERG_VARCAT + " IS NULL", + "isnull(" + SPARK_VARIANT_GET + ")", + "", + 4, 4, + ImmutableList.of())); + } + + /** + * Evaluation of the IS NOT NULL predicate; this finds everything. + */ + @TestTemplate + public void filterVariantCategoryIsNotNull() { + List rows = LongStream.rangeClosed(0, 19) + .mapToObj(this::row) + .toList(); + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + ICEBERG_VARCAT + " IS NOT NULL", + SPARK_ISNOTNULL_NESTED + " AND isnotnull(" + SPARK_VARIANT_GET + ")", + "nested IS NOT NULL, variant_get(nested, '$.varcategory', 'int') IS NOT NULL", + 4, 4, + rows)); + } + + /** + * Filter on element 0 of an array variant. Iceberg pushes a null check on the array column; the + * actual element comparison is done post-scan. + */ + @TestTemplate + public void filterArrayElementProjectId() { + withDefaultTimeZone( + "UTC", + () -> + checkFilters( + "id", + "variant_get(arr, '$[0]', 'int') = 5", + "isnotnull(arr) AND (variant_get(arr, $[0], IntegerType, true, Some(UTC)) = 5)", + "arr IS NOT NULL", + 0, 0, + ImmutableList.of(row(5L)))); + } + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + /** + * Run {@code SELECT FROM WHERE ORDER BY id}, assert the returned + * rows, and verify that the physical plan contains the expected Spark post-scan filter and + * Iceberg scan filters. 
+ * + * @param projection column expression(s) to select + * @param predicate SQL WHERE clause (no "WHERE" keyword) + * @param sparkFilter expected post-scan Spark Filter node text + * @param icebergFilters expected {@code filters=...} value from the Iceberg scan node; empty + * string means no Iceberg pushdown shall take place. + * @param expectedPlanningEvaluations expected number of evaluations during planning + * @param expectedExecutionEvaluations number of evaluations of a rowgroup filter predicate + * on a shredded column. + * @param expectedRows expected result rows in id order + */ + private void checkFilters( + String projection, + String predicate, + String sparkFilter, + String icebergFilters, + int expectedPlanningEvaluations, + int expectedExecutionEvaluations, + List expectedRows) { + + ParquetMetricsRowGroupFilter.resetShreddedMetricsCounter(); + String query = + String.format("SELECT %s FROM %s WHERE %s ORDER BY id", projection, tableName, predicate); + + SparkPlan plan = executeAndKeepPlan(query); + long planShredCount = ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetrics(); + String planString = plan.toString().replaceAll("#\\d+L?", ""); + String summary = String.format("%s with plan shred count %d", query, planShredCount); + + assertThat(planString) + .as("Post-scan Spark filter of %s", summary) + .containsAnyOf("Filter (" + sparkFilter + ")", "Filter " + sparkFilter); + + if (!icebergFilters.isEmpty()) { + assertThat(planString).as("No iceberg scan generated from %s", summary).contains("IcebergScan"); + assertThat(planString) + .as("Iceberg pushed filters of must match from %s", summary) + .contains(", filters=" + icebergFilters + ","); + } else { + assertThat(planString) + .as("No iceberg scan generated from %s", summary) + .doesNotContain("IcebergScan"); + } + + if (shredded) { + assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetrics()) + .describedAs("Count of shredded metrics filtered during planning of of %s to
plan %s", + summary, planString) + .isEqualTo(expectedPlanningEvaluations); + } + ParquetMetricsRowGroupFilter.resetShreddedMetricsCounter(); + final List rows = sql("SELECT %s FROM %s WHERE %s ORDER BY id", projection, selectTarget(), predicate); + assertEquals( + "Execution of " + summary + " to plan " + planString, + expectedRows, + rows); + if (shredded) { + assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetrics()) + .describedAs("Count of shredded metrics filtered during execution of of %s to plan %s", + summary, planString) + .isEqualTo(expectedExecutionEvaluations); + } + } +} diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java index e4e66abfefa0..596375809a23 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSpark3Util.java @@ -24,6 +24,7 @@ import static org.apache.iceberg.expressions.Expressions.bucket; import static org.apache.iceberg.expressions.Expressions.day; import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.expressions.Expressions.extract; import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.hour; @@ -165,6 +166,21 @@ public void testDescribeExpression() { assertThat(Spark3Util.describe(andExpression)).isEqualTo("(id = 1 AND year(ts) > 10)"); } + @Test + public void testDescribeExtractExpression() { + Expression extractGt = greaterThan(extract("v", "$.city", "string"), "Boston"); + assertThat(Spark3Util.describe(extractGt)) + .isEqualTo("variant_get(v, '$.city', 'string') > 'Boston'"); + + Expression extractEq = equal(extract("v", "$.event.id", "long"), 42L); + assertThat(Spark3Util.describe(extractEq)) + 
.isEqualTo("variant_get(v, '$.event.id', 'long') = 42"); + + Expression extractIn = in(extract("v", "$.city", "string"), "NYC", "LA"); + assertThat(Spark3Util.describe(extractIn)) + .isEqualTo("variant_get(v, '$.city', 'string') IN ('NYC', 'LA')"); + } + private SortOrder buildSortOrder(String transform, Schema schema, int sourceId) { String jsonString = "{\n" diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java index e0b590e5a6e8..0538b8130771 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java @@ -652,6 +652,66 @@ public void testTruncate() { testUDF(udf, Expressions.truncate("strCol", 6), "prefix", DataTypes.StringType); } + @Test + public void testVariantGet() { + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + "variant_get", + "variant_get", + expressions( + FieldReference.apply("v"), + LiteralValue.apply(UTF8String.fromString("$.city"), DataTypes.StringType), + LiteralValue.apply(UTF8String.fromString("string"), DataTypes.StringType))); + testUDF(udf, Expressions.extract("v", "$.city", "string"), "NYC", DataTypes.StringType); + } + + @Test + public void testTryVariantGet() { + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + "try_variant_get", + "try_variant_get", + expressions( + FieldReference.apply("v"), + LiteralValue.apply(UTF8String.fromString("$.city"), DataTypes.StringType), + LiteralValue.apply(UTF8String.fromString("string"), DataTypes.StringType))); + testUDF(udf, Expressions.extract("v", "$.city", "string"), "NYC", DataTypes.StringType); + } + + @Test + public void testVariantGetNestedPath() { + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + "variant_get", + "variant_get", + expressions( + FieldReference.apply("v"), + 
LiteralValue.apply(UTF8String.fromString("$.event.id"), DataTypes.StringType), + LiteralValue.apply(UTF8String.fromString("long"), DataTypes.StringType))); + testUDF(udf, Expressions.extract("v", "$.event.id", "long"), 42L, DataTypes.LongType); + } + + @Test + public void testVariantGetInPredicate() { + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + "variant_get", + "variant_get", + expressions( + FieldReference.apply("v"), + LiteralValue.apply(UTF8String.fromString("$.city"), DataTypes.StringType), + LiteralValue.apply(UTF8String.fromString("string"), DataTypes.StringType))); + org.apache.spark.sql.connector.expressions.Expression[] attrAndValues = + expressions( + udf, + LiteralValue.apply(UTF8String.fromString("NYC"), DataTypes.StringType), + LiteralValue.apply(UTF8String.fromString("LA"), DataTypes.StringType)); + Predicate in = new Predicate("IN", attrAndValues); + Expression actual = SparkV2Filters.convert(in); + Expression expected = Expressions.in(Expressions.extract("v", "$.city", "string"), "NYC", "LA"); + assertEquals(expected, actual); + } + @Test public void testUnsupportedUDFConvert() { ScalarFunction icebergVersionFunc = From b06b4e9a274222a313547914f4786c72dbb1171c Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 20 May 2026 20:55:18 +0100 Subject: [PATCH 3/4] ParquetMetricsRowGroupFilter enhancement: now counting RGs skipped. Allows for assertions in tests and in benchmarks that rowgroup skipping is taking place. 
Needed as there's not much tangible speedup, yet Change-Id: I8c03eb33d2d3d8a2139c347e6a72a7284e627f62 --- .../parquet/ParquetMetricsRowGroupFilter.java | 43 +++++++-- .../TestShreddedVariantRowGroupFilter.java | 89 ++++++++++++++++++- .../TestSparkVariantFilterPushDown.java | 85 ++++++++++-------- 3 files changed, 166 insertions(+), 51 deletions(-) diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java index f5aefbe56726..51606608b59f 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java @@ -83,7 +83,12 @@ public class ParquetMetricsRowGroupFilter { *

This permits assertions to be made that variant predicate pushdown reached this far and * processed shredded columns. */ - private static final AtomicLong VARIANT_PREDICATES_SHREDDED_METRICS = new AtomicLong(); + private static final AtomicLong VARIANT_PREDICATES_SHREDDED_METRICS_EVALUATED = new AtomicLong(); + + /** + * Counter for row groups proven skippable by a shredded variant predicate. + */ + private static final AtomicLong VARIANT_PREDICATES_SHREDDED_SKIPPED = new AtomicLong(); public ParquetMetricsRowGroupFilter(Schema schema, Expression unbound) { this(schema, unbound, true); @@ -632,7 +637,7 @@ private T max(Statistics statistics, int id) { public Boolean predicate(BoundPredicate pred) { if (pred.term() instanceof BoundExtract term) { // it's a variant predicate: process accordingly. - return compareVariant(pred, term); + return recordOutcome(compareVariant(pred, term)); } else { return super.predicate(pred); } @@ -668,7 +673,7 @@ private boolean compareVariant(BoundPredicate pred, BoundExtract extra return ROWS_MIGHT_MATCH; } // increment shredded metrics counter. - VARIANT_PREDICATES_SHREDDED_METRICS.incrementAndGet(); + VARIANT_PREDICATES_SHREDDED_METRICS_EVALUATED.incrementAndGet(); // now do the evaluation. LOG.info("Evaluating column {} with info {}", columnPath, columnInfo); @@ -700,6 +705,16 @@ private boolean compareVariant(BoundPredicate pred, BoundExtract extra return evalBinaryPredicateOnShreddedVariant(pred, extract, parquetType, colStats); } + /** + * Increment the skipped counter on {@code ROWS_CANNOT_MATCH} and return the input unchanged. + */ + private boolean recordOutcome(boolean shouldRead) { + if (!shouldRead) { + VARIANT_PREDICATES_SHREDDED_SKIPPED.incrementAndGet(); + } + return shouldRead; + } + /** * Evaluate the statistics, return an Boolean value if there was enough information to make a * decision. 
@@ -983,13 +998,25 @@ public Boolean handleNonReference(Bound term) { * @return zero or a positive integer */ @VisibleForTesting - static long variantPredicatesShreddedMetrics() { - return VARIANT_PREDICATES_SHREDDED_METRICS.get(); + static long variantPredicatesShreddedMetricsEvaluated() { + return VARIANT_PREDICATES_SHREDDED_METRICS_EVALUATED.get(); + } + + /** + * The number of row groups proven skippable by a shredded variant predicate. Will always be equal to or less than + * the value of {@link #variantPredicatesShreddedMetricsEvaluated()}. + * + * @return zero or a positive integer + */ + @VisibleForTesting + static long variantPredicatesShreddedSkipped() { + return VARIANT_PREDICATES_SHREDDED_SKIPPED.get(); } - /** Reset the shredded metrics counter. */ + /** Reset both shredded metrics counters (examined and skipped). */ @VisibleForTesting - static void resetShreddedMetricsCounter() { - VARIANT_PREDICATES_SHREDDED_METRICS.set(0); + static void resetShreddedMetricsCounters() { + VARIANT_PREDICATES_SHREDDED_METRICS_EVALUATED.set(0); + VARIANT_PREDICATES_SHREDDED_SKIPPED.set(0); } } diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestShreddedVariantRowGroupFilter.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestShreddedVariantRowGroupFilter.java index 367025c0b1b6..281962cc6998 100644 --- a/parquet/src/test/java/org/apache/iceberg/parquet/TestShreddedVariantRowGroupFilter.java +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestShreddedVariantRowGroupFilter.java @@ -104,7 +104,7 @@ class TestShreddedVariantRowGroupFilter { @BeforeEach void before() { - ParquetMetricsRowGroupFilter.resetShreddedMetricsCounter(); + ParquetMetricsRowGroupFilter.resetShreddedMetricsCounters(); } @Test @@ -618,7 +618,7 @@ void testShreddedUUIDNotEqual() throws IOException { @Test void testShreddedUUIDIn() throws IOException { - ParquetMetricsRowGroupFilter.resetShreddedMetricsCounter(); + ParquetMetricsRowGroupFilter.resetShreddedMetricsCounters(); 
// Row group has deviceid range [UUID_LOW, UUID_HIGH] List variants = uuidDeviceIdVariants(UUID_LOW, UUID_MID, UUID_HIGH); @@ -658,13 +658,94 @@ void testShreddedUUIDIn() throws IOException { assertShreddedMetricsProcessed(expected++); } + // --------------------------------------------------------------------------- + // Skip-counter tests: prove the filter is actually skipping row groups + // --------------------------------------------------------------------------- + + @Test + void testSkipCounterLessThanSkips() throws IOException { + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(lessThan(PRICE, 10), variants)).isFalse(); + assertShreddedSkipped(1); + } + + @Test + void testSkipCounterGreaterThanSkips() throws IOException { + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(greaterThan(PRICE, 14), variants)).isFalse(); + assertShreddedSkipped(1); + } + + @Test + void testSkipCounterEqualBelowRangeSkips() throws IOException { + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(equal(PRICE, 5), variants)).isFalse(); + assertShreddedSkipped(1); + } + + @Test + void testSkipCounterEqualAboveRangeSkips() throws IOException { + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(equal(PRICE, 99), variants)).isFalse(); + assertShreddedSkipped(1); + } + + @Test + void testSkipCounterEqualInRangeDoesNotSkip() throws IOException { + List variants = intPriceVariants(10, 11, 12, 13, 14); + assertThat(shouldRead(equal(PRICE, 12), variants)).isTrue(); + // examined > 0, but no skip + assertShreddedSkipped(0); + assertShreddedMetricsProcessed(1); + } + + @Test + void testSkipCounterIsNullWithNoNullsSkips() throws IOException { + List variants = intPriceVariants(10, 11, 12); + assertThat(shouldRead(isNull(PRICE), variants)).isFalse(); + assertShreddedSkipped(1); + } + + @Test + void testSkipCounterNotNullAllNullsSkips() throws IOException { + ImmutableList.Builder 
builder = ImmutableList.builder(); + for (int i = 0; i < 3; i++) { + ShreddedObject obj = Variants.object(METADATA); + obj.put("price", Variants.ofNull()); + builder.add(Variant.of(METADATA, obj)); + } + List variants = builder.build(); + ShreddedObject example = Variants.object(METADATA); + example.put("price", Variants.of(0)); + VariantShreddingFunction shreddingFunc = + (id, name) -> ParquetVariantUtil.toParquetSchema(example); + + assertThat(shouldRead(notNull(PRICE), variants, shreddingFunc)).isFalse(); + assertShreddedSkipped(1); + } + + @Test + void testSkipCounterUnshreddedPathDoesNotSkip() throws IOException { + // $.name isn't in the shredded schema → falls back to MIGHT_MATCH before consulting stats → + // the skipped counter must NOT advance (we don't count fallbacks as skips). + List variants = intPriceVariants(10, 11, 12); + assertThat(shouldRead(equal(NAME, "foo"), variants)).isTrue(); + assertShreddedSkipped(0); + } + + // --- helpers --- + private static void assertShreddedMetricsProcessed(final int expected) { - assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetrics()) + assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetricsEvaluated()) .describedAs("Count of shredded metrics filtered on in predicates") .isEqualTo(expected); } - // --- helpers --- + private static void assertShreddedSkipped(final long expected) { + assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedSkipped()) + .describedAs("Count of row groups skipped by shredded variant predicates") + .isEqualTo(expected); + } private List intPriceVariants(int... 
prices) { ImmutableList.Builder builder = ImmutableList.builder(); diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/parquet/TestSparkVariantFilterPushDown.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/parquet/TestSparkVariantFilterPushDown.java index 79eb62a3ef94..00f44b6c9fca 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/parquet/TestSparkVariantFilterPushDown.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/parquet/TestSparkVariantFilterPushDown.java @@ -164,7 +164,8 @@ public void filterCategoryProjectId() { "category = 5", "isnotnull(category) AND (category = 5)", "category IS NOT NULL, category = 5", - 0, 0, + 0, + 0, ImmutableList.of(row(5L)))); } @@ -182,7 +183,8 @@ public void filterVariantCategoryProjectId() { ICEBERG_VARCAT + " = 5", SPARK_ISNOTNULL_NESTED + " AND (" + SPARK_VARIANT_GET + " = 5)", ICEBERG_NESTED_ISNOTNULL + ", " + ICEBERG_VARCAT + " = 5", - 2, 2, + 2, + 2, ImmutableList.of(row(5L)))); } @@ -209,7 +211,8 @@ public void filterVariantCategoryInRange() { + " > 4, " + ICEBERG_VARCAT + " < 7", - 4, 4, + 4, + 4, ImmutableList.of(row(5L), row(6L)))); } @@ -236,7 +239,8 @@ public void filterVariantCategoryGreateThanEquals() { + " >= 4, " + ICEBERG_VARCAT + " <= 7", - 4, 4, + 4, + 4, ImmutableList.of(row(4L), row(5L), row(6L), row(7L)))); } @@ -254,7 +258,8 @@ public void filterVariantCategoryProjectVariantId() { ICEBERG_VARCAT + " = 5", SPARK_ISNOTNULL_NESTED + " AND (" + SPARK_VARIANT_GET + " = 5)", ICEBERG_NESTED_ISNOTNULL + ", " + ICEBERG_VARCAT + " = 5", - 2, 2, + 2, + 2, ImmutableList.of(row(5)))); } @@ -269,7 +274,8 @@ public void filterVariantCategorySetMembership() { ICEBERG_VARCAT + " IN (5, 10)", SPARK_VARIANT_GET + " IN (5,10)", ICEBERG_VARCAT + " IN (5, 10)", // no null check - 4, 4, + 4, + 4, ImmutableList.of(row(5L), row(10L)))); } @@ -287,13 +293,13 @@ public void filterVariantCategorySetMembership2() { ICEBERG_VARCAT + " IN (100, 400)", SPARK_VARIANT_GET + " IN (100,400)", "", - 
0, 0, + 0, + 0, ImmutableList.of())); } /** - * Set membership of a single element is remapped to equality, filtering takes place in - * planning. + * Set membership of a single element is remapped to equality, filtering takes place in planning. */ @TestTemplate public void filterVariantCategorySetMembership3() { @@ -305,7 +311,8 @@ public void filterVariantCategorySetMembership3() { ICEBERG_VARCAT + " IN (100)", SPARK_ISNOTNULL_NESTED + " AND (" + SPARK_VARIANT_GET + " = 100)", "", - 0, 0, + 0, + 0, ImmutableList.of())); } @@ -323,13 +330,12 @@ public void filterVariantCategorySetMembership4() { ICEBERG_VARCAT + " IN (4)", SPARK_ISNOTNULL_NESTED + " AND (" + SPARK_VARIANT_GET + " = 4)", ICEBERG_NESTED_ISNOTNULL + ", " + ICEBERG_VARCAT + " = 4", - 2, 2, + 2, + 2, ImmutableList.of(row(4L)))); } - /** - * Evaluation of the IS NULL predicate. - */ + /** Evaluation of the IS NULL predicate. */ @TestTemplate public void filterVariantCategoryIsNull() { withDefaultTimeZone( @@ -340,18 +346,15 @@ public void filterVariantCategoryIsNull() { ICEBERG_VARCAT + " IS NULL", "isnull(" + SPARK_VARIANT_GET + ")", "", - 4, 4, + 4, + 4, ImmutableList.of())); } - /** - * Evaluation of the IS NOT NULL predicate; this finds everything. - */ + /** Evaluation of the IS NOT NULL predicate; this finds everything. 
*/ @TestTemplate public void filterVariantCategoryIsNotNull() { - List rows = LongStream.rangeClosed(0, 19) - .mapToObj(this::row) - .toList(); + List rows = LongStream.rangeClosed(0, 19).mapToObj(this::row).toList(); withDefaultTimeZone( "UTC", () -> @@ -360,7 +363,8 @@ public void filterVariantCategoryIsNotNull() { ICEBERG_VARCAT + " IS NOT NULL", SPARK_ISNOTNULL_NESTED + " AND isnotnull(" + SPARK_VARIANT_GET + ")", "nested IS NOT NULL, variant_get(nested, '$.varcategory', 'int') IS NOT NULL", - 4, 4, + 4, + 4, rows)); } @@ -378,7 +382,8 @@ public void filterArrayElementProjectId() { "variant_get(arr, '$[0]', 'int') = 5", "isnotnull(arr) AND (variant_get(arr, $[0], IntegerType, true, Some(UTC)) = 5)", "arr IS NOT NULL", - 0, 0, + 0, + 0, ImmutableList.of(row(5L)))); } @@ -395,10 +400,10 @@ public void filterArrayElementProjectId() { * @param predicate SQL WHERE clause (no "WHERE" keyword) * @param sparkFilter expected post-scan Spark Filter node text * @param icebergFilters expected {@code filters=...} value from the Iceberg scan node; empty - * string means no Iceberg pushdown shall take places. + * string means no Iceberg pushdown will take place. * @param expectedPlanningEvaluations expected number of evaluations during planning - * @param expectedExecutionEvaluations number of evaluations of a rowgroup filter predicate - * on shredded column. + * @param expectedExecutionEvaluations number of evaluations of a rowgroup filter predicate on + * shredded column.
* @param expectedRows expected result rows in id order */ private void checkFilters( @@ -410,12 +415,12 @@ private void checkFilters( int expectedExecutionEvaluations, List expectedRows) { - ParquetMetricsRowGroupFilter.resetShreddedMetricsCounter(); + ParquetMetricsRowGroupFilter.resetShreddedMetricsCounters(); String query = String.format("SELECT %s FROM %s WHERE %s ORDER BY id", projection, tableName, predicate); SparkPlan plan = executeAndKeepPlan(query); - long planShredCount = ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetrics(); + long planShredCount = ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetricsEvaluated(); String planString = plan.toString().replaceAll("#\\d+L?", ""); String summary = String.format("%s with plan shred count %d", query, planShredCount); @@ -424,7 +429,9 @@ private void checkFilters( .containsAnyOf("Filter (" + sparkFilter + ")", "Filter " + sparkFilter); if (!icebergFilters.isEmpty()) { - assertThat(planString).as("No iceberg scan generated from %s", summary).contains("IcebergScan"); + assertThat(planString) + .as("No iceberg scan generated from %s", summary) + .contains("IcebergScan"); assertThat(planString) .as("Iceberg pushed filters of must match from %s", summary) .contains(", filters=" + icebergFilters + ","); @@ -435,20 +442,20 @@ private void checkFilters( } if (shredded) { - assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetrics()) - .describedAs("Count of shredded metrics filtered during planning of of %s to plan %s", + assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetricsEvaluated()) + .describedAs( + "Count of shredded metrics filtered during planning of of %s to plan %s", summary, planString) .isEqualTo(expectedPlanningEvaluations); } - ParquetMetricsRowGroupFilter.resetShreddedMetricsCounter(); - final List rows = sql("SELECT %s FROM %s WHERE %s ORDER BY id", projection, selectTarget(), predicate); - assertEquals( - "Execution of " + summary + " to plan " + 
planString, - expectedRows, - rows); + ParquetMetricsRowGroupFilter.resetShreddedMetricsCounters(); + final List rows = + sql("SELECT %s FROM %s WHERE %s ORDER BY id", projection, selectTarget(), predicate); + assertEquals("Execution of " + summary + " to plan " + planString, expectedRows, rows); if (shredded) { - assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetrics()) - .describedAs("Count of shredded metrics filtered during execution of of %s to plan %s", + assertThat(ParquetMetricsRowGroupFilter.variantPredicatesShreddedMetricsEvaluated()) + .describedAs( + "Count of shredded metrics filtered during execution of of %s to plan %s", summary, planString) .isEqualTo(expectedExecutionEvaluations); } From 6ad21e1eb269de537cabccc44b24c91551b29c6a Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Thu, 21 May 2026 19:16:46 +0100 Subject: [PATCH 4/4] ParquetMetricsRowGroupFilter: logging downgraded from info to debug Filtering is working well enough to not need logging on normal execution. Change-Id: Ief88bbe7e1df28a93b1fd988f6d9f224fbb846e0 --- .../parquet/ParquetMetricsRowGroupFilter.java | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java index 51606608b59f..155ddeec589a 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java @@ -85,9 +85,7 @@ public class ParquetMetricsRowGroupFilter { */ private static final AtomicLong VARIANT_PREDICATES_SHREDDED_METRICS_EVALUATED = new AtomicLong(); - /** - * Counter for row groups proven skippable by a shredded variant predicate. - */ + /** Counter for row groups proven skippable by a shredded variant predicate.
*/ private static final AtomicLong VARIANT_PREDICATES_SHREDDED_SKIPPED = new AtomicLong(); public ParquetMetricsRowGroupFilter(Schema schema, Expression unbound) { @@ -131,9 +129,15 @@ private class MetricsEvalVisitor extends BoundExpressionVisitor { private Map valueCounts = null; private Map> conversions = null; - // ID-less columns collected during the main column scan for lazy variantInfoByColumnPath build + /** + * ID-less columns collected during the main column scan for lazy variantInfoByColumnPath build. + */ private List shreddedVariantColumns = null; - // Built lazily on the first compareVariant() call; null means not yet built + + /** + * Built lazily on the first compareVariant() call; null means not yet built. TODO: should + * construction be synchronized? + */ private Map variantInfoByColumnPath = null; private boolean eval(MessageType fileSchema, BlockMetaData rowGroup) { @@ -655,11 +659,10 @@ public Boolean predicate(BoundPredicate pred) { */ private boolean compareVariant(BoundPredicate pred, BoundExtract extract) { if (variantInfoByColumnPath == null) { - // TODO: concurrency ? buildVariantInfo(); } int fieldId = extract.ref().fieldId(); - LOG.info("comparing variant {}", extract); + LOG.debug("comparing variant {}", extract); String colName = variantColumnNames.get(fieldId); if (colName == null) { // not in the variant columns @@ -676,7 +679,7 @@ private boolean compareVariant(BoundPredicate pred, BoundExtract extra VARIANT_PREDICATES_SHREDDED_METRICS_EVALUATED.incrementAndGet(); // now do the evaluation. 
- LOG.info("Evaluating column {} with info {}", columnPath, columnInfo); + LOG.debug("Evaluating column {} with info {}", columnPath, columnInfo); PrimitiveType parquetType = columnInfo.type(); final ColumnChunkMetaData col = columnInfo.chunkMetaData; Statistics colStats = col.getStatistics(); @@ -818,7 +821,7 @@ private boolean evalMembershipPredicateOnShreddedVariant( if (literalSet.size() > IN_PREDICATE_LIMIT) { return ROWS_MIGHT_MATCH; } - LOG.info("Set membership evaluation"); + LOG.debug("Set membership evaluation"); Function converter = ParquetConversions.converterFromParquet(parquetType, extract.type()); T min = (T) converter.apply(colStats.genericGetMin()); @@ -838,7 +841,7 @@ private boolean evalMembershipPredicateOnShreddedVariant( candidates.stream().filter(v -> pred.term().comparator().compare(max, v) >= 0).toList(); final boolean match = candidates.isEmpty() ? ROWS_CANNOT_MATCH : ROWS_MIGHT_MATCH; - LOG.info("Outcome match={}", match); + LOG.debug("Outcome match={}", match); return match; } @@ -854,7 +857,7 @@ private boolean evalMembershipPredicateOnShreddedVariant( */ private boolean evalUnaryPredicate( BoundPredicate pred, Statistics colStats, long valueCount) { - LOG.info("Evaluating unary predicate: {}", pred.op()); + LOG.debug("Evaluating unary predicate: {}", pred.op()); switch (pred.op()) { case IS_NULL -> { // If every row has a non-null typed value, no row can match IS_NULL @@ -922,7 +925,7 @@ private static boolean mayContainNull(Statistics statistics) { * @return the map of variant column names, may be empty. 
*/ private Map buildVariantColumnNames(MessageType fileSchema) { - LOG.info("Building variant column names..."); + LOG.debug("Building variant column names..."); Map names = Maps.newHashMap(); for (org.apache.parquet.schema.Type field : fileSchema.getFields()) { if (field.getId() != null) { @@ -933,7 +936,7 @@ private Map buildVariantColumnNames(MessageType fileSchema) { } } } - LOG.info("Found {} names", names.size()); + LOG.debug("Found {} names", names.size()); return names; } @@ -1003,8 +1006,8 @@ static long variantPredicatesShreddedMetricsEvaluated() { } /** - * The number of row groups proven skippable by a shredded variant predicate. Will always be equal to or less than - * the value of {@link #variantPredicatesShreddedMetricsEvaluated()}. + * The number of row groups proven skippable by a shredded variant predicate. Will always be equal + * to or less than the value of {@link #variantPredicatesShreddedMetricsEvaluated()}. * * @return zero or a positive integer */