
Commit 939e128

[SPARK-56521][SQL] Refactor BatchScanExec: guard cast with runtimeFilters.nonEmpty, simplify partPredicates
1 parent e085be4 commit 939e128
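
The gist of the change, as a simplified standalone sketch (the Filterable trait and the before/after helpers here are invented stand-ins for illustration, not the actual Spark types):

    trait Filterable { def filter(filters: Array[String]): Unit }

    // Before: the cast ran even when no runtime filters were assigned, so a
    // scan that does not implement the filtering trait could fail the cast.
    def before(scan: Any, runtimeFilters: Seq[String]): Unit = {
      val filterableScan = scan.asInstanceOf[Filterable]
      if (runtimeFilters.nonEmpty) filterableScan.filter(runtimeFilters.toArray)
    }

    // After: the cast is reached only when runtime filters exist, and the
    // planner assigns runtime filters only to scans that can be filtered,
    // so the cast is safe by construction.
    def after(scan: Any, runtimeFilters: Seq[String]): Unit = {
      if (runtimeFilters.nonEmpty) {
        scan.asInstanceOf[Filterable].filter(runtimeFilters.toArray)
      }
    }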

1 file changed

Lines changed: 78 additions & 64 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala

@@ -61,83 +61,97 @@ case class BatchScanExec(

   // Visible for testing
   @transient private[sql] lazy val filteredPartitions: Seq[Option[InputPartition]] = {
-    val dataSourceFilters = runtimeFilters.flatMap {
-      case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
-      case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f)
-    }
-
     val originalPartitioning = outputPartitioning
-    // the cast is safe as runtime filters are only assigned if the scan can be filtered
-    val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
-    var filtered = false
-
-    if (dataSourceFilters.nonEmpty) {
-      filterableScan.filter(dataSourceFilters.toArray)
-      filtered = true
-    }
+    if (runtimeFilters.nonEmpty) {
+      // the cast is safe as runtime filters are only assigned if the scan can be filtered
+      val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
+
+      // push down translatable runtime filters
+      val dataSourceFilters = runtimeFilters.flatMap {
+        case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
+        case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f)
+      }
+      if (dataSourceFilters.nonEmpty) {
+        filterableScan.filter(dataSourceFilters.toArray)
+      }

-    // If the scan supports iterative filtering, derive PartitionPredicates from the
-    // runtime filters and push them in a second pass. (See SPARK-55596)
-    if (filterableScan.supportsIterativeFiltering()) {
-      PushDownUtils.getPartitionPredicateSchema(table, output).foreach { partitionFields =>
-        val partPredicates =
+      // If the scan supports iterative filtering, derive PartitionPredicates from the
+      // runtime filters and push them in a second pass. (See SPARK-55596)
+      val partPredicates = if (filterableScan.supportsIterativeFiltering()) {
+        PushDownUtils.getPartitionPredicateSchema(table, output).map { partitionFields =>
           PushDownUtils.createRuntimePartitionPredicates(runtimeFilters, partitionFields)
-        if (partPredicates.nonEmpty) {
-          filterableScan.filter(partPredicates.toArray)
-          filtered = true
-        }
+        }.getOrElse(Seq.empty)
+      } else {
+        Seq.empty
+      }
+      if (partPredicates.nonEmpty) {
+        filterableScan.filter(partPredicates.toArray)
       }
-    }

-    if (filtered) {
-      // call toBatch again to get filtered partitions
-      val newPartitions = scan.toBatch.planInputPartitions()
+      if (dataSourceFilters.nonEmpty || partPredicates.nonEmpty) {
+        // call toBatch again to get filtered partitions
+        val newPartitions = scan.toBatch.planInputPartitions()
+
+        originalPartitioning match {
+          case k: KeyedPartitioning =>
+            if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+              throw new SparkException(
+                "Data source must have preserved the original partitioning " +
+                "during runtime filtering: not all partitions implement " +
+                "HasPartitionKey after filtering")
+            }

-      originalPartitioning match {
-        case k: KeyedPartitioning =>
-          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
-            throw new SparkException("Data source must have preserved the original partitioning " +
-              "during runtime filtering: not all partitions implement HasPartitionKey after " +
-              "filtering")
-          }
-
-          val inputMap = k.partitionKeys.groupBy(identity).view.mapValues(_.size)
-          val comparableKeyWrapperFactory = InternalRowComparableWrapper
-            .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
-          val filteredMap = newPartitions.groupBy(
-            p => comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey())
-          )
-
-          if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
-            throw new SparkException("During runtime filtering, data source must not report new " +
-              "partition keys that are not present in the original partitioning.")
-          }
-
-          inputMap.toSeq
-            .sortBy(_._1)(k.keyOrdering)
-            .flatMap { case (key, size) =>
-              // We require the new number of partitions to be equal or less than the old number of
-              // partitions for a given key. In the case of less than, empty partitions are added.
-              val fps = filteredMap.getOrElse(key, Array.empty)
-
-              if (fps.size > size) {
-                throw new SparkException("During runtime filtering, data source must not report " +
-                  s"new partitions for a given key. Before: $size partitions. " +
-                  s"After: ${fps.size} partitions")
+            val inputMap = k.partitionKeys.groupBy(identity).view.mapValues(_.size)
+            val comparableKeyWrapperFactory = InternalRowComparableWrapper
+              .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
+            val filteredMap = newPartitions.groupBy(
+              p => comparableKeyWrapperFactory(
+                p.asInstanceOf[HasPartitionKey].partitionKey()))
+
+            if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
+              throw new SparkException(
+                "During runtime filtering, data source must not report new " +
+                "partition keys that are not present in the original partitioning.")
+            }
+
+            inputMap.toSeq
+              .sortBy(_._1)(k.keyOrdering)
+              .flatMap { case (key, size) =>
+                // We require the new number of partitions to be equal or less than
+                // the old number of partitions for a given key. In the case of less
+                // than, empty partitions are added.
+                val fps = filteredMap.getOrElse(key, Array.empty)
+
+                if (fps.size > size) {
+                  throw new SparkException(
+                    "During runtime filtering, data source must not report " +
+                    s"new partitions for a given key. Before: $size partitions. " +
+                    s"After: ${fps.size} partitions")
+                }
+
+                fps.map(Some).padTo(size, None)
               }

-              fps.map(Some).padTo(size, None)
-            }
+          case _ =>
+            // no validation is needed as the data source did not report any specific
+            // partitioning
+            newPartitions.toSeq.map(Some)
+        }

-        case _ =>
-          // no validation is needed as the data source did not report any specific partitioning
-          newPartitions.toSeq.map(Some)
-      }
+      } else {
+        (originalPartitioning match {
+          case k: KeyedPartitioning =>
+            inputPartitions.sortBy(
+              _.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)

+          case _ => inputPartitions
+        }).map(Some)
+      }
     } else {
       (originalPartitioning match {
         case k: KeyedPartitioning =>
-          inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)
+          inputPartitions.sortBy(
+            _.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)

         case _ => inputPartitions
       }).map(Some)
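
For context, a scan opts into this code path by implementing the DSv2 SupportsRuntimeV2Filtering interface. Below is a minimal sketch of such a scan; the DatePartition type and the predicate-evaluation callback are invented for illustration (real sources evaluate V2 Predicates against their own partition metadata), and the supportsIterativeFiltering() hook referenced in the diff is omitted:

    import org.apache.spark.sql.connector.expressions.NamedReference
    import org.apache.spark.sql.connector.expressions.filter.Predicate
    import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeV2Filtering}
    import org.apache.spark.sql.types.StructType

    // Hypothetical partition that carries its partition value as metadata.
    case class DatePartition(date: String) extends InputPartition

    class ExampleFilterableScan(
        schema: StructType,
        allPartitions: Seq[DatePartition],
        partitionCol: NamedReference,
        // Evaluating V2 Predicates against partition metadata is source-specific,
        // so this sketch takes the evaluation as a callback.
        survives: (DatePartition, Predicate) => Boolean)
      extends Scan with SupportsRuntimeV2Filtering {

      // Partitions remaining after runtime filtering; starts as the full set.
      private var current: Seq[DatePartition] = allPartitions

      override def readSchema(): StructType = schema

      // Columns on which Spark may push runtime filters to this scan.
      override def filterAttributes(): Array[NamedReference] = Array(partitionCol)

      // BatchScanExec calls this with the translated runtime filters; filtering
      // may only drop partitions, never add partitions or new partition keys.
      override def filter(predicates: Array[Predicate]): Unit = {
        current = current.filter(p => predicates.forall(survives(p, _)))
      }

      // BatchScanExec calls toBatch again after filter() to pick up the
      // reduced partition list.
      override def toBatch: Batch = new Batch {
        override def planInputPartitions(): Array[InputPartition] = current.toArray
        override def createReaderFactory(): PartitionReaderFactory =
          throw new UnsupportedOperationException("reader omitted in this sketch")
      }
    }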

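The KeyedPartitioning branch above enforces two invariants: runtime filtering may not introduce new partition keys, and may not grow the partition count for any key; per-key shortfalls are padded with None. A standalone sketch of that bookkeeping, with plain String keys standing in for the InternalRowComparableWrapper keys and natural ordering standing in for k.keyOrdering:

    // Per-key validation and padding, simplified from the diff above.
    def padFiltered(
        originalKeys: Seq[String],               // one entry per original partition
        filteredByKey: Map[String, Seq[String]]  // surviving partitions per key
      ): Seq[Option[String]] = {
      val inputMap = originalKeys.groupBy(identity).view.mapValues(_.size)

      // Filtering must not introduce keys the original partitioning lacked.
      require(filteredByKey.keySet.subsetOf(inputMap.keySet),
        "runtime filtering reported a new partition key")

      inputMap.toSeq.sortBy(_._1).flatMap { case (key, size) =>
        val fps = filteredByKey.getOrElse(key, Seq.empty)
        // Per key, the count may only shrink; pad with None so the overall
        // partition count and per-key order are preserved.
        require(fps.size <= size, s"key $key grew from $size to ${fps.size}")
        fps.map(Some(_)).padTo(size, None)
      }
    }

    // e.g. padFiltered(Seq("a", "a", "b"), Map("a" -> Seq("a-part1")))
    //   == Seq(Some("a-part1"), None, None)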