diff --git a/README.md b/README.md
index d87cfabd85d..4e4dc1a2da5 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,10 @@ and others would not be possible without your help.

Ready? [Getting Started](https://kyuubi.readthedocs.io/en/master/quick_start/) with Kyuubi.

+## Security & Guardrails
+
+- [Dangerous Join Watchdog](./docs/watchdog/dangerous-join.md)
+
## [Contributing](./CONTRIBUTING.md)

## Project & Community Status

diff --git a/docs/configuration/settings.md b/docs/configuration/settings.md
index 591860243ad..5299d4efa6b 100644
--- a/docs/configuration/settings.md
+++ b/docs/configuration/settings.md
@@ -580,6 +580,18 @@ jdbc:hive2://localhost:10009/default;#spark.sql.shuffle.partitions=2;spark.execu

Please refer to the Spark official online documentation for [SET Command](https://spark.apache.org/docs/latest/sql-ref-syntax-aux-conf-mgmt-set.html)

+### Dangerous Join Watchdog
+
+You can enable dangerous join detection for the Spark SQL extension with:
+
+| Name | Default | Description |
+|------------------------------------------------|---------|------------------------------------------------------------------------------------|
+| `kyuubi.watchdog.dangerousJoin.enabled` | `false` | Enable dangerous join detection |
+| `kyuubi.watchdog.dangerousJoin.broadcastRatio` | `0.8` | Ratio against Spark broadcast threshold to identify oversized broadcast fallback |
+| `kyuubi.watchdog.dangerousJoin.action` | `WARN` | `WARN` logs warning diagnostics, `REJECT` throws exception with error code `41101` |
+
+Please see [Dangerous Join Watchdog](../watchdog/dangerous-join.md) for rules and examples.
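The interaction of the three settings above can be sketched as one small decision function. This is a hypothetical standalone illustration in Python, not Kyuubi code; `watchdog_decision` and its parameter names are invented for the example:

```python
# Hypothetical sketch of how the settings combine: the smaller join side is
# compared against broadcastRatio * spark.sql.autoBroadcastJoinThreshold,
# and the configured action decides between a warning and a hard rejection.
def watchdog_decision(enabled, left_size, right_size, broadcast_threshold,
                      broadcast_ratio=0.8, action="WARN"):
    """Return None when the join looks safe, else "WARN" or "REJECT"."""
    if not enabled or broadcast_threshold <= 0:
        return None  # watchdog off, or broadcast joins disabled entirely
    build_side = min(left_size, right_size)  # Spark broadcasts the smaller side
    if build_side > broadcast_threshold * broadcast_ratio:
        return action  # oversized broadcast fallback detected
    return None

# With a 10 MiB threshold and the default 0.8 ratio, a 9 MiB build side
# crosses the 8 MiB line and triggers the configured action.
print(watchdog_decision(True, 9 << 20, 20 << 20, 10 << 20))
```

Raising `broadcastRatio` toward `1.0` shrinks the window in which a join is flagged; lowering it makes the watchdog more aggressive.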
+

## Flink Configurations

### Via flink-conf.yaml

diff --git a/docs/deployment/index.rst b/docs/deployment/index.rst
index 1b6bf876678..b9b6c0ded8a 100644
--- a/docs/deployment/index.rst
+++ b/docs/deployment/index.rst
@@ -27,6 +27,7 @@ Basics
    :glob:

    kyuubi_on_kubernetes
+   settings
    hive_metastore
    high_availability_guide
    migration-guide
@@ -42,4 +43,4 @@ Engines
    engine_on_kubernetes
    engine_share_level
    engine_lifecycle
-   spark/index
\ No newline at end of file
+   spark/index
diff --git a/docs/deployment/settings.md b/docs/deployment/settings.md
new file mode 100644
index 00000000000..921df077f4b
--- /dev/null
+++ b/docs/deployment/settings.md
@@ -0,0 +1,34 @@
+
+
+# Deployment Settings for Dangerous Join Watchdog
+
+## Spark SQL Extensions
+
+```properties
+spark.sql.extensions=org.apache.kyuubi.sql.KyuubiSparkSQLExtension,org.apache.kyuubi.sql.watchdog.KyuubiDangerousJoinExtension
+```
+
+## Dangerous Join Configurations
+
+| Name | Default | Description |
+|------------------------------------------------|---------|-----------------------------------------------------------|
+| `kyuubi.watchdog.dangerousJoin.enabled` | `false` | Enable dangerous join watchdog |
+| `kyuubi.watchdog.dangerousJoin.broadcastRatio` | `0.8` | Broadcast threshold coefficient |
+| `kyuubi.watchdog.dangerousJoin.action` | `WARN` | `WARN` only logs diagnostics, `REJECT` throws error 41101 |
+
+For detailed rules and examples, see [Dangerous Join Watchdog](../watchdog/dangerous-join.md).
-| Name | Default Value | Description | Since | -|---------------------------------------------------------------------|----------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------| -| spark.sql.optimizer.insertRepartitionBeforeWrite.enabled | true | Add repartition node at the top of query plan. An approach of merging small files. | 1.2.0 | -| spark.sql.optimizer.forceShuffleBeforeJoin.enabled | false | Ensure shuffle node exists before shuffled join (shj and smj) to make AQE `OptimizeSkewedJoin` works (complex scenario join, multi table join). | 1.2.0 | -| spark.sql.optimizer.finalStageConfigIsolation.enabled | false | If true, the final stage support use different config with previous stage. The prefix of final stage config key should be `spark.sql.finalStage.`. For example, the raw spark config: `spark.sql.adaptive.advisoryPartitionSizeInBytes`, then the final stage config should be: `spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes`. | 1.2.0 | -| spark.sql.optimizer.insertZorderBeforeWriting.enabled | true | When true, we will follow target table properties to insert zorder or not. The key properties are: 1) `kyuubi.zorder.enabled`: if this property is true, we will insert zorder before writing data. 2) `kyuubi.zorder.cols`: string split by comma, we will zorder by these cols. | 1.4.0 | -| spark.sql.optimizer.zorderGlobalSort.enabled | true | When true, we do a global sort using zorder. Note that, it can cause data skew issue if the zorder columns have less cardinality. When false, we only do local sort using zorder. | 1.4.0 | -| spark.sql.watchdog.maxPartitions | none | Set the max partition number when spark scans a data source. 
Enable maxPartition Strategy by specifying this configuration. Add maxPartitions Strategy to avoid scan excessive partitions on partitioned table, it's optional that works with defined | 1.4.0 | -| spark.sql.watchdog.maxFileSize | none | Set the maximum size in bytes of files when spark scans a data source. Enable maxFileSize Strategy by specifying this configuration. Add maxFileSize Strategy to avoid scan excessive size of files, it's optional that works with defined | 1.8.0 | -| spark.sql.optimizer.dropIgnoreNonExistent | false | When true, do not report an error if DROP DATABASE/TABLE/VIEW/FUNCTION/PARTITION specifies a non-existent database/table/view/function/partition | 1.5.0 | -| spark.sql.optimizer.rebalanceBeforeZorder.enabled | false | When true, we do a rebalance before zorder in case data skew. Note that, if the insertion is dynamic partition we will use the partition columns to rebalance. | 1.6.0 | -| spark.sql.optimizer.rebalanceZorderColumns.enabled | false | When true and `spark.sql.optimizer.rebalanceBeforeZorder.enabled` is true, we do rebalance before Z-Order. If it's dynamic partition insert, the rebalance expression will include both partition columns and Z-Order columns. | 1.6.0 | -| spark.sql.optimizer.twoPhaseRebalanceBeforeZorder.enabled | false | When true and `spark.sql.optimizer.rebalanceBeforeZorder.enabled` is true, we do two phase rebalance before Z-Order for the dynamic partition write. The first phase rebalance using dynamic partition column; The second phase rebalance using dynamic partition column Z-Order columns. | 1.6.0 | -| spark.sql.optimizer.zorderUsingOriginalOrdering.enabled | false | When true and `spark.sql.optimizer.rebalanceBeforeZorder.enabled` is true, we do sort by the original ordering i.e. lexicographical order. | 1.6.0 | -| spark.sql.optimizer.inferRebalanceAndSortOrders.enabled | false | When ture, infer columns for rebalance and sort orders from original query, e.g. the join keys from join. 
It can avoid compression ratio regression. | 1.7.0 | -| spark.sql.optimizer.inferRebalanceAndSortOrdersMaxColumns | 3 | The max columns of inferred columns. | 1.7.0 | -| spark.sql.optimizer.insertRepartitionBeforeWriteIfNoShuffle.enabled | false | When true, add repartition even if the original plan does not have shuffle. | 1.7.0 | -| spark.sql.optimizer.finalStageConfigIsolationWriteOnly.enabled | true | When true, only enable final stage isolation for writing. | 1.7.0 | -| spark.sql.finalWriteStage.eagerlyKillExecutors.enabled | false | When true, eagerly kill redundant executors before running final write stage. | 1.8.0 | -| spark.sql.finalWriteStage.skipKillingExecutorsForTableCache | true | When true, skip killing executors if the plan has table caches. | 1.8.0 | -| spark.sql.finalWriteStage.retainExecutorsFactor | 1.2 | If the target executors * factor < active executors, and target executors * factor > min executors, then inject kill executors or inject custom resource profile. | 1.8.0 | -| spark.sql.finalWriteStage.resourceIsolation.enabled | false | When true, make final write stage resource isolation using custom RDD resource profile. | 1.8.0 | -| spark.sql.finalWriteStageExecutorCores | fallback spark.executor.cores | Specify the executor core request for final write stage. It would be passed to the RDD resource profile. | 1.8.0 | -| spark.sql.finalWriteStageExecutorMemory | fallback spark.executor.memory | Specify the executor on heap memory request for final write stage. It would be passed to the RDD resource profile. | 1.8.0 | -| spark.sql.finalWriteStageExecutorMemoryOverhead | fallback spark.executor.memoryOverhead | Specify the executor memory overhead request for final write stage. It would be passed to the RDD resource profile. | 1.8.0 | -| spark.sql.finalWriteStageExecutorOffHeapMemory | NONE | Specify the executor off heap memory request for final write stage. It would be passed to the RDD resource profile. 
| 1.8.0 | -| spark.sql.execution.scriptTransformation.enabled | true | When false, script transformation is not allowed. | 1.9.0 | +| Name | Default Value | Description | Since | +|---------------------------------------------------------------------|----------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------| +| spark.sql.optimizer.insertRepartitionBeforeWrite.enabled | true | Add repartition node at the top of query plan. An approach of merging small files. | 1.2.0 | +| spark.sql.optimizer.forceShuffleBeforeJoin.enabled | false | Ensure shuffle node exists before shuffled join (shj and smj) to make AQE `OptimizeSkewedJoin` works (complex scenario join, multi table join). | 1.2.0 | +| spark.sql.optimizer.finalStageConfigIsolation.enabled | false | If true, the final stage support use different config with previous stage. The prefix of final stage config key should be `spark.sql.finalStage.`. For example, the raw spark config: `spark.sql.adaptive.advisoryPartitionSizeInBytes`, then the final stage config should be: `spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes`. | 1.2.0 | +| spark.sql.optimizer.insertZorderBeforeWriting.enabled | true | When true, we will follow target table properties to insert zorder or not. The key properties are: 1) `kyuubi.zorder.enabled`: if this property is true, we will insert zorder before writing data. 2) `kyuubi.zorder.cols`: string split by comma, we will zorder by these cols. | 1.4.0 | +| spark.sql.optimizer.zorderGlobalSort.enabled | true | When true, we do a global sort using zorder. Note that, it can cause data skew issue if the zorder columns have less cardinality. 
When false, we only do local sort using zorder. | 1.4.0 |
+| spark.sql.watchdog.maxPartitions | none | Set the max partition number when spark scans a data source. Enable maxPartition Strategy by specifying this configuration. Add maxPartitions Strategy to avoid scan excessive partitions on partitioned table, it's optional that works with defined | 1.4.0 |
+| spark.sql.watchdog.maxFileSize | none | Set the maximum size in bytes of files when spark scans a data source. Enable maxFileSize Strategy by specifying this configuration. Add maxFileSize Strategy to avoid scan excessive size of files, it's optional that works with defined | 1.8.0 |
+| kyuubi.watchdog.dangerousJoin.enabled | false | Enable dangerous join condition detection in planner stage. | 1.11.0 |
+| kyuubi.watchdog.dangerousJoin.broadcastRatio | 0.8 | Broadcast threshold coefficient used to identify oversized broadcast fallback. | 1.11.0 |
+| kyuubi.watchdog.dangerousJoin.action | WARN | Action when dangerous join is detected, one of `WARN` and `REJECT`. | 1.11.0 |
+| spark.sql.optimizer.dropIgnoreNonExistent | false | When true, do not report an error if DROP DATABASE/TABLE/VIEW/FUNCTION/PARTITION specifies a non-existent database/table/view/function/partition | 1.5.0 |
+| spark.sql.optimizer.rebalanceBeforeZorder.enabled | false | When true, we do a rebalance before zorder in case data skew. Note that, if the insertion is dynamic partition we will use the partition columns to rebalance. | 1.6.0 |
+| spark.sql.optimizer.rebalanceZorderColumns.enabled | false | When true and `spark.sql.optimizer.rebalanceBeforeZorder.enabled` is true, we do rebalance before Z-Order. If it's dynamic partition insert, the rebalance expression will include both partition columns and Z-Order columns.
| 1.6.0 |
+| spark.sql.optimizer.twoPhaseRebalanceBeforeZorder.enabled | false | When true and `spark.sql.optimizer.rebalanceBeforeZorder.enabled` is true, we do two phase rebalance before Z-Order for the dynamic partition write. The first phase rebalance using dynamic partition column; The second phase rebalance using dynamic partition column Z-Order columns. | 1.6.0 |
+| spark.sql.optimizer.zorderUsingOriginalOrdering.enabled | false | When true and `spark.sql.optimizer.rebalanceBeforeZorder.enabled` is true, we do sort by the original ordering i.e. lexicographical order. | 1.6.0 |
+| spark.sql.optimizer.inferRebalanceAndSortOrders.enabled | false | When true, infer columns for rebalance and sort orders from original query, e.g. the join keys from join. It can avoid compression ratio regression. | 1.7.0 |
+| spark.sql.optimizer.inferRebalanceAndSortOrdersMaxColumns | 3 | The max columns of inferred columns. | 1.7.0 |
+| spark.sql.optimizer.insertRepartitionBeforeWriteIfNoShuffle.enabled | false | When true, add repartition even if the original plan does not have shuffle. | 1.7.0 |
+| spark.sql.optimizer.finalStageConfigIsolationWriteOnly.enabled | true | When true, only enable final stage isolation for writing. | 1.7.0 |
+| spark.sql.finalWriteStage.eagerlyKillExecutors.enabled | false | When true, eagerly kill redundant executors before running final write stage. | 1.8.0 |
+| spark.sql.finalWriteStage.skipKillingExecutorsForTableCache | true | When true, skip killing executors if the plan has table caches. | 1.8.0 |
+| spark.sql.finalWriteStage.retainExecutorsFactor | 1.2 | If the target executors * factor < active executors, and target executors * factor > min executors, then inject kill executors or inject custom resource profile. | 1.8.0 |
+| spark.sql.finalWriteStage.resourceIsolation.enabled | false | When true, make final write stage resource isolation using custom RDD resource profile.
| 1.8.0 | +| spark.sql.finalWriteStageExecutorCores | fallback spark.executor.cores | Specify the executor core request for final write stage. It would be passed to the RDD resource profile. | 1.8.0 | +| spark.sql.finalWriteStageExecutorMemory | fallback spark.executor.memory | Specify the executor on heap memory request for final write stage. It would be passed to the RDD resource profile. | 1.8.0 | +| spark.sql.finalWriteStageExecutorMemoryOverhead | fallback spark.executor.memoryOverhead | Specify the executor memory overhead request for final write stage. It would be passed to the RDD resource profile. | 1.8.0 | +| spark.sql.finalWriteStageExecutorOffHeapMemory | NONE | Specify the executor off heap memory request for final write stage. It would be passed to the RDD resource profile. | 1.8.0 | +| spark.sql.execution.scriptTransformation.enabled | true | When false, script transformation is not allowed. | 1.9.0 | diff --git a/docs/index.rst b/docs/index.rst index 8966ea444a5..4386eb47d41 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -181,6 +181,7 @@ What's Next quick_start/index configuration/settings + watchdog/dangerous-join deployment/index Security monitor/index diff --git a/docs/watchdog/dangerous-join.md b/docs/watchdog/dangerous-join.md new file mode 100644 index 00000000000..d81d9977273 --- /dev/null +++ b/docs/watchdog/dangerous-join.md @@ -0,0 +1,105 @@ + + +# Dangerous Join Watchdog + +Kyuubi Dangerous Join Watchdog detects risky join planning patterns before query execution. +It helps reduce accidental Cartesian products, oversized broadcast attempts, and long-running nested loop joins. + +## Background + +In shared SQL gateway environments, a single risky join can consume excessive driver memory or create very slow jobs. +The Dangerous Join Watchdog adds planning-time checks for these high-risk patterns. + +## Risk Rules + +### Equi-Join + +- Rule 1: Equi-join is marked dangerous when it degrades to a Cartesian pattern. 
+- Rule 2: Equi-join is marked dangerous when the estimated build side exceeds the configured broadcast ratio threshold.
+
+### Non-Equi Join
+
+- Rule 1: Non-equi join is marked dangerous when both sides exceed the broadcast threshold and effectively become a Cartesian risk.
+- Rule 2: Non-equi join is marked dangerous when no build side is selectable and the plan falls back to a second BNLJ pattern.
+
+## Configurations
+
+| Name | Default | Meaning |
+|------------------------------------------------|---------|---------------------------------------------------------------------------|
+| `kyuubi.watchdog.dangerousJoin.enabled` | `false` | Enable or disable dangerous join detection |
+| `kyuubi.watchdog.dangerousJoin.broadcastRatio` | `0.8` | Ratio against Spark broadcast threshold for warning/reject decision |
+| `kyuubi.watchdog.dangerousJoin.action` | `WARN` | `WARN` logs diagnostics; `REJECT` throws exception and rejects submission |
+
+## Usage
+
+1. Put the Kyuubi Spark extension jar on the Spark classpath.
+2. Configure SQL extensions:
+
+```properties
+spark.sql.extensions=org.apache.kyuubi.sql.KyuubiSparkSQLExtension,org.apache.kyuubi.sql.watchdog.KyuubiDangerousJoinExtension
+```
+
+3.
Configure action: + +```properties +kyuubi.watchdog.dangerousJoin.action=WARN +``` + +or + +```properties +kyuubi.watchdog.dangerousJoin.action=REJECT +``` + +## Sample WARN Log + +When action is `WARN`, Kyuubi writes a structured JSON payload: + +```text +KYUUBI_LOG_KEY={"sql":"SELECT ...","joinType":"INNER","reason":"Cartesian","leftSize":10485760,"rightSize":15728640,"broadcastThreshold":10485760,"broadcastRatio":0.8} +``` + +## Sample REJECT Error + +When action is `REJECT`, query submission fails with: + +```text +errorCode=41101 +Query rejected due to dangerous join strategy: {...details...} +``` + +## Disable or Tune + +- Disable watchdog: + +```properties +kyuubi.watchdog.dangerousJoin.enabled=false +``` + +- Increase tolerance: + +```properties +kyuubi.watchdog.dangerousJoin.broadcastRatio=0.95 +``` + +## FAQ + +### What if `spark.sql.adaptive.enabled=true`? + +Dangerous Join Watchdog runs in planner strategy phase and evaluates pre-execution plan statistics. +AQE may still optimize runtime plans, but watchdog decisions are made before query execution starts. 
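The rule tables above can be summarized in a single classification function. The sketch below is illustrative Python only (assumed semantics; `classify_join` and its arguments are invented names, and the real checks run on Catalyst plan statistics inside Spark's planner, before execution):

```python
# Illustrative sketch of the risk rules: classify a join and return the
# reason string that would appear in the WARN payload / REJECT message,
# or None when the join looks safe.
def classify_join(is_equi, has_cross_table_key, left_size, right_size,
                  threshold, ratio=0.8):
    """Sizes are estimated bytes; threshold mirrors autoBroadcastJoinThreshold."""
    if threshold <= 0:
        return None  # broadcast joins disabled; nothing to compare against
    build_side = min(left_size, right_size)
    if is_equi:
        if not has_cross_table_key:
            return "Cartesian"  # equi-join degraded to a cross pattern
        if build_side > threshold * ratio:
            return "OversizedBroadcastFallback"
        return None
    if left_size > threshold and right_size > threshold:
        return "Cartesian"  # both sides too large: nested-loop blow-up risk
    if build_side > threshold * ratio:
        return "SecondBNLJ"  # no safe broadcast build side remains
    return None
```

Note that the decisions depend only on pre-execution size estimates, which is why, as the FAQ says, AQE may still rewrite the runtime plan after the watchdog has ruled.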
diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala index e72a6c07354..6c38fb0041f 100644 --- a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala +++ b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala @@ -123,6 +123,30 @@ object KyuubiSQLConf { .bytesConf(ByteUnit.BYTE) .createOptional + val DANGEROUS_JOIN_ENABLED = + buildConf("kyuubi.watchdog.dangerousJoin.enabled") + .doc("Enable dangerous join condition detection.") + .version("1.11.0") + .booleanConf + .createWithDefault(false) + + val DANGEROUS_JOIN_BROADCAST_RATIO = + buildConf("kyuubi.watchdog.dangerousJoin.broadcastRatio") + .doc("The threshold ratio to mark oversized broadcast fallback.") + .version("1.11.0") + .doubleConf + .checkValue(v => v > 0 && v <= 1, "must be in (0, 1]") + .createWithDefault(0.8) + + val DANGEROUS_JOIN_ACTION = + buildConf("kyuubi.watchdog.dangerousJoin.action") + .doc("Action when dangerous join is detected, one of WARN and REJECT.") + .version("1.11.0") + .stringConf + .transform(_.toUpperCase(java.util.Locale.ROOT)) + .checkValues(Set("WARN", "REJECT")) + .createWithDefault("WARN") + val DROP_IGNORE_NONEXISTENT = buildConf("spark.sql.optimizer.dropIgnoreNonExistent") .doc("Do not report an error if DROP DATABASE/TABLE/VIEW/FUNCTION/PARTITION specifies " + diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index 33ff3e3177a..637df43d826 100644 --- a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ 
b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -19,7 +19,7 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions} -import org.apache.kyuubi.sql.watchdog.{KyuubiUnsupportedOperationsCheck, MaxScanStrategy} +import org.apache.kyuubi.sql.watchdog.{DangerousJoinInterceptor, KyuubiUnsupportedOperationsCheck, MaxScanStrategy} // scalastyle:off line.size.limit /** @@ -39,6 +39,7 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { // watchdog extension extensions.injectCheckRule(_ => KyuubiUnsupportedOperationsCheck) extensions.injectPlannerStrategy(MaxScanStrategy) + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) extensions.injectQueryStagePrepRule(FinalStageResourceManager(_)) extensions.injectQueryStagePrepRule(InjectCustomResourceProfile) diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala new file mode 100644 index 00000000000..d5ff84c5e3f --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import scala.collection.mutable.ArrayBuffer + +object DangerousJoinCounter { + case class Entry( + sqlText: String, + joinType: String, + reason: String, + leftSize: BigInt, + rightSize: BigInt, + broadcastThreshold: Long, + broadcastRatio: Double) { + def toJson: String = { + val pairs = Seq( + "sql" -> escape(sqlText), + "joinType" -> escape(joinType), + "reason" -> escape(reason), + "leftSize" -> leftSize.toString, + "rightSize" -> rightSize.toString, + "broadcastThreshold" -> broadcastThreshold.toString, + "broadcastRatio" -> broadcastRatio.toString) + pairs.map { case (k, v) => + if (k == "leftSize" || k == "rightSize" || k == "broadcastThreshold" || k == "broadcastRatio") { + s""""$k":$v""" + } else { + s""""$k":"$v"""" + } + }.mkString("{", ",", "}") + } + } + + private val entries = ArrayBuffer.empty[Entry] + + def add(entry: Entry): Unit = synchronized { + entries += entry + } + + def count: Int = synchronized { + entries.size + } + + def latest: Option[Entry] = synchronized { + entries.lastOption + } + + def snapshot: Seq[Entry] = synchronized { + entries.toSeq + } + + def reset(): Unit = synchronized { + entries.clear() + } + + private def escape(raw: String): String = { + raw + .replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala 
b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala new file mode 100644 index 00000000000..4ff7a420e22 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.slf4j.LoggerFactory + +import org.apache.kyuubi.sql.KyuubiSQLConf + +case class DangerousJoinInterceptor(session: SparkSession) extends SparkStrategy { + import DangerousJoinInterceptor._ + + private val logger = LoggerFactory.getLogger(classOf[DangerousJoinInterceptor]) + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + val conf = session.sessionState.conf + if (!conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ENABLED)) { + return Nil + } + val ratio = conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO) + val threshold = conf.autoBroadcastJoinThreshold + plan.foreach { + case join: Join => + detect(join, threshold, ratio).foreach { reason => + val entry = DangerousJoinCounter.Entry( + sqlText = plan.toString(), + joinType = join.joinType.sql, + reason = reason, + leftSize = join.left.stats.sizeInBytes, + rightSize = join.right.stats.sizeInBytes, + broadcastThreshold = threshold, + broadcastRatio = ratio) + DangerousJoinCounter.add(entry) + logger.warn(s"$KYUUBI_LOG_KEY=${entry.toJson}") + if (conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ACTION) == REJECT) { + throw new KyuubiDangerousJoinException(entry.toJson) + } + } + case _ => + } + Nil + } + + private def detect(join: Join, threshold: Long, ratio: Double): Option[String] = { + if (threshold <= 0) { + return None + } + val leftSize = join.left.stats.sizeInBytes + val rightSize = join.right.stats.sizeInBytes + val hasEquiJoin = isEquiJoin(join) + if (hasEquiJoin) { + if (isCartesianCondition(join.condition)) { + Some("Cartesian") + } else if (minSize(leftSize, rightSize) > BigInt((threshold * 
ratio).toLong)) {
+        Some("OversizedBroadcastFallback")
+      } else {
+        None
+      }
+    } else {
+      if (leftSize > threshold && rightSize > threshold) {
+        Some("Cartesian")
+      } else if (cannotSelectBuildSide(leftSize, rightSize, threshold, ratio)) {
+        Some("SecondBNLJ")
+      } else {
+        None
+      }
+    }
+  }
+
+  private def isEquiJoin(join: Join): Boolean = {
+    join match {
+      case ExtractEquiJoinKeys(_, _, _, _, _, _, _, _) => true
+      case _ => false
+    }
+  }
+
+  private def isCartesianCondition(condition: Option[Expression]): Boolean = {
+    condition.forall(c => !containsJoinKey(c))
+  }
+
+  private def containsJoinKey(expr: Expression): Boolean = {
+    expr match {
+      case EqualTo(l: AttributeReference, r: AttributeReference) =>
+        l.qualifier.nonEmpty && r.qualifier.nonEmpty && l.qualifier != r.qualifier
+      case And(l, r) => containsJoinKey(l) || containsJoinKey(r)
+      case _ => false
+    }
+  }
+
+  private def minSize(leftSize: BigInt, rightSize: BigInt): BigInt = {
+    if (leftSize <= rightSize) leftSize else rightSize
+  }
+
+  private def cannotSelectBuildSide(
+      leftSize: BigInt,
+      rightSize: BigInt,
+      threshold: Long,
+      ratio: Double): Boolean = {
+    // The smaller side is already too large to broadcast once the ratio-scaled
+    // threshold applies, so the plan can only fall back to a second BNLJ.
+    minSize(leftSize, rightSize) > BigInt((threshold * ratio).toLong)
+  }
+}
+
+object DangerousJoinInterceptor {
+  val WARN = "WARN"
+  val REJECT = "REJECT"
+  val KYUUBI_LOG_KEY = "KYUUBI_LOG_KEY"
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala
new file mode 100644
index 00000000000..dc2a1336cc8
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +class KyuubiDangerousJoinException(details: String) + extends java.sql.SQLException( + s"Query rejected due to dangerous join strategy: $details", + null, + 41101) diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala new file mode 100644 index 00000000000..32d315bdd0b --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-3/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSessionExtensions + +class KyuubiDangerousJoinExtension extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala b/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala new file mode 100644 index 00000000000..1f511972b77 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class DangerousJoinInterceptorSuite extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("equi join oversized broadcast fallback should be counted") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO.key -> "0.8", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + } + } + + test("non equi join cartesian should include Cartesian marker") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("Cartesian"))) + } + } + + test("reject action should throw dangerous join exception with 41101") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + } + assert(e.getErrorCode == 41101) + } + } + + test("disabled dangerous join should not count") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "false", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { 
+ sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count == 0) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala b/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala new file mode 100644 index 00000000000..91ae27605ed --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-3/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class KyuubiDangerousJoinIT extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("warn action should keep query successful and emit warning diagnostics") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val rows = sql("SELECT count(*) FROM t1 a JOIN t2 b ON a.c1 = b.c1").collect() + assert(rows.nonEmpty) + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("joinType"))) + } + } + + test("reject action should return detailed dangerous join exception") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").collect() + } + assert(e.getErrorCode == 41101) + assert(e.getMessage.contains("leftSize")) + assert(e.getMessage.contains("rightSize")) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala index e72a6c07354..6c38fb0041f 100644 --- a/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala +++ b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala @@ -123,6 +123,30 @@ object KyuubiSQLConf { .bytesConf(ByteUnit.BYTE) .createOptional + val DANGEROUS_JOIN_ENABLED = + 
buildConf("kyuubi.watchdog.dangerousJoin.enabled") + .doc("Enable dangerous join detection.") + .version("1.11.0") + .booleanConf + .createWithDefault(true) + + val DANGEROUS_JOIN_BROADCAST_RATIO = + buildConf("kyuubi.watchdog.dangerousJoin.broadcastRatio") + .doc("The threshold ratio to mark oversized broadcast fallback.") + .version("1.11.0") + .doubleConf + .checkValue(v => v > 0 && v <= 1, "must be in (0, 1]") + .createWithDefault(0.8) + + val DANGEROUS_JOIN_ACTION = + buildConf("kyuubi.watchdog.dangerousJoin.action") + .doc("Action when dangerous join is detected, one of WARN and REJECT.") + .version("1.11.0") + .stringConf + .transform(_.toUpperCase(java.util.Locale.ROOT)) + .checkValues(Set("WARN", "REJECT")) + .createWithDefault("WARN") + val DROP_IGNORE_NONEXISTENT = buildConf("spark.sql.optimizer.dropIgnoreNonExistent") .doc("Do not report an error if DROP DATABASE/TABLE/VIEW/FUNCTION/PARTITION specifies " + diff --git a/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index 33ff3e3177a..637df43d826 100644 --- a/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -19,7 +19,7 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions} -import org.apache.kyuubi.sql.watchdog.{KyuubiUnsupportedOperationsCheck, MaxScanStrategy} +import org.apache.kyuubi.sql.watchdog.{DangerousJoinInterceptor, KyuubiUnsupportedOperationsCheck, MaxScanStrategy} // scalastyle:off line.size.limit /** @@ -39,6 +39,7 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { // watchdog extension extensions.injectCheckRule(_ => 
KyuubiUnsupportedOperationsCheck) extensions.injectPlannerStrategy(MaxScanStrategy) + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) extensions.injectQueryStagePrepRule(FinalStageResourceManager(_)) extensions.injectQueryStagePrepRule(InjectCustomResourceProfile) diff --git a/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala new file mode 100644 index 00000000000..d5ff84c5e3f --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import scala.collection.mutable.ArrayBuffer + +object DangerousJoinCounter { + case class Entry( + sqlText: String, + joinType: String, + reason: String, + leftSize: BigInt, + rightSize: BigInt, + broadcastThreshold: Long, + broadcastRatio: Double) { + def toJson: String = { + // numeric fields are emitted without surrounding quotes + val numericKeys = + Set("leftSize", "rightSize", "broadcastThreshold", "broadcastRatio") + val pairs = Seq( + "sql" -> escape(sqlText), + "joinType" -> escape(joinType), + "reason" -> escape(reason), + "leftSize" -> leftSize.toString, + "rightSize" -> rightSize.toString, + "broadcastThreshold" -> broadcastThreshold.toString, + "broadcastRatio" -> broadcastRatio.toString) + pairs.map { case (k, v) => + if (numericKeys.contains(k)) s""""$k":$v""" else s""""$k":"$v"""" + }.mkString("{", ",", "}") + } + } + + private val entries = ArrayBuffer.empty[Entry] + + def add(entry: Entry): Unit = synchronized { + entries += entry + } + + def count: Int = synchronized { + entries.size + } + + def latest: Option[Entry] = synchronized { + entries.lastOption + } + + def snapshot: Seq[Entry] = synchronized { + entries.toSeq + } + + def reset(): Unit = synchronized { + entries.clear() + } + + private def escape(raw: String): String = { + raw + .replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala new file mode 100644 index 00000000000..4ff7a420e22 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.slf4j.LoggerFactory + +import org.apache.kyuubi.sql.KyuubiSQLConf + +case class DangerousJoinInterceptor(session: SparkSession) extends SparkStrategy { + import DangerousJoinInterceptor._ + + private val logger = LoggerFactory.getLogger(classOf[DangerousJoinInterceptor]) + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + val conf = session.sessionState.conf + if (!conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ENABLED)) { + return Nil + } + val ratio = conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO) + val threshold = conf.autoBroadcastJoinThreshold + plan.foreach { + case join: Join => + detect(join, threshold, ratio).foreach { reason => + val entry = DangerousJoinCounter.Entry( + sqlText = plan.toString(), + joinType = join.joinType.sql, + reason = reason, + leftSize = join.left.stats.sizeInBytes, + rightSize = join.right.stats.sizeInBytes, + broadcastThreshold = 
threshold, + broadcastRatio = ratio) + DangerousJoinCounter.add(entry) + logger.warn(s"$KYUUBI_LOG_KEY=${entry.toJson}") + if (conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ACTION) == REJECT) { + throw new KyuubiDangerousJoinException(entry.toJson) + } + } + case _ => + } + Nil + } + + private def detect(join: Join, threshold: Long, ratio: Double): Option[String] = { + if (threshold <= 0) { + return None + } + val leftSize = join.left.stats.sizeInBytes + val rightSize = join.right.stats.sizeInBytes + val hasEquiJoin = isEquiJoin(join) + if (hasEquiJoin) { + if (isCartesianCondition(join.condition)) { + Some("Cartesian") + } else if (minSize(leftSize, rightSize) > BigInt((threshold * ratio).toLong)) { + Some("OversizedBroadcastFallback") + } else { + None + } + } else { + if (leftSize > threshold && rightSize > threshold) { + Some("Cartesian") + } else if (cannotSelectBuildSide(leftSize, rightSize, threshold, ratio)) { + Some("SecondBNLJ") + } else { + None + } + } + } + + private def isEquiJoin(join: Join): Boolean = { + join match { + case ExtractEquiJoinKeys(_, _, _, _, _, _, _, _) => true + case _ => false + } + } + + private def isCartesianCondition(condition: Option[Expression]): Boolean = { + condition.forall(!containsJoinKey(_)) + } + + private def containsJoinKey(expr: Expression): Boolean = { + expr match { + case EqualTo(l: AttributeReference, r: AttributeReference) => + l.qualifier.nonEmpty && r.qualifier.nonEmpty && l.qualifier != r.qualifier + case And(l, r) => containsJoinKey(l) || containsJoinKey(r) + case _ => false + } + } + + private def minSize(leftSize: BigInt, rightSize: BigInt): BigInt = { + if (leftSize <= rightSize) leftSize else rightSize + } + + // reachable only when not both sides exceed the plain threshold: the smaller + // side still exceeds the ratio-scaled threshold, so broadcasting it for a + // nested loop join would build an oversized relation + private def cannotSelectBuildSide( + leftSize: BigInt, + rightSize: BigInt, + threshold: Long, + ratio: Double): Boolean = { + minSize(leftSize, rightSize) > BigInt((threshold * ratio).toLong) + } +} + +object DangerousJoinInterceptor { + val WARN = "WARN" + val REJECT = "REJECT" + val KYUUBI_LOG_KEY = "KYUUBI_LOG_KEY" +} diff --git 
a/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala new file mode 100644 index 00000000000..dc2a1336cc8 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +class KyuubiDangerousJoinException(details: String) + extends java.sql.SQLException( + s"Query rejected due to dangerous join strategy: $details", + null, + 41101) diff --git a/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala new file mode 100644 index 00000000000..32d315bdd0b --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-4/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSessionExtensions + +class KyuubiDangerousJoinExtension extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-4/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala b/extensions/spark/kyuubi-extension-spark-3-4/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala new file mode 100644 index 00000000000..1f511972b77 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-4/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class DangerousJoinInterceptorSuite extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("equi join oversized broadcast fallback should be counted") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO.key -> "0.8", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + } + } + + test("non equi join cartesian should include Cartesian marker") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("Cartesian"))) + } + } + + test("reject action should throw dangerous join exception with 41101") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + } + assert(e.getErrorCode == 41101) + } + } + + test("disabled dangerous join should not count") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "false", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { 
+ sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count == 0) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-4/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala b/extensions/spark/kyuubi-extension-spark-3-4/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala new file mode 100644 index 00000000000..91ae27605ed --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-4/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class KyuubiDangerousJoinIT extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("warn action should keep query successful and emit warning diagnostics") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val rows = sql("SELECT count(*) FROM t1 a JOIN t2 b ON a.c1 = b.c1").collect() + assert(rows.nonEmpty) + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("joinType"))) + } + } + + test("reject action should return detailed dangerous join exception") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").collect() + } + assert(e.getErrorCode == 41101) + assert(e.getMessage.contains("leftSize")) + assert(e.getMessage.contains("rightSize")) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala index 9644537a06d..b307bdec875 100644 --- a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala @@ -123,6 +123,30 @@ object KyuubiSQLConf { .bytesConf(ByteUnit.BYTE) .createOptional + val DANGEROUS_JOIN_ENABLED = + 
buildConf("kyuubi.watchdog.dangerousJoin.enabled") + .doc("Enable dangerous join detection.") + .version("1.11.0") + .booleanConf + .createWithDefault(true) + + val DANGEROUS_JOIN_BROADCAST_RATIO = + buildConf("kyuubi.watchdog.dangerousJoin.broadcastRatio") + .doc("The threshold ratio to mark oversized broadcast fallback.") + .version("1.11.0") + .doubleConf + .checkValue(v => v > 0 && v <= 1, "must be in (0, 1]") + .createWithDefault(0.8) + + val DANGEROUS_JOIN_ACTION = + buildConf("kyuubi.watchdog.dangerousJoin.action") + .doc("Action when dangerous join is detected, one of WARN and REJECT.") + .version("1.11.0") + .stringConf + .transform(_.toUpperCase(java.util.Locale.ROOT)) + .checkValues(Set("WARN", "REJECT")) + .createWithDefault("WARN") + val DROP_IGNORE_NONEXISTENT = buildConf("spark.sql.optimizer.dropIgnoreNonExistent") .doc("Do not report an error if DROP DATABASE/TABLE/VIEW/FUNCTION/PARTITION specifies " + diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index db7f4b6ea30..89e4490eed4 100644 --- a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -19,7 +19,7 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions} -import org.apache.kyuubi.sql.watchdog.{KyuubiUnsupportedOperationsCheck, MaxScanStrategy} +import org.apache.kyuubi.sql.watchdog.{DangerousJoinInterceptor, KyuubiUnsupportedOperationsCheck, MaxScanStrategy} import org.apache.kyuubi.sql.zorder.{InsertZorderBeforeWritingDatasource, InsertZorderBeforeWritingHive, ResolveZorder} // scalastyle:off line.size.limit @@ -48,6 +48,7 @@ class 
KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { // watchdog extension extensions.injectCheckRule(_ => KyuubiUnsupportedOperationsCheck) extensions.injectPlannerStrategy(MaxScanStrategy) + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) extensions.injectQueryStagePrepRule(_ => InsertShuffleNodeBeforeJoin) extensions.injectQueryStagePrepRule(DynamicShufflePartitions) diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala new file mode 100644 index 00000000000..d5ff84c5e3f --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import scala.collection.mutable.ArrayBuffer + +object DangerousJoinCounter { + case class Entry( + sqlText: String, + joinType: String, + reason: String, + leftSize: BigInt, + rightSize: BigInt, + broadcastThreshold: Long, + broadcastRatio: Double) { + def toJson: String = { + // numeric fields are emitted without surrounding quotes + val numericKeys = + Set("leftSize", "rightSize", "broadcastThreshold", "broadcastRatio") + val pairs = Seq( + "sql" -> escape(sqlText), + "joinType" -> escape(joinType), + "reason" -> escape(reason), + "leftSize" -> leftSize.toString, + "rightSize" -> rightSize.toString, + "broadcastThreshold" -> broadcastThreshold.toString, + "broadcastRatio" -> broadcastRatio.toString) + pairs.map { case (k, v) => + if (numericKeys.contains(k)) s""""$k":$v""" else s""""$k":"$v"""" + }.mkString("{", ",", "}") + } + } + + private val entries = ArrayBuffer.empty[Entry] + + def add(entry: Entry): Unit = synchronized { + entries += entry + } + + def count: Int = synchronized { + entries.size + } + + def latest: Option[Entry] = synchronized { + entries.lastOption + } + + def snapshot: Seq[Entry] = synchronized { + entries.toSeq + } + + def reset(): Unit = synchronized { + entries.clear() + } + + private def escape(raw: String): String = { + raw + .replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala new file mode 100644 index 00000000000..4ff7a420e22 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.slf4j.LoggerFactory + +import org.apache.kyuubi.sql.KyuubiSQLConf + +case class DangerousJoinInterceptor(session: SparkSession) extends SparkStrategy { + import DangerousJoinInterceptor._ + + private val logger = LoggerFactory.getLogger(classOf[DangerousJoinInterceptor]) + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + val conf = session.sessionState.conf + if (!conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ENABLED)) { + return Nil + } + val ratio = conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO) + val threshold = conf.autoBroadcastJoinThreshold + plan.foreach { + case join: Join => + detect(join, threshold, ratio).foreach { reason => + val entry = DangerousJoinCounter.Entry( + sqlText = plan.toString(), + joinType = join.joinType.sql, + reason = reason, + leftSize = join.left.stats.sizeInBytes, + rightSize = join.right.stats.sizeInBytes, + broadcastThreshold = 
threshold, + broadcastRatio = ratio) + DangerousJoinCounter.add(entry) + logger.warn(s"$KYUUBI_LOG_KEY=${entry.toJson}") + if (conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ACTION) == REJECT) { + throw new KyuubiDangerousJoinException(entry.toJson) + } + } + case _ => + } + Nil + } + + private def detect(join: Join, threshold: Long, ratio: Double): Option[String] = { + if (threshold <= 0) { + return None + } + val leftSize = join.left.stats.sizeInBytes + val rightSize = join.right.stats.sizeInBytes + val hasEquiJoin = isEquiJoin(join) + if (hasEquiJoin) { + if (isCartesianCondition(join.condition)) { + Some("Cartesian") + } else if (minSize(leftSize, rightSize) > BigInt((threshold * ratio).toLong)) { + Some("OversizedBroadcastFallback") + } else { + None + } + } else { + // A non-equi join with no join key at all degenerates to a cartesian product; + // otherwise, both sides exceeding the threshold forces an oversized BNLJ build side. + if (isCartesianCondition(join.condition)) { + Some("Cartesian") + } else if (cannotSelectBuildSide(leftSize, rightSize, threshold)) { + Some("SecondBNLJ") + } else { + None + } + } + } + + private def isEquiJoin(join: Join): Boolean = { + join match { + case ExtractEquiJoinKeys(_, _, _, _, _, _, _, _) => true + case _ => false + } + } + + private def isCartesianCondition(condition: Option[Expression]): Boolean = { + condition.forall(!containsJoinKey(_)) + } + + private def containsJoinKey(expr: Expression): Boolean = { + expr match { + case EqualTo(l: AttributeReference, r: AttributeReference) => + l.qualifier.nonEmpty && r.qualifier.nonEmpty && l.qualifier != r.qualifier + case And(l, r) => containsJoinKey(l) || containsJoinKey(r) + case _ => false + } + } + + private def minSize(leftSize: BigInt, rightSize: BigInt): BigInt = { + if (leftSize <= rightSize) leftSize else rightSize + } + + private def cannotSelectBuildSide( + leftSize: BigInt, + rightSize: BigInt, + threshold: Long): Boolean = { + leftSize > threshold && rightSize > threshold + } +} + +object DangerousJoinInterceptor { + val WARN = "WARN" + val REJECT = "REJECT" + val KYUUBI_LOG_KEY = "KYUUBI_LOG_KEY" +} diff --git 
a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala new file mode 100644 index 00000000000..dc2a1336cc8 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +class KyuubiDangerousJoinException(details: String) + extends java.sql.SQLException( + s"Query rejected due to dangerous join strategy: $details", + null, + 41101) diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala new file mode 100644 index 00000000000..32d315bdd0b --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSessionExtensions + +class KyuubiDangerousJoinExtension extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala new file mode 100644 index 00000000000..1f511972b77 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class DangerousJoinInterceptorSuite extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("equi join oversized broadcast fallback should be counted") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO.key -> "0.8", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + } + } + + test("non equi join cartesian should include Cartesian marker") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("Cartesian"))) + } + } + + test("reject action should throw dangerous join exception with 41101") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + } + assert(e.getErrorCode == 41101) + } + } + + test("disabled dangerous join should not count") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "false", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { 
+ sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count == 0) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala new file mode 100644 index 00000000000..91ae27605ed --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class KyuubiDangerousJoinIT extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("warn action should keep query successful and emit warning diagnostics") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val rows = sql("SELECT count(*) FROM t1 a JOIN t2 b ON a.c1 = b.c1").collect() + assert(rows.nonEmpty) + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("joinType"))) + } + } + + test("reject action should return detailed dangerous join exception") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").collect() + } + assert(e.getErrorCode == 41101) + assert(e.getMessage.contains("leftSize")) + assert(e.getMessage.contains("rightSize")) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala index 9644537a06d..b307bdec875 100644 --- a/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala +++ b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala @@ -123,6 +123,30 @@ object KyuubiSQLConf { .bytesConf(ByteUnit.BYTE) .createOptional + val DANGEROUS_JOIN_ENABLED = + 
buildConf("kyuubi.watchdog.dangerousJoin.enabled") + .doc("Enable dangerous join condition detection.") + .version("1.11.0") + .booleanConf + .createWithDefault(true) + + val DANGEROUS_JOIN_BROADCAST_RATIO = + buildConf("kyuubi.watchdog.dangerousJoin.broadcastRatio") + .doc("The threshold ratio to mark oversized broadcast fallback.") + .version("1.11.0") + .doubleConf + .checkValue(v => v > 0 && v <= 1, "must be in (0, 1]") + .createWithDefault(0.8) + + val DANGEROUS_JOIN_ACTION = + buildConf("kyuubi.watchdog.dangerousJoin.action") + .doc("Action when a dangerous join is detected, either WARN or REJECT.") + .version("1.11.0") + .stringConf + .transform(_.toUpperCase(java.util.Locale.ROOT)) + .checkValues(Set("WARN", "REJECT")) + .createWithDefault("WARN") + val DROP_IGNORE_NONEXISTENT = buildConf("spark.sql.optimizer.dropIgnoreNonExistent") .doc("Do not report an error if DROP DATABASE/TABLE/VIEW/FUNCTION/PARTITION specifies " + diff --git a/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index db7f4b6ea30..89e4490eed4 100644 --- a/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -19,7 +19,7 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions} -import org.apache.kyuubi.sql.watchdog.{KyuubiUnsupportedOperationsCheck, MaxScanStrategy} +import org.apache.kyuubi.sql.watchdog.{DangerousJoinInterceptor, KyuubiUnsupportedOperationsCheck, MaxScanStrategy} import org.apache.kyuubi.sql.zorder.{InsertZorderBeforeWritingDatasource, InsertZorderBeforeWritingHive, ResolveZorder} // scalastyle:off line.size.limit @@ -48,6 +48,7 @@ class 
KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { // watchdog extension extensions.injectCheckRule(_ => KyuubiUnsupportedOperationsCheck) extensions.injectPlannerStrategy(MaxScanStrategy) + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) extensions.injectQueryStagePrepRule(_ => InsertShuffleNodeBeforeJoin) extensions.injectQueryStagePrepRule(DynamicShufflePartitions) diff --git a/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala new file mode 100644 index 00000000000..d5ff84c5e3f --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import scala.collection.mutable.ArrayBuffer + +object DangerousJoinCounter { + case class Entry( + sqlText: String, + joinType: String, + reason: String, + leftSize: BigInt, + rightSize: BigInt, + broadcastThreshold: Long, + broadcastRatio: Double) { + def toJson: String = { + // Sizes, the broadcast threshold and the ratio are emitted as bare JSON numbers; + // all other fields are escaped and quoted as JSON strings. + val numericKeys = Set("leftSize", "rightSize", "broadcastThreshold", "broadcastRatio") + val pairs = Seq( + "sql" -> escape(sqlText), + "joinType" -> escape(joinType), + "reason" -> escape(reason), + "leftSize" -> leftSize.toString, + "rightSize" -> rightSize.toString, + "broadcastThreshold" -> broadcastThreshold.toString, + "broadcastRatio" -> broadcastRatio.toString) + pairs.map { case (k, v) => + if (numericKeys.contains(k)) s""""$k":$v""" else s""""$k":"$v"""" + }.mkString("{", ",", "}") + } + } + + private val entries = ArrayBuffer.empty[Entry] + + def add(entry: Entry): Unit = synchronized { + entries += entry + } + + def count: Int = synchronized { + entries.size + } + + def latest: Option[Entry] = synchronized { + entries.lastOption + } + + def snapshot: Seq[Entry] = synchronized { + entries.toSeq + } + + def reset(): Unit = synchronized { + entries.clear() + } + + private def escape(raw: String): String = { + raw + .replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + } +} diff --git a/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala new file mode 100644 index 00000000000..4ff7a420e22 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.slf4j.LoggerFactory + +import org.apache.kyuubi.sql.KyuubiSQLConf + +case class DangerousJoinInterceptor(session: SparkSession) extends SparkStrategy { + import DangerousJoinInterceptor._ + + private val logger = LoggerFactory.getLogger(classOf[DangerousJoinInterceptor]) + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + val conf = session.sessionState.conf + if (!conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ENABLED)) { + return Nil + } + val ratio = conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO) + val threshold = conf.autoBroadcastJoinThreshold + plan.foreach { + case join: Join => + detect(join, threshold, ratio).foreach { reason => + val entry = DangerousJoinCounter.Entry( + sqlText = plan.toString(), + joinType = join.joinType.sql, + reason = reason, + leftSize = join.left.stats.sizeInBytes, + rightSize = join.right.stats.sizeInBytes, + broadcastThreshold = 
threshold, + broadcastRatio = ratio) + DangerousJoinCounter.add(entry) + logger.warn(s"$KYUUBI_LOG_KEY=${entry.toJson}") + if (conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ACTION) == REJECT) { + throw new KyuubiDangerousJoinException(entry.toJson) + } + } + case _ => + } + Nil + } + + private def detect(join: Join, threshold: Long, ratio: Double): Option[String] = { + if (threshold <= 0) { + return None + } + val leftSize = join.left.stats.sizeInBytes + val rightSize = join.right.stats.sizeInBytes + val hasEquiJoin = isEquiJoin(join) + if (hasEquiJoin) { + if (isCartesianCondition(join.condition)) { + Some("Cartesian") + } else if (minSize(leftSize, rightSize) > BigInt((threshold * ratio).toLong)) { + Some("OversizedBroadcastFallback") + } else { + None + } + } else { + // A non-equi join with no join key at all degenerates to a cartesian product; + // otherwise, both sides exceeding the threshold forces an oversized BNLJ build side. + if (isCartesianCondition(join.condition)) { + Some("Cartesian") + } else if (cannotSelectBuildSide(leftSize, rightSize, threshold)) { + Some("SecondBNLJ") + } else { + None + } + } + } + + private def isEquiJoin(join: Join): Boolean = { + join match { + case ExtractEquiJoinKeys(_, _, _, _, _, _, _, _) => true + case _ => false + } + } + + private def isCartesianCondition(condition: Option[Expression]): Boolean = { + condition.forall(!containsJoinKey(_)) + } + + private def containsJoinKey(expr: Expression): Boolean = { + expr match { + case EqualTo(l: AttributeReference, r: AttributeReference) => + l.qualifier.nonEmpty && r.qualifier.nonEmpty && l.qualifier != r.qualifier + case And(l, r) => containsJoinKey(l) || containsJoinKey(r) + case _ => false + } + } + + private def minSize(leftSize: BigInt, rightSize: BigInt): BigInt = { + if (leftSize <= rightSize) leftSize else rightSize + } + + private def cannotSelectBuildSide( + leftSize: BigInt, + rightSize: BigInt, + threshold: Long): Boolean = { + leftSize > threshold && rightSize > threshold + } +} + +object DangerousJoinInterceptor { + val WARN = "WARN" + val REJECT = "REJECT" + val KYUUBI_LOG_KEY = "KYUUBI_LOG_KEY" +} diff --git 
a/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala new file mode 100644 index 00000000000..dc2a1336cc8 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +class KyuubiDangerousJoinException(details: String) + extends java.sql.SQLException( + s"Query rejected due to dangerous join strategy: $details", + null, + 41101) diff --git a/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala new file mode 100644 index 00000000000..32d315bdd0b --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-0/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSessionExtensions + +class KyuubiDangerousJoinExtension extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-4-0/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala b/extensions/spark/kyuubi-extension-spark-4-0/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala new file mode 100644 index 00000000000..1f511972b77 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-0/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class DangerousJoinInterceptorSuite extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("equi join oversized broadcast fallback should be counted") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO.key -> "0.8", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + } + } + + test("non equi join cartesian should include Cartesian marker") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("Cartesian"))) + } + } + + test("reject action should throw dangerous join exception with 41101") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + } + assert(e.getErrorCode == 41101) + } + } + + test("disabled dangerous join should not count") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "false", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { 
+ sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count == 0) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-4-0/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala b/extensions/spark/kyuubi-extension-spark-4-0/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala new file mode 100644 index 00000000000..91ae27605ed --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-0/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class KyuubiDangerousJoinIT extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("warn action should keep query successful and emit warning diagnostics") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val rows = sql("SELECT count(*) FROM t1 a JOIN t2 b ON a.c1 = b.c1").collect() + assert(rows.nonEmpty) + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("joinType"))) + } + } + + test("reject action should return detailed dangerous join exception") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").collect() + } + assert(e.getErrorCode == 41101) + assert(e.getMessage.contains("leftSize")) + assert(e.getMessage.contains("rightSize")) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala index 9644537a06d..465c12d9615 100644 --- a/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala @@ -123,6 +123,30 @@ object KyuubiSQLConf { .bytesConf(ByteUnit.BYTE) .createOptional + val DANGEROUS_JOIN_ENABLED = + 
buildConf("kyuubi.watchdog.dangerousJoin.enabled") + .doc("Enable dangerous join condition detection.") + .version("1.11.0") + .booleanConf + .createWithDefault(true) + + val DANGEROUS_JOIN_BROADCAST_RATIO = + buildConf("kyuubi.watchdog.dangerousJoin.broadcastRatio") + .doc("The threshold ratio to mark oversized broadcast fallback.") + .version("1.11.0") + .doubleConf + .checkValue(v => v > 0 && v <= 1, "must be in (0, 1]") + .createWithDefault(0.8) + + val DANGEROUS_JOIN_ACTION = + buildConf("kyuubi.watchdog.dangerousJoin.action") + .doc("Action when a dangerous join is detected, either WARN or REJECT.") + .version("1.11.0") + .stringConf + .transform(_.toUpperCase(java.util.Locale.ROOT)) + .checkValues(Set("WARN", "REJECT")) + .createWithDefault("WARN") + val DROP_IGNORE_NONEXISTENT = buildConf("spark.sql.optimizer.dropIgnoreNonExistent") .doc("Do not report an error if DROP DATABASE/TABLE/VIEW/FUNCTION/PARTITION specifies " + diff --git a/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index db7f4b6ea30..89e4490eed4 100644 --- a/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -19,7 +19,7 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions} -import org.apache.kyuubi.sql.watchdog.{KyuubiUnsupportedOperationsCheck, MaxScanStrategy} +import org.apache.kyuubi.sql.watchdog.{DangerousJoinInterceptor, KyuubiUnsupportedOperationsCheck, MaxScanStrategy} import org.apache.kyuubi.sql.zorder.{InsertZorderBeforeWritingDatasource, InsertZorderBeforeWritingHive, ResolveZorder} // scalastyle:off line.size.limit @@ -48,6 +48,7 @@ class 
KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { // watchdog extension extensions.injectCheckRule(_ => KyuubiUnsupportedOperationsCheck) extensions.injectPlannerStrategy(MaxScanStrategy) + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) extensions.injectQueryStagePrepRule(_ => InsertShuffleNodeBeforeJoin) extensions.injectQueryStagePrepRule(DynamicShufflePartitions) diff --git a/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala new file mode 100644 index 00000000000..d5ff84c5e3f --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinCounter.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import scala.collection.mutable.ArrayBuffer + +object DangerousJoinCounter { + case class Entry( + sqlText: String, + joinType: String, + reason: String, + leftSize: BigInt, + rightSize: BigInt, + broadcastThreshold: Long, + broadcastRatio: Double) { + def toJson: String = { + val pairs = Seq( + "sql" -> escape(sqlText), + "joinType" -> escape(joinType), + "reason" -> escape(reason), + "leftSize" -> leftSize.toString, + "rightSize" -> rightSize.toString, + "broadcastThreshold" -> broadcastThreshold.toString, + "broadcastRatio" -> broadcastRatio.toString) + pairs.map { case (k, v) => + if (k == "leftSize" || k == "rightSize" || k == "broadcastThreshold" || k == "broadcastRatio") { + s""""$k":$v""" + } else { + s""""$k":"$v"""" + } + }.mkString("{", ",", "}") + } + } + + private val entries = ArrayBuffer.empty[Entry] + + def add(entry: Entry): Unit = synchronized { + entries += entry + } + + def count: Int = synchronized { + entries.size + } + + def latest: Option[Entry] = synchronized { + entries.lastOption + } + + def snapshot: Seq[Entry] = synchronized { + entries.toSeq + } + + def reset(): Unit = synchronized { + entries.clear() + } + + private def escape(raw: String): String = { + raw + .replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t") + } +} diff --git a/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala new file mode 100644 index 00000000000..4ff7a420e22 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/DangerousJoinInterceptor.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.slf4j.LoggerFactory + +import org.apache.kyuubi.sql.KyuubiSQLConf + +case class DangerousJoinInterceptor(session: SparkSession) extends SparkStrategy { + import DangerousJoinInterceptor._ + + private val logger = LoggerFactory.getLogger(classOf[DangerousJoinInterceptor]) + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + val conf = session.sessionState.conf + if (!conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ENABLED)) { + return Nil + } + val ratio = conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO) + val threshold = conf.autoBroadcastJoinThreshold + plan.foreach { + case join: Join => + detect(join, threshold, ratio).foreach { reason => + val entry = DangerousJoinCounter.Entry( + sqlText = plan.toString(), + joinType = join.joinType.sql, + reason = reason, + leftSize = join.left.stats.sizeInBytes, + rightSize = join.right.stats.sizeInBytes, + broadcastThreshold = 
threshold, + broadcastRatio = ratio) + DangerousJoinCounter.add(entry) + logger.warn(s"$KYUUBI_LOG_KEY=${entry.toJson}") + if (conf.getConf(KyuubiSQLConf.DANGEROUS_JOIN_ACTION) == REJECT) { + throw new KyuubiDangerousJoinException(entry.toJson) + } + } + case _ => + } + Nil + } + + private def detect(join: Join, threshold: Long, ratio: Double): Option[String] = { + if (threshold <= 0) { + return None + } + val leftSize = join.left.stats.sizeInBytes + val rightSize = join.right.stats.sizeInBytes + val hasEquiJoin = isEquiJoin(join) + if (hasEquiJoin) { + if (isCartesianCondition(join.condition)) { + Some("Cartesian") + } else if (minSize(leftSize, rightSize) > BigInt((threshold * ratio).toLong)) { + Some("OversizedBroadcastFallback") + } else { + None + } + } else { + if (!cannotSelectBuildSide(leftSize, rightSize, threshold)) { + None + } else if (isCartesianCondition(join.condition)) { + Some("Cartesian") + } else { + Some("SecondBNLJ") + } + } + } + + private def isEquiJoin(join: Join): Boolean = { + join match { + case ExtractEquiJoinKeys(_, _, _, _, _, _, _, _) => true + case _ => false + } + } + + private def isCartesianCondition(condition: Option[Expression]): Boolean = { + condition.forall(!containsJoinKey(_)) + } + + private def containsJoinKey(expr: Expression): Boolean = { + expr match { + case EqualTo(l: AttributeReference, r: AttributeReference) => + l.qualifier.nonEmpty && r.qualifier.nonEmpty && l.qualifier != r.qualifier + case And(l, r) => containsJoinKey(l) || containsJoinKey(r) + case _ => false + } + } + + private def minSize(leftSize: BigInt, rightSize: BigInt): BigInt = { + if (leftSize <= rightSize) leftSize else rightSize + } + + private def cannotSelectBuildSide( + leftSize: BigInt, + rightSize: BigInt, + threshold: Long): Boolean = { + leftSize > threshold && rightSize > threshold + } +} + +object DangerousJoinInterceptor { + val WARN = "WARN" + val REJECT = "REJECT" + val KYUUBI_LOG_KEY = "KYUUBI_LOG_KEY" +} diff --git 
a/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala new file mode 100644 index 00000000000..dc2a1336cc8 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinException.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +class KyuubiDangerousJoinException(details: String) + extends java.sql.SQLException( + s"Query rejected due to dangerous join strategy: $details", + null, + 41101) diff --git a/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala new file mode 100644 index 00000000000..32d315bdd0b --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiDangerousJoinExtension.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSessionExtensions + +class KyuubiDangerousJoinExtension extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectPlannerStrategy(DangerousJoinInterceptor(_)) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala new file mode 100644 index 00000000000..1f511972b77 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/DangerousJoinInterceptorSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class DangerousJoinInterceptorSuite extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("equi join oversized broadcast fallback should be counted") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_BROADCAST_RATIO.key -> "0.8", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + } + } + + test("non equi join cartesian should include Cartesian marker") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("Cartesian"))) + } + } + + test("reject action should throw dangerous join exception with 41101") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 = b.c1").queryExecution.sparkPlan + } + assert(e.getErrorCode == 41101) + } + } + + test("disabled dangerous join should not count") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "false", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { 
+ sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").queryExecution.sparkPlan + assert(DangerousJoinCounter.count == 0) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala new file mode 100644 index 00000000000..91ae27605ed --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/KyuubiDangerousJoinIT.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{DangerousJoinCounter, KyuubiDangerousJoinException} + +class KyuubiDangerousJoinIT extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("warn action should keep query successful and emit warning diagnostics") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "WARN", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val rows = sql("SELECT count(*) FROM t1 a JOIN t2 b ON a.c1 = b.c1").collect() + assert(rows.nonEmpty) + assert(DangerousJoinCounter.count >= 1) + assert(DangerousJoinCounter.latest.exists(_.toJson.contains("joinType"))) + } + } + + test("reject action should return detailed dangerous join exception") { + DangerousJoinCounter.reset() + withSQLConf( + KyuubiSQLConf.DANGEROUS_JOIN_ENABLED.key -> "true", + KyuubiSQLConf.DANGEROUS_JOIN_ACTION.key -> "REJECT", + "spark.sql.autoBroadcastJoinThreshold" -> "1") { + val e = intercept[KyuubiDangerousJoinException] { + sql("SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1").collect() + } + assert(e.getErrorCode == 41101) + assert(e.getMessage.contains("leftSize")) + assert(e.getMessage.contains("rightSize")) + } + } +}
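For reviewers: the WARN path's parseability hinges on the hand-rolled JSON emitted by `DangerousJoinCounter.Entry.toJson`. Below is a standalone sketch of that serialization, runnable without Spark; the object name and the `fields`-sequence shape are hypothetical helpers for illustration, while the field names, the unquoted-numeric rule, and the escaping rules are taken from the diff.

```scala
// Standalone sketch mirroring DangerousJoinCounter.Entry.toJson from the diff.
// The object and helper names here are illustrative, not part of the patch.
object DangerousJoinJsonSketch {
  // Keys emitted as bare JSON numbers; everything else is quoted and escaped.
  private val numericKeys =
    Set("leftSize", "rightSize", "broadcastThreshold", "broadcastRatio")

  // Same minimal escaping as the patch: backslash first, then quote and
  // control characters, so log collectors can parse one entry per line.
  private def escape(raw: String): String =
    raw.replace("\\", "\\\\")
      .replace("\"", "\\\"")
      .replace("\n", "\\n")
      .replace("\r", "\\r")
      .replace("\t", "\\t")

  def toJson(fields: Seq[(String, String)]): String =
    fields.map { case (k, v) =>
      if (numericKeys.contains(k)) s""""$k":$v""" else s""""$k":"${escape(v)}""""
    }.mkString("{", ",", "}")

  def main(args: Array[String]): Unit = {
    // Illustrative sizes; in the patch these come from plan statistics.
    println(toJson(Seq(
      "sql" -> "SELECT * FROM t1 a JOIN t2 b ON a.c1 > b.c1",
      "joinType" -> "INNER",
      "reason" -> "Cartesian",
      "leftSize" -> "1048576",
      "rightSize" -> "2097152",
      "broadcastThreshold" -> "10485760",
      "broadcastRatio" -> "0.8")))
  }
}
```

Running `main` prints a single-line payload in which size fields are unquoted numbers, so downstream aggregation of `leftSize`/`rightSize` needs no casting; the interceptor prefixes the same payload with `KYUUBI_LOG_KEY=` in the warning log.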