Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
async-asgi-testclient
black
flake8
mypy>=0.941
Expand Down
10 changes: 8 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#
# This file is autogenerated by pip-compile with Python 3.11
# This file is autogenerated by pip-compile with Python 3.14
# by the following command:
#
# pip-compile
#
alabaster==1.0.0
# via sphinx
async-asgi-testclient==1.4.11
# via -r requirements.in
babel==2.18.0
# via sphinx
black==26.3.1
Expand Down Expand Up @@ -49,6 +51,8 @@ markupsafe==3.0.3
# via jinja2
mccabe==0.7.0
# via flake8
multidict==6.7.0
# via async-asgi-testclient
mypy==1.19.1
# via -r requirements.in
mypy-extensions==1.1.0
Expand Down Expand Up @@ -94,7 +98,9 @@ python-discovery==1.2.0
pytokens==0.4.1
# via black
requests==2.32.5
# via sphinx
# via
# async-asgi-testclient
# sphinx
roman-numerals==4.1.0
# via sphinx
snowballstemmer==3.0.1
Expand Down
223 changes: 223 additions & 0 deletions tornado/asgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import inspect
from asyncio import create_task, Future, Task, wait
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass
from typing import Optional, Union

from tornado.httputil import (
HTTPConnection,
HTTPHeaders,
RequestStartLine,
ResponseStartLine,
)
from tornado.web import Application

ReceiveCallable = Callable[[], Awaitable[dict]]
SendCallable = Callable[[dict], Awaitable[None]]
ApplicationGen = Callable[[], AsyncGenerator]


@dataclass
class ASGIHTTPRequestContext:
"""To convey connection details to the HTTPServerRequest object"""

protocol: str
address: Optional[tuple] = None
remote_ip: str = "0.0.0.0"


class ASGIHTTPConnection(HTTPConnection):
"""Represents the connection for 1 request/response pair

This provides the API for sending the response.
"""

def __init__(self, send_cb: SendCallable, context: ASGIHTTPRequestContext):
self.send_cb = send_cb
self.context = context
self.task_holder: set[Task] = set()
self._close_callback: Callable[[], None] | None = None
self._request_finished: Future[None] = Future()

# Various tornado APIs (e.g. RequestHandler.flush()) return a Future which
# application code does not need to await. The operations these represent
# are expected to complete even if the Future is not awaited. We hold onto
# these on the connection so they are not destroyed, and so that we can
# wait for them at the end of the ASGI connections scope.
def _bg_task(self, coro) -> Future: # type: ignore
task = create_task(coro)
self.task_holder.add(task)
task.add_done_callback(self.task_holder.discard)
return task

async def _write_headers(
self,
start_line: Union["RequestStartLine", "ResponseStartLine"],
headers: HTTPHeaders,
chunk: Optional[bytes] = None,
) -> None:
assert isinstance(start_line, ResponseStartLine)
await self.send_cb(
{
"type": "http.response.start",
"status": start_line.code,
"headers": [
[k.lower().encode("latin1"), v.encode("latin1")]
for k, v in headers.get_all()
],
}
)
if chunk is not None:
await self._write(chunk)

def write_headers(
self,
start_line: Union["RequestStartLine", "ResponseStartLine"],
headers: HTTPHeaders,
chunk: Optional[bytes] = None,
) -> "Future[None]":
return self._bg_task(self._write_headers(start_line, headers, chunk))

async def _write(self, chunk: bytes) -> None:
await self.send_cb(
{"type": "http.response.body", "body": chunk, "more_body": True}
)

def write(self, chunk: bytes) -> "Future[None]":
return self._bg_task(self._write(chunk))

def finish(self) -> None:
self._bg_task(
self.send_cb(
{
"type": "http.response.body",
"body": b"",
"more_body": False,
}
)
)
self._request_finished.set_result(None)

def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None:
self._close_callback = callback

def _on_connection_close(self) -> None:
if self._close_callback is not None:
callback = self._close_callback
self._close_callback = None
callback()
self._request_finished.set_result(None)

async def wait_finish(self) -> None:
"""For the ASGI interface: wait for all input & output to finish"""
await self._request_finished
await wait(self.task_holder)


class ASGIAdapter:
"""Wrap a tornado application object to use with an ASGI server"""

application: Optional[Application]

def __init__(self, application: Application | ApplicationGen):
if isinstance(application, Application):
self.application_gen = None
self.application = application
elif inspect.isasyncgenfunction(application):
self.application_gen = application()
self.application = None
else:
raise TypeError(f"ASGIAdapter does not recognise {application!r}")

async def __call__(
self, scope: dict, receive: ReceiveCallable, send: SendCallable
) -> None:
if scope["type"] == "lifespan":
return await self.lifespan_scope(scope, receive, send)
if scope["type"] == "http":
return await self.http_scope(scope, receive, send)
raise KeyError(scope["type"])

async def _initialise_application(self) -> None:
# Ideally triggered by a lifespan startup message, but if the server
# doesn't support that, we'll do the setup on the first request.
if self.application is None:
assert self.application_gen is not None
self.application = await anext(self.application_gen)

async def lifespan_scope(
self, scope: dict, receive: ReceiveCallable, send: SendCallable
) -> None:
while True:
event = await receive()
if event["type"] == "lifespan.startup":
try:
await self._initialise_application()
except Exception as e:
await send({"type": "lifespan.startup.failed", "message": str(e)})
else:
await send({"type": "lifespan.startup.complete"})

elif event["type"] == "lifespan.shutdown":
try:
if self.application_gen is not None:
await anext(self.application_gen)
except StopAsyncIteration:
await send({"type": "lifespan.shutdown.complete"})
except Exception as e:
await send({"type": "lifespan.shutdown.failed", "message": str(e)})
else:
await send(
{
"type": "lifespan.shutdown.failed",
"message": "Async generator did not exit as expected",
}
)

async def http_scope(
self, scope: dict, receive: ReceiveCallable, send: SendCallable
) -> None:
"""Handles one HTTP request"""
await self._initialise_application()
assert self.application is not None

ctx = ASGIHTTPRequestContext(scope["scheme"])
if client_addr := scope.get("client", None):
ctx.address = tuple(client_addr)
ctx.remote_ip = client_addr[0]

conn = ASGIHTTPConnection(send, ctx)
msg_delegate = self.application.start_request(None, conn)
start_line, req_headers = self._http_convert_req(scope)
if (fut := msg_delegate.headers_received(start_line, req_headers)) is not None:
await fut

while True:
event = await receive()
if event["type"] == "http.request":
if chunk := event.get("body", b""):
if (fut := msg_delegate.data_received(chunk)) is not None:
await fut
if not event.get("more_body", False):
msg_delegate.finish()
break
elif event["type"] == "http.disconnect":
msg_delegate.on_connection_close()
conn._on_connection_close()
break

await conn.wait_finish()

@staticmethod
def _http_convert_req(scope: dict) -> tuple[RequestStartLine, HTTPHeaders]:
req_target = scope["path"]
if qs := scope["query_string"]:
req_target += "?" + qs.decode("latin1")
req_start_line = RequestStartLine(
scope["method"], req_target, scope["http_version"]
)
req_headers = HTTPHeaders()
for k, v in scope["headers"]:
req_headers.add(k.decode("latin1"), v.decode("latin1"))

return req_start_line, req_headers
20 changes: 17 additions & 3 deletions tornado/httputil.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@

import typing
from collections.abc import Awaitable, Generator, Iterable, Iterator, Mapping
from typing import (
AnyStr,
)
from typing import AnyStr, Optional

if typing.TYPE_CHECKING:
# These are relatively heavy imports and aren't needed in this file
Expand Down Expand Up @@ -762,6 +760,22 @@ def finish(self) -> None:
"""Indicates that the last body data has been written."""
raise NotImplementedError()

def set_close_callback(
self, callback: Optional[collections.abc.Callable[[], None]]
) -> None:
"""Sets a callback that will be run when the connection is closed.

Note that this callback is slightly different from
`.HTTPMessageDelegate.on_connection_close`: The
`.HTTPMessageDelegate` method is called when the connection is
closed while receiving a message. This callback is used when
there is not an active delegate (for example, on the server
side this callback is used if the client closes the connection
after sending its request but before receiving all the
response.
"""
raise NotImplementedError()


def url_concat(
url: str,
Expand Down
74 changes: 74 additions & 0 deletions tornado/test/asgi_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import unittest

try:
import async_asgi_testclient # type: ignore
except ImportError:
async_asgi_testclient = None

from tornado.asgi import ASGIAdapter
from tornado.web import Application, RequestHandler
from tornado.testing import AsyncTestCase, gen_test


class BasicHandler(RequestHandler):
def get(self):
name = self.get_argument("name", "world")
self.write(f"Hello, {name}")


class InspectHandler(RequestHandler):
def make_response(self, path_var):
# Send the response as JSON
self.finish(
{
"method": self.request.method,
"path": self.request.path,
"path_var": path_var,
"query_params": {
k: self.get_query_arguments(k) for k in self.request.query_arguments
},
"body": self.request.body.decode("latin1"),
}
)

def get(self, path_var):
return self.make_response(path_var)

def post(self, path_var):
return self.make_response(path_var)


@unittest.skipIf(
async_asgi_testclient is None, "async_asgi_testclient module not present"
)
class AsyncASGITestCase(AsyncTestCase):
def setUp(self) -> None:
super().setUp()
self.asgi_app = ASGIAdapter(
Application([(r"/", BasicHandler), (r"/inspect(/.*)", InspectHandler)])
)
self.client = async_asgi_testclient.TestClient(self.asgi_app)

@gen_test
async def test_basic_request(self):
resp = await self.client.get("/?name=foo")
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.text, "Hello, foo")

@gen_test
async def test_get_request_details(self):
resp = await self.client.get("/inspect/foo/?bar=baz")
d = resp.json()
self.assertEqual(d["method"], "GET")
self.assertEqual(d["path"], "/inspect/foo/")
self.assertEqual(d["query_params"], {"bar": ["baz"]})
self.assertEqual(d["body"], "")

@gen_test
async def test_post_request_details(self):
resp = await self.client.post("/inspect/foo/?bar=baz", data=b"123")
d = resp.json()
self.assertEqual(d["method"], "POST")
self.assertEqual(d["path"], "/inspect/foo/")
self.assertEqual(d["query_params"], {"bar": ["baz"]})
self.assertEqual(d["body"], "123")
1 change: 1 addition & 0 deletions tornado/test/runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"tornado.httputil.doctests",
"tornado.iostream.doctests",
"tornado.util.doctests",
"tornado.test.asgi_test",
"tornado.test.asyncio_test",
"tornado.test.auth_test",
"tornado.test.autoreload_test",
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ deps =
# And since CaresResolver is deprecated, I do not expect to fix it, so just
# pin the previous version. (This should really be in requirements.{in,txt} instead)
full: pycares<5
full: async-asgi-testclient
docs: -r{toxinidir}/requirements.txt
lint: -r{toxinidir}/requirements.txt

Expand Down
Loading