Make spooled result segments lazy: fetch() now returns a SegmentIterator
instead of materializing all rows with list(), and execute() blocks for the
first row by pulling one element from the iterator (direct protocol still
accumulates lists). Also pins the localstack image used by the development
server and adds integration tests covering fetchone(), fetchmany(), and the
cursor iterator protocol over lazily fetched spooled segments.

diff --git a/tests/development_server.py b/tests/development_server.py
index fd15c201..78cb08cc 100644
--- a/tests/development_server.py
+++ b/tests/development_server.py
@@ -51,7 +51,7 @@ def start_development_server(port=None, trino_version=TRINO_VERSION):
     network = Network().create()
     supports_spooling_protocol = TRINO_VERSION == "latest" or int(TRINO_VERSION) >= 466
     if supports_spooling_protocol:
-        localstack = LocalStackContainer(image="localstack/localstack:latest", region_name="us-east-1") \
+        localstack = LocalStackContainer(image="localstack/localstack:4.14.0", region_name="us-east-1") \
            .with_name("localstack") \
            .with_network(network) \
            .with_bind_ports(4566, 4566) \
diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py
index 71b6c663..625a282b 100644
--- a/tests/integration/test_dbapi_integration.py
+++ b/tests/integration/test_dbapi_integration.py
@@ -1900,6 +1900,79 @@ def test_segments_cursor(trino_connection):
     assert total == 300875, f"Expected total rows 300875, got {total}"
 
 
+@pytest.mark.skipif(
+    trino_version() <= 466,
+    reason="spooling protocol was introduced in version 466"
+)
+def test_spooled_segments_lazy_fetchone(trino_connection):
+    """Verify that spooled results can be consumed row-by-row via fetchone()
+    without materializing the entire result set in memory."""
+    if trino_connection._client_session.encoding is None:
+        pytest.skip("spooling requires an encoding")
+
+    cur = trino_connection.cursor()
+    cur.execute("""SELECT l.*
+        FROM tpch.tiny.lineitem l, TABLE(sequence(
+            start => 1,
+            stop => 5,
+            step => 1)) n""")
+
+    # The underlying result rows should be an iterator, not a list
+    result_rows = cur._query._result._rows
+    assert not isinstance(result_rows, list), (
+        f"Expected lazy iterator for spooled results, got {type(result_rows)}"
+    )
+
+    # Consume rows one by one and count them
+    count = 0
+    while cur.fetchone() is not None:
+        count += 1
+    assert count == 300875, f"Expected 300875 rows, got {count}"
+
+
+@pytest.mark.skipif(
+    trino_version() <= 466,
+    reason="spooling protocol was introduced in version 466"
+)
+def test_spooled_segments_fetchmany(trino_connection):
+    """Verify that fetchmany() works correctly with lazily loaded spooled segments."""
+    if trino_connection._client_session.encoding is None:
+        pytest.skip("spooling requires an encoding")
+
+    cur = trino_connection.cursor()
+    cur.execute("SELECT * FROM tpch.tiny.lineitem")
+
+    batch = cur.fetchmany(100)
+    assert len(batch) == 100
+
+    total = len(batch)
+    while True:
+        batch = cur.fetchmany(1000)
+        if not batch:
+            break
+        total += len(batch)
+    assert total == 60175, f"Expected 60175 rows, got {total}"
+
+
+@pytest.mark.skipif(
+    trino_version() <= 466,
+    reason="spooling protocol was introduced in version 466"
+)
+def test_spooled_segments_iterator_protocol(trino_connection):
+    """Verify that cursor iteration works correctly with spooled segments."""
+    if trino_connection._client_session.encoding is None:
+        pytest.skip("spooling requires an encoding")
+
+    cur = trino_connection.cursor()
+    cur.execute("SELECT * FROM tpch.tiny.lineitem")
+
+    count = 0
+    for row in cur:
+        count += 1
+        assert isinstance(row, list)
+    assert count == 60175, f"Expected 60175 rows, got {count}"
+
+
 def get_cursor(legacy_prepared_statements, run_trino):
     host, port = run_trino
 
diff --git a/trino/client.py b/trino/client.py
index 3ab27e33..85199036 100644
--- a/trino/client.py
+++ b/trino/client.py
@@ -39,6 +39,7 @@
 import base64
 import copy
 import functools
+import itertools
 import os
 import random
 import re
@@ -904,9 +905,26 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
         rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
         self._result = TrinoResult(self, rows)
 
-        # Execute should block until at least one row is received or query is finished or cancelled
-        while not self.finished and not self.cancelled and len(self._result.rows) == 0:
-            self._result.rows += self.fetch()
+        # Block until rows are available, the query finishes, or it is canceled.
+        # Rows start as an empty list. Early responses often contain only stats,
+        # so we keep fetching until actual data arrives.
+        #
+        # Two protocols produce rows differently:
+        # - Direct: fetch() returns a list - accumulate into the existing list.
+        # - Spooling: fetch() returns a lazy iterator - replace rows and stop,
+        #   because we cannot cheaply check iterator length.
+        while not self.finished and not self.cancelled and self._result.rows == []:
+            new_rows = self.fetch()
+            if isinstance(new_rows, list):
+                self._result.rows += new_rows
+            else:
+                try:
+                    first_row = next(new_rows)
+                    self._result.rows = itertools.chain([first_row], new_rows)
+                    break
+                except StopIteration:
+                    self._result.rows = []
+
         return self._result
 
     def _update_state(self, status):
@@ -920,7 +938,7 @@ def _update_state(self, status):
         if status.columns:
             self._columns = status.columns
 
-    def fetch(self) -> List[Union[List[Any]], Any]:
+    def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
         """Continue fetching data for the current query_id"""
         try:
             response = self._request.get(self._request.next_uri)
@@ -941,7 +959,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
             spooled = self._to_segments(rows)
             if self._fetch_mode == "segments":
                 return spooled
-            return list(SegmentIterator(spooled, self._row_mapper))
+            # Return iterator directly, do NOT materialize with list()
+            return SegmentIterator(spooled, self._row_mapper)
         elif isinstance(status.rows, list):
             return self._row_mapper.map(rows)
         else: