Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion tornado/simple_httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def __init__(
IOLoop.current().add_future(
gen.convert_yielded(self.run()), lambda f: f.result()
)
self._connect_phase: str | None = None

async def run(self) -> None:
try:
Expand Down Expand Up @@ -337,6 +338,7 @@ async def run(self) -> None:
ssl_options=ssl_options,
max_buffer_size=self.max_buffer_size,
source_ip=source_ip,
on_phase_change=self._set_connect_phase,
)

if self.final_callback is None:
Expand Down Expand Up @@ -478,7 +480,8 @@ def _on_timeout(self, info: str | None = None) -> None:
:info string key: More detailed timeout information.
"""
self._timeout = None
error_message = f"Timeout {info}" if info else "Timeout"
error_message = f"Timeout during {self._connect_phase}" if self._connect_phase else "Timeout"
error_message = f"{error_message} {info}" if info else error_message
if self.final_callback is not None:
self._handle_exception(
HTTPTimeoutError, HTTPTimeoutError(error_message), None
Expand Down Expand Up @@ -690,6 +693,12 @@ def data_received(self, chunk: bytes) -> Awaitable[None] | None:
self.chunks.append(chunk)
return None

def _set_connect_phase(self, phase: str) -> None:
if phase == "dns":
self._connect_phase = "DNS resolution"
elif phase == "tcp_connect":
self._connect_phase = "TCP connection"


if __name__ == "__main__":
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
Expand Down
11 changes: 10 additions & 1 deletion tornado/tcpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import socket
import ssl
from collections.abc import Callable, Iterator
from typing import Any, Tuple
from typing import Any, Tuple, Optional

from tornado import gen
from tornado.concurrent import Future, future_add_done_callback
Expand Down Expand Up @@ -219,6 +219,7 @@ async def connect(
source_ip: str | None = None,
source_port: int | None = None,
timeout: float | datetime.timedelta | None = None,
on_phase_change: Optional[Callable[[str], None]] = None,
) -> IOStream:
"""Connect to the given host and port.

Expand Down Expand Up @@ -252,6 +253,10 @@ async def connect(
timeout = IOLoop.current().time() + timeout.total_seconds()
else:
raise TypeError("Unsupported timeout %r" % timeout)

if on_phase_change:
on_phase_change("dns")

if timeout is not None:
addrinfo = await gen.with_timeout(
timeout, self.resolver.resolve(host, port, af)
Expand All @@ -267,6 +272,10 @@ async def connect(
source_port=source_port,
),
)

if on_phase_change:
on_phase_change("tcp_connect")

af, addr, stream = await connector.start(connect_timeout=timeout)
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on subsequent connections to
Expand Down
59 changes: 57 additions & 2 deletions tornado/test/simple_httpclient_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
HTTPStreamClosedError,
HTTPTimeoutError,
SimpleAsyncHTTPClient,
TCPClient,
)
from tornado.test import httpclient_test
from tornado.test.httpclient_test import (
Expand Down Expand Up @@ -299,13 +300,66 @@ async def resolve(self, *args, **kwargs):
return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]

with closing(self.create_client(resolver=TimeoutResolver())) as client:
with self.assertRaises(HTTPTimeoutError):
with self.assertRaises(HTTPTimeoutError) as cm:
yield client.fetch(
self.get_url("/hello"),
connect_timeout=timeout,
request_timeout=3600,
raise_error=True,
)
self.assertEqual(str(cm.exception), "Timeout during DNS resolution while connecting")

# Let the hanging coroutine clean up after itself. We need to
# wait more than a single IOLoop iteration for the SSL case,
# which logs errors on unexpected EOF.
cleanup_event.set()
yield gen.sleep(0.2)

@gen_test
def test_connect_timeout_tcp_conn(self):
timeout = 0.1

cleanup_event = Event()
test = self

class TimeoutTCPClient(TCPClient):
async def connect(
self,
host,
port,
af=socket.AF_UNSPEC,
ssl_options=None,
max_buffer_size=None,
source_ip=None,
source_port=None,
timeout=None,
on_phase_change=None,
):
if on_phase_change is not None:
on_phase_change("tcp_connect")
await cleanup_event.wait()
# Return something valid so the test doesn't raise during shutdown.
return await super().connect(
host,
port,
af=af,
ssl_options=ssl_options,
max_buffer_size=max_buffer_size,
source_ip=source_ip,
source_port=source_port,
timeout=timeout,
)

with closing(self.create_client()) as client:
client.tcp_client = TimeoutTCPClient(resolver=client.resolver)
with self.assertRaises(HTTPTimeoutError) as cm:
yield client.fetch(
self.get_url("/hello"),
connect_timeout=timeout,
request_timeout=3600,
raise_error=True,
)
self.assertEqual(str(cm.exception),"Timeout during TCP connection while connecting")

# Let the hanging coroutine clean up after itself. We need to
# wait more than a single IOLoop iteration for the SSL case,
Expand Down Expand Up @@ -772,8 +826,9 @@ def get_app(self):
return Application([url("/hello", HelloWorldHandler)])

def test_resolve_timeout(self):
with self.assertRaises(HTTPTimeoutError):
with self.assertRaises(HTTPTimeoutError) as cm:
self.fetch("/hello", connect_timeout=0.1, raise_error=True)
self.assertEqual(str(cm.exception), "Timeout during DNS resolution while connecting")

# Let the hanging coroutine clean up after itself
self.cleanup_event.set()
Expand Down