diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 625a282b..53f86504 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -1973,6 +1973,33 @@ def test_spooled_segments_iterator_protocol(trino_connection): assert count == 60175, f"Expected 60175 rows, got {count}" +@pytest.mark.skipif( + trino_version() <= 466, + reason="spooling protocol was introduced in version 466" +) +def test_spooled_segments_lazy_description(trino_connection): + """Verify that accessing cursor.description does not materialize the lazy spooled iterator.""" + if trino_connection._client_session.encoding is None: + pytest.skip("spooling requires an encoding") + + cur = trino_connection.cursor() + cur.execute("SELECT * FROM tpch.tiny.lineitem") + + assert not isinstance(cur._query._result._rows, list), ( + f"Expected lazy iterator for spooled results, got {type(cur._query._result._rows)}" + ) + + desc = cur.description + assert desc is not None + assert len(desc) > 0 + + assert not isinstance(cur._query._result._rows, list), ( + f"Expected lazy iterator after description access, got {type(cur._query._result._rows)}" + ) + + assert len(cur.fetchall()) == 60175 + + def get_cursor(legacy_prepared_statements, run_trino): host, port = run_trino diff --git a/trino/client.py b/trino/client.py index ac32d75f..97ef11ae 100644 --- a/trino/client.py +++ b/trino/client.py @@ -880,7 +880,21 @@ def columns(self): while not self._columns and not self.finished and not self.cancelled: # Columns are not returned immediately after query is submitted. # Continue fetching data until columns information is available and push fetched rows into buffer. - self._result.rows += self.fetch() + # + # Two protocols produce rows differently: + # - Direct: fetch() returns a list - accumulate into the existing list. + # - Spooling: fetch() returns a lazy iterator - replace rows and stop, + # because we cannot cheaply check iterator length. + new_rows = self.fetch() + if isinstance(new_rows, list): + self._result.rows += new_rows + else: + try: + first_row = next(new_rows) + self._result.rows = itertools.chain([first_row], new_rows) + break + except StopIteration: + self._result.rows = [] return self._columns @property