diff --git a/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala index d1c439d2231..d3db9eecc2a 100644 --- a/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql +import java.util.concurrent.atomic.AtomicBoolean + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkConf import org.apache.spark.sql.classic.SparkSession @@ -100,13 +102,17 @@ trait KyuubiSparkSQLExtensionTest extends QueryTest withListener(sql(sqlString))(callback) } - def withListener(df: => DataFrame)(callback: DataWritingCommand => Unit): Unit = { + def withListener(df: => DataFrame, mustBeCalled: Boolean = true)( + callback: DataWritingCommand => Unit): Unit = { + val called = new AtomicBoolean(false) val listener = new QueryExecutionListener { override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { - qe.executedPlan match { - case write: DataWritingCommandExec => callback(write.cmd) + collect(qe.executedPlan) { + case write: DataWritingCommandExec => + called.set(true) + callback(write.cmd) case _ => } } @@ -115,6 +121,7 @@ trait KyuubiSparkSQLExtensionTest extends QueryTest try { df.collect() sparkContext.listenerBus.waitUntilEmpty() + assert(!mustBeCalled || called.get(), "callback function should be executed.") } finally { spark.listenerManager.unregister(listener) } diff --git a/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala b/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala index d63e79996b2..b0b25efab30 100644 --- a/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala +++ b/extensions/spark/kyuubi-extension-spark-4-1/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala @@ -109,7 +109,7 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { test("check rebalance does not exists") { def check(df: DataFrame): Unit = { - withListener(df) { write => + withListener(df, false) { write => assert(write.collect { case r: RebalancePartitions => r }.isEmpty)