Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,42 +57,45 @@
"ai_agg",
"ai_summarize_agg",
"any_value",
"approximate_count_distinct",
"approximate_jaccard_index",
"approximate_similarity",
"approx_count_distinct",
"approx_percentile",
"approx_percentile_accumulate",
"approx_percentile_combine",
"approx_top_k",
"approx_top_k_accumulate",
"approx_top_k_combine",
"arrayagg",
"approximate_count_distinct",
"approximate_jaccard_index",
"approximate_similarity",
"array_agg",
"array_union_agg",
"array_unique_agg",
"arrayagg",
"avg",
"bitandagg",
"bit_and_agg",
"bit_andagg",
"bit_or_agg",
"bit_oragg",
"bit_xor_agg",
"bit_xoragg",
"bitand_agg",
"bitandagg",
"bitmap_and_agg",
"bitmap_construct_agg",
"bitmap_or_agg",
"bitoragg",
"bitor_agg",
"bitxoragg",
"bitoragg",
"bitxor_agg",
"bit_andagg",
"bit_and_agg",
"bit_oragg",
"bit_or_agg",
"bit_xoragg",
"bit_xor_agg",
"bitxoragg",
"booland_agg",
"boolor_agg",
"boolxor_agg",
"corr",
"count",
"count(*)",
"count_if",
"count_internal",
"count_internal(*)",
"covar_pop",
"covar_samp",
"datasketches_hll",
Expand All @@ -110,12 +113,12 @@
"max_by",
"median",
"min",
"min_by",
"minhash",
"minhash_combine",
"min_by",
"mode",
"objectagg",
"object_agg",
"objectagg",
"percentile_cont",
"percentile_disc",
"regr_avgx",
Expand All @@ -128,27 +131,29 @@
"regr_sxy",
"regr_syy",
"skew",
"st_intersection_agg_geography_internal",
"st_union_agg_geography_internal",
"stddev",
"stddev_pop",
"stddev_samp",
"st_intersection_agg_geography_internal",
"st_union_agg_geography_internal",
"sum",
"sum_internal",
"sum_internal_real",
"sum_real",
"summarize_agg",
"var_pop",
"var_samp",
"variance",
"variance_pop",
"variance_samp",
"var_pop",
"var_samp",
"vector_avg",
"vector_max",
"vector_min",
"vector_sum",
]
)


_cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session.

# Following are internal-only global flags, used to enable development features.
Expand Down
161 changes: 132 additions & 29 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import defaultdict
from functools import reduce
from logging import getLogger
from threading import RLock
from threading import Event, Lock, RLock
from types import ModuleType
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -856,8 +856,15 @@ def __init__(
self._dataframe_profiler = DataframeProfiler(session=self)
self._catalog = None
self._client_telemetry = EventTableTelemetry(session=self)
self._agg_function_prefetch_job: Optional[AsyncJob] = None
# Guards the one-time atomic claim of _agg_function_prefetch_job.
self._agg_function_prefetch_lock = Lock()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there's a way for the same thread to attempt to acquire this lock multiple times, but I think we should make this an RLock instead (which is already imported in this file) to be safe.

# Set by the thread that claimed the async job once it finishes (success or failure),
# so other threads can wait instead of issuing redundant sync queries.
self._agg_function_fetch_event: Optional[Event] = None

self._ast_batch = AstBatch(self)
self._start_async_aggregation_prefetch_if_needed()

_logger.info("Snowpark Session information: %s", self._session_info)

Expand Down Expand Up @@ -5055,43 +5062,139 @@ def _retrieve_aggregation_function_list(self) -> None:
return

retrieved_set = set()
system_fetch_succeeded = False

# Atomically claim the async job. The claiming thread creates an Event so concurrent
# threads can wait on it rather than issuing redundant sync queries.
# AsyncJob.result() is not thread-safe — the underlying connector cursor mutates
# shared state (_result, _rownumber, _prefetch_hook) during result fetching, so only
# one thread may call it. The lock is held only for the pointer swap and event setup
# (nanoseconds), not the network call itself.
with self._agg_function_prefetch_lock:
job, self._agg_function_prefetch_job = self._agg_function_prefetch_job, None
if job is not None:
fetch_event = Event()
self._agg_function_fetch_event = fetch_event
wait_event = None
elif self._agg_function_fetch_event is not None:
fetch_event = None
wait_event = self._agg_function_fetch_event
else:
fetch_event = None
wait_event = None

if wait_event is not None:
# The query typically finishes in ~5s; 20s gives ample headroom while
# bounding the hang in the rare case where the thread that claimed the async
# job dies before its finally block runs (e.g. os._exit, interpreter shutdown).
wait_event.wait(timeout=20)
if context._aggregation_function_set:
Comment thread
sfc-gh-bkogan marked this conversation as resolved.
Outdated
return
# The claiming thread failed or the wait timed out; fall through to the sync query.

# User-defined aggregation functions
try:
retrieved_set.update(
{
r[0].lower()
for r in self.sql(
"""select function_name from information_schema.functions where is_aggregate = 'YES'"""
).collect()
}
)
except Exception as e:
_logger.debug(
"Unable to get user-defined aggregation functions: %s",
e,
)
if job is not None:
try:
retrieved_set.update(
{r[0].lower() for r in job.result()}
)
system_fetch_succeeded = True
except Exception as e:
_logger.debug(
"Unable to use async aggregation function prefetch: %s",
e,
)
else:
_logger.debug(
"Async aggregation function prefetch job is unavailable; using sync fallback."
)
try:
retrieved_set.update(
{
r[0].lower()
for r in self._conn.run_query(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'
union
select function_name from information_schema.functions where is_aggregate = 'YES'""",
_is_internal=True,
)["data"]
}
)
system_fetch_succeeded = True
except Exception as e:
_logger.debug(
"Unable to get aggregation functions via sync union query: %s",
e,
)

# Sync fallback query.
if not system_fetch_succeeded:
try:
retrieved_set.update(
{
r[0].lower()
for r in self._conn.run_query(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
_is_internal=True,
)["data"]
}
)
system_fetch_succeeded = True
except Exception as e:
_logger.debug(
"Unable to get aggregation functions via sync fallback query: %s",
e,
)

# Fallback to the local hardcoded list only when metadata retrieval fails.
if not system_fetch_succeeded:
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)
Comment thread
sfc-gh-bkogan marked this conversation as resolved.

with context._aggregation_function_set_lock:
context._aggregation_function_set.update(retrieved_set)
finally:
# Signal after _aggregation_function_set is published so waiters see
# the populated set immediately upon waking. Also fires on BaseException
# (e.g. KeyboardInterrupt) so waiters are never left blocking until timeout.
if fetch_event is not None:
fetch_event.set()

def _start_async_aggregation_prefetch_if_needed(self) -> None:
"""Start aggregation metadata prefetch only when not already in progress."""
if not (
context._is_snowpark_connect_compatible_mode
and context._snowpark_connect_flatten_select_after_sort
):
return
if context._aggregation_function_set:
return
if self._agg_function_prefetch_job is not None:
return

# System built-in aggregation functions
try:
retrieved_set.update(
{
r[0].lower()
for r in self.sql(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'"""
Comment thread
sfc-gh-joshi marked this conversation as resolved.
).collect()
}
self._agg_function_prefetch_job = self._submit_internal_async_prefetch_query(
Comment thread
sfc-gh-joshi marked this conversation as resolved.
Outdated
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'
union
select function_name from information_schema.functions where is_aggregate = 'YES'"""
)
except Exception as e:
except Exception as e: # pragma: no cover
_logger.debug(
"Unable to get system aggregation functions, "
"falling back to hardcoded list: %s",
"Unable to start async aggregation metadata prefetch: %s",
e,
)
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)
self._agg_function_prefetch_job = None

with context._aggregation_function_set_lock:
context._aggregation_function_set.update(retrieved_set)
def _submit_internal_async_prefetch_query(self, query: str) -> Optional[AsyncJob]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we inline this method, since it's pretty short and only called once?

"""Submit a prefetch query as internal async and return an AsyncJob handle."""
try:
result = self._conn.execute_async_and_notify_query_listener(
query,
_is_internal=True,
)
return self.create_async_job(result["queryId"])
except Exception as e: # pragma: no cover
_logger.debug("Unable to submit internal async prefetch query: %s", e)
return None

def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame:
"""
Expand Down
Loading
Loading