Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,42 +57,45 @@
"ai_agg",
"ai_summarize_agg",
"any_value",
"approximate_count_distinct",
"approximate_jaccard_index",
"approximate_similarity",
"approx_count_distinct",
"approx_percentile",
"approx_percentile_accumulate",
"approx_percentile_combine",
"approx_top_k",
"approx_top_k_accumulate",
"approx_top_k_combine",
"arrayagg",
"approximate_count_distinct",
"approximate_jaccard_index",
"approximate_similarity",
"array_agg",
"array_union_agg",
"array_unique_agg",
"arrayagg",
"avg",
"bitandagg",
"bit_and_agg",
"bit_andagg",
"bit_or_agg",
"bit_oragg",
"bit_xor_agg",
"bit_xoragg",
"bitand_agg",
"bitandagg",
"bitmap_and_agg",
"bitmap_construct_agg",
"bitmap_or_agg",
"bitoragg",
"bitor_agg",
"bitxoragg",
"bitoragg",
"bitxor_agg",
"bit_andagg",
"bit_and_agg",
"bit_oragg",
"bit_or_agg",
"bit_xoragg",
"bit_xor_agg",
"bitxoragg",
"booland_agg",
"boolor_agg",
"boolxor_agg",
"corr",
"count",
"count(*)",
"count_if",
"count_internal",
"count_internal(*)",
"covar_pop",
"covar_samp",
"datasketches_hll",
Expand All @@ -110,12 +113,12 @@
"max_by",
"median",
"min",
"min_by",
"minhash",
"minhash_combine",
"min_by",
"mode",
"objectagg",
"object_agg",
"objectagg",
"percentile_cont",
"percentile_disc",
"regr_avgx",
Expand All @@ -128,27 +131,29 @@
"regr_sxy",
"regr_syy",
"skew",
"st_intersection_agg_geography_internal",
"st_union_agg_geography_internal",
"stddev",
"stddev_pop",
"stddev_samp",
"st_intersection_agg_geography_internal",
"st_union_agg_geography_internal",
"sum",
"sum_internal",
"sum_internal_real",
"sum_real",
"summarize_agg",
"var_pop",
"var_samp",
"variance",
"variance_pop",
"variance_samp",
"var_pop",
"var_samp",
"vector_avg",
"vector_max",
"vector_min",
"vector_sum",
]
)


_cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session.

# Following are internal-only global flags, used to enable development features.
Expand Down
102 changes: 75 additions & 27 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,8 +856,10 @@ def __init__(
self._dataframe_profiler = DataframeProfiler(session=self)
self._catalog = None
self._client_telemetry = EventTableTelemetry(session=self)
self._agg_function_prefetch_job: Optional[AsyncJob] = None

self._ast_batch = AstBatch(self)
self._start_async_aggregation_prefetch_if_needed()

_logger.info("Snowpark Session information: %s", self._session_info)

Expand Down Expand Up @@ -5055,43 +5057,89 @@ def _retrieve_aggregation_function_list(self) -> None:
return

retrieved_set = set()
system_fetch_succeeded = False

# User-defined aggregation functions
try:
retrieved_set.update(
{
r[0].lower()
for r in self.sql(
"""select function_name from information_schema.functions where is_aggregate = 'YES'"""
).collect()
}
)
except Exception as e:
# Try async result first if prefetch was already started.
if self._agg_function_prefetch_job is not None:
try:
retrieved_set.update(
{r[0].lower() for r in self._agg_function_prefetch_job.result()}
)
system_fetch_succeeded = True
except Exception as e:
_logger.debug(
"Unable to use async aggregation function prefetch: %s",
e,
)
finally:
self._agg_function_prefetch_job = None
else:
_logger.debug(
"Unable to get user-defined aggregation functions: %s",
e,
"Async aggregation function prefetch job is unavailable; using sync fallback."
Comment thread
sfc-gh-yuwang marked this conversation as resolved.
Outdated
)

# System built-in aggregation functions
# Sync fallback query.
if not system_fetch_succeeded:
try:
retrieved_set.update(
{
r[0].lower()
for r in self._conn.run_query(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
_is_internal=True,
Comment thread
sfc-gh-yuwang marked this conversation as resolved.
Outdated
)["data"]
}
)
system_fetch_succeeded = True
except Exception as e:
_logger.debug(
"Unable to get aggregation functions via sync fallback query: %s",
e,
)

# Fallback to the local hardcoded list only when metadata retrieval fails.
if not system_fetch_succeeded:
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)

with context._aggregation_function_set_lock:
context._aggregation_function_set.update(retrieved_set)

def _start_async_aggregation_prefetch_if_needed(self) -> None:
"""Start aggregation metadata prefetch only when not already in progress."""
if not (
context._is_snowpark_connect_compatible_mode
and context._snowpark_connect_flatten_select_after_sort
):
return
if context._aggregation_function_set:
return
if self._agg_function_prefetch_job is not None:
return

try:
retrieved_set.update(
{
r[0].lower()
for r in self.sql(
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'"""
Comment thread
sfc-gh-joshi marked this conversation as resolved.
).collect()
}
self._agg_function_prefetch_job = self._submit_internal_async_prefetch_query(
Comment thread
sfc-gh-joshi marked this conversation as resolved.
Outdated
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'
union
select function_name from information_schema.functions where is_aggregate = 'YES'"""
)
except Exception as e:
except Exception as e: # pragma: no cover
_logger.debug(
"Unable to get system aggregation functions, "
"falling back to hardcoded list: %s",
"Unable to start async aggregation metadata prefetch: %s",
e,
)
retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS)
self._agg_function_prefetch_job = None

with context._aggregation_function_set_lock:
context._aggregation_function_set.update(retrieved_set)
def _submit_internal_async_prefetch_query(self, query: str) -> Optional[AsyncJob]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we inline this method since it's only called once, and pretty short?

"""Submit a prefetch query as internal async and return an AsyncJob handle."""
try:
result = self._conn.execute_async_and_notify_query_listener(
query,
_is_internal=True,
)
return self.create_async_job(result["queryId"])
except Exception as e: # pragma: no cover
_logger.debug("Unable to submit internal async prefetch query: %s", e)
return None

def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame:
"""
Expand Down
76 changes: 76 additions & 0 deletions tests/integ/test_simplifier_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2519,3 +2519,79 @@ def test_retrieving_aggregation_funcs(session, monkeypatch):
assert not context._aggregation_function_set
session._retrieve_aggregation_function_list()
assert not context._aggregation_function_set


def test_internal_async_aggregation_prefetch_submission(session, monkeypatch):
import snowflake.snowpark.context as context

monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True)
monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True)
monkeypatch.setattr(context, "_aggregation_function_set", set())
session._agg_function_prefetch_job = None

calls = []

def _fake_execute_async(query, **kwargs):
calls.append((query, kwargs))
return {"queryId": "qid_combined"}

monkeypatch.setattr(
session._conn, "execute_async_and_notify_query_listener", _fake_execute_async
)
session._start_async_aggregation_prefetch_if_needed()

assert len(calls) == 1
assert calls[0][1].get("_is_internal") is True
assert "show functions" in calls[0][0]
assert "information_schema.functions" in calls[0][0]
assert session._agg_function_prefetch_job.query_id == "qid_combined"


def test_aggregation_fallback_not_used_when_combined_async_succeeds(
session, monkeypatch
):
import snowflake.snowpark.context as context

class _FakeAsyncJob:
def __init__(self, rows=None, error=None) -> None:
self._rows = rows
self._error = error

def result(self):
if self._error is not None:
raise self._error
return self._rows

monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True)
monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True)
monkeypatch.setattr(context, "_aggregation_function_set", set())
session._agg_function_prefetch_job = _FakeAsyncJob(rows=[("SUM",)])

session._retrieve_aggregation_function_list()

assert "sum" in context._aggregation_function_set
assert "sum_internal" not in context._aggregation_function_set


def test_internal_sync_aggregation_fallback_submission(session, monkeypatch):
import snowflake.snowpark.context as context

monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True)
monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True)
monkeypatch.setattr(context, "_aggregation_function_set", set())
session._agg_function_prefetch_job = None

calls = []

def _fake_run_query(query, **kwargs):
calls.append((query, kwargs))
return {"data": [("AVG",)]}

monkeypatch.setattr(session._conn, "run_query", _fake_run_query)
session._retrieve_aggregation_function_list()

assert len(calls) == 1
assert calls[0][1].get("_is_internal") is True
assert "show functions" in calls[0][0]
assert "information_schema.functions" not in calls[0][0]
assert "avg" in context._aggregation_function_set
Loading
Loading