Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""Context module for Snowpark."""
import logging
import sys
from typing import Callable, Optional
from typing import Any, Callable, Optional

import snowflake.snowpark
import threading
Expand Down Expand Up @@ -45,6 +45,11 @@
set()
) # lower cased names of aggregation functions, used in sql simplification
_aggregation_function_set_lock = threading.RLock()
_aggregation_function_prefetch_state: dict[str, Any] = {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be a dict? Can we just use 3 variables or a singleton class instead?

Copy link
Copy Markdown
Collaborator Author

@sfc-gh-yuwang sfc-gh-yuwang May 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we definitely can. I thought using a dictionary would make it clearer that this state belongs to the aggregation-function prefetch.

"lock": threading.RLock(),
"event": None,
"job": None,
}

# Hardcoded fallback for system built-in aggregation functions.
# Used when the dynamic query fails to retrieve the list from the database.
Expand All @@ -62,35 +67,36 @@
"ai_agg",
"ai_summarize_agg",
"any_value",
"approximate_count_distinct",
"approximate_jaccard_index",
"approximate_similarity",
"approx_count_distinct",
"approx_percentile",
"approx_percentile_accumulate",
"approx_percentile_combine",
"approx_top_k",
"approx_top_k_accumulate",
"approx_top_k_combine",
"arrayagg",
"approximate_count_distinct",
"approximate_jaccard_index",
"approximate_similarity",
"array_agg",
"array_union_agg",
"array_unique_agg",
"arrayagg",
"avg",
"bitandagg",
"bit_and_agg",
"bit_andagg",
"bit_or_agg",
"bit_oragg",
"bit_xor_agg",
"bit_xoragg",
"bitand_agg",
"bitandagg",
"bitmap_and_agg",
"bitmap_construct_agg",
"bitmap_or_agg",
"bitoragg",
"bitor_agg",
"bitxoragg",
"bitoragg",
"bitxor_agg",
"bit_andagg",
"bit_and_agg",
"bit_oragg",
"bit_or_agg",
"bit_xoragg",
"bit_xor_agg",
"bitxoragg",
"booland_agg",
"boolor_agg",
"boolxor_agg",
Expand All @@ -115,12 +121,12 @@
"max_by",
"median",
"min",
"min_by",
"minhash",
"minhash_combine",
"min_by",
"mode",
"objectagg",
"object_agg",
"objectagg",
"percentile_cont",
"percentile_disc",
"regr_avgx",
Expand All @@ -133,27 +139,29 @@
"regr_sxy",
"regr_syy",
"skew",
"st_intersection_agg_geography_internal",
"st_union_agg_geography_internal",
"stddev",
"stddev_pop",
"stddev_samp",
"st_intersection_agg_geography_internal",
"st_union_agg_geography_internal",
"sum",
"sum_internal",
"sum_internal_real",
"sum_real",
"summarize_agg",
"var_pop",
"var_samp",
"variance",
"variance_pop",
"variance_samp",
"var_pop",
"var_samp",
"vector_avg",
"vector_max",
"vector_min",
"vector_sum",
]
)


# Number of times CTE optimization may fail before it is automatically disabled
# for the remainder of the session. Set to 0 to turn off this auto-disable behavior.
_cte_error_threshold = 3

# Following are internal-only global flags, used to enable development features.
Expand Down
142 changes: 108 additions & 34 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import defaultdict
from functools import reduce
from logging import getLogger
from threading import RLock
from threading import Event, RLock
from types import ModuleType
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -858,6 +858,7 @@ def __init__(
self._client_telemetry = EventTableTelemetry(session=self)

self._ast_batch = AstBatch(self)
self._start_async_aggregation_prefetch_if_needed()

_logger.info("Snowpark Session information: %s", self._session_info)

Expand Down Expand Up @@ -5055,43 +5056,116 @@ def _retrieve_aggregation_function_list(self) -> None:
return

retrieved_set = set()
system_fetch_succeeded = False
prefetch_state = context._aggregation_function_prefetch_state

# Atomically claim the async job. The claiming thread creates an Event so concurrent
# threads can wait on it rather than issuing redundant sync queries.
# AsyncJob.result() is not thread-safe — the underlying connector cursor mutates
# shared state (_result, _rownumber, _prefetch_hook) during result fetching, so only
# one thread may call it. The lock is held only for the pointer swap and event setup
# (nanoseconds), not the network call itself.
with prefetch_state["lock"]:
job, prefetch_state["job"] = prefetch_state["job"], None
if job is not None:
fetch_event = Event()
prefetch_state["event"] = fetch_event
wait_event = None
elif prefetch_state["event"] is not None:
fetch_event = None
wait_event = prefetch_state["event"]
else:
fetch_event = None
wait_event = None

if wait_event is not None:
# The query typically finishes in ~5s; 20s gives ample headroom while
# bounding the hang in the rare case the winner thread dies before its
# finally block runs (e.g. os._exit, interpreter shutdown).
wait_event.wait(timeout=20)
if context._aggregation_function_set:
Comment thread
sfc-gh-bkogan marked this conversation as resolved.
Outdated
return
# Winner failed or timed out; fall through to sync query.

# User-defined aggregation functions
try:
retrieved_set.update(
{
r[0].lower()
for r in self.sql(
"""select function_name from information_schema.functions where is_aggregate = 'YES'"""
).collect()
}
)
except Exception as e:
_logger.debug(
"Unable to get user-defined aggregation functions: %s",
e,
)
if job is not None:
try:
retrieved_set.update({r[0].lower() for r in job.result()})
system_fetch_succeeded = True
except Exception as e:
_logger.debug(
"Unable to use async aggregation function prefetch: %s",
e,
)
else:
_logger.debug(
"Async aggregation function prefetch job is unavailable; using sync fallback."
)

# System built-in aggregation functions
try:
retrieved_set.update(
{
r[0].lower()
for r in self.sql(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'"""
Comment thread
sfc-gh-joshi marked this conversation as resolved.
).collect()
}
)
except Exception as e:
_logger.debug(
"Unable to get system aggregation functions, "
"falling back to hardcoded list: %s",
e,
)
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)
# Sync fallback query.
if not system_fetch_succeeded:
try:
retrieved_set.update(
{
r[0].lower()
for r in self._conn.run_query(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
_is_internal=True,
)["data"]
}
)
system_fetch_succeeded = True
except Exception as e:
_logger.debug(
"Unable to get aggregation functions via sync fallback query: %s",
e,
)

with context._aggregation_function_set_lock:
context._aggregation_function_set.update(retrieved_set)
# Fallback to the local hardcoded list only when metadata retrieval fails.
if not system_fetch_succeeded:
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)
Comment thread
sfc-gh-bkogan marked this conversation as resolved.

with context._aggregation_function_set_lock:
context._aggregation_function_set.update(retrieved_set)
finally:
# Signal after _aggregation_function_set is published so waiters see
# the populated set immediately upon waking. Also fires on BaseException
# (e.g. KeyboardInterrupt) so waiters are never left blocking until timeout.
if fetch_event is not None:
fetch_event.set()

def _start_async_aggregation_prefetch_if_needed(self) -> None:
    """Submit the async "list aggregate functions" query unless one is unnecessary.

    The method returns without doing anything when:

    * either ``context._is_snowpark_connect_compatible_mode`` or
      ``context._snowpark_connect_flatten_select_after_sort`` is falsy;
    * ``context._aggregation_function_set`` is already populated;
    * a prefetch job is already stored in the shared prefetch state; or
    * the shared state holds an event that has not been set yet, i.e. another
      thread has claimed the job and is still publishing its results.

    Otherwise the query is submitted asynchronously and the resulting async job
    is stored under the ``"job"`` key of
    ``context._aggregation_function_prefetch_state``. Any exception raised while
    submitting is logged at debug level and swallowed, leaving ``"job"`` as ``None``.
    """
    # The prefetch is only relevant when both feature flags are enabled.
    if not context._is_snowpark_connect_compatible_mode:
        return
    if not context._snowpark_connect_flatten_select_after_sort:
        return

    state = context._aggregation_function_prefetch_state
    with state["lock"]:
        # Skip when the function set is already cached or a job is waiting to be consumed.
        if context._aggregation_function_set or state["job"] is not None:
            return

        # An unset event means a thread that claimed the async job is still
        # publishing results; do not start a second query while it is in flight.
        in_flight = state["event"]
        if in_flight is not None and not in_flight.is_set():
            return

        try:
            submitted = self._conn.execute_async_and_notify_query_listener(
                """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
                _is_internal=True,
            )
            query_id = submitted["queryId"]
            state["job"] = self.create_async_job(query_id)
        except Exception as e:  # pragma: no cover
            # Best effort only: a failed submission must not break the caller.
            _logger.debug(
                "Unable to start async aggregation metadata prefetch: %s",
                e,
            )
            state["job"] = None

def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame:
"""
Expand Down
Loading
Loading