diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 23ef5eb..5782625 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,10 +16,10 @@ jobs: runs-on: "ubuntu-latest" strategy: + fail-fast: false matrix: python-version: ["3.11", "3.12", "3.13", "3.14"] redis-version: [6] - fail-fast: false steps: - uses: "actions/checkout@v4" diff --git a/CHANGELOG.md b/CHANGELOG.md index bb83fcc..f7f5ca5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ The **third number** is for emergencies when we need to start branches for older ([#58](https://github.com/Tinche/uapi/pull/58)) - Dictionaries are now supported in the OpenAPI schema, rendering to object schemas with `additionalProperties`. ([#58](https://github.com/Tinche/uapi/pull/58)) +- Multiple query parameters can now be received by annotating a parameter with `list` or `Sequence`. + ([#68](https://github.com/Tinche/uapi/pull/68)) - {meth}`uapi.flask.FlaskApp.run`, {meth}`uapi.quart.QuartApp.run` and {meth}`uapi.starlette.StarletteApp.run` now expose `host` parameters. ([#59](https://github.com/Tinche/uapi/pull/59)) - _uapi_ is now tested against Python 3.13 and 3.14. diff --git a/docs/handlers.md b/docs/handlers.md index e284db1..b6e2ff9 100644 --- a/docs/handlers.md +++ b/docs/handlers.md @@ -49,6 +49,11 @@ To receive query parameters, annotate a handler parameter with any type that has The {class}`App `'s dependency injection system is configured to fulfill handler parameters from query parameters by default; directly when annotated as strings or Any or through the App's converter if any other type. Query parameters may have default values. +```{note} +Technically, HTTP requests may contain multiple query parameters with the same name. +Unless the parameter is annotated as a list or sequence, all underlying frameworks return the *first* value encountered, except Django; it returns the last. +``` + Query params will be present in the [OpenAPI schema](openapi.md); parameters with defaults will be rendered as `required=False`. ```python @@ -58,6 +63,32 @@ async def query_handler(string_query: str, int_query: int = 0) -> None: return ``` +When a required query parameter is not provided, the result depends on the underlying framework used: + +* Starlette, aiohttp and Django return a `500 Internal Server Error`. +* Quart and Flask return a `400 Bad Request` error. + +#### Multiple Query Parameters + +To receive multiple query parameters, annotate a handler parameter with `list[T]` or [`Sequence[T]`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence). +When `list[str]` is used, the underlying framework's result will be directly returned; +otherwise the result will be structured into the parameter type by the App converter. +Because the underlying frameworks generally support only basic parsing of query parameters, this is usually only useful with simple types, like `list[int]` or `Sequence[int]`. + +```python +@app.get("/query_handler") +async def query_handler(string_query: list[str]) -> None: + # `string_query` can be provided multiple times. + return +``` + +```{note} +A multiple query parameter without a default value will be marked as `required` in the OpenAPI schema +even though technically it is not. +This is done mostly for consistency. +Assign a default value to make it non-required. +``` + ### Path Parameters One of the simplest ways of getting data into a handler is by using _path parameters_. diff --git a/pyproject.toml b/pyproject.toml index 3af771e..746ca58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ - "cattrs >= 23.2.2", + "cattrs>=25.3.0", "incant >= 23.2.0", "itsdangerous", "attrs >= 23.1.0", diff --git a/src/uapi/_openapi.py b/src/uapi/_openapi.py index 3bb84b4..ba06890 100644 --- a/src/uapi/_openapi.py +++ b/src/uapi/_openapi.py @@ -163,6 +163,7 @@ def build_operation( builder.get_schema_for_type(form_type) ) else: + # Query params if is_union_type(arg_type): refs: list[Reference | Schema | IntegerSchema] = [] for union_member in arg_type.__args__: @@ -170,7 +171,9 @@ def build_operation( refs.append(Schema(Schema.Type.NULL)) elif union_member in builder.PYTHON_PRIMITIVES_TO_OPENAPI: refs.append(builder.PYTHON_PRIMITIVES_TO_OPENAPI[union_member]) - param_schema: OneOfSchema | Schema | IntegerSchema = OneOfSchema(refs) + param_schema: AnySchema | Reference = OneOfSchema(refs) + elif getattr(arg_type, "__origin__", None) in (list, Sequence): + param_schema = builder.get_schema_for_type(arg_type) else: param_schema = builder.PYTHON_PRIMITIVES_TO_OPENAPI.get( arg_param.annotation, builder.PYTHON_PRIMITIVES_TO_OPENAPI[str] diff --git a/src/uapi/aiohttp.py b/src/uapi/aiohttp.py index 368d2d3..9fd8881 100644 --- a/src/uapi/aiohttp.py +++ b/src/uapi/aiohttp.py @@ -1,5 +1,5 @@ from asyncio import sleep -from collections.abc import Callable, Coroutine +from collections.abc import Callable, Coroutine, Sequence from functools import partial from inspect import Parameter, Signature, signature from logging import Logger @@ -248,6 +248,48 @@ def read_query(_request: FrameworkRequest): res.register_hook_factory( lambda p: p.annotation in (Signature.empty, str), string_query_factory ) + + def nonstring_list_query_factory( + p: Parameter, + ) -> Callable[[FrameworkRequest], list]: + def read_query(_request: FrameworkRequest): + return ( + converter.structure(_request.query.getall(p.name), p.annotation) + if p.default is Signature.empty + else ( + converter.structure(_request.query.getall(p.name), p.annotation) + if p.name in _request.query + else p.default + ) + ) + + return read_query + + res.register_hook_factory( + lambda p: getattr(p.annotation, "__origin__", None) in (list, Sequence), + nonstring_list_query_factory, + ) + + def string_list_query_factory( + p: Parameter, + ) -> Callable[[FrameworkRequest], list[str]]: + def read_query(_request: FrameworkRequest): + return ( + _request.query.getall(p.name) + if p.default is Signature.empty + else ( + _request.query.getall(p.name) + if p.name in _request.query + else p.default + ) + ) + + return read_query + + res.register_hook_factory( + lambda p: p.annotation == list[str], string_list_query_factory + ) + res.register_hook_factory( is_header, lambda p: _make_header_dependency( diff --git a/src/uapi/django.py b/src/uapi/django.py index 9bf0621..b6eee9b 100644 --- a/src/uapi/django.py +++ b/src/uapi/django.py @@ -1,4 +1,4 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import partial from inspect import Parameter, Signature, signature from typing import Any, ClassVar, Generic, TypeAlias, TypeVar @@ -255,6 +255,43 @@ def read_query(_request: FrameworkRequest) -> str: lambda p: p.annotation in (Signature.empty, str), string_query_factory ) + def nonstring_list_query_factory( + p: Parameter, + ) -> Callable[[FrameworkRequest], Sequence]: + def read_query(_request: FrameworkRequest): + return ( + converter.structure(_request.GET.getlist(p.name), p.annotation) + if p.default is Signature.empty + else ( + converter.structure(_request.GET.getlist(p.name), p.annotation) + if p.name in _request.GET + else p.default + ) + ) + + return read_query + + res.register_hook_factory( + lambda p: getattr(p.annotation, "__origin__", None) in (list, Sequence), + nonstring_list_query_factory, + ) + + def string_list_query_factory( + p: Parameter, + ) -> Callable[[FrameworkRequest], list[str]]: + def read_query(_request: FrameworkRequest) -> list[str]: + return ( + _request.GET.getlist(p.name) + if p.default is Signature.empty + else _request.GET.getlist(p.name, p.default) + ) + + return read_query + + res.register_hook_factory( + lambda p: p.annotation == list[str], string_list_query_factory + ) + res.register_hook_factory( is_header, lambda p: _make_header_dependency( diff --git a/src/uapi/flask.py b/src/uapi/flask.py index de9f7f3..03ab5c5 100644 --- a/src/uapi/flask.py +++ b/src/uapi/flask.py @@ -1,6 +1,6 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import partial -from inspect import Signature, signature +from inspect import Parameter, Signature, signature from typing import Any, ClassVar, Generic, TypeAlias, TypeVar from attrs import Factory, define @@ -178,6 +178,44 @@ def _make_flask_incanter(converter: Converter) -> Incanter: else request.args.get(p.name, p.default) ), ) + + def nonstring_list_query_factory(p: Parameter) -> Callable[[], Sequence]: + def read_query(): + return ( + converter.structure(request.args.getlist(p.name), p.annotation) + if p.default is Signature.empty + else ( + converter.structure(request.args.getlist(p.name), p.annotation) + if p.name in request.args + else p.default + ) + ) + + return read_query + + res.register_hook_factory( + lambda p: getattr(p.annotation, "__origin__", None) in (list, Sequence), + nonstring_list_query_factory, + ) + + def string_list_query_factory(p: Parameter) -> Callable[[], list[str]]: + def read_query() -> list[str]: + return ( + request.args.getlist(p.name) + if p.default is Signature.empty + else ( + request.args.getlist(p.name) + if p.name in request.args + else p.default + ) + ) + + return read_query + + res.register_hook_factory( + lambda p: p.annotation == list[str], string_list_query_factory + ) + res.register_hook_factory( is_header, lambda p: _make_header_dependency( diff --git a/src/uapi/openapi.py b/src/uapi/openapi.py index fe12855..073f685 100644 --- a/src/uapi/openapi.py +++ b/src/uapi/openapi.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Callable, Mapping, Sequence -from contextlib import suppress from datetime import date, datetime from enum import Enum, unique from typing import Any, ClassVar, Literal, TypeAlias @@ -218,10 +217,7 @@ def get_schema_for_type( return ArraySchema(inner) raise Exception("Nested arrays are unsupported") - mapping = False - # TODO: remove this when cattrs 24.1 releases - with suppress(TypeError): - mapping = is_mapping(type) + mapping = is_mapping(type) if mapping: # Dicts also get created inline. args = get_args(type) diff --git a/src/uapi/quart.py b/src/uapi/quart.py index f5f0ef4..652f602 100644 --- a/src/uapi/quart.py +++ b/src/uapi/quart.py @@ -1,8 +1,8 @@ from asyncio import create_task, sleep -from collections.abc import Callable, Coroutine, Generator +from collections.abc import Callable, Coroutine, Generator, Sequence from contextlib import contextmanager, suppress from functools import partial -from inspect import Signature, signature +from inspect import Parameter, Signature, signature from typing import Any, ClassVar, Generic, TypeAlias, TypeVar from attrs import Factory, define @@ -237,6 +237,44 @@ def _make_quart_incanter(converter: Converter) -> Incanter: else request.args.get(p.name, p.default) ), ) + + def nonstring_list_query_factory(p: Parameter) -> Callable[[], Sequence]: + def read_query(): + return ( + converter.structure(request.args.getlist(p.name), p.annotation) + if p.default is Signature.empty + else ( + converter.structure(request.args.getlist(p.name), p.annotation) + if p.name in request.args + else p.default + ) + ) + + return read_query + + res.register_hook_factory( + lambda p: getattr(p.annotation, "__origin__", None) in (list, Sequence), + nonstring_list_query_factory, + ) + + def string_list_query_factory(p: Parameter) -> Callable[[], list[str]]: + def read_query() -> list[str]: + return ( + request.args.getlist(p.name) + if p.default is Signature.empty + else ( + request.args.getlist(p.name) + if p.name in request.args + else p.default + ) + ) + + return read_query + + res.register_hook_factory( + lambda p: p.annotation == list[str], string_list_query_factory + ) + res.register_hook_factory( is_header, lambda p: _make_header_dependency( diff --git a/src/uapi/starlette.py b/src/uapi/starlette.py index 59652bf..d34b7ca 100644 --- a/src/uapi/starlette.py +++ b/src/uapi/starlette.py @@ -1,5 +1,5 @@ from asyncio import create_task, sleep -from collections.abc import Callable, Coroutine, Generator +from collections.abc import Callable, Coroutine, Generator, Sequence from contextlib import contextmanager, suppress from functools import partial from inspect import Parameter, Signature, signature @@ -261,6 +261,50 @@ def read_query(_request: FrameworkRequest) -> Any: res.register_hook_factory( lambda p: p.annotation in (Signature.empty, str), string_query_factory ) + + def nonstring_list_query_factory( + p: Parameter, + ) -> Callable[[FrameworkRequest], list]: + def read_query(_request: FrameworkRequest): + return ( + converter.structure(_request.query_params.getlist(p.name), p.annotation) + if p.default is Signature.empty + else ( + converter.structure( + _request.query_params.getlist(p.name), p.annotation + ) + if p.name in _request.query_params + else p.default + ) + ) + + return read_query + + res.register_hook_factory( + lambda p: getattr(p.annotation, "__origin__", None) in (list, Sequence), + nonstring_list_query_factory, + ) + + def string_list_query_factory( + p: Parameter, + ) -> Callable[[FrameworkRequest], list[str]]: + def read_query(_request: FrameworkRequest): + return ( + _request.query_params.getlist(p.name) + if p.default is Signature.empty + else ( + _request.query_params.getlist(p.name) + if p.name in _request.query_params + else p.default + ) + ) + + return read_query + + res.register_hook_factory( + lambda p: p.annotation == list[str], string_list_query_factory + ) + res.register_hook_factory( is_header, lambda p: _make_header_dependency( diff --git a/tests/apps.py b/tests/apps.py index 5cb4df7..32117a3 100644 --- a/tests/apps.py +++ b/tests/apps.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from datetime import datetime from typing import Annotated, TypeAlias, TypeVar @@ -50,6 +51,26 @@ async def hello() -> str: async def query_post(page: int) -> str: return str(page + 1) + @app.get("/query-list") + async def query_list(param: list[str]) -> str: + return str(sum(int(q) for q in param)) + + @app.get("/query-list-def") + async def query_list_def(param: list[str] = ["1", "2"]) -> str: + """Query lists with defaults.""" + return str(sum(int(q) for q in param)) + + @app.get("/query-list-nonstring") + async def query_list_nonstring(param: list[int]) -> str: + """Query lists with non-strings work.""" + return str(sum(param)) + + @app.get("/query-seq") + async def query_list_seq(param: Sequence[int]) -> str: + """Query sequences work.""" + assert isinstance(param, tuple) + return str(sum(param)) + @app.get("/response-bytes", tags=["query"]) async def response_bytes() -> bytes: return b"2" @@ -285,6 +306,26 @@ def hello() -> str: def query_post(page: int) -> str: return str(page + 1) + @app.get("/query-list") + def query_list(param: list[str]) -> str: + return str(sum(int(q) for q in param)) + + @app.get("/query-list-def") + def query_list_def(param: list[str] = ["1", "2"]) -> str: + """Query lists with defaults.""" + return str(sum(int(q) for q in param)) + + @app.get("/query-list-nonstring") + def query_list_nonstring(param: list[int]) -> str: + """Query lists with non-strings work.""" + return str(sum(param)) + + @app.get("/query-seq") + def query_list_seq(param: Sequence[int]) -> str: + """Query sequences work.""" + assert isinstance(param, tuple) + return str(sum(param)) + @app.get("/response-bytes", tags=["query"]) def response_bytes() -> bytes: return b"2" diff --git a/tests/openapi/test_openapi.py b/tests/openapi/test_openapi.py index 7441231..e5acc79 100644 --- a/tests/openapi/test_openapi.py +++ b/tests/openapi/test_openapi.py @@ -56,78 +56,6 @@ def test_get_path_param(app: App) -> None: assert op.get.description is None -def test_get_query_int(app: App) -> None: - spec: OpenAPI = app.make_openapi_spec() - - op = spec.paths["/query"] - assert op is not None - assert op.get is not None - assert op.get.parameters == [ - Parameter( - name="page", - kind=Parameter.Kind.QUERY, - required=True, - schema=IntegerSchema(), - ) - ] - assert len(op.get.responses) == 1 - assert op.get.responses["200"] - - -def test_get_query_default(app: App) -> None: - spec: OpenAPI = app.make_openapi_spec() - - op = spec.paths["/query-default"] - assert op is not None - assert op.get - assert op.get.parameters == [ - Parameter( - name="page", - kind=Parameter.Kind.QUERY, - required=False, - schema=IntegerSchema(), - ) - ] - assert len(op.get.responses) == 1 - assert op.get.responses["200"] - - -def test_get_query_unannotated(app: App) -> None: - spec: OpenAPI = app.make_openapi_spec() - - op = spec.paths["/query/unannotated"] - assert op is not None - assert op.get - assert op.get.parameters == [ - Parameter( - name="query", - kind=Parameter.Kind.QUERY, - required=True, - schema=Schema(Schema.Type.STRING), - ) - ] - assert len(op.get.responses) == 1 - assert op.get.responses["200"] - - -def test_get_query_string(app: App) -> None: - spec: OpenAPI = app.make_openapi_spec() - - op = spec.paths["/query/string"] - assert op is not None - assert op.get is not None - assert op.get.parameters == [ - Parameter( - name="query", - kind=Parameter.Kind.QUERY, - required=True, - schema=Schema(Schema.Type.STRING), - ) - ] - assert len(op.get.responses) == 1 - assert op.get.responses["200"] - - def test_get_bytes(app: App) -> None: spec: OpenAPI = app.make_openapi_spec() diff --git a/tests/openapi/test_openapi_query.py b/tests/openapi/test_openapi_query.py new file mode 100644 index 0000000..9d26571 --- /dev/null +++ b/tests/openapi/test_openapi_query.py @@ -0,0 +1,130 @@ +"""Test query parameters.""" + +from uapi.base import App +from uapi.openapi import ArraySchema, IntegerSchema, OpenAPI, Parameter, Schema + + +def test_get_query_int(app: App) -> None: + spec: OpenAPI = app.make_openapi_spec() + + op = spec.paths["/query"] + assert op is not None + assert op.get is not None + assert op.get.parameters == [ + Parameter( + name="page", + kind=Parameter.Kind.QUERY, + required=True, + schema=IntegerSchema(), + ) + ] + assert len(op.get.responses) == 1 + assert op.get.responses["200"] + + +def test_get_query_default(app: App) -> None: + spec: OpenAPI = app.make_openapi_spec() + + op = spec.paths["/query-default"] + assert op is not None + assert op.get + assert op.get.parameters == [ + Parameter( + name="page", + kind=Parameter.Kind.QUERY, + required=False, + schema=IntegerSchema(), + ) + ] + assert len(op.get.responses) == 1 + assert op.get.responses["200"] + + +def test_get_query_unannotated(app: App) -> None: + spec: OpenAPI = app.make_openapi_spec() + + op = spec.paths["/query/unannotated"] + assert op is not None + assert op.get + assert op.get.parameters == [ + Parameter( + name="query", + kind=Parameter.Kind.QUERY, + required=True, + schema=Schema(Schema.Type.STRING), + ) + ] + assert len(op.get.responses) == 1 + assert op.get.responses["200"] + + +def test_get_query_string(app: App) -> None: + spec: OpenAPI = app.make_openapi_spec() + + op = spec.paths["/query/string"] + assert op is not None + assert op.get is not None + assert op.get.parameters == [ + Parameter( + name="query", + kind=Parameter.Kind.QUERY, + required=True, + schema=Schema(Schema.Type.STRING), + ) + ] + assert len(op.get.responses) == 1 + assert op.get.responses["200"] + + +def test_get_query_list(app: App) -> None: + spec: OpenAPI = app.make_openapi_spec() + + op = spec.paths["/query-list"] + assert op is not None + assert op.get + assert op.get.parameters == [ + Parameter( + name="param", + kind=Parameter.Kind.QUERY, + required=True, + schema=ArraySchema(Schema(Schema.Type.STRING)), + ) + ] + assert len(op.get.responses) == 1 + assert op.get.responses["200"] + + +def test_get_query_list_nonstring(app: App) -> None: + spec: OpenAPI = app.make_openapi_spec() + + op = spec.paths["/query-list-nonstring"] + assert op is not None + assert op.get + assert op.get.parameters == [ + Parameter( + name="param", + kind=Parameter.Kind.QUERY, + required=True, + schema=ArraySchema(IntegerSchema()), + ) + ] + assert len(op.get.responses) == 1 + assert op.get.responses["200"] + + +def test_get_query_seq(app: App) -> None: + spec: OpenAPI = app.make_openapi_spec() + + op = spec.paths["/query-seq"] + assert op is not None + assert op.get + assert op.get.parameters == [ + Parameter( + name="param", + kind=Parameter.Kind.QUERY, + required=True, + schema=ArraySchema(IntegerSchema()), + ) + ] + assert len(op.get.responses) == 1 + assert op.get.responses["200"] diff --git a/tests/test_query.py b/tests/test_query.py index 6f36ca2..08a6e81 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -3,7 +3,7 @@ @pytest.mark.asyncio(loop_scope="session") -async def test_query_post(server): +async def test_query_post(server: int): """Test query params in posts.""" async with AsyncClient() as client: resp = await client.post( @@ -11,3 +11,48 @@ async def test_query_post(server): ) assert resp.status_code == 200 assert resp.read() == b"3" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_query_string_list(server: int): + """Multiple query params can be gathered into lists.""" + async with AsyncClient() as client: + resp = await client.get( + f"http://localhost:{server}/query-list", params={"param": ["1", "2", "3"]} + ) + assert resp.status_code == 200 + assert resp.read() == b"6" + + resp = await client.get( + f"http://localhost:{server}/query-list-def", + params={"param": ["1", "2", "3"]}, + ) + assert resp.status_code == 200 + assert resp.read() == b"6" + + resp = await client.get(f"http://localhost:{server}/query-list-def") + assert resp.status_code == 200 + assert resp.read() == b"3" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_query_nonstring_list(server: int): + """Multiple non-string query params can be gathered into lists.""" + async with AsyncClient() as client: + resp = await client.get( + f"http://localhost:{server}/query-list-nonstring", + params={"param": ["1", "2", "3"]}, + ) + assert resp.status_code == 200 + assert resp.read() == b"6" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_query_seq(server: int): + """Multiple query params can be gathered into sequences.""" + async with AsyncClient() as client: + resp = await client.get( + f"http://localhost:{server}/query-seq", params={"param": ["1", "2", "3"]} + ) + assert resp.status_code == 200 + assert resp.read() == b"6" diff --git a/uv.lock b/uv.lock index 3482ca9..990c67b 100644 --- a/uv.lock +++ b/uv.lock @@ -2438,7 +2438,7 @@ test = [ [package.metadata] requires-dist = [ { name = "attrs", specifier = ">=23.1.0" }, - { name = "cattrs", specifier = ">=23.2.2" }, + { name = "cattrs", specifier = ">=25.3.0" }, { name = "incant", specifier = ">=23.2.0" }, { name = "itsdangerous" }, { name = "orjson", specifier = ">=3.11.3" },