[SPARK-44462][SS][CONNECT] Fix the session passed to foreachBatch in Spark Connect#55410
LuciferYang wants to merge 7 commits into apache:master from
Conversation
Add a `dataFrameQueryIndex` map to `SessionHolder` that tracks the active cached DataFrame ID per streaming query. Before caching a new batch DataFrame in `dataFrameCachingWrapper`, we check for any stale entry from a previous batch and remove it, logging a warning. This addresses TODO 1 in StreamingForeachBatchHelper.

Changes:
- `SessionHolder`: add a `dataFrameQueryIndex` ConcurrentMap.
- `StreamingForeachBatchHelper`: accept `queryIdRef` in `dataFrameCachingWrapper`, perform the sanity check, and clean up the mapping in a finally block. Update `scalaForeachBatchWrapper` and `pythonForeachBatchWrapper` return types to include `AtomicReference[String]`.
- `SparkConnectPlanner`: destructure the new return types; set the query id on the AtomicReference after the query starts.
- `SparkConnectSessionHolderSuite`: update the test to destructure the 3-element tuple from `pythonForeachBatchWrapper`.
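The stale-entry check this commit describes can be sketched in plain Python. The real maps are ConcurrentMaps on the Scala `SessionHolder`; the names and dict-based storage here are illustrative only:

```python
# Hypothetical sketch of the per-query "stale entry" sanity check, with
# plain dicts standing in for SessionHolder's concurrent maps.
data_frame_cache = {}        # df_id -> cached batch DataFrame
data_frame_query_index = {}  # query_id -> df_id currently cached for that query

def cache_batch_df(query_id, df_id, df):
    # Sanity check: a previous batch for this query may have skipped its
    # cleanup (e.g. async interruption); evict its entry before caching.
    stale_df_id = data_frame_query_index.get(query_id)
    if stale_df_id is not None and stale_df_id != df_id:
        data_frame_cache.pop(stale_df_id, None)  # the real code logs a warning here
    data_frame_cache[df_id] = df
    data_frame_query_index[query_id] = df_id

def release_batch_df(query_id, df_id):
    # Happy-path cleanup, run in a finally block after the user function.
    data_frame_cache.pop(df_id, None)
    data_frame_query_index.pop(query_id, None)
```

On the happy path only `release_batch_df` ever runs; the eviction branch in `cache_batch_df` is purely defensive.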
…nHolder for foreachBatch

When a streaming query starts, `StreamExecution` clones the SparkSession. The DataFrame passed to foreachBatch has this cloned session, but the `SessionHolder` still references the original session. This causes:
1. Session mismatch: batch DataFrames operate on a different session.
2. Session lifetime leak: the Python worker keeps the original session alive.
3. `CachedRemoteRelation` resolves against the wrong session.

This commit fixes all three issues by:
- Adding `registerExistingSession` to `SparkConnectSessionManager` to wrap an existing SparkSession (the stream clone) under a new session ID without creating another clone.
- Adding `ForeachBatchSessionManager` that lazily creates a stream-level `SessionHolder` on the first batch invocation and reuses it for subsequent batches.
- Adding `ForeachBatchCleaner` that closes both the Python runner and the stream `SessionHolder` on query termination.
- Sending the stream session ID per batch to the Python worker so it resolves `CachedRemoteRelation` against the correct session.
- Updating `foreach_batch_worker.py` to read the stream session ID and create/switch the SparkSession accordingly.

The stream `SessionHolder` uses `customInactiveTimeoutMs = -1` (never expire by inactivity); its lifecycle is managed by query termination via `CleanerCache`. Addresses TODO 2 and TODO 3 in StreamingForeachBatchHelper, removing all SPARK-44462 TODO comments.
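The lazy, reuse-across-batches behavior of `ForeachBatchSessionManager` can be sketched roughly like this; a plain Python class with a lock stands in for the Scala implementation, and the class and callback names are hypothetical:

```python
import threading

class ForeachBatchSessionManagerSketch:
    """Illustrative-only model of the lazy stream-level holder: the first
    batch registers the cloned session, later batches reuse the holder."""

    def __init__(self, register_existing_session):
        # Callback standing in for SparkConnectSessionManager.registerExistingSession.
        self._register = register_existing_session
        self._holder = None
        self._lock = threading.Lock()

    def get_or_create(self, cloned_session):
        # Double-checked locking: create the holder once on the first
        # batch invocation, then hand back the same instance.
        if self._holder is None:
            with self._lock:
                if self._holder is None:
                    self._holder = self._register(cloned_session)
        return self._holder
```

The key property mirrored here is that the registration side effect happens exactly once per streaming query, no matter how many batches run.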
…chHelper Run scalafmt to fix formatting violations detected by CI linter.
| .asInstanceOf[org.apache.spark.sql.classic.SparkSession]
| val streamSessionId = UUID.randomUUID().toString
| _streamSessionHolder = SparkConnectService.sessionManager
|   .registerExistingSession(parentSessionHolder.userId, streamSessionId, streamSession)
The registerExistingSession call uses parentSessionHolder.userId and a new random UUID, but nothing ties this stream session back to the parent session's artifacts or classloader. If the user's foreachBatch function relies on session-level artifacts (added jars, UDFs, etc.) from the parent session, would those be missing in the stream SessionHolder? The underlying SparkSession is the cloned one from StreamExecution, so Spark-level config should carry over, but Connect-level SessionHolder state (artifact manager, plan cache, etc.) would be fresh/empty.
Checked this. StreamExecution clones via sparkSession.cloneSession() (StreamExecution.scala:164), and SparkSession.cloneSession() force-copies ArtifactManager and its resources (SparkSession.scala:295). Our cloned SessionHolder wraps that cloned classic SparkSession, and SessionHolder.artifactManager just delegates to session.artifactManager, so jars and the UDF classloader added to the root session before the stream starts are inherited. Artifacts added after the stream is running would not auto-sync, but that mirrors classic mode.
| # The per-batch stream session id will be received with each batch and used to create
| # or switch to the correct stream-level session.
| connect_url_init = connect_url + ";session_id=" + session_id
| spark_connect_session = SparkSession.builder.remote(connect_url_init).getOrCreate()
The initial `spark_connect_session` is created but never assigned to `spark` and never used afterward; does it only serve as validation? Is it intended to remain open for the lifetime of the worker?
I guess we need a session to read the FEB func, but the cloned session is not yet available, so we have to use the original session first.
+1 to @bogao007. We need a session to unpickle the foreachBatch function before the first batch arrives, at which point the cloned session id is not yet known. Once batches start flowing, spark is reassigned to the cloned session via .create(), so the initial one is bootstrap-only.
| log_name = "Streaming ForeachBatch worker"
|
| - def process(df_id, batch_id):  # type: ignore[no-untyped-def]
| + def process(df_id, batch_id, stream_session_id):  # type: ignore[no-untyped-def]
Can we use clone_session_id to indicate that this is a cloned session? To differentiate, maybe we could name the previous session to be root_session_id.
Renamed to cloned_session_id / root_session_id (matches the Scala side: clonedSession* / rootSessionHolder).
| def curried_function(df):
|     def inner(batch_df, batch_id):
|         df.createOrReplaceTempView("updates")
|         batch_df.createOrReplaceTempView("batch_updates")
Can we still keep the createOrReplaceTempView() to verify that the root session cannot access the temp view created by the cloned session? I think this is an important test case, and we should make sure the behavior aligns with classic foreachBatch.
Added test_temp_view_is_isolated_from_root_session. It creates a temp view inside FEB, writes a sentinel table so the test can prove collect_batch actually ran, then asserts the root session hits TABLE_OR_VIEW_NOT_FOUND when reading that view. Kept test_nested_dataframes on saveAsTable — that one is about cross-session persistent tables.
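The isolation property that test asserts can be modeled without Spark at all: temp views live in a per-session catalog while persistent tables live in a shared store. All names below are illustrative:

```python
# Minimal model (no Spark) of what test_temp_view_is_isolated_from_root_session
# checks: the root session sees the sentinel written to the shared store but
# not the cloned session's temp view.
persistent_tables = {}  # shared store, standing in for a metastore

class FakeSession:
    def __init__(self):
        self.temp_views = {}  # scoped to this session only

root_session = FakeSession()
cloned_session = FakeSession()

def foreach_batch_body(session, batch_rows):
    # What the test's foreachBatch function does: register a temp view in
    # the (cloned) session it runs under, and write a sentinel "table" to
    # the shared store to prove the batch actually executed.
    session.temp_views["updates"] = batch_rows
    persistent_tables["sentinel"] = len(batch_rows)

foreach_batch_body(cloned_session, [1, 2, 3])
```

In the real test, the root-session lookup of `updates` fails with TABLE_OR_VIEW_NOT_FOUND; here it is simply absent from `root_session.temp_views`.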
| log"[session: ${MDC(SESSION_ID, effectiveHolder.sessionId)}] " +
| log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")
|
| // Sanity check: remove any stale DataFrame left over from a previous batch for this query.
Why do we need this check? We are already removing the DF from the cache after the batch completes
It is defensive. Happy path is handled by the finally, but if a previous batch skipped its cleanup (async interruption / unchecked error path), a stale dfId could still be sitting in dataFrameQueryIndex. This also lines up with the original TODO(SPARK-44462) that the sanity-check commit replaced. I tightened the comment to spell this out. Fine to drop it if you prefer; it is strictly belt-and-suspenders.
I see, yeah I'm okay with it since it's fixing an existing TODO.
| // Sanity check: remove any stale DataFrame left over from a previous batch for this query.
| val queryId = queryIdRef.get()
|
| effectiveHolder.cacheDataFrameById(dfId, df)
For a case where a table is created outside the FEB func() but used inside FEB() to perform a Spark operation, it would probably fail because it was created by the root session and would use the root session id to find the cached dataframe. How do we handle that case?
Depends on what kind of table:

- Persistent / managed tables (Hive metastore etc.): fine, visible across sessions, no `CachedRemoteRelation` lookup involved.
- Temp views from the root session: not visible in the cloned session; intentional isolation, same as classic.
- DataFrames cached via `df.persist()` in the root session and captured in the closure: this one would fail because the `CachedRemoteRelation` id is only registered in the root `SessionHolder`, not the cloned one. That is a real gap but outside the scope of this PR; I will file a follow-up JIRA. WDYT?
Sure, I'm okay with a follow up Jira, thanks!
| extends Logging {
|   @volatile private var _streamSessionHolder: SessionHolder = null
|
|   def getOrCreateStreamSessionHolder(batchDf: DataFrame): SessionHolder = {
Same as above: can we use clonedSession as naming to indicate that this is a cloned session?
Done — renamed streamSession* → clonedSession* and the ctor param parentSessionHolder → rootSessionHolder so Scala and Python stay symmetric.
- Rename `streamSession`/`parentSessionHolder` -> `clonedSession`/`rootSessionHolder` across Scala and Python for consistency with StreamExecution's semantics.
- Tighten `close()` to clear the reference before calling `closeSession` so a failing close cannot leave a half-closed holder; add success/failure logs with MDC.
- Eagerly close the cleaner in `registerCleanerForQuery` when the query is already inactive, so the cloned SessionHolder (which never expires by inactivity) cannot leak if the query terminates before cleaner registration.
- Add a parity test `test_temp_view_is_isolated_from_root_session` that writes a sentinel table from inside foreachBatch to prove the batch ran, then asserts the root session hits TABLE_OR_VIEW_NOT_FOUND for a temp view created only in the cloned session.
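The eager-cleanup behavior in `registerCleanerForQuery` can be sketched as follows; all names here are hypothetical stand-ins, since the real code works against Spark's streaming query listener machinery:

```python
# Illustrative model: if the query terminated before the cleaner was
# registered, the termination listener has already fired and will never
# call this cleaner, so close it immediately to avoid leaking the
# never-expiring cloned SessionHolder.
class FakeQuery:
    def __init__(self, query_id, is_active):
        self.id = query_id
        self.is_active = is_active

class FakeCleaner:
    def __init__(self):
        self.closed = False
    def close(self):
        self.closed = True

def register_cleaner_for_query(query, cleaner_cache, cleaner):
    cleaner_cache[query.id] = cleaner
    # Eager-cleanup branch: purge and close if the query is already gone.
    if not query.is_active:
        cleaner_cache.pop(query.id, None)
        cleaner.close()
```

This is also why the test suite's mockQuery now stubs `query.isActive`: without it, the mock reports inactive and this branch would purge cleaners the moment they are registered.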
- Stub `query.isActive` in StreamingForeachBatchHelperSuite's mockQuery so the eager-cleanup branch added in the previous commit doesn't immediately purge the registered cleaners.
- Rewrap the ForeachBatchSessionManager docstring so no line exceeds scalafmt's maxColumn=98.
| * The returned SessionHolder has `customInactiveTimeoutMs = Some(-1)` (never expire by
| * inactivity); its lifecycle is managed by the streaming query termination.
| */
| private[connect] def registerExistingSession(
@hvanhovell could you help review the changes here? Thanks!
bogao007 left a comment
LGTM, thanks for making the fix! I've asked @hvanhovell to help review changes as well.
Thank you @bogao007 ~
What changes were proposed in this pull request?
Two fixes for foreachBatch session handling in Spark Connect:
1. DataFrame cache sanity check: track which DataFrame is cached per streaming query in `SessionHolder`. When caching a new batch DataFrame, detect and remove any stale one left over from a previous batch. Pass the query ID into the foreachBatch wrappers so it can be set after the query starts.
2. Dedicated stream SessionHolder: on the first batch, lazily create a new `SessionHolder` wrapping `StreamExecution`'s cloned SparkSession (instead of the original). Register it with `SparkConnectSessionManager` with no inactivity timeout, since the streaming query manages its lifecycle. Clean up on query termination via `ForeachBatchCleaner`. On the Python side, each batch now receives `(dfId, batchId, streamSessionId)` from the server and the worker creates a session bound to the stream session ID.

Why are the changes needed?

`SessionHolder` referenced the original SparkSession, not the `StreamExecution` clone (`sparkSessionForStream`) that batch DataFrames actually use. This caused: (1) batch DataFrames ran against the cloned session but `SessionHolder` pointed at the original, so session-level state was invisible to Connect; (2) the Python foreachBatch worker operated under the original session ID, keeping the parent session active and delaying cleanup after client disconnect; (3) nothing cleaned up stale cached DataFrames from a previous batch that did not exit cleanly.

Does this PR introduce any user-facing change?
No.
How was this patch tested?
- Updated `SparkConnectSessionHolderSuite` to destructure the new 3-tuple return from `pythonForeachBatchWrapper`.
- `test_nested_dataframes` in `StreamingForeachBatchParityTests` exercises both closured and batch DataFrames inside foreachBatch, using `saveAsTable` for cross-session visibility.
Generated-by: Claude Code