diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index c402871889..812f29d502 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -6,7 +6,7 @@ """Context module for Snowpark.""" import logging import sys -from typing import Callable, Optional +from typing import Any, Callable, Optional import snowflake.snowpark import threading @@ -45,6 +45,11 @@ set() ) # lower cased names of aggregation functions, used in sql simplification _aggregation_function_set_lock = threading.RLock() +_aggregation_function_prefetch_state: dict[str, Any] = { + "lock": threading.RLock(), + "event": None, + "job": None, +} # Hardcoded fallback for system built-in aggregation functions. # Used when the dynamic query fails to retrieve the list from the database. @@ -62,9 +67,6 @@ "ai_agg", "ai_summarize_agg", "any_value", - "approximate_count_distinct", - "approximate_jaccard_index", - "approximate_similarity", "approx_count_distinct", "approx_percentile", "approx_percentile_accumulate", @@ -72,25 +74,29 @@ "approx_top_k", "approx_top_k_accumulate", "approx_top_k_combine", - "arrayagg", + "approximate_count_distinct", + "approximate_jaccard_index", + "approximate_similarity", "array_agg", "array_union_agg", "array_unique_agg", + "arrayagg", "avg", - "bitandagg", + "bit_and_agg", + "bit_andagg", + "bit_or_agg", + "bit_oragg", + "bit_xor_agg", + "bit_xoragg", "bitand_agg", + "bitandagg", + "bitmap_and_agg", "bitmap_construct_agg", "bitmap_or_agg", - "bitoragg", "bitor_agg", - "bitxoragg", + "bitoragg", "bitxor_agg", - "bit_andagg", - "bit_and_agg", - "bit_oragg", - "bit_or_agg", - "bit_xoragg", - "bit_xor_agg", + "bitxoragg", "booland_agg", "boolor_agg", "boolxor_agg", @@ -115,12 +121,12 @@ "max_by", "median", "min", + "min_by", "minhash", "minhash_combine", - "min_by", "mode", - "objectagg", "object_agg", + "objectagg", "percentile_cont", "percentile_disc", "regr_avgx", @@ -133,20 +139,21 @@ "regr_sxy", "regr_syy", "skew", + "st_intersection_agg_geography_internal", + "st_union_agg_geography_internal", "stddev", "stddev_pop", "stddev_samp", - "st_intersection_agg_geography_internal", - "st_union_agg_geography_internal", "sum", "sum_internal", "sum_internal_real", "sum_real", + "summarize_agg", + "var_pop", + "var_samp", "variance", "variance_pop", "variance_samp", - "var_pop", - "var_samp", "vector_avg", "vector_max", "vector_min", @@ -154,6 +161,7 @@ ] ) + _cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session. # Following are internal-only global flags, used to enable development features. diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 457f28f95b..c8bdc0ca5a 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -17,7 +17,7 @@ from collections import defaultdict from functools import reduce from logging import getLogger -from threading import RLock +from threading import Event, RLock from types import ModuleType from typing import ( TYPE_CHECKING, @@ -858,6 +858,7 @@ def __init__( self._client_telemetry = EventTableTelemetry(session=self) self._ast_batch = AstBatch(self) + self._start_async_aggregation_prefetch_if_needed() _logger.info("Snowpark Session information: %s", self._session_info) @@ -5045,53 +5046,129 @@ def _execute_sproc_internal( def _retrieve_aggregation_function_list(self) -> None: """Retrieve the list of aggregation functions which will later be used in sql simplifier.""" - if ( - not ( - context._is_snowpark_connect_compatible_mode - and context._snowpark_connect_flatten_select_after_sort - ) - or context._aggregation_function_set + if not ( + context._is_snowpark_connect_compatible_mode + and context._snowpark_connect_flatten_select_after_sort ): return + with context._aggregation_function_set_lock: + if context._aggregation_function_set: + return + retrieved_set = set() + system_fetch_succeeded = False + prefetch_state = context._aggregation_function_prefetch_state + + # Atomically claim the async job. The claiming thread creates an Event so concurrent + # threads can wait on it rather than issuing redundant sync queries. + # AsyncJob.result() is not thread-safe — the underlying connector cursor mutates + # shared state (_result, _rownumber, _prefetch_hook) during result fetching, so only + # one thread may call it. The lock is held only for the pointer swap and event setup + # (nanoseconds), not the network call itself. + with prefetch_state["lock"]: + job, prefetch_state["job"] = prefetch_state["job"], None + if job is not None: + fetch_event = Event() + prefetch_state["event"] = fetch_event + wait_event = None + elif prefetch_state["event"] is not None: + fetch_event = None + wait_event = prefetch_state["event"] + else: + fetch_event = None + wait_event = None + + if wait_event is not None: + # The query typically finishes in ~5s; 20s gives ample headroom while + # bounding the hang in the rare case the winner thread dies before its + # finally block runs (e.g. os._exit, interpreter shutdown). + wait_event.wait(timeout=20) + with context._aggregation_function_set_lock: + if context._aggregation_function_set: + return + # Winner failed or timed out; fall through to sync query. - # User-defined aggregation functions try: - retrieved_set.update( - { - r[0].lower() - for r in self.sql( - """select function_name from information_schema.functions where is_aggregate = 'YES'""" - ).collect() - } - ) - except Exception as e: - _logger.debug( - "Unable to get user-defined aggregation functions: %s", - e, - ) + if job is not None: + try: + retrieved_set.update({r[0].lower() for r in job.result()}) + system_fetch_succeeded = True + except Exception as e: + _logger.debug( + "Unable to use async aggregation function prefetch: %s", + e, + ) + else: + _logger.debug( + "Async aggregation function prefetch job is unavailable; using sync fallback." + ) - # System built-in aggregation functions - try: - retrieved_set.update( - { - r[0].lower() - for r in self.sql( - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" - ).collect() - } - ) - except Exception as e: - _logger.debug( - "Unable to get system aggregation functions, " - "falling back to hardcoded list: %s", - e, - ) - retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + # Sync fallback query. + if not system_fetch_succeeded: + try: + retrieved_set.update( + { + r[0].lower() + for r in self._conn.run_query( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + _is_internal=True, + )["data"] + } + ) + system_fetch_succeeded = True + except Exception as e: + _logger.debug( + "Unable to get aggregation functions via sync fallback query: %s", + e, + ) - with context._aggregation_function_set_lock: - context._aggregation_function_set.update(retrieved_set) + # Fallback to the local hardcoded list only when metadata retrieval fails. + if not system_fetch_succeeded: + retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + + with context._aggregation_function_set_lock: + context._aggregation_function_set.update(retrieved_set) + finally: + # Signal after _aggregation_function_set is published so waiters see + # the populated set immediately upon waking. Also fires on BaseException + # (e.g. KeyboardInterrupt) so waiters are never left blocking until timeout. + if fetch_event is not None: + fetch_event.set() + + def _start_async_aggregation_prefetch_if_needed(self) -> None: + """Start aggregation metadata prefetch only when not already in progress.""" + if not ( + context._is_snowpark_connect_compatible_mode + and context._snowpark_connect_flatten_select_after_sort + ): + return + prefetch_state = context._aggregation_function_prefetch_state + with prefetch_state["lock"]: + with context._aggregation_function_set_lock: + if context._aggregation_function_set: + return + if prefetch_state["job"] is not None: + return + # A winner thread has already claimed the async job and is still publishing results. + # Do not start a new async query while that in-flight fetch is unfinished. + if ( + prefetch_state["event"] is not None + and not prefetch_state["event"].is_set() + ): + return + try: + result = self._conn.execute_async_and_notify_query_listener( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + _is_internal=True, + ) + prefetch_state["job"] = self.create_async_job(result["queryId"]) + except Exception as e: # pragma: no cover + _logger.debug( + "Unable to start async aggregation metadata prefetch: %s", + e, + ) + prefetch_state["job"] = None def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index b446347d51..616d1d928a 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2519,3 +2519,197 @@ def test_retrieving_aggregation_funcs(session, monkeypatch): assert not context._aggregation_function_set session._retrieve_aggregation_function_list() assert not context._aggregation_function_set + + +def test_internal_async_aggregation_prefetch_submission(session, monkeypatch): + from threading import Event + + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + context._aggregation_function_prefetch_state["job"] = None + context._aggregation_function_prefetch_state["event"] = None + + calls = [] + + def _fake_execute_async(query, **kwargs): + calls.append((query, kwargs)) + return {"queryId": "qid_combined"} + + monkeypatch.setattr( + session._conn, "execute_async_and_notify_query_listener", _fake_execute_async + ) + session._start_async_aggregation_prefetch_if_needed() + + assert len(calls) == 1 + assert calls[0][1].get("_is_internal") is True + assert "show functions" in calls[0][0] + assert "information_schema.functions" not in calls[0][0] + assert ( + context._aggregation_function_prefetch_state["job"].query_id == "qid_combined" + ) + + # Another session start during in-flight fetch should not submit another async query. + context._aggregation_function_prefetch_state["job"] = None + context._aggregation_function_prefetch_state["event"] = Event() + session._start_async_aggregation_prefetch_if_needed() + assert len(calls) == 1 + + +def test_aggregation_fallback_not_used_when_combined_async_succeeds( + session, monkeypatch +): + import snowflake.snowpark.context as context + + class _FakeAsyncJob: + def __init__(self, rows=None, error=None) -> None: + self._rows = rows + self._error = error + + def result(self): + if self._error is not None: + raise self._error + return self._rows + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + context._aggregation_function_prefetch_state["job"] = _FakeAsyncJob(rows=[("SUM",)]) + + session._retrieve_aggregation_function_list() + + assert "sum" in context._aggregation_function_set + assert "sum_internal" not in context._aggregation_function_set + + +def test_internal_sync_aggregation_fallback_submission(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + context._aggregation_function_prefetch_state["job"] = None + context._aggregation_function_prefetch_state["event"] = None + + calls = [] + + def _fake_run_query(query, **kwargs): + calls.append((query, kwargs)) + return {"data": [("AVG",)]} + + monkeypatch.setattr(session._conn, "run_query", _fake_run_query) + session._retrieve_aggregation_function_list() + + assert len(calls) == 1 + assert calls[0][1].get("_is_internal") is True + assert "show functions" in calls[0][0] + assert "information_schema.functions" not in calls[0][0] + assert "avg" in context._aggregation_function_set + + +def test_concurrent_retrieve_agg_waiters_no_sync_query(session, monkeypatch): + """Concurrent calls to _retrieve_aggregation_function_list must result in zero + sync queries from waiters — they reuse the winner's async result via the Event.""" + import threading + import time + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + context._aggregation_function_prefetch_state["event"] = None + + job_may_proceed = threading.Event() + waiter_count = [0] + waiter_count_lock = threading.Lock() + + class SlowFakeAsyncJob: + def result(self): + job_may_proceed.wait() + return [("SUM",), ("AVG",)] + + context._aggregation_function_prefetch_state["job"] = SlowFakeAsyncJob() + + sync_query_calls = [] + original_run_query = session._conn.run_query + + def counting_run_query(query, **kwargs): + if kwargs.get("_is_internal") and "show functions" in query: + sync_query_calls.append(query) + return original_run_query(query, **kwargs) + + monkeypatch.setattr(session._conn, "run_query", counting_run_query) + + errors = [] + + def run_winner(): + try: + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + def run_waiter(): + try: + with waiter_count_lock: + waiter_count[0] += 1 + if waiter_count[0] == 2: + job_may_proceed.set() + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + winner = threading.Thread(target=run_winner) + waiters = [threading.Thread(target=run_waiter) for _ in range(2)] + + winner.start() + time.sleep(0.05) # give winner time to claim job and set fetch_event + for w in waiters: + w.start() + winner.join(timeout=15) + for w in waiters: + w.join(timeout=15) + + assert not errors + assert "sum" in context._aggregation_function_set + assert "avg" in context._aggregation_function_set + assert ( + len(sync_query_calls) == 0 + ), f"Expected 0 sync queries from waiters, got {len(sync_query_calls)}" + + +def test_concurrent_retrieve_agg_event_set_after_context_published( + session, monkeypatch +): + """The fetch_event must be set only after _aggregation_function_set is published — + waiters must see a non-empty set the moment they wake up""" + import snowflake.snowpark.context as context + from threading import Event as _Event + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + context._aggregation_function_prefetch_state["event"] = None + + class _FakeAsyncJob: + def result(self): + return [("SUM",)] + + context._aggregation_function_prefetch_state["job"] = _FakeAsyncJob() + + snapshot_at_set = [] + original_event_set = _Event.set + + def patched_set(self_event): + snapshot_at_set.append(frozenset(context._aggregation_function_set)) + original_event_set(self_event) + + monkeypatch.setattr(_Event, "set", patched_set) + + session._retrieve_aggregation_function_list() + + assert snapshot_at_set, "fetch_event.set() was never called" + assert snapshot_at_set[ + 0 + ], "fetch_event fired before _aggregation_function_set was published" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0349618659..235f2ff130 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -808,9 +808,8 @@ def test_infer_is_return_table_uses_internal_describe(): assert mocked_run_query.call_count == 1 -def test_retrieve_aggregation_function_list_handles_user_defined_error(): - """When querying user-defined aggregation functions fails, the error is - swallowed and the method continues to query system functions.""" +def test_retrieve_aggregation_function_list_handles_async_error(): + """When async metadata prefetch fails, sync internal fallback is used.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -818,34 +817,38 @@ def test_retrieve_aggregation_function_list_handles_user_defined_error(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + ctx._aggregation_function_prefetch_state["event"] = None - mock_df = MagicMock() - call_count = [0] + fake_async_job = MagicMock() + fake_async_job.result.side_effect = RuntimeError("async query failed") + ctx._aggregation_function_prefetch_state["job"] = fake_async_job - def sql_side_effect(query, **kwargs): - call_count[0] += 1 - if call_count[0] == 1: - raise RuntimeError("user-defined query failed") - mock_df.collect.return_value = [["SUM"], ["AVG"]] - return mock_df + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + assert "show functions" in query + return {"data": [["SUM"], ["AVG"]]} - with mock.patch.object(session, "sql", side_effect=sql_side_effect): + with mock.patch.object( + fake_server_connection, "run_query", side_effect=run_query_side_effect + ): session._retrieve_aggregation_function_list() assert "sum" in ctx._aggregation_function_set assert "avg" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set -def test_retrieve_aggregation_function_list_handles_system_error(): - """When querying system aggregation functions fails, the method falls back - to the hardcoded _KNOWN_AGGREGATION_FUNCTIONS set.""" +def test_retrieve_aggregation_function_list_handles_sync_error(): + """When sync metadata query fails, hardcoded fallback applies.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -853,26 +856,32 @@ def test_retrieve_aggregation_function_list_handles_system_error(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() - mock_df = MagicMock() - mock_df.collect.side_effect = RuntimeError("system query failed") + def run_query_side_effect(query, **kwargs): + assert kwargs.get("_is_internal") is True + assert "show functions" in query + raise RuntimeError("sync query failed") - with mock.patch.object(session, "sql", return_value=mock_df): + with mock.patch.object( + fake_server_connection, "run_query", side_effect=run_query_side_effect + ): session._retrieve_aggregation_function_list() assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set -def test_retrieve_aggregation_function_list_handles_both_errors(): - """When both aggregation function queries fail, the hardcoded fallback - set is still populated.""" +def test_retrieve_aggregation_function_list_uses_single_internal_sync_query(): + """Sync fallback executes exactly one internal metadata query.""" import snowflake.snowpark.context as ctx fake_server_connection = mock.create_autospec(ServerConnection) @@ -880,17 +889,177 @@ def test_retrieve_aggregation_function_list_handles_both_errors(): session = Session(fake_server_connection) original_compat = ctx._is_snowpark_connect_compatible_mode + original_flatten = ctx._snowpark_connect_flatten_select_after_sort original_agg_set = ctx._aggregation_function_set try: ctx._is_snowpark_connect_compatible_mode = True + ctx._snowpark_connect_flatten_select_after_sort = True ctx._aggregation_function_set = set() + ctx._aggregation_function_prefetch_state["event"] = None + ctx._aggregation_function_prefetch_state["job"] = None + + called_queries = [] + + def run_query_side_effect(query, **kwargs): + called_queries.append(query) + assert kwargs.get("_is_internal") is True + return {"data": [["SUM"]]} with mock.patch.object( - session, "sql", side_effect=RuntimeError("query failed") + fake_server_connection, + "run_query", + side_effect=run_query_side_effect, ): session._retrieve_aggregation_function_list() - assert ctx._KNOWN_AGGREGATION_FUNCTIONS.issubset(ctx._aggregation_function_set) + assert len(called_queries) == 1 + assert "show functions" in called_queries[0] + assert "information_schema.functions" not in called_queries[0] + assert "sum" in ctx._aggregation_function_set finally: ctx._is_snowpark_connect_compatible_mode = original_compat + ctx._snowpark_connect_flatten_select_after_sort = original_flatten ctx._aggregation_function_set = original_agg_set + + +def test_retrieve_agg_event_set_after_context_published(monkeypatch): + """fetch_event.set() must be called only after _aggregation_function_set is + populated — the original bug was setting the event before publishing.""" + import snowflake.snowpark.context as ctx + from threading import Event as _Event + + fake_conn = mock.create_autospec(ServerConnection) + fake_conn._thread_safe_session_enabled = True + session = Session(fake_conn) + ctx._aggregation_function_prefetch_state["event"] = None + + monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(ctx, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(ctx, "_aggregation_function_set", set()) + + class TrackingAsyncJob: + def result(self): + return [("SUM",)] + + ctx._aggregation_function_prefetch_state["job"] = TrackingAsyncJob() + + publish_order = [] + original_set = _Event.set + + def patched_set(self_event): + publish_order.append(("event_set", frozenset(ctx._aggregation_function_set))) + original_set(self_event) + + with mock.patch.object(_Event, "set", patched_set): + session._retrieve_aggregation_function_list() + + assert publish_order, "fetch_event.set() was never called" + # At the moment event fires the set must already contain the result. + _, snapshot = publish_order[0] + assert ( + "sum" in snapshot + ), f"fetch_event fired before context was populated; snapshot={snapshot}" + + +def test_retrieve_agg_waiters_fall_through_on_winner_failure(monkeypatch): + """When the winner's async job fails, waiters fall through to sync query + rather than hanging or returning an empty set.""" + import threading + import time + import snowflake.snowpark.context as ctx + + fake_conn = mock.create_autospec(ServerConnection) + fake_conn._thread_safe_session_enabled = True + session = Session(fake_conn) + ctx._aggregation_function_prefetch_state["event"] = None + + monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(ctx, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(ctx, "_aggregation_function_set", set()) + + job_may_proceed = threading.Event() + waiter_registered = threading.Event() + + class FailingAsyncJob: + def result(self): + job_may_proceed.wait() + raise RuntimeError("async job failed") + + ctx._aggregation_function_prefetch_state["job"] = FailingAsyncJob() + + sync_query_calls = [] + + def run_query_side_effect(query, **kwargs): + sync_query_calls.append(query) + return {"data": [("COUNT",)]} + + fake_conn.run_query.side_effect = run_query_side_effect + + errors = [] + + def run_winner(): + try: + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + def run_waiter(): + try: + waiter_registered.set() + session._retrieve_aggregation_function_list() + except Exception as e: + errors.append(e) + + winner = threading.Thread(target=run_winner) + waiter = threading.Thread(target=run_waiter) + + winner.start() + time.sleep(0.05) # give winner time to claim the job and set fetch_event + waiter.start() + waiter_registered.wait(timeout=5) + time.sleep(0.05) # give waiter time to reach wait_event.wait() + job_may_proceed.set() # let the winner fail + + winner.join(timeout=10) + waiter.join(timeout=10) + + assert not errors + # Winner failed → waiter fell through to sync query → count in set. + assert "count" in ctx._aggregation_function_set + + +def test_retrieve_agg_event_always_set_on_base_exception(monkeypatch): + """fetch_event.set() fires even when a BaseException escapes the async job, + so waiters are never left blocking until timeout""" + import snowflake.snowpark.context as ctx + from threading import Event as _Event + + fake_conn = mock.create_autospec(ServerConnection) + fake_conn._thread_safe_session_enabled = True + session = Session(fake_conn) + ctx._aggregation_function_prefetch_state["event"] = None + + monkeypatch.setattr(ctx, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(ctx, "_snowpark_connect_flatten_select_after_sort", True) + monkeypatch.setattr(ctx, "_aggregation_function_set", set()) + + class KeyboardInterruptJob: + def result(self): + raise KeyboardInterrupt() + + ctx._aggregation_function_prefetch_state["job"] = KeyboardInterruptJob() + + event_was_set = [] + original_set = _Event.set + + def patched_set(self_event): + event_was_set.append(True) + original_set(self_event) + + with mock.patch.object(_Event, "set", patched_set): + try: + session._retrieve_aggregation_function_list() + except KeyboardInterrupt: + pass + + assert event_was_set, "fetch_event.set() was not called despite BaseException"