@@ -61,83 +61,97 @@ case class BatchScanExec(
 
   // Visible for testing
   @transient private[sql] lazy val filteredPartitions: Seq[Option[InputPartition]] = {
-    val dataSourceFilters = runtimeFilters.flatMap {
-      case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
-      case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f)
-    }
-
     val originalPartitioning = outputPartitioning
-    // the cast is safe as runtime filters are only assigned if the scan can be filtered
-    val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
-    var filtered = false
-
-    if (dataSourceFilters.nonEmpty) {
-      filterableScan.filter(dataSourceFilters.toArray)
-      filtered = true
-    }
+    if (runtimeFilters.nonEmpty) {
+      // the cast is safe as runtime filters are only assigned if the scan can be filtered
+      val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
+
+      // push down translatable runtime filters
+      val dataSourceFilters = runtimeFilters.flatMap {
+        case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
+        case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f)
+      }
+      if (dataSourceFilters.nonEmpty) {
+        filterableScan.filter(dataSourceFilters.toArray)
+      }
 
-    // If the scan supports iterative filtering, derive PartitionPredicates from the
-    // runtime filters and push them in a second pass. (See SPARK-55596)
-    if (filterableScan.supportsIterativeFiltering()) {
-      PushDownUtils.getPartitionPredicateSchema(table, output).foreach { partitionFields =>
-        val partPredicates =
+      // If the scan supports iterative filtering, derive PartitionPredicates from the
+      // runtime filters and push them in a second pass. (See SPARK-55596)
+      val partPredicates = if (filterableScan.supportsIterativeFiltering()) {
+        PushDownUtils.getPartitionPredicateSchema(table, output).map { partitionFields =>
           PushDownUtils.createRuntimePartitionPredicates(runtimeFilters, partitionFields)
-        if (partPredicates.nonEmpty) {
-          filterableScan.filter(partPredicates.toArray)
-          filtered = true
-        }
+        }.getOrElse(Seq.empty)
+      } else {
+        Seq.empty
+      }
+      if (partPredicates.nonEmpty) {
+        filterableScan.filter(partPredicates.toArray)
       }
-    }
 
-    if (filtered) {
-      // call toBatch again to get filtered partitions
-      val newPartitions = scan.toBatch.planInputPartitions()
+      if (dataSourceFilters.nonEmpty || partPredicates.nonEmpty) {
+        // call toBatch again to get filtered partitions
+        val newPartitions = scan.toBatch.planInputPartitions()
+
+        originalPartitioning match {
+          case k: KeyedPartitioning =>
+            if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+              throw new SparkException(
+                "Data source must have preserved the original partitioning " +
+                "during runtime filtering: not all partitions implement " +
+                "HasPartitionKey after filtering")
+            }
 
-      originalPartitioning match {
-        case k: KeyedPartitioning =>
-          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
-            throw new SparkException("Data source must have preserved the original partitioning " +
-              "during runtime filtering: not all partitions implement HasPartitionKey after " +
-              "filtering")
-          }
-
-          val inputMap = k.partitionKeys.groupBy(identity).view.mapValues(_.size)
-          val comparableKeyWrapperFactory = InternalRowComparableWrapper
-            .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
-          val filteredMap = newPartitions.groupBy(
-            p => comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey())
-          )
-
-          if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
-            throw new SparkException("During runtime filtering, data source must not report new " +
-              "partition keys that are not present in the original partitioning.")
-          }
-
-          inputMap.toSeq
-            .sortBy(_._1)(k.keyOrdering)
-            .flatMap { case (key, size) =>
-              // We require the new number of partitions to be equal or less than the old number of
-              // partitions for a given key. In the case of less than, empty partitions are added.
-              val fps = filteredMap.getOrElse(key, Array.empty)
-
-              if (fps.size > size) {
-                throw new SparkException("During runtime filtering, data source must not report " +
-                  s"new partitions for a given key. Before: $size partitions. " +
-                  s"After: ${fps.size} partitions")
+            val inputMap = k.partitionKeys.groupBy(identity).view.mapValues(_.size)
+            val comparableKeyWrapperFactory = InternalRowComparableWrapper
+              .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
+            val filteredMap = newPartitions.groupBy(
+              p => comparableKeyWrapperFactory(
+                p.asInstanceOf[HasPartitionKey].partitionKey()))
+
+            if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
+              throw new SparkException(
+                "During runtime filtering, data source must not report new " +
+                "partition keys that are not present in the original partitioning.")
+            }
+
+            inputMap.toSeq
+              .sortBy(_._1)(k.keyOrdering)
+              .flatMap { case (key, size) =>
+                // We require the new number of partitions to be equal or less than
+                // the old number of partitions for a given key. In the case of less
+                // than, empty partitions are added.
+                val fps = filteredMap.getOrElse(key, Array.empty)
+
+                if (fps.size > size) {
+                  throw new SparkException(
+                    "During runtime filtering, data source must not report " +
+                    s"new partitions for a given key. Before: $size partitions. " +
+                    s"After: ${fps.size} partitions")
+                }
+
+                fps.map(Some(_)).padTo(size, None)
               }
 
-              fps.map(Some(_)).padTo(size, None)
-            }
+          case _ =>
+            // no validation is needed as the data source did not report any specific
+            // partitioning
+            newPartitions.toSeq.map(Some(_))
+        }
 
-        case _ =>
-          // no validation is needed as the data source did not report any specific partitioning
-          newPartitions.toSeq.map(Some(_))
-      }
+      } else {
+        (originalPartitioning match {
+          case k: KeyedPartitioning =>
+            inputPartitions.sortBy(
+              _.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)
 
+          case _ => inputPartitions
+        }).map(Some(_))
+      }
     } else {
       (originalPartitioning match {
         case k: KeyedPartitioning =>
-          inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)
+          inputPartitions.sortBy(
+            _.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)
 
         case _ => inputPartitions
       }).map(Some(_))
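
Conceptually, the rewrite replaces the old mutable `var filtered` flag with two explicitly computed predicate sets and re-plans partitions only when either set was actually pushed down. A minimal sketch of that control flow, with simplified stand-in types (`Predicate`, `FilterableScan`, and `pushRuntimeFilters` are illustrative names, not Spark's DSv2 API):

```scala
object RuntimeFilteringFlow {
  // Illustrative stand-ins for Spark's DSv2 interfaces
  // (SupportsRuntimeV2Filtering and its Predicate type), simplified here.
  trait Predicate
  trait FilterableScan {
    def filter(predicates: Array[Predicate]): Unit
    def supportsIterativeFiltering(): Boolean
  }

  // Returns true if any predicates were pushed, i.e. partitions must be re-planned.
  // Deriving this from the two predicate sets replaces the old `var filtered` flag.
  def pushRuntimeFilters(
      scan: FilterableScan,
      dataSourceFilters: Seq[Predicate],
      derivePartitionPredicates: () => Seq[Predicate]): Boolean = {
    // first pass: translated runtime filters
    if (dataSourceFilters.nonEmpty) {
      scan.filter(dataSourceFilters.toArray)
    }
    // second pass: partition predicates, only for scans that filter iteratively
    val partPredicates =
      if (scan.supportsIterativeFiltering()) derivePartitionPredicates() else Seq.empty
    if (partPredicates.nonEmpty) {
      scan.filter(partPredicates.toArray)
    }
    dataSourceFilters.nonEmpty || partPredicates.nonEmpty
  }
}
```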
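The keyed-partitioning branch preserves the reported `KeyedPartitioning` across runtime filtering by aligning the filtered partitions to the original per-key counts: no new keys, no per-key growth, and `None` padding for pruned slots. A minimal, runnable sketch of that alignment, assuming plain `String` keys in place of the `InternalRowComparableWrapper` keys and `require` in place of `SparkException` (`align` is a hypothetical helper, not part of the patch):

```scala
object PadToExample {
  def align[P](
      originalKeys: Seq[String],            // one entry per original partition
      filtered: Map[String, Seq[P]]): Seq[Option[P]] = {
    // how many partitions each key had before filtering
    val inputMap = originalKeys.groupBy(identity).view.mapValues(_.size)

    // invariant 1: filtering must not surface keys the original partitioning lacked
    require(filtered.keySet.subsetOf(inputMap.keySet),
      "filtered result contains unknown partition keys")

    inputMap.toSeq.sortBy(_._1).flatMap { case (key, size) =>
      val fps = filtered.getOrElse(key, Seq.empty)
      // invariant 2: a key may not gain partitions during filtering
      require(fps.size <= size, s"key $key grew from $size to ${fps.size} partitions")
      // pruned slots become None so the per-key partition count is preserved
      fps.map(Some(_)).padTo(size, None)
    }
  }

  def main(args: Array[String]): Unit = {
    val originalKeys = Seq("a", "a", "b", "c")
    val filtered = Map("a" -> Seq("a-part-0"), "c" -> Seq("c-part-0"))
    // prints: List(Some(a-part-0), None, None, Some(c-part-0))
    println(align(originalKeys, filtered))
  }
}
```

Padding with `None` keeps the partition count per key stable, which is what allows the exec node to keep reporting the original `KeyedPartitioning` after filtering instead of falling back to an unknown partitioning.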