diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala
index f8acef9fe355..6766ad338b9f 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelCommandDynamicPruning.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
 import org.apache.spark.sql.catalyst.trees.TreePattern.SORT
-import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Implicits
@@ -67,7 +67,7 @@ case class RowLevelCommandDynamicPruning(spark: SparkSession)
     // apply special dynamic filtering only for plans that don't support deltas
     case RewrittenRowLevelCommand(
         command: RowLevelCommand,
-        DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _, _),
+        DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _),
         rewritePlan: ReplaceIcebergData)
         if conf.dynamicPartitionPruningEnabled && isCandidate(command) =>
 
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
index dbf5d455b948..9674a7333fa8 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
@@ -39,11 +39,11 @@
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.expressions.Expressions;
+import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.connector.read.Statistics;
-import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering;
-import org.apache.spark.sql.sources.Filter;
-import org.apache.spark.sql.sources.In;
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.MetadataBuilder;
 import org.apache.spark.sql.types.StructField;
@@ -52,7 +52,7 @@
 import org.slf4j.LoggerFactory;
 
 class SparkCopyOnWriteScan extends SparkPartitioningAwareScan<FileScanTask>
-    implements SupportsRuntimeFiltering {
+    implements SupportsRuntimeV2Filtering {
 
   private static final Logger LOG = LoggerFactory.getLogger(SparkCopyOnWriteScan.class);
 
@@ -118,7 +118,7 @@ public NamedReference[] filterAttributes() {
   }
 
   @Override
-  public void filter(Filter[] filters) {
+  public void filter(Predicate[] predicates) {
     Preconditions.checkState(
         Objects.equals(snapshotId(), currentSnapshotId()),
         "Runtime file filtering is not possible: the table has been concurrently modified. "
@@ -128,16 +128,10 @@ public void filter(Filter[] filters) {
         snapshotId(),
         currentSnapshotId());
 
-    for (Filter filter : filters) {
-      // Spark can only pass In filters at the moment
-      if (filter instanceof In
-          && ((In) filter).attribute().equalsIgnoreCase(MetadataColumns.FILE_PATH.name())) {
-        In in = (In) filter;
-
-        Set<String> fileLocations = Sets.newHashSet();
-        for (Object value : in.values()) {
-          fileLocations.add((String) value);
-        }
+    for (Predicate predicate : predicates) {
+      // Spark can only pass IN predicates at the moment
+      if (isFilePathInPredicate(predicate)) {
+        Set<String> fileLocations = extractStringLiterals(predicate);
 
         // Spark may call this multiple times for UPDATEs with subqueries
         // as such cases are rewritten using UNION and the same scan on both sides
@@ -159,7 +153,7 @@ public void filter(Filter[] filters) {
           resetTasks(filteredTasks);
         }
       } else {
-        LOG.warn("Unsupported runtime filter {}", filter);
+        LOG.warn("Unsupported runtime filter {}", predicate);
       }
     }
   }
@@ -228,4 +222,41 @@ private boolean isRowLineageField(StructField field) {
             || field.name().equals(MetadataColumns.LAST_UPDATED_SEQUENCE_NUMBER.name());
     return hasLineageFieldName && field.metadata().contains("__metadata_col");
   }
+
+  private static boolean isFilePathInPredicate(Predicate predicate) {
+    if (!"IN".equals(predicate.name()) || predicate.children().length < 1) {
+      return false;
+    }
+
+    if (!(predicate.children()[0] instanceof NamedReference)) {
+      return false;
+    }
+
+    // every IN value must be a literal; otherwise the filter is reported as unsupported
+    // instead of silently dropping values and pruning files that may still match
+    for (int i = 1; i < predicate.children().length; i++) {
+      if (!(predicate.children()[i] instanceof Literal)) {
+        return false;
+      }
+    }
+
+    String[] fieldNames = ((NamedReference) predicate.children()[0]).fieldNames();
+
+    return fieldNames.length == 1
+        && fieldNames[0].equalsIgnoreCase(MetadataColumns.FILE_PATH.name());
+  }
+
+  private static Set<String> extractStringLiterals(Predicate predicate) {
+    Set<String> values = Sets.newHashSet();
+    for (int i = 1; i < predicate.children().length; i++) {
+      Object value = ((Literal<?>) predicate.children()[i]).value();
+      // a NULL literal can never equal a file path, so skip it rather than fail on toString()
+      if (value != null) {
+        // V2 string literals come through as UTF8String; toString() materializes the Java String
+        values.add(value.toString());
+      }
+    }
+
+    return values;
+  }
 }
diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
index dbf5d455b948..9674a7333fa8 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
@@ -39,11 +39,11 @@
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.expressions.Expressions;
+import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.connector.read.Statistics;
-import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering;
-import org.apache.spark.sql.sources.Filter;
-import org.apache.spark.sql.sources.In;
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.MetadataBuilder;
 import org.apache.spark.sql.types.StructField;
@@ -52,7 +52,7 @@
 import org.slf4j.LoggerFactory;
 
 class SparkCopyOnWriteScan extends SparkPartitioningAwareScan<FileScanTask>
-    implements SupportsRuntimeFiltering {
+    implements SupportsRuntimeV2Filtering {
 
   private static final Logger LOG = LoggerFactory.getLogger(SparkCopyOnWriteScan.class);
 
@@ -118,7 +118,7 @@ public NamedReference[] filterAttributes() {
   }
 
   @Override
-  public void filter(Filter[] filters) {
+  public void filter(Predicate[] predicates) {
     Preconditions.checkState(
         Objects.equals(snapshotId(), currentSnapshotId()),
         "Runtime file filtering is not possible: the table has been concurrently modified. "
@@ -128,16 +128,10 @@ public void filter(Filter[] filters) {
         snapshotId(),
         currentSnapshotId());
 
-    for (Filter filter : filters) {
-      // Spark can only pass In filters at the moment
-      if (filter instanceof In
-          && ((In) filter).attribute().equalsIgnoreCase(MetadataColumns.FILE_PATH.name())) {
-        In in = (In) filter;
-
-        Set<String> fileLocations = Sets.newHashSet();
-        for (Object value : in.values()) {
-          fileLocations.add((String) value);
-        }
+    for (Predicate predicate : predicates) {
+      // Spark can only pass IN predicates at the moment
+      if (isFilePathInPredicate(predicate)) {
+        Set<String> fileLocations = extractStringLiterals(predicate);
 
         // Spark may call this multiple times for UPDATEs with subqueries
         // as such cases are rewritten using UNION and the same scan on both sides
@@ -159,7 +153,7 @@ public void filter(Filter[] filters) {
           resetTasks(filteredTasks);
         }
       } else {
-        LOG.warn("Unsupported runtime filter {}", filter);
+        LOG.warn("Unsupported runtime filter {}", predicate);
       }
     }
   }
@@ -228,4 +222,41 @@ private boolean isRowLineageField(StructField field) {
             || field.name().equals(MetadataColumns.LAST_UPDATED_SEQUENCE_NUMBER.name());
     return hasLineageFieldName && field.metadata().contains("__metadata_col");
   }
+
+  private static boolean isFilePathInPredicate(Predicate predicate) {
+    if (!"IN".equals(predicate.name()) || predicate.children().length < 1) {
+      return false;
+    }
+
+    if (!(predicate.children()[0] instanceof NamedReference)) {
+      return false;
+    }
+
+    // every IN value must be a literal; otherwise the filter is reported as unsupported
+    // instead of silently dropping values and pruning files that may still match
+    for (int i = 1; i < predicate.children().length; i++) {
+      if (!(predicate.children()[i] instanceof Literal)) {
+        return false;
+      }
+    }
+
+    String[] fieldNames = ((NamedReference) predicate.children()[0]).fieldNames();
+
+    return fieldNames.length == 1
+        && fieldNames[0].equalsIgnoreCase(MetadataColumns.FILE_PATH.name());
+  }
+
+  private static Set<String> extractStringLiterals(Predicate predicate) {
+    Set<String> values = Sets.newHashSet();
+    for (int i = 1; i < predicate.children().length; i++) {
+      Object value = ((Literal<?>) predicate.children()[i]).value();
+      // a NULL literal can never equal a file path, so skip it rather than fail on toString()
+      if (value != null) {
+        // V2 string literals come through as UTF8String; toString() materializes the Java String
+        values.add(value.toString());
+      }
+    }
+
+    return values;
+  }
 }
diff --git a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
index ee4be2461894..f957b97d60f5 100644
--- a/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
+++ b/spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java
@@ -38,16 +38,16 @@
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.expressions.Expressions;
+import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.connector.read.Statistics;
-import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering;
-import org.apache.spark.sql.sources.Filter;
-import org.apache.spark.sql.sources.In;
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 class SparkCopyOnWriteScan extends SparkPartitioningAwareScan<FileScanTask>
-    implements SupportsRuntimeFiltering {
+    implements SupportsRuntimeV2Filtering {
 
   private static final Logger LOG = LoggerFactory.getLogger(SparkCopyOnWriteScan.class);
 
@@ -103,7 +103,7 @@ public NamedReference[] filterAttributes() {
   }
 
   @Override
-  public void filter(Filter[] filters) {
+  public void filter(Predicate[] predicates) {
     Preconditions.checkState(
         Objects.equals(snapshotId(), currentSnapshotId()),
         "Runtime file filtering is not possible: the table has been concurrently modified. "
@@ -113,16 +113,10 @@ public void filter(Filter[] filters) {
         snapshotId(),
         currentSnapshotId());
 
-    for (Filter filter : filters) {
-      // Spark can only pass In filters at the moment
-      if (filter instanceof In
-          && ((In) filter).attribute().equalsIgnoreCase(MetadataColumns.FILE_PATH.name())) {
-        In in = (In) filter;
-
-        Set<String> fileLocations = Sets.newHashSet();
-        for (Object value : in.values()) {
-          fileLocations.add((String) value);
-        }
+    for (Predicate predicate : predicates) {
+      // Spark can only pass IN predicates at the moment
+      if (isFilePathInPredicate(predicate)) {
+        Set<String> fileLocations = extractStringLiterals(predicate);
 
         // Spark may call this multiple times for UPDATEs with subqueries
         // as such cases are rewritten using UNION and the same scan on both sides
@@ -144,7 +138,7 @@ public void filter(Filter[] filters) {
           resetTasks(filteredTasks);
         }
       } else {
-        LOG.warn("Unsupported runtime filter {}", filter);
+        LOG.warn("Unsupported runtime filter {}", predicate);
       }
     }
   }
@@ -188,4 +182,41 @@ private Long currentSnapshotId() {
     Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table(), branch());
     return currentSnapshot != null ? currentSnapshot.snapshotId() : null;
   }
+
+  private static boolean isFilePathInPredicate(Predicate predicate) {
+    if (!"IN".equals(predicate.name()) || predicate.children().length < 1) {
+      return false;
+    }
+
+    if (!(predicate.children()[0] instanceof NamedReference)) {
+      return false;
+    }
+
+    // every IN value must be a literal; otherwise the filter is reported as unsupported
+    // instead of silently dropping values and pruning files that may still match
+    for (int i = 1; i < predicate.children().length; i++) {
+      if (!(predicate.children()[i] instanceof Literal)) {
+        return false;
+      }
+    }
+
+    String[] fieldNames = ((NamedReference) predicate.children()[0]).fieldNames();
+
+    return fieldNames.length == 1
+        && fieldNames[0].equalsIgnoreCase(MetadataColumns.FILE_PATH.name());
+  }
+
+  private static Set<String> extractStringLiterals(Predicate predicate) {
+    Set<String> values = Sets.newHashSet();
+    for (int i = 1; i < predicate.children().length; i++) {
+      Object value = ((Literal<?>) predicate.children()[i]).value();
+      // a NULL literal can never equal a file path, so skip it rather than fail on toString()
+      if (value != null) {
+        // V2 string literals come through as UTF8String; toString() materializes the Java String
+        values.add(value.toString());
+      }
+    }
+
+    return values;
+  }
 }