From fa341d004187fdff7d9299c3d72f44098d4094ba Mon Sep 17 00:00:00 2001 From: Vladislav Date: Thu, 4 Jun 2026 18:58:00 +0300 Subject: [PATCH 1/2] feat: add DISTINCT ON support for PostgreSQL. --- tests/test_distinct.py | 204 ++++++++++++++++++++++++++++++ tortoise/contrib/test/__init__.py | 64 ++++++++++ tortoise/queryset.py | 140 ++++++++++++++++---- 3 files changed, 384 insertions(+), 24 deletions(-) create mode 100644 tests/test_distinct.py diff --git a/tests/test_distinct.py b/tests/test_distinct.py new file mode 100644 index 000000000..e2ffeb983 --- /dev/null +++ b/tests/test_distinct.py @@ -0,0 +1,204 @@ +import pytest + +from tests.testmodels import Tournament +from tortoise.contrib import test +from tortoise.exceptions import OperationalError + +# --------------------------------------------------------------------------- +# Basic DISTINCT (all databases) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_distinct_no_args(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + tournaments = await Tournament.all().distinct() + assert len(tournaments) == 2 + + +# --------------------------------------------------------------------------- +# DISTINCT ON (PostgreSQL only) +# --------------------------------------------------------------------------- + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_single_field(db): + tournament_1 = await Tournament.create(name="1", desc="1") + await Tournament.create(name="1", desc="2") + await Tournament.create(name="1", desc="3") + + tournaments = await Tournament.all().distinct("name") + assert tournaments == [tournament_1] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_single_field_with_order_by(db): + await Tournament.create(name="1", desc="1") + await Tournament.create(name="1", desc="2") + tournament_3 = await Tournament.create(name="1", desc="3") + + tournaments = await Tournament.all().distinct("name").order_by("name", "-desc") + assert tournaments == [tournament_3] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_multiple_fields(db): + tournament_1 = await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="a") + tournament_3 = await Tournament.create(name="2", desc="b") + + tournaments = await Tournament.all().distinct("name", "desc").order_by("name", "desc") + assert tournaments == [tournament_1, tournament_3] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_list_single_field(db): + """values_list selects one field, same as DISTINCT ON field.""" + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values_list("name", flat=True) + assert tournaments == ["1", "2"] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_list_multiple_fields(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values_list("name", "desc") + assert tournaments == [("1", "a"), ("2", "c")] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_list_extra_fields(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values_list("desc", flat=True) + assert tournaments == ["a", "c"] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_list_extra_field_respects_order_by(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = ( + await Tournament.all() + .distinct("name") + .order_by("name", "-desc") + .values_list("desc", flat=True) + ) + assert tournaments == ["b", "c"] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_single_field(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values("name") + assert tournaments == [{"name": "1"}, {"name": "2"}] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_multiple_fields(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values("name", "desc") + assert tournaments == [{"name": "1", "desc": "a"}, {"name": "2", "desc": "c"}] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_extra_fields(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").values("desc") + assert tournaments == [{"desc": "a"}, {"desc": "c"}] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_values_extra_field_respects_order_by(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").order_by("name", "-desc").values("desc") + assert tournaments == [{"desc": "b"}, {"desc": "c"}] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_only_same_field(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").only("name") + assert [t.name for t in tournaments] == ["1", "2"] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_only_extra_field(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = await Tournament.all().distinct("name").only("name", "desc") + assert [(t.name, t.desc) for t in tournaments] == [("1", "a"), ("2", "c")] + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_only_with_order_by(db): + await Tournament.create(name="1", desc="a") + await Tournament.create(name="1", desc="b") + await Tournament.create(name="2", desc="c") + + tournaments = ( + await Tournament.all().distinct("name").order_by("name", "-desc").only("name", "desc") + ) + assert [(t.name, t.desc) for t in tournaments] == [("1", "b"), ("2", "c")] + + +# --------------------------------------------------------------------------- +# DISTINCT ON validation errors +# --------------------------------------------------------------------------- + + +@test.requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_invalid_order_by(db): + await Tournament.create(name="1") + with pytest.raises(OperationalError): + await Tournament.all().distinct("name").order_by("desc") + + +@test.skipCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_distinct_on_not_supported_outside_postgres(db): + with pytest.raises(OperationalError): + Tournament.all().distinct("name") diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index b55864a45..beaa068b0 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -50,6 +50,7 @@ async def test_sqlite_only(db): "TortoiseContext", "tortoise_test_context", "requireCapability", + "skipCapability", "truncate_all_models", "init_memory_sqlite", "SkipTest", @@ -235,6 +236,69 @@ def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return decorator +def skipCapability( + connection_name: str = "models", **conditions: typing.Any +) -> Callable[[_FT], _FT]: + """ + Skip a test if the specified capabilities are matched. + + This is the inverse of :func:`requireCapability`. + + Usage: + + .. code-block:: python3 + + @skipCapability(dialect='postgres') + @pytest.mark.asyncio + async def test_skip_on_postgres(db): + ... + + :param connection_name: name of the connection to retrieve capabilities from. + :param conditions: capability tests — if all match, the test is skipped. + """ + + def decorator(test_item: _FT) -> _FT: + if not isinstance(test_item, type): + + def check_capabilities() -> None: + db = get_connection(connection_name) + if all(getattr(db.capabilities, key) == val for key, val in conditions.items()): + raise SkipTest(f"Skipped because capabilities match: {conditions}") + + if inspect.iscoroutinefunction(test_item): + + @wraps(test_item) + async def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + check_capabilities() + return await test_item(*args, **kwargs) + + else: + + @wraps(test_item) + def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + check_capabilities() + return test_item(*args, **kwargs) + + return cast(_FT, skip_wrapper) + + # Assume a class is decorated + funcs = { + var: f + for var in dir(test_item) + if var.startswith("test_") and callable(f := getattr(test_item, var)) + } + for name, func in funcs.items(): + setattr( + test_item, + name, + skipCapability(connection_name=connection_name, **conditions)(func), + ) + + return test_item + + return decorator + + @typing.overload def init_memory_sqlite(models: ModulesConfigType | None = None) -> AsyncFuncDeco: ... diff --git a/tortoise/queryset.py b/tortoise/queryset.py index aecffc42d..99f53661e 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -8,6 +8,7 @@ from pypika_tortoise import JoinType, Order, Table from pypika_tortoise.analytics import Count +from pypika_tortoise.dialects import PostgreSQLQueryBuilder from pypika_tortoise.functions import Cast from pypika_tortoise.queries import QueryBuilder, _SetOperation from pypika_tortoise.terms import Case, Field, Star, Term, ValueWrapper @@ -42,6 +43,7 @@ # Empty placeholder - Should never be edited. QUERY: QueryBuilder = QueryBuilder() +POSTGRES_QUERY: PostgreSQLQueryBuilder = PostgreSQLQueryBuilder() if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model @@ -84,6 +86,7 @@ def values( class _ChooseDBMixin(Generic[MODEL]): _db: BaseDBAsyncClient | None model: type[MODEL] + query: QueryBuilder | PostgreSQLQueryBuilder def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: """ @@ -99,9 +102,23 @@ def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: db = router.db_for_read(self.model) return db or self.model._meta.db + def _apply_db(self, db: BaseDBAsyncClient) -> None: + """ + Set the database connection for this query and update the query builder dialect. + + Assigns ``db`` to ``_db`` and, when the connection targets PostgreSQL, + replaces the default ``query`` placeholder with ``POSTGRES_QUERY`` so + that subsequent query-building calls produce PostgreSQL-specific SQL. + + :param db: The database connection to use for this query. + """ + self._db = db + if db is not None and hasattr(self, "query") and db.capabilities.dialect == "postgres": + self.query = POSTGRES_QUERY + def _choose_db_if_not_chosen(self, for_write: bool = False) -> None: if self._db is None: - self._db = self._choose_db(for_write) + self._apply_db(self._choose_db(for_write)) class AwaitableQuery(_ChooseDBMixin[MODEL], Generic[MODEL]): @@ -119,7 +136,7 @@ class AwaitableQuery(_ChooseDBMixin[MODEL], Generic[MODEL]): def __init__(self, model: type[MODEL]) -> None: self._joined_tables: list[Table] = [] self.model: type[MODEL] = model - self.query: QueryBuilder = QUERY + self.query: QueryBuilder | PostgreSQLQueryBuilder = QUERY self._db: BaseDBAsyncClient = None # type: ignore self._capabilities: Capabilities | None = None self._annotations: dict[str, Expression | Term] = {} @@ -362,6 +379,7 @@ def __init__(self, model: type[MODEL]) -> None: self._filter_kwargs: dict[str, Any] = {} self._orderings: list[tuple[str, Any]] = [] self._distinct: bool = False + self._distinct_on: list[str] = [] self._having: dict[str, Any] = {} self._fields_for_select: tuple[str, ...] = () self._group_bys: tuple[str, ...] = () @@ -396,6 +414,7 @@ def _clone(self) -> QuerySet[MODEL]: queryset._joined_tables = copy(self._joined_tables) queryset._q_objects = copy(self._q_objects) queryset._distinct = self._distinct + queryset._distinct_on = copy(self._distinct_on) queryset._annotations = copy(self._annotations) queryset._having = copy(self._having) queryset._custom_filters = copy(self._custom_filters) @@ -580,15 +599,44 @@ def __getitem__(self, key: slice) -> QuerySet[MODEL]: queryset = queryset.limit(key.stop - start) return queryset - def distinct(self) -> QuerySet[MODEL]: + def distinct(self, *args: str) -> QuerySet[MODEL]: """ - Make QuerySet distinct. + Make QuerySet return distinct results. + + Without arguments, adds a plain ``DISTINCT`` to the query, which works on all databases + and is most useful with ``.values()`` or ``.values_list()``. + + With arguments (PostgreSQL only), generates ``DISTINCT ON (fields)`` which keeps one row + per unique combination of the given fields. ``ORDER BY`` is optional, but if specified + it must begin with the same fields in the same order as ``DISTINCT ON`` — otherwise an + :exc:`~tortoise.exceptions.OperationalError` is raised. - Only makes sense in combination with a ``.values()`` or ``.values_list()`` as it - precedes all the fetched fields with a distinct. + Can be combined with ``.only()``, ``.values()``, and ``.values_list()`` — fields not + present in ``DISTINCT ON`` are taken from the row selected by the ordering. + + .. code-block:: python3 + + # Plain DISTINCT — all databases + await Tournament.all().distinct().values("name") + + # DISTINCT ON without ORDER BY — PostgreSQL only + await Tournament.all().distinct("name") + + # DISTINCT ON with ORDER BY — ORDER BY must start with DISTINCT ON fields + await Tournament.all().distinct("name").order_by("name", "-desc") + + :param args: Field names for ``DISTINCT ON`` (PostgreSQL only). Omit for plain + ``DISTINCT``. + :raises OperationalError: If field arguments are given on a non-PostgreSQL database, + or if ``ORDER BY`` is specified but does not start with the ``DISTINCT ON`` fields. """ queryset = self._clone() queryset._distinct = True + if args: + if isinstance(self.query, PostgreSQLQueryBuilder): + queryset._distinct_on = list(args) + else: + raise OperationalError("DISTINCT ON is only supported by PostgreSQL") return queryset def union(self, *other_qs: QuerySet[Model], all: bool = False) -> UnionQuery[MODEL]: @@ -698,6 +746,7 @@ def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Lite group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + distinct_on=self._distinct_on, ) def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: @@ -753,6 +802,7 @@ def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + distinct_on=self._distinct_on, ) def delete(self) -> DeleteQuery: @@ -1115,7 +1165,7 @@ def using_db(self, _db: BaseDBAsyncClient | None) -> QuerySet[MODEL]: Useful for transactions workaround. """ queryset = self._clone() - queryset._db = _db if _db else queryset._db + queryset._apply_db(_db if _db else queryset._db) return queryset def _join_select_related(self, lookup_expression: str) -> tuple[type[Model], Table]: @@ -1270,6 +1320,19 @@ def _make_query(self) -> None: self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True + if isinstance(self.query, PostgreSQLQueryBuilder) and self._distinct_on: + ordering_fields = [ordering[0] for ordering in self._orderings] + len_ordering_fields = len(ordering_fields) + for i, field in enumerate(self._distinct_on): + if ordering_fields and ( + i >= len_ordering_fields or ordering_fields[i] != field + ): + raise OperationalError( + f"DISTINCT ON fields must match the leading ORDER BY fields. " + f"Expected ORDER BY to start with {self._distinct_on!r}." + ) + self.query._distinct_on = [] + self.query.distinct_on(*self._distinct_on) if self._select_for_update: self.query = self.query.for_update( self._select_for_update_nowait, @@ -1288,8 +1351,7 @@ def _make_query(self) -> None: self.query = self.query.use_index(*self._use_indexes) def __await__(self) -> Generator[Any, None, list[MODEL]]: - if self._db is None: - self._db = self._choose_db(self._select_for_update) # type: ignore + self._choose_db_if_not_chosen(self._select_for_update) self._make_query() return self._execute().__await__() @@ -1343,7 +1405,7 @@ def __init__( self._q_objects = q_objects self._annotations = annotations self._custom_filters = custom_filters - self._db = db + self._apply_db(db) self._limit = limit self._orderings = orderings @@ -1422,7 +1484,7 @@ def __init__( self._q_objects = q_objects self._annotations = annotations self._custom_filters = custom_filters - self._db = db + self._apply_db(db) self._limit = limit self._orderings = orderings @@ -1467,7 +1529,7 @@ def __init__( ) -> None: super().__init__(model) self._q_objects = q_objects - self._db = db + self._apply_db(db) self._annotations = annotations self._custom_filters = custom_filters self._force_indexes = force_indexes @@ -1540,7 +1602,7 @@ def __init__( self._custom_filters = custom_filters self._limit = limit self._offset = offset or 0 - self._db = db + self._apply_db(db) self._force_indexes = force_indexes self._use_indexes = use_indexes @@ -1595,8 +1657,7 @@ def _join_table_with_forwarded_fields( if field in self.model._meta.fetch_fields and not forwarded_fields: raise ValueError( - f'Selecting relation "{field}" is not possible, select concrete ' - "field on related model" + f'Selecting relation "{field}" is not possible, select concrete field on related model' ) field_object = cast(RelationalField, model._meta.fields_map.get(field)) @@ -1627,8 +1688,7 @@ def add_field_to_select_query(self, field: str, return_as: str) -> None: if field in self.model._meta.fetch_fields: raise ValueError( - f'Selecting relation "{field}" is not possible, select ' - "concrete field on related model" + f'Selecting relation "{field}" is not possible, select concrete field on related model' ) field_, __, forwarded_fields = field.partition("__") @@ -1704,6 +1764,7 @@ class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): "_force_indexes", "_use_indexes", "_fields_to_select_sql", + "_distinct_on", ) def __init__( @@ -1724,6 +1785,7 @@ def __init__( group_bys: tuple[str, ...], force_indexes: set[str], use_indexes: set[str], + distinct_on: list[str], ) -> None: super().__init__(model, annotations) if flat and (len(fields_for_select_list) != 1): @@ -1741,10 +1803,11 @@ def __init__( self._raise_does_not_exist = raise_does_not_exist self._fields_for_select_list = fields_for_select_list self._flat = flat - self._db = db + self._apply_db(db) self._group_bys = group_bys self._force_indexes = force_indexes self._use_indexes = use_indexes + self._distinct_on = distinct_on self._fields_to_select_sql = { *self._fields_for_select_list, *(key for key, value in self.fields.items() if value in self._fields_for_select_list), @@ -1771,6 +1834,19 @@ def _make_query(self) -> None: self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True + if isinstance(self.query, PostgreSQLQueryBuilder) and self._distinct_on: + ordering_fields = [ordering[0] for ordering in self._orderings] + len_ordering_fields = len(ordering_fields) + for i, field in enumerate(self._distinct_on): + if ordering_fields and ( + i >= len_ordering_fields or ordering_fields[i] != field + ): + raise OperationalError( + f"DISTINCT ON fields must match the leading ORDER BY fields. " + f"Expected ORDER BY to start with {self._distinct_on!r}." + ) + self.query._distinct_on = [] + self.query.distinct_on(*self._distinct_on) if self._group_bys: self.query._groupbys = self._resolve_group_bys(*self._group_bys) @@ -1837,6 +1913,7 @@ class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): "_group_bys", "_force_indexes", "_use_indexes", + "_distinct_on", ) def __init__( @@ -1856,6 +1933,7 @@ def __init__( group_bys: tuple[str, ...], force_indexes: set[str], use_indexes: set[str], + distinct_on: list[str], ) -> None: super().__init__(model, annotations) self._fields_for_select = fields_for_select @@ -1867,10 +1945,11 @@ def __init__( self._q_objects = q_objects self._single = single self._raise_does_not_exist = raise_does_not_exist - self._db = db + self._apply_db(db) self._group_bys = group_bys self._force_indexes = force_indexes self._use_indexes = use_indexes + self._distinct_on = distinct_on def _make_query(self) -> None: self._joined_tables = [] @@ -1899,6 +1978,19 @@ def _make_query(self) -> None: self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True + if isinstance(self.query, PostgreSQLQueryBuilder) and self._distinct_on: + ordering_fields = [ordering[0] for ordering in self._orderings] + len_ordering_fields = len(ordering_fields) + for i, field in enumerate(self._distinct_on): + if ordering_fields and ( + i >= len_ordering_fields or ordering_fields[i] != field + ): + raise OperationalError( + f"DISTINCT ON fields must match the leading ORDER BY fields. " + f"Expected ORDER BY to start with {self._distinct_on!r}." + ) + self.query._distinct_on = [] + self.query.distinct_on(*self._distinct_on) if self._group_bys: self.query._groupbys = self._resolve_group_bys(*self._group_bys) @@ -1963,7 +2055,7 @@ class RawSQLQuery(AwaitableQuery): def __init__(self, model: type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: super().__init__(model) self._sql = sql - self._db = db + self._apply_db(db) async def _execute(self) -> Any: instance_list = await self._db.executor_class( @@ -2094,7 +2186,7 @@ def __init__( self._objects = objects self._ignore_conflicts = ignore_conflicts self._batch_size = batch_size - self._db = db + self._apply_db(db) self._update_fields = update_fields self._on_conflict = on_conflict @@ -2298,7 +2390,7 @@ def __init__( ) -> None: super().__init__(model) self._union_query = union_query - self._db = db + self._apply_db(db) def _make_query(self) -> None: self._union_query._make_query() @@ -2344,11 +2436,11 @@ def __init__( all: bool = False, ): self.model = model - self.query = QUERY + self.query: QueryBuilder | PostgreSQLQueryBuilder = QUERY self._models: set[type[Model]] = {model, *(qs.model for qs in querysets)} self._union_query: QueryBuilder | _SetOperation | None = None self._selects: list[str] = [] - self._db = db + self._apply_db(db) self._qs = querysets self._all = all self._orderings: list[tuple[str, Order]] | None = None From 0b7bb860327e4257690f21d49f78e9ea7add5224 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Sat, 6 Jun 2026 12:27:16 +0300 Subject: [PATCH 2/2] Removed the skipCapability function. --- tests/test_distinct.py | 3 +- tortoise/contrib/test/__init__.py | 64 ------------------------------- 2 files changed, 2 insertions(+), 65 deletions(-) diff --git a/tests/test_distinct.py b/tests/test_distinct.py index e2ffeb983..5fcb44d2d 100644 --- a/tests/test_distinct.py +++ b/tests/test_distinct.py @@ -2,6 +2,7 @@ from tests.testmodels import Tournament from tortoise.contrib import test +from tortoise.contrib.test.condition import NotIn from tortoise.exceptions import OperationalError # --------------------------------------------------------------------------- @@ -197,7 +198,7 @@ async def test_distinct_on_invalid_order_by(db): await Tournament.all().distinct("name").order_by("desc") -@test.skipCapability(dialect="postgres") +@test.requireCapability(dialect=NotIn("postgres")) @pytest.mark.asyncio async def test_distinct_on_not_supported_outside_postgres(db): with pytest.raises(OperationalError): diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index beaa068b0..b55864a45 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -50,7 +50,6 @@ async def test_sqlite_only(db): "TortoiseContext", "tortoise_test_context", "requireCapability", - "skipCapability", "truncate_all_models", "init_memory_sqlite", "SkipTest", @@ -236,69 +235,6 @@ def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return decorator -def skipCapability( - connection_name: str = "models", **conditions: typing.Any -) -> Callable[[_FT], _FT]: - """ - Skip a test if the specified capabilities are matched. - - This is the inverse of :func:`requireCapability`. - - Usage: - - .. code-block:: python3 - - @skipCapability(dialect='postgres') - @pytest.mark.asyncio - async def test_skip_on_postgres(db): - ... - - :param connection_name: name of the connection to retrieve capabilities from. - :param conditions: capability tests — if all match, the test is skipped. - """ - - def decorator(test_item: _FT) -> _FT: - if not isinstance(test_item, type): - - def check_capabilities() -> None: - db = get_connection(connection_name) - if all(getattr(db.capabilities, key) == val for key, val in conditions.items()): - raise SkipTest(f"Skipped because capabilities match: {conditions}") - - if inspect.iscoroutinefunction(test_item): - - @wraps(test_item) - async def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: - check_capabilities() - return await test_item(*args, **kwargs) - - else: - - @wraps(test_item) - def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: - check_capabilities() - return test_item(*args, **kwargs) - - return cast(_FT, skip_wrapper) - - # Assume a class is decorated - funcs = { - var: f - for var in dir(test_item) - if var.startswith("test_") and callable(f := getattr(test_item, var)) - } - for name, func in funcs.items(): - setattr( - test_item, - name, - skipCapability(connection_name=connection_name, **conditions)(func), - ) - - return test_item - - return decorator - - @typing.overload def init_memory_sqlite(models: ModulesConfigType | None = None) -> AsyncFuncDeco: ...