[SPARK-44462][SS][CONNECT] Fix the session passed to foreachBatch in Spark Connect #55410

Open

LuciferYang wants to merge 7 commits into apache:master from LuciferYang:SPARK-44462-foreachBatch-session-fix

Conversation

@LuciferYang
Contributor

What changes were proposed in this pull request?

Two fixes for foreachBatch session handling in Spark Connect:

  1. DataFrame cache sanity check: Track which DataFrame is cached per streaming query in SessionHolder. When caching a new batch DataFrame, detect and remove any stale one left over from a previous batch. Pass the query ID into the foreachBatch wrappers so it can be set after the query starts.

  2. Dedicated stream SessionHolder: On the first batch, lazily create a new SessionHolder wrapping StreamExecution's cloned SparkSession (instead of the original). Register it with SparkConnectSessionManager with no inactivity timeout since the streaming query manages its lifecycle. Clean up on query termination via ForeachBatchCleaner. On the Python side, each batch now receives (dfId, batchId, streamSessionId) from the server and the worker creates a session bound to the stream session ID.

Why are the changes needed?

SessionHolder referenced the original SparkSession, not the StreamExecution clone (sparkSessionForStream) that batch DataFrames actually use:

```scala
// StreamExecution.scala
protected[sql] val sparkSessionForStream: SparkSession = sparkSession.cloneSession()
```

This caused: (1) batch DataFrames ran against the cloned session but SessionHolder pointed at the original, so session-level state was invisible to Connect; (2) the Python foreachBatch worker operated under the original session ID, keeping the parent session active and delaying cleanup after client disconnect; (3) nothing cleaned up stale cached DataFrames from a previous batch that did not exit cleanly.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

  • Updated SparkConnectSessionHolderSuite to destructure the new 3-tuple return from pythonForeachBatchWrapper.
  • Added test_nested_dataframes in StreamingForeachBatchParityTests -- exercises both closured and batch DataFrames inside foreachBatch, using saveAsTable for cross-session visibility.
  • Pass GitHub Actions.

Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude Code

Add a `dataFrameQueryIndex` map to `SessionHolder` that tracks the
active cached DataFrame ID per streaming query. Before caching a new
batch DataFrame in `dataFrameCachingWrapper`, we check for any stale
entry from a previous batch and remove it, logging a warning.

This addresses TODO 1 in StreamingForeachBatchHelper.

Changes:
- `SessionHolder`: Add `dataFrameQueryIndex` ConcurrentMap
- `StreamingForeachBatchHelper`: Accept `queryIdRef` in
  `dataFrameCachingWrapper`, perform sanity check, clean up mapping
  in finally block. Update `scalaForeachBatchWrapper` and
  `pythonForeachBatchWrapper` return types to include
  `AtomicReference[String]`.
- `SparkConnectPlanner`: Destructure new return types, set query id
  on the AtomicReference after query starts.
- `SparkConnectSessionHolderSuite`: Update test to destructure
  3-element tuple from `pythonForeachBatchWrapper`.
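A minimal sketch of the sanity check this commit describes; the map and the finally-block cleanup mirror the change list above, but the surrounding types and helper shapes are simplified assumptions, not the PR's exact code:

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicReference

object StaleDataFrameCheckSketch {
  // Hypothetical stand-ins for SessionHolder's DataFrame cache and the PR's new map.
  private val dataFrameCache = new ConcurrentHashMap[String, AnyRef]()
  private val dataFrameQueryIndex = new ConcurrentHashMap[String, String]()

  def cacheBatchDataFrame(queryIdRef: AtomicReference[String], dfId: String, df: AnyRef): Unit = {
    val queryId = queryIdRef.get() // set by the planner once the query has started
    if (queryId != null) {
      // Sanity check: a previous batch that skipped its cleanup (async interruption,
      // unchecked error path) may have left a stale entry; drop it before caching.
      val staleDfId = dataFrameQueryIndex.put(queryId, dfId)
      if (staleDfId != null && staleDfId != dfId) {
        println(s"WARN: removing stale DataFrame $staleDfId from a previous batch of query $queryId")
        dataFrameCache.remove(staleDfId)
      }
    }
    dataFrameCache.put(dfId, df)
  }

  def afterBatch(queryIdRef: AtomicReference[String], dfId: String): Unit = {
    // Happy-path cleanup, mirroring the wrapper's finally block.
    dataFrameCache.remove(dfId)
    Option(queryIdRef.get()).foreach(q => dataFrameQueryIndex.remove(q))
  }
}
```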
…nHolder for foreachBatch

When a streaming query starts, `StreamExecution` clones the SparkSession.
The DataFrame passed to foreachBatch has this cloned session, but the
`SessionHolder` still references the original session. This causes:

1. Session mismatch: batch DataFrames operate on a different session
2. Session lifetime leak: Python worker keeps original session alive
3. CachedRemoteRelation resolves against the wrong session

This commit fixes all three issues by:

- Adding `registerExistingSession` to `SparkConnectSessionManager` to
  wrap an existing SparkSession (the stream clone) under a new session
  ID without creating another clone.
- Adding `ForeachBatchSessionManager` that lazily creates a stream-level
  `SessionHolder` on the first batch invocation and reuses it for
  subsequent batches.
- Adding `ForeachBatchCleaner` that closes both the Python runner and
  the stream `SessionHolder` on query termination.
- Sending the stream session ID per-batch to the Python worker so it
  resolves `CachedRemoteRelation` against the correct session.
- Updating `foreach_batch_worker.py` to read the stream session ID and
  create/switch the SparkSession accordingly.

The stream `SessionHolder` uses `customInactiveTimeoutMs = -1` (never
expire by inactivity); its lifecycle is managed by query termination
via `CleanerCache`.

Addresses TODO 2 and TODO 3 in StreamingForeachBatchHelper, removing
all SPARK-44462 TODO comments.
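A minimal sketch of the lazy stream-level SessionHolder flow described in this commit message; the type shapes and the registerExistingSession body here are simplified assumptions, with only the names taken from the PR:

```scala
import java.util.UUID

// Hypothetical minimal types standing in for the Connect server classes.
case class SessionHolder(userId: String, sessionId: String, session: AnyRef)

object SessionManagerSketch {
  // Wraps an existing SparkSession (the stream clone) under a new session id without
  // cloning again; the customInactiveTimeoutMs = -1 registration detail is elided.
  def registerExistingSession(userId: String, sessionId: String, session: AnyRef): SessionHolder =
    SessionHolder(userId, sessionId, session)
}

class ForeachBatchSessionManagerSketch(rootSessionHolder: SessionHolder) {
  @volatile private var _clonedSessionHolder: SessionHolder = null

  // Lazily create the stream-level holder on the first batch and reuse it afterwards.
  def getOrCreateClonedSessionHolder(batchSession: AnyRef): SessionHolder = {
    if (_clonedSessionHolder == null) {
      synchronized {
        if (_clonedSessionHolder == null) {
          _clonedSessionHolder = SessionManagerSketch.registerExistingSession(
            rootSessionHolder.userId, UUID.randomUUID().toString, batchSession)
        }
      }
    }
    _clonedSessionHolder
  }
}
```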
@HyukjinKwon
Member

@heyihong @HeartSaVioR FYI

@HyukjinKwon HyukjinKwon changed the title [SPARK-44462][CONNECT] Fix the session passed to foreachBatch in Spark Connect [SPARK-44462][SS][CONNECT] Fix the session passed to foreachBatch in Spark Connect Apr 19, 2026
```scala
    .asInstanceOf[org.apache.spark.sql.classic.SparkSession]
val streamSessionId = UUID.randomUUID().toString
_streamSessionHolder = SparkConnectService.sessionManager
  .registerExistingSession(parentSessionHolder.userId, streamSessionId, streamSession)
```
Contributor

The registerExistingSession call uses parentSessionHolder.userId and a new random UUID, but nothing ties this stream session back to the parent session's artifacts or classloader. If the user's foreachBatch function relies on session-level artifacts (added jars, UDFs, etc.) from the parent session, would those be missing in the stream SessionHolder? The underlying SparkSession is the cloned one from StreamExecution, so Spark-level config should carry over, but Connect-level SessionHolder state (artifact manager, plan cache, etc.) would be fresh/empty.

Contributor Author

Checked this. StreamExecution clones via sparkSession.cloneSession() (StreamExecution.scala:164), and SparkSession.cloneSession() force-copies ArtifactManager and its resources (SparkSession.scala:295). Our cloned SessionHolder wraps that cloned classic SparkSession, and SessionHolder.artifactManager just delegates to session.artifactManager, so jars and the UDF classloader added to the root session before the stream starts are inherited. Artifacts added after the stream is running would not auto-sync, but that mirrors classic mode.

```python
# The per-batch stream session id will be received with each batch and used to create
# or switch to the correct stream-level session.
connect_url_init = connect_url + ";session_id=" + session_id
spark_connect_session = SparkSession.builder.remote(connect_url_init).getOrCreate()
```
Contributor

The initial spark_connect_session is created but never assigned to spark and never used afterward; does it only serve as validation? Is this intended to remain open for the lifetime of the worker?

Contributor

I guess we need a session to read the FEB func, but the cloned session is not yet available, so we have to use the original session first.

Contributor Author

+1 to @bogao007. We need a session to unpickle the foreachBatch function before the first batch arrives, at which point the cloned session id is not yet known. Once batches start flowing, spark is reassigned to the cloned session via .create(), so the initial one is bootstrap-only.

```diff
 log_name = "Streaming ForeachBatch worker"

-def process(df_id, batch_id):  # type: ignore[no-untyped-def]
+def process(df_id, batch_id, stream_session_id):  # type: ignore[no-untyped-def]
```
Contributor

Can we use clone_session_id to indicate that this is a cloned session? To differentiate, maybe we could rename the previous session id to root_session_id.

Contributor Author

Renamed to cloned_session_id / root_session_id (matches the Scala side: clonedSession* / rootSessionHolder).

```python
def curried_function(df):
    def inner(batch_df, batch_id):
        df.createOrReplaceTempView("updates")
        batch_df.createOrReplaceTempView("batch_updates")
```
Contributor

Can we still keep the createOrReplaceTempView() call to verify that the root session cannot access the temp view created by the cloned session? I think this is an important test case, and we should make sure the behavior aligns with classic foreachBatch.

Contributor Author

Added test_temp_view_is_isolated_from_root_session. It creates a temp view inside FEB, writes a sentinel table so the test can prove collect_batch actually ran, then asserts the root session hits TABLE_OR_VIEW_NOT_FOUND when reading that view. Kept test_nested_dataframes on saveAsTable — that one is about cross-session persistent tables.
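For illustration, a classic-mode analogue of the isolation this test asserts, using the public newSession() API rather than the PR's Connect parity test (an assumption-level sketch, not code from this PR):

```scala
import org.apache.spark.sql.SparkSession

object TempViewIsolationSketch extends App {
  val root = SparkSession.builder().master("local[1]").appName("isolation-sketch").getOrCreate()
  val cloned = root.newSession() // shares the SparkContext, isolates temp views and SQL conf

  import cloned.implicits._
  Seq(1, 2, 3).toDF("v").createOrReplaceTempView("updates")

  // The view exists only in the session that created it.
  assert(cloned.catalog.tableExists("updates"))
  assert(!root.catalog.tableExists("updates"))

  root.stop()
}
```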

log"[session: ${MDC(SESSION_ID, effectiveHolder.sessionId)}] " +
log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")

// Sanity check: remove any stale DataFrame left over from a previous batch for this query.
Contributor

Why do we need this check? We are already removing the DF from the cache after the batch completes.

Contributor Author

It is defensive. Happy path is handled by the finally, but if a previous batch skipped its cleanup (async interruption / unchecked error path), a stale dfId could still be sitting in dataFrameQueryIndex. This also lines up with the original TODO(SPARK-44462) that the sanity-check commit replaced. I tightened the comment to spell this out. Fine to drop it if you prefer; it is strictly belt-and-suspenders.

Contributor

I see, yeah I'm okay with it since it's fixing an existing TODO.

```scala
// Sanity check: remove any stale DataFrame left over from a previous batch for this query.
val queryId = queryIdRef.get()

effectiveHolder.cacheDataFrameById(dfId, df)
```
Contributor

For a case where a table is created outside the FEB func() but used inside it to perform a Spark operation, it would probably fail, because it was created by the root session and would use the root session id to find the cached DataFrame. How do we handle that case?

Contributor Author

Depends on what kind of table:

  • persistent / managed tables (Hive metastore etc.): fine, visible across sessions, no CachedRemoteRelation lookup involved (see the sketch after this list).
  • temp views from the root session: not visible in the cloned session — intentional isolation, same as classic.
  • DataFrames cached via df.persist() in the root session and captured in the closure: this one would fail because the CachedRemoteRelation id is only registered in the root SessionHolder, not the cloned one. That is a real gap but outside the scope of this PR; I will file a follow-up JIRA. WDYT?
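To illustrate the first case above: persistent tables go through the shared catalog, so they stay visible across sessions in the same application (an illustrative sketch, not code from this PR):

```scala
import org.apache.spark.sql.SparkSession

object PersistentTableVisibilitySketch extends App {
  val root = SparkSession.builder().master("local[1]").appName("table-sketch").getOrCreate()
  import root.implicits._

  // saveAsTable registers the table in the shared external catalog...
  Seq((1, "a"), (2, "b")).toDF("id", "v").write.mode("overwrite").saveAsTable("events")

  // ...so a different session in the same application can read it directly,
  // with no CachedRemoteRelation lookup involved.
  val other = root.newSession()
  assert(other.table("events").count() == 2)

  root.stop()
}
```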

Contributor

Sure, I'm okay with a follow up Jira, thanks!

```scala
  extends Logging {
  @volatile private var _streamSessionHolder: SessionHolder = null

  def getOrCreateStreamSessionHolder(batchDf: DataFrame): SessionHolder = {
```
Contributor

Same as above: can we use clonedSession as naming to indicate that this is a cloned session?

Contributor Author

Done — renamed streamSession* -> clonedSession* and the ctor param parentSessionHolder -> rootSessionHolder so Scala and Python stay symmetric.

- Rename streamSession/parentSessionHolder -> clonedSession/rootSessionHolder
  across Scala and Python for consistency with StreamExecution's semantics.
- Tighten close() to clear the reference before calling closeSession so a
  failing close cannot leave a half-closed holder; add success/failure logs
  with MDC.
- Eagerly close the cleaner in registerCleanerForQuery when the query is
  already inactive, so the cloned SessionHolder (which never expires by
  inactivity) cannot leak if the query terminates before cleaner registration.
- Add a parity test test_temp_view_is_isolated_from_root_session that writes
  a sentinel table from inside foreachBatch to prove the batch ran, then
  asserts the root session hits TABLE_OR_VIEW_NOT_FOUND for a temp view
  created only in the cloned session.
- Stub `query.isActive` in StreamingForeachBatchHelperSuite's mockQuery
  so the eager-cleanup branch added in the previous commit doesn't
  immediately purge the registered cleaners.
- Rewrap the ForeachBatchSessionManager docstring so no line exceeds
  scalafmt's maxColumn=98.
```scala
 * The returned SessionHolder has `customInactiveTimeoutMs = Some(-1)` (never expire by
 * inactivity); its lifecycle is managed by the streaming query termination.
 */
private[connect] def registerExistingSession(
```
Contributor

@hvanhovell could you help review the changes here? Thanks!

Contributor

@bogao007 bogao007 left a comment

LGTM, thanks for making the fix! I've asked @hvanhovell to help review changes as well.

@LuciferYang
Contributor Author

Thank you @bogao007 ~
