Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/development_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def start_development_server(port=None, trino_version=TRINO_VERSION):
network = Network().create()
supports_spooling_protocol = TRINO_VERSION == "latest" or int(TRINO_VERSION) >= 466
if supports_spooling_protocol:
localstack = LocalStackContainer(image="localstack/localstack:latest", region_name="us-east-1") \
localstack = LocalStackContainer(image="localstack/localstack:4.14.0", region_name="us-east-1") \
Comment thread
wendigo marked this conversation as resolved.
.with_name("localstack") \
.with_network(network) \
.with_bind_ports(4566, 4566) \
Expand Down
73 changes: 73 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,6 +1900,79 @@ def test_segments_cursor(trino_connection):
assert total == 300875, f"Expected total rows 300875, got {total}"


@pytest.mark.skipif(
    # Spooling is supported from 466 onward (see development_server.py),
    # so only skip for versions strictly below 466.
    trino_version() < 466,
    reason="spooling protocol was introduced in version 466"
)
def test_spooled_segments_lazy_fetchone(trino_connection):
    """Verify that spooled results can be consumed row-by-row via fetchone()
    without materializing the entire result set in memory.

    Uses a cross join of tpch.tiny.lineitem (60175 rows) with a 5-element
    sequence, giving 60175 * 5 = 300875 expected rows.
    """
    if trino_connection._client_session.encoding is None:
        pytest.skip("spooling requires an encoding")

    cur = trino_connection.cursor()
    cur.execute("""SELECT l.*
        FROM tpch.tiny.lineitem l, TABLE(sequence(
                                start => 1,
                                stop => 5,
                                step => 1)) n""")

    # The underlying result rows should be an iterator, not a list --
    # a list would mean the spooled segments were eagerly materialized.
    result_rows = cur._query._result._rows
    assert not isinstance(result_rows, list), (
        f"Expected lazy iterator for spooled results, got {type(result_rows)}"
    )

    # Consume rows one by one and count them
    count = 0
    while cur.fetchone() is not None:
        count += 1
    assert count == 300875, f"Expected 300875 rows, got {count}"


@pytest.mark.skipif(
    # Spooling is supported from 466 onward (see development_server.py),
    # so only skip for versions strictly below 466.
    trino_version() < 466,
    reason="spooling protocol was introduced in version 466"
)
def test_spooled_segments_fetchmany(trino_connection):
    """Verify that fetchmany() works correctly with lazily loaded spooled
    segments: an initial partial batch followed by draining the rest must
    account for every row exactly once (tpch.tiny.lineitem has 60175 rows).
    """
    if trino_connection._client_session.encoding is None:
        pytest.skip("spooling requires an encoding")

    cur = trino_connection.cursor()
    cur.execute("SELECT * FROM tpch.tiny.lineitem")

    # First batch crosses into the lazy iterator; it must honor the size hint.
    batch = cur.fetchmany(100)
    assert len(batch) == 100

    # Drain the remainder in larger batches; an empty batch signals the end.
    total = len(batch)
    while True:
        batch = cur.fetchmany(1000)
        if not batch:
            break
        total += len(batch)
    assert total == 60175, f"Expected 60175 rows, got {total}"


@pytest.mark.skipif(
    # Spooling is supported from 466 onward (see development_server.py),
    # so only skip for versions strictly below 466.
    trino_version() < 466,
    reason="spooling protocol was introduced in version 466"
)
def test_spooled_segments_iterator_protocol(trino_connection):
    """Verify that iterating the cursor directly (the DB-API iterator
    protocol) yields every row of a spooled result set exactly once
    (tpch.tiny.lineitem has 60175 rows), each as a list.
    """
    if trino_connection._client_session.encoding is None:
        pytest.skip("spooling requires an encoding")

    cur = trino_connection.cursor()
    cur.execute("SELECT * FROM tpch.tiny.lineitem")

    count = 0
    for row in cur:
        count += 1
        assert isinstance(row, list)
    assert count == 60175, f"Expected 60175 rows, got {count}"


def get_cursor(legacy_prepared_statements, run_trino):
host, port = run_trino

Expand Down
29 changes: 24 additions & 5 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import base64
import copy
import functools
import itertools
import os
import random
import re
Expand Down Expand Up @@ -904,9 +905,26 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
self._result = TrinoResult(self, rows)

# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
# Block until rows are available, the query finishes, or it is canceled.
# Rows start as an empty list. Early responses often contain only stats,
# so we keep fetching until actual data arrives.
#
# Two protocols produce rows differently:
# - Direct: fetch() returns a list - accumulate into the existing list.
# - Spooling: fetch() returns a lazy iterator - replace rows and stop,
# because we cannot cheaply check iterator length.
Comment thread
wendigo marked this conversation as resolved.
while not self.finished and not self.cancelled and self._result.rows == []:
new_rows = self.fetch()
if isinstance(new_rows, list):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we also need this check in columns?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? I don't think we do

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._result.rows += new_rows
else:
try:
first_row = next(new_rows)
self._result.rows = itertools.chain([first_row], new_rows)
break
except StopIteration:
self._result.rows = []

return self._result

def _update_state(self, status):
Expand All @@ -920,7 +938,7 @@ def _update_state(self, status):
if status.columns:
self._columns = status.columns

def fetch(self) -> List[Union[List[Any]], Any]:
def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
"""Continue fetching data for the current query_id"""
try:
response = self._request.get(self._request.next_uri)
Expand All @@ -941,7 +959,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
spooled = self._to_segments(rows)
if self._fetch_mode == "segments":
return spooled
return list(SegmentIterator(spooled, self._row_mapper))
# Return iterator directly, do NOT materialize with list()
return SegmentIterator(spooled, self._row_mapper)
elif isinstance(status.rows, list):
return self._row_mapper.map(rows)
else:
Expand Down
Loading