diff --git a/requirements.in b/requirements.in index 417afeb5d..00920aed4 100644 --- a/requirements.in +++ b/requirements.in @@ -1,3 +1,4 @@ +async-asgi-testclient black flake8 mypy>=0.941 diff --git a/requirements.txt b/requirements.txt index 9af3bd586..e5efa3cc0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,13 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.14 # by the following command: # # pip-compile # alabaster==1.0.0 # via sphinx +async-asgi-testclient==1.4.11 + # via -r requirements.in babel==2.18.0 # via sphinx black==26.3.1 @@ -49,6 +51,8 @@ markupsafe==3.0.3 # via jinja2 mccabe==0.7.0 # via flake8 +multidict==6.7.0 + # via async-asgi-testclient mypy==1.19.1 # via -r requirements.in mypy-extensions==1.1.0 @@ -94,7 +98,9 @@ python-discovery==1.2.0 pytokens==0.4.1 # via black requests==2.32.5 - # via sphinx + # via + # async-asgi-testclient + # sphinx roman-numerals==4.1.0 # via sphinx snowballstemmer==3.0.1 diff --git a/tornado/asgi.py b/tornado/asgi.py new file mode 100644 index 000000000..069b71f9a --- /dev/null +++ b/tornado/asgi.py @@ -0,0 +1,223 @@ +import inspect +from asyncio import create_task, Future, Task, wait +from collections.abc import AsyncGenerator, Awaitable, Callable +from dataclasses import dataclass +from typing import Optional, Union + +from tornado.httputil import ( + HTTPConnection, + HTTPHeaders, + RequestStartLine, + ResponseStartLine, +) +from tornado.web import Application + +ReceiveCallable = Callable[[], Awaitable[dict]] +SendCallable = Callable[[dict], Awaitable[None]] +ApplicationGen = Callable[[], AsyncGenerator] + + +@dataclass +class ASGIHTTPRequestContext: + """To convey connection details to the HTTPServerRequest object""" + + protocol: str + address: Optional[tuple] = None + remote_ip: str = "0.0.0.0" + + +class ASGIHTTPConnection(HTTPConnection): + """Represents the connection for 1 request/response pair + + This provides the API for sending the response. + """ + + def __init__(self, send_cb: SendCallable, context: ASGIHTTPRequestContext): + self.send_cb = send_cb + self.context = context + self.task_holder: set[Task] = set() + self._close_callback: Callable[[], None] | None = None + self._request_finished: Future[None] = Future() + + # Various tornado APIs (e.g. RequestHandler.flush()) return a Future which + # application code does not need to await. The operations these represent + # are expected to complete even if the Future is not awaited. We hold onto + # these on the connection so they are not destroyed, and so that we can + # wait for them at the end of the ASGI connections scope. + def _bg_task(self, coro) -> Future: # type: ignore + task = create_task(coro) + self.task_holder.add(task) + task.add_done_callback(self.task_holder.discard) + return task + + async def _write_headers( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + chunk: Optional[bytes] = None, + ) -> None: + assert isinstance(start_line, ResponseStartLine) + await self.send_cb( + { + "type": "http.response.start", + "status": start_line.code, + "headers": [ + [k.lower().encode("latin1"), v.encode("latin1")] + for k, v in headers.get_all() + ], + } + ) + if chunk is not None: + await self._write(chunk) + + def write_headers( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + chunk: Optional[bytes] = None, + ) -> "Future[None]": + return self._bg_task(self._write_headers(start_line, headers, chunk)) + + async def _write(self, chunk: bytes) -> None: + await self.send_cb( + {"type": "http.response.body", "body": chunk, "more_body": True} + ) + + def write(self, chunk: bytes) -> "Future[None]": + return self._bg_task(self._write(chunk)) + + def finish(self) -> None: + self._bg_task( + self.send_cb( + { + "type": "http.response.body", + "body": b"", + "more_body": False, + } + ) + ) + self._request_finished.set_result(None) + + def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None: + self._close_callback = callback + + def _on_connection_close(self) -> None: + if self._close_callback is not None: + callback = self._close_callback + self._close_callback = None + callback() + self._request_finished.set_result(None) + + async def wait_finish(self) -> None: + """For the ASGI interface: wait for all input & output to finish""" + await self._request_finished + await wait(self.task_holder) + + +class ASGIAdapter: + """Wrap a tornado application object to use with an ASGI server""" + + application: Optional[Application] + + def __init__(self, application: Application | ApplicationGen): + if isinstance(application, Application): + self.application_gen = None + self.application = application + elif inspect.isasyncgenfunction(application): + self.application_gen = application() + self.application = None + else: + raise TypeError(f"ASGIAdapter does not recognise {application!r}") + + async def __call__( + self, scope: dict, receive: ReceiveCallable, send: SendCallable + ) -> None: + if scope["type"] == "lifespan": + return await self.lifespan_scope(scope, receive, send) + if scope["type"] == "http": + return await self.http_scope(scope, receive, send) + raise KeyError(scope["type"]) + + async def _initialise_application(self) -> None: + # Ideally triggered by a lifespan startup message, but if the server + # doesn't support that, we'll do the setup on the first request. + if self.application is None: + assert self.application_gen is not None + self.application = await anext(self.application_gen) + + async def lifespan_scope( + self, scope: dict, receive: ReceiveCallable, send: SendCallable + ) -> None: + while True: + event = await receive() + if event["type"] == "lifespan.startup": + try: + await self._initialise_application() + except Exception as e: + await send({"type": "lifespan.startup.failed", "message": str(e)}) + else: + await send({"type": "lifespan.startup.complete"}) + + elif event["type"] == "lifespan.shutdown": + try: + if self.application_gen is not None: + await anext(self.application_gen) + except StopAsyncIteration: + await send({"type": "lifespan.shutdown.complete"}) + except Exception as e: + await send({"type": "lifespan.shutdown.failed", "message": str(e)}) + else: + await send( + { + "type": "lifespan.shutdown.failed", + "message": "Async generator did not exit as expected", + } + ) + + async def http_scope( + self, scope: dict, receive: ReceiveCallable, send: SendCallable + ) -> None: + """Handles one HTTP request""" + await self._initialise_application() + assert self.application is not None + + ctx = ASGIHTTPRequestContext(scope["scheme"]) + if client_addr := scope.get("client", None): + ctx.address = tuple(client_addr) + ctx.remote_ip = client_addr[0] + + conn = ASGIHTTPConnection(send, ctx) + msg_delegate = self.application.start_request(None, conn) + start_line, req_headers = self._http_convert_req(scope) + if (fut := msg_delegate.headers_received(start_line, req_headers)) is not None: + await fut + + while True: + event = await receive() + if event["type"] == "http.request": + if chunk := event.get("body", b""): + if (fut := msg_delegate.data_received(chunk)) is not None: + await fut + if not event.get("more_body", False): + msg_delegate.finish() + break + elif event["type"] == "http.disconnect": + msg_delegate.on_connection_close() + conn._on_connection_close() + break + + await conn.wait_finish() + + @staticmethod + def _http_convert_req(scope: dict) -> tuple[RequestStartLine, HTTPHeaders]: + req_target = scope["path"] + if qs := scope["query_string"]: + req_target += "?" + qs.decode("latin1") + req_start_line = RequestStartLine( + scope["method"], req_target, scope["http_version"] + ) + req_headers = HTTPHeaders() + for k, v in scope["headers"]: + req_headers.add(k.decode("latin1"), v.decode("latin1")) + + return req_start_line, req_headers diff --git a/tornado/httputil.py b/tornado/httputil.py index 2936e61b7..51d8937c0 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -45,9 +45,7 @@ import typing from collections.abc import Awaitable, Generator, Iterable, Iterator, Mapping -from typing import ( - AnyStr, -) +from typing import AnyStr, Optional if typing.TYPE_CHECKING: # These are relatively heavy imports and aren't needed in this file @@ -762,6 +760,22 @@ def finish(self) -> None: """Indicates that the last body data has been written.""" raise NotImplementedError() + def set_close_callback( + self, callback: Optional[collections.abc.Callable[[], None]] + ) -> None: + """Sets a callback that will be run when the connection is closed. + + Note that this callback is slightly different from + `.HTTPMessageDelegate.on_connection_close`: The + `.HTTPMessageDelegate` method is called when the connection is + closed while receiving a message. This callback is used when + there is not an active delegate (for example, on the server + side this callback is used if the client closes the connection + after sending its request but before receiving all the + response. + """ + raise NotImplementedError() + def url_concat( url: str, diff --git a/tornado/test/asgi_test.py b/tornado/test/asgi_test.py new file mode 100644 index 000000000..d8217e06a --- /dev/null +++ b/tornado/test/asgi_test.py @@ -0,0 +1,74 @@ +import unittest + +try: + import async_asgi_testclient # type: ignore +except ImportError: + async_asgi_testclient = None + +from tornado.asgi import ASGIAdapter +from tornado.web import Application, RequestHandler +from tornado.testing import AsyncTestCase, gen_test + + +class BasicHandler(RequestHandler): + def get(self): + name = self.get_argument("name", "world") + self.write(f"Hello, {name}") + + +class InspectHandler(RequestHandler): + def make_response(self, path_var): + # Send the response as JSON + self.finish( + { + "method": self.request.method, + "path": self.request.path, + "path_var": path_var, + "query_params": { + k: self.get_query_arguments(k) for k in self.request.query_arguments + }, + "body": self.request.body.decode("latin1"), + } + ) + + def get(self, path_var): + return self.make_response(path_var) + + def post(self, path_var): + return self.make_response(path_var) + + +@unittest.skipIf( + async_asgi_testclient is None, "async_asgi_testclient module not present" +) +class AsyncASGITestCase(AsyncTestCase): + def setUp(self) -> None: + super().setUp() + self.asgi_app = ASGIAdapter( + Application([(r"/", BasicHandler), (r"/inspect(/.*)", InspectHandler)]) + ) + self.client = async_asgi_testclient.TestClient(self.asgi_app) + + @gen_test + async def test_basic_request(self): + resp = await self.client.get("/?name=foo") + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.text, "Hello, foo") + + @gen_test + async def test_get_request_details(self): + resp = await self.client.get("/inspect/foo/?bar=baz") + d = resp.json() + self.assertEqual(d["method"], "GET") + self.assertEqual(d["path"], "/inspect/foo/") + self.assertEqual(d["query_params"], {"bar": ["baz"]}) + self.assertEqual(d["body"], "") + + @gen_test + async def test_post_request_details(self): + resp = await self.client.post("/inspect/foo/?bar=baz", data=b"123") + d = resp.json() + self.assertEqual(d["method"], "POST") + self.assertEqual(d["path"], "/inspect/foo/") + self.assertEqual(d["query_params"], {"bar": ["baz"]}) + self.assertEqual(d["body"], "123") diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py index 60d11519c..edc0006f4 100644 --- a/tornado/test/runtests.py +++ b/tornado/test/runtests.py @@ -19,6 +19,7 @@ "tornado.httputil.doctests", "tornado.iostream.doctests", "tornado.util.doctests", + "tornado.test.asgi_test", "tornado.test.asyncio_test", "tornado.test.auth_test", "tornado.test.autoreload_test", diff --git a/tox.ini b/tox.ini index db0a4b604..e7496947f 100644 --- a/tox.ini +++ b/tox.ini @@ -50,6 +50,7 @@ deps = # And since CaresResolver is deprecated, I do not expect to fix it, so just # pin the previous version. (This should really be in requirements.{in,txt} instead) full: pycares<5 + full: async-asgi-testclient docs: -r{toxinidir}/requirements.txt lint: -r{toxinidir}/requirements.txt