Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
import org.apache.spark.sql.catalyst.trees.TreePattern.SORT
import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering
import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Implicits
Expand All @@ -67,7 +67,7 @@ case class RowLevelCommandDynamicPruning(spark: SparkSession)
// apply special dynamic filtering only for plans that don't support deltas
case RewrittenRowLevelCommand(
command: RowLevelCommand,
DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _, _),
DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _),
rewritePlan: ReplaceIcebergData)
if conf.dynamicPartitionPruningEnabled && isCandidate(command) =>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@
import org.apache.iceberg.util.SnapshotUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.Statistics;
import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.In;
import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.MetadataBuilder;
import org.apache.spark.sql.types.StructField;
Expand All @@ -52,7 +52,7 @@
import org.slf4j.LoggerFactory;

class SparkCopyOnWriteScan extends SparkPartitioningAwareScan<FileScanTask>
implements SupportsRuntimeFiltering {
implements SupportsRuntimeV2Filtering {

private static final Logger LOG = LoggerFactory.getLogger(SparkCopyOnWriteScan.class);

Expand Down Expand Up @@ -118,7 +118,7 @@ public NamedReference[] filterAttributes() {
}

@Override
public void filter(Filter[] filters) {
public void filter(Predicate[] predicates) {
Preconditions.checkState(
Objects.equals(snapshotId(), currentSnapshotId()),
"Runtime file filtering is not possible: the table has been concurrently modified. "
Expand All @@ -128,16 +128,10 @@ public void filter(Filter[] filters) {
snapshotId(),
currentSnapshotId());

for (Filter filter : filters) {
// Spark can only pass In filters at the moment
if (filter instanceof In
&& ((In) filter).attribute().equalsIgnoreCase(MetadataColumns.FILE_PATH.name())) {
In in = (In) filter;

Set<String> fileLocations = Sets.newHashSet();
for (Object value : in.values()) {
fileLocations.add((String) value);
}
for (Predicate predicate : predicates) {
Comment thread
kevinjqliu marked this conversation as resolved.
// Spark can only pass IN predicates at the moment
if (isFilePathInPredicate(predicate)) {
Set<String> fileLocations = extractStringLiterals(predicate);

// Spark may call this multiple times for UPDATEs with subqueries
// as such cases are rewritten using UNION and the same scan on both sides
Expand All @@ -159,7 +153,7 @@ public void filter(Filter[] filters) {
resetTasks(filteredTasks);
}
} else {
LOG.warn("Unsupported runtime filter {}", filter);
LOG.warn("Unsupported runtime filter {}", predicate);
}
}
}
Expand Down Expand Up @@ -228,4 +222,32 @@ private boolean isRowLineageField(StructField field) {
|| field.name().equals(MetadataColumns.LAST_UPDATED_SEQUENCE_NUMBER.name());
return hasLineageFieldName && field.metadata().contains("__metadata_col");
}

/**
 * Tells whether a runtime predicate is an IN predicate over the file path metadata column.
 *
 * <p>The predicate qualifies when its name is {@code IN}, it has at least one child, that first
 * child is a {@link NamedReference}, and the reference consists of a single field name equal
 * (ignoring case) to {@code MetadataColumns.FILE_PATH.name()}.
 *
 * @param predicate the runtime predicate to inspect
 * @return true if the predicate is an IN over the file path column, false otherwise
 */
private static boolean isFilePathInPredicate(Predicate predicate) {
  boolean isIn = "IN".equals(predicate.name());
  if (!isIn || predicate.children().length == 0) {
    return false;
  }

  // the first child is the expression being tested; the remaining children are the IN values
  Object target = predicate.children()[0];
  if (target instanceof NamedReference) {
    String[] nameParts = ((NamedReference) target).fieldNames();
    // only a top-level (single-part) reference can be the file path metadata column
    return nameParts.length == 1
        && nameParts[0].equalsIgnoreCase(MetadataColumns.FILE_PATH.name());
  }

  return false;
}

/**
 * Collects the literal values of an IN predicate as Java strings.
 *
 * <p>The first child is skipped because it is the expression being tested, not a value. Every
 * remaining child that is a {@link Literal} contributes {@code value().toString()}. Children that
 * are not literals and literals holding {@code null} are ignored; a null can never equal a file
 * location, and calling {@code toString()} on it would throw.
 *
 * @param predicate a predicate whose children after the first are the candidate values
 * @return a mutable set with the string form of every non-null literal value; may be empty
 */
private static Set<String> extractStringLiterals(Predicate predicate) {
  Set<String> values = Sets.newHashSet();
  // read the children once instead of on every loop iteration
  Object[] children = predicate.children();
  for (int i = 1; i < children.length; i++) {
    if (!(children[i] instanceof Literal)) {
      continue;
    }

    Object value = ((Literal<?>) children[i]).value();
    if (value != null) {
      // V2 string literals come through as UTF8String; toString() materializes the Java String
      values.add(value.toString());
    }
  }

  return values;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@
import org.apache.iceberg.util.SnapshotUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.Statistics;
import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.In;
import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.MetadataBuilder;
import org.apache.spark.sql.types.StructField;
Expand All @@ -52,7 +52,7 @@
import org.slf4j.LoggerFactory;

class SparkCopyOnWriteScan extends SparkPartitioningAwareScan<FileScanTask>
implements SupportsRuntimeFiltering {
implements SupportsRuntimeV2Filtering {

private static final Logger LOG = LoggerFactory.getLogger(SparkCopyOnWriteScan.class);

Expand Down Expand Up @@ -118,7 +118,7 @@ public NamedReference[] filterAttributes() {
}

@Override
public void filter(Filter[] filters) {
public void filter(Predicate[] predicates) {
Preconditions.checkState(
Objects.equals(snapshotId(), currentSnapshotId()),
"Runtime file filtering is not possible: the table has been concurrently modified. "
Expand All @@ -128,16 +128,10 @@ public void filter(Filter[] filters) {
snapshotId(),
currentSnapshotId());

for (Filter filter : filters) {
// Spark can only pass In filters at the moment
if (filter instanceof In
&& ((In) filter).attribute().equalsIgnoreCase(MetadataColumns.FILE_PATH.name())) {
In in = (In) filter;

Set<String> fileLocations = Sets.newHashSet();
for (Object value : in.values()) {
fileLocations.add((String) value);
}
for (Predicate predicate : predicates) {
Comment thread
kevinjqliu marked this conversation as resolved.
// Spark can only pass IN predicates at the moment
if (isFilePathInPredicate(predicate)) {
Set<String> fileLocations = extractStringLiterals(predicate);

// Spark may call this multiple times for UPDATEs with subqueries
// as such cases are rewritten using UNION and the same scan on both sides
Expand All @@ -159,7 +153,7 @@ public void filter(Filter[] filters) {
resetTasks(filteredTasks);
}
} else {
LOG.warn("Unsupported runtime filter {}", filter);
LOG.warn("Unsupported runtime filter {}", predicate);
}
}
}
Expand Down Expand Up @@ -228,4 +222,32 @@ private boolean isRowLineageField(StructField field) {
|| field.name().equals(MetadataColumns.LAST_UPDATED_SEQUENCE_NUMBER.name());
return hasLineageFieldName && field.metadata().contains("__metadata_col");
}

/**
 * Tells whether a runtime predicate is an IN predicate over the file path metadata column.
 *
 * <p>The predicate qualifies when its name is {@code IN}, it has at least one child, that first
 * child is a {@link NamedReference}, and the reference consists of a single field name equal
 * (ignoring case) to {@code MetadataColumns.FILE_PATH.name()}.
 *
 * @param predicate the runtime predicate to inspect
 * @return true if the predicate is an IN over the file path column, false otherwise
 */
private static boolean isFilePathInPredicate(Predicate predicate) {
  boolean isIn = "IN".equals(predicate.name());
  if (!isIn || predicate.children().length == 0) {
    return false;
  }

  // the first child is the expression being tested; the remaining children are the IN values
  Object target = predicate.children()[0];
  if (target instanceof NamedReference) {
    String[] nameParts = ((NamedReference) target).fieldNames();
    // only a top-level (single-part) reference can be the file path metadata column
    return nameParts.length == 1
        && nameParts[0].equalsIgnoreCase(MetadataColumns.FILE_PATH.name());
  }

  return false;
}

/**
 * Collects the literal values of an IN predicate as Java strings.
 *
 * <p>The first child is skipped because it is the expression being tested, not a value. Every
 * remaining child that is a {@link Literal} contributes {@code value().toString()}. Children that
 * are not literals and literals holding {@code null} are ignored; a null can never equal a file
 * location, and calling {@code toString()} on it would throw.
 *
 * @param predicate a predicate whose children after the first are the candidate values
 * @return a mutable set with the string form of every non-null literal value; may be empty
 */
private static Set<String> extractStringLiterals(Predicate predicate) {
  Set<String> values = Sets.newHashSet();
  // read the children once instead of on every loop iteration
  Object[] children = predicate.children();
  for (int i = 1; i < children.length; i++) {
    if (!(children[i] instanceof Literal)) {
      continue;
    }

    Object value = ((Literal<?>) children[i]).value();
    if (value != null) {
      // V2 string literals come through as UTF8String; toString() materializes the Java String
      values.add(value.toString());
    }
  }

  return values;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@
import org.apache.iceberg.util.SnapshotUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.read.Statistics;
import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.In;
import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class SparkCopyOnWriteScan extends SparkPartitioningAwareScan<FileScanTask>
implements SupportsRuntimeFiltering {
implements SupportsRuntimeV2Filtering {

private static final Logger LOG = LoggerFactory.getLogger(SparkCopyOnWriteScan.class);

Expand Down Expand Up @@ -103,7 +103,7 @@ public NamedReference[] filterAttributes() {
}

@Override
public void filter(Filter[] filters) {
public void filter(Predicate[] predicates) {
Preconditions.checkState(
Objects.equals(snapshotId(), currentSnapshotId()),
"Runtime file filtering is not possible: the table has been concurrently modified. "
Expand All @@ -113,16 +113,10 @@ public void filter(Filter[] filters) {
snapshotId(),
currentSnapshotId());

for (Filter filter : filters) {
// Spark can only pass In filters at the moment
if (filter instanceof In
&& ((In) filter).attribute().equalsIgnoreCase(MetadataColumns.FILE_PATH.name())) {
In in = (In) filter;

Set<String> fileLocations = Sets.newHashSet();
for (Object value : in.values()) {
fileLocations.add((String) value);
}
for (Predicate predicate : predicates) {
Comment thread
kevinjqliu marked this conversation as resolved.
// Spark can only pass IN predicates at the moment
if (isFilePathInPredicate(predicate)) {
Set<String> fileLocations = extractStringLiterals(predicate);

// Spark may call this multiple times for UPDATEs with subqueries
// as such cases are rewritten using UNION and the same scan on both sides
Expand All @@ -144,7 +138,7 @@ public void filter(Filter[] filters) {
resetTasks(filteredTasks);
}
} else {
LOG.warn("Unsupported runtime filter {}", filter);
LOG.warn("Unsupported runtime filter {}", predicate);
}
}
}
Expand Down Expand Up @@ -188,4 +182,32 @@ private Long currentSnapshotId() {
Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table(), branch());
return currentSnapshot != null ? currentSnapshot.snapshotId() : null;
}

/**
 * Tells whether a runtime predicate is an IN predicate over the file path metadata column.
 *
 * <p>The predicate qualifies when its name is {@code IN}, it has at least one child, that first
 * child is a {@link NamedReference}, and the reference consists of a single field name equal
 * (ignoring case) to {@code MetadataColumns.FILE_PATH.name()}.
 *
 * @param predicate the runtime predicate to inspect
 * @return true if the predicate is an IN over the file path column, false otherwise
 */
private static boolean isFilePathInPredicate(Predicate predicate) {
  boolean isIn = "IN".equals(predicate.name());
  if (!isIn || predicate.children().length == 0) {
    return false;
  }

  // the first child is the expression being tested; the remaining children are the IN values
  Object target = predicate.children()[0];
  if (target instanceof NamedReference) {
    String[] nameParts = ((NamedReference) target).fieldNames();
    // only a top-level (single-part) reference can be the file path metadata column
    return nameParts.length == 1
        && nameParts[0].equalsIgnoreCase(MetadataColumns.FILE_PATH.name());
  }

  return false;
}

/**
 * Collects the literal values of an IN predicate as Java strings.
 *
 * <p>The first child is skipped because it is the expression being tested, not a value. Every
 * remaining child that is a {@link Literal} contributes {@code value().toString()}. Children that
 * are not literals and literals holding {@code null} are ignored; a null can never equal a file
 * location, and calling {@code toString()} on it would throw.
 *
 * @param predicate a predicate whose children after the first are the candidate values
 * @return a mutable set with the string form of every non-null literal value; may be empty
 */
private static Set<String> extractStringLiterals(Predicate predicate) {
  Set<String> values = Sets.newHashSet();
  // read the children once instead of on every loop iteration
  Object[] children = predicate.children();
  for (int i = 1; i < children.length; i++) {
    if (!(children[i] instanceof Literal)) {
      continue;
    }

    Object value = ((Literal<?>) children[i]).value();
    if (value != null) {
      // V2 string literals come through as UTF8String; toString() materializes the Java String
      values.add(value.toString());
    }
  }

  return values;
}
}