diff --git a/tests/test_distinct.py b/tests/test_distinct.py
new file mode 100644
index 000000000..5fcb44d2d
--- /dev/null
+++ b/tests/test_distinct.py
@@ -0,0 +1,217 @@
+import pytest
+
+from tests.testmodels import Tournament
+from tortoise.contrib import test
+from tortoise.contrib.test.condition import NotIn
+from tortoise.exceptions import OperationalError
+
+# ---------------------------------------------------------------------------
+# Basic DISTINCT (all databases)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_distinct_no_args(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    # Both rows differ in ``desc``, so DISTINCT over full rows must keep them; projecting
+    # onto ``name`` is what proves duplicates are actually collapsed.
+    tournaments = await Tournament.all().distinct()
+    assert len(tournaments) == 2
+    names = await Tournament.all().distinct().values_list("name", flat=True)
+    assert names == ["1"]
+
+
+# ---------------------------------------------------------------------------
+# DISTINCT ON (PostgreSQL only)
+#
+# Without an explicit ORDER BY, PostgreSQL leaves both the order of the groups and the
+# row kept from each DISTINCT ON group unspecified, so tests that omit order_by() sort the
+# result and accept any row of a group.
+# ---------------------------------------------------------------------------
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_single_field(db):
+    await Tournament.create(name="1", desc="1")
+    await Tournament.create(name="1", desc="2")
+    await Tournament.create(name="1", desc="3")
+
+    tournaments = await Tournament.all().distinct("name")
+    assert [t.name for t in tournaments] == ["1"]
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_single_field_with_order_by(db):
+    await Tournament.create(name="1", desc="1")
+    await Tournament.create(name="1", desc="2")
+    tournament_3 = await Tournament.create(name="1", desc="3")
+
+    tournaments = await Tournament.all().distinct("name").order_by("name", "-desc")
+    assert tournaments == [tournament_3]
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_multiple_fields(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="2", desc="b")
+
+    # The two ("1", "a") rows tie under this ORDER BY, so which of them survives is
+    # unspecified: compare the selected values rather than specific instances.
+    tournaments = await Tournament.all().distinct("name", "desc").order_by("name", "desc")
+    assert [(t.name, t.desc) for t in tournaments] == [("1", "a"), ("2", "b")]
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_values_list_single_field(db):
+    """values_list selects one field, same as DISTINCT ON field."""
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = await Tournament.all().distinct("name").values_list("name", flat=True)
+    assert sorted(tournaments) == ["1", "2"]
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_values_list_multiple_fields(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = await Tournament.all().distinct("name").values_list("name", "desc")
+    assert sorted(tournaments) in ([("1", "a"), ("2", "c")], [("1", "b"), ("2", "c")])
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_values_list_extra_fields(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = await Tournament.all().distinct("name").values_list("desc", flat=True)
+    assert sorted(tournaments) in (["a", "c"], ["b", "c"])
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_values_list_extra_field_respects_order_by(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = (
+        await Tournament.all()
+        .distinct("name")
+        .order_by("name", "-desc")
+        .values_list("desc", flat=True)
+    )
+    assert tournaments == ["b", "c"]
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_values_single_field(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = await Tournament.all().distinct("name").values("name")
+    assert sorted(tournaments, key=lambda row: row["name"]) == [{"name": "1"}, {"name": "2"}]
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_values_multiple_fields(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = await Tournament.all().distinct("name").values("name", "desc")
+    rows = sorted((row["name"], row["desc"]) for row in tournaments)
+    assert rows in ([("1", "a"), ("2", "c")], [("1", "b"), ("2", "c")])
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_values_extra_fields(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = await Tournament.all().distinct("name").values("desc")
+    assert sorted(row["desc"] for row in tournaments) in (["a", "c"], ["b", "c"])
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_values_extra_field_respects_order_by(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = await Tournament.all().distinct("name").order_by("name", "-desc").values("desc")
+    assert tournaments == [{"desc": "b"}, {"desc": "c"}]
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_only_same_field(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = await Tournament.all().distinct("name").only("name")
+    assert sorted(t.name for t in tournaments) == ["1", "2"]
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_only_extra_field(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = await Tournament.all().distinct("name").only("name", "desc")
+    rows = sorted((t.name, t.desc) for t in tournaments)
+    assert rows in ([("1", "a"), ("2", "c")], [("1", "b"), ("2", "c")])
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_only_with_order_by(db):
+    await Tournament.create(name="1", desc="a")
+    await Tournament.create(name="1", desc="b")
+    await Tournament.create(name="2", desc="c")
+
+    tournaments = (
+        await Tournament.all().distinct("name").order_by("name", "-desc").only("name", "desc")
+    )
+    assert [(t.name, t.desc) for t in tournaments] == [("1", "b"), ("2", "c")]
+
+
+# ---------------------------------------------------------------------------
+# DISTINCT ON validation errors
+# ---------------------------------------------------------------------------
+
+
+@test.requireCapability(dialect="postgres")
+@pytest.mark.asyncio
+async def test_distinct_on_invalid_order_by(db):
+    await Tournament.create(name="1")
+    with pytest.raises(OperationalError):
+        await Tournament.all().distinct("name").order_by("desc")
+
+
+@test.requireCapability(dialect=NotIn("postgres"))
+@pytest.mark.asyncio
+async def test_distinct_on_not_supported_outside_postgres(db):
+    with pytest.raises(OperationalError):
+        Tournament.all().distinct("name")
diff --git a/tortoise/queryset.py b/tortoise/queryset.py
index aecffc42d..99f53661e 100644
--- a/tortoise/queryset.py
+++ b/tortoise/queryset.py
@@ -8,6 +8,7 @@
 
 from pypika_tortoise import JoinType, Order, Table
 from pypika_tortoise.analytics import Count
+from pypika_tortoise.dialects import PostgreSQLQueryBuilder
 from pypika_tortoise.functions import Cast
 from pypika_tortoise.queries import QueryBuilder, _SetOperation
 from pypika_tortoise.terms import Case, Field, Star, Term, ValueWrapper
@@ -42,6 +43,7 @@
 
# Empty placeholder - Should never be edited. QUERY: QueryBuilder = QueryBuilder() +POSTGRES_QUERY: PostgreSQLQueryBuilder = PostgreSQLQueryBuilder() if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model @@ -84,6 +86,7 @@ def values( class _ChooseDBMixin(Generic[MODEL]): _db: BaseDBAsyncClient | None model: type[MODEL] + query: QueryBuilder | PostgreSQLQueryBuilder def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: """ @@ -99,9 +102,23 @@ def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: db = router.db_for_read(self.model) return db or self.model._meta.db + def _apply_db(self, db: BaseDBAsyncClient) -> None: + """ + Set the database connection for this query and update the query builder dialect. + + Assigns ``db`` to ``_db`` and, when the connection targets PostgreSQL, + replaces the default ``query`` placeholder with ``POSTGRES_QUERY`` so + that subsequent query-building calls produce PostgreSQL-specific SQL. + + :param db: The database connection to use for this query. 
+ """ + self._db = db + if db is not None and hasattr(self, "query") and db.capabilities.dialect == "postgres": + self.query = POSTGRES_QUERY + def _choose_db_if_not_chosen(self, for_write: bool = False) -> None: if self._db is None: - self._db = self._choose_db(for_write) + self._apply_db(self._choose_db(for_write)) class AwaitableQuery(_ChooseDBMixin[MODEL], Generic[MODEL]): @@ -119,7 +136,7 @@ class AwaitableQuery(_ChooseDBMixin[MODEL], Generic[MODEL]): def __init__(self, model: type[MODEL]) -> None: self._joined_tables: list[Table] = [] self.model: type[MODEL] = model - self.query: QueryBuilder = QUERY + self.query: QueryBuilder | PostgreSQLQueryBuilder = QUERY self._db: BaseDBAsyncClient = None # type: ignore self._capabilities: Capabilities | None = None self._annotations: dict[str, Expression | Term] = {} @@ -362,6 +379,7 @@ def __init__(self, model: type[MODEL]) -> None: self._filter_kwargs: dict[str, Any] = {} self._orderings: list[tuple[str, Any]] = [] self._distinct: bool = False + self._distinct_on: list[str] = [] self._having: dict[str, Any] = {} self._fields_for_select: tuple[str, ...] = () self._group_bys: tuple[str, ...] = () @@ -396,6 +414,7 @@ def _clone(self) -> QuerySet[MODEL]: queryset._joined_tables = copy(self._joined_tables) queryset._q_objects = copy(self._q_objects) queryset._distinct = self._distinct + queryset._distinct_on = copy(self._distinct_on) queryset._annotations = copy(self._annotations) queryset._having = copy(self._having) queryset._custom_filters = copy(self._custom_filters) @@ -580,15 +599,44 @@ def __getitem__(self, key: slice) -> QuerySet[MODEL]: queryset = queryset.limit(key.stop - start) return queryset - def distinct(self) -> QuerySet[MODEL]: + def distinct(self, *args: str) -> QuerySet[MODEL]: """ - Make QuerySet distinct. + Make QuerySet return distinct results. + + Without arguments, adds a plain ``DISTINCT`` to the query, which works on all databases + and is most useful with ``.values()`` or ``.values_list()``. 
+ + With arguments (PostgreSQL only), generates ``DISTINCT ON (fields)`` which keeps one row + per unique combination of the given fields. ``ORDER BY`` is optional, but if specified + it must begin with the same fields in the same order as ``DISTINCT ON`` — otherwise an + :exc:`~tortoise.exceptions.OperationalError` is raised. - Only makes sense in combination with a ``.values()`` or ``.values_list()`` as it - precedes all the fetched fields with a distinct. + Can be combined with ``.only()``, ``.values()``, and ``.values_list()`` — fields not + present in ``DISTINCT ON`` are taken from the row selected by the ordering. + + .. code-block:: python3 + + # Plain DISTINCT — all databases + await Tournament.all().distinct().values("name") + + # DISTINCT ON without ORDER BY — PostgreSQL only + await Tournament.all().distinct("name") + + # DISTINCT ON with ORDER BY — ORDER BY must start with DISTINCT ON fields + await Tournament.all().distinct("name").order_by("name", "-desc") + + :param args: Field names for ``DISTINCT ON`` (PostgreSQL only). Omit for plain + ``DISTINCT``. + :raises OperationalError: If field arguments are given on a non-PostgreSQL database, + or if ``ORDER BY`` is specified but does not start with the ``DISTINCT ON`` fields. 
""" queryset = self._clone() queryset._distinct = True + if args: + if isinstance(self.query, PostgreSQLQueryBuilder): + queryset._distinct_on = list(args) + else: + raise OperationalError("DISTINCT ON is only supported by PostgreSQL") return queryset def union(self, *other_qs: QuerySet[Model], all: bool = False) -> UnionQuery[MODEL]: @@ -698,6 +746,7 @@ def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Lite group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + distinct_on=self._distinct_on, ) def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: @@ -753,6 +802,7 @@ def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + distinct_on=self._distinct_on, ) def delete(self) -> DeleteQuery: @@ -1115,7 +1165,7 @@ def using_db(self, _db: BaseDBAsyncClient | None) -> QuerySet[MODEL]: Useful for transactions workaround. """ queryset = self._clone() - queryset._db = _db if _db else queryset._db + queryset._apply_db(_db if _db else queryset._db) return queryset def _join_select_related(self, lookup_expression: str) -> tuple[type[Model], Table]: @@ -1270,6 +1320,19 @@ def _make_query(self) -> None: self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True + if isinstance(self.query, PostgreSQLQueryBuilder) and self._distinct_on: + ordering_fields = [ordering[0] for ordering in self._orderings] + len_ordering_fields = len(ordering_fields) + for i, field in enumerate(self._distinct_on): + if ordering_fields and ( + i >= len_ordering_fields or ordering_fields[i] != field + ): + raise OperationalError( + f"DISTINCT ON fields must match the leading ORDER BY fields. " + f"Expected ORDER BY to start with {self._distinct_on!r}." 
+ ) + self.query._distinct_on = [] + self.query.distinct_on(*self._distinct_on) if self._select_for_update: self.query = self.query.for_update( self._select_for_update_nowait, @@ -1288,8 +1351,7 @@ def _make_query(self) -> None: self.query = self.query.use_index(*self._use_indexes) def __await__(self) -> Generator[Any, None, list[MODEL]]: - if self._db is None: - self._db = self._choose_db(self._select_for_update) # type: ignore + self._choose_db_if_not_chosen(self._select_for_update) self._make_query() return self._execute().__await__() @@ -1343,7 +1405,7 @@ def __init__( self._q_objects = q_objects self._annotations = annotations self._custom_filters = custom_filters - self._db = db + self._apply_db(db) self._limit = limit self._orderings = orderings @@ -1422,7 +1484,7 @@ def __init__( self._q_objects = q_objects self._annotations = annotations self._custom_filters = custom_filters - self._db = db + self._apply_db(db) self._limit = limit self._orderings = orderings @@ -1467,7 +1529,7 @@ def __init__( ) -> None: super().__init__(model) self._q_objects = q_objects - self._db = db + self._apply_db(db) self._annotations = annotations self._custom_filters = custom_filters self._force_indexes = force_indexes @@ -1540,7 +1602,7 @@ def __init__( self._custom_filters = custom_filters self._limit = limit self._offset = offset or 0 - self._db = db + self._apply_db(db) self._force_indexes = force_indexes self._use_indexes = use_indexes @@ -1595,8 +1657,7 @@ def _join_table_with_forwarded_fields( if field in self.model._meta.fetch_fields and not forwarded_fields: raise ValueError( - f'Selecting relation "{field}" is not possible, select concrete ' - "field on related model" + f'Selecting relation "{field}" is not possible, select concrete field on related model' ) field_object = cast(RelationalField, model._meta.fields_map.get(field)) @@ -1627,8 +1688,7 @@ def add_field_to_select_query(self, field: str, return_as: str) -> None: if field in self.model._meta.fetch_fields: raise 
ValueError( - f'Selecting relation "{field}" is not possible, select ' - "concrete field on related model" + f'Selecting relation "{field}" is not possible, select concrete field on related model' ) field_, __, forwarded_fields = field.partition("__") @@ -1704,6 +1764,7 @@ class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): "_force_indexes", "_use_indexes", "_fields_to_select_sql", + "_distinct_on", ) def __init__( @@ -1724,6 +1785,7 @@ def __init__( group_bys: tuple[str, ...], force_indexes: set[str], use_indexes: set[str], + distinct_on: list[str], ) -> None: super().__init__(model, annotations) if flat and (len(fields_for_select_list) != 1): @@ -1741,10 +1803,11 @@ def __init__( self._raise_does_not_exist = raise_does_not_exist self._fields_for_select_list = fields_for_select_list self._flat = flat - self._db = db + self._apply_db(db) self._group_bys = group_bys self._force_indexes = force_indexes self._use_indexes = use_indexes + self._distinct_on = distinct_on self._fields_to_select_sql = { *self._fields_for_select_list, *(key for key, value in self.fields.items() if value in self._fields_for_select_list), @@ -1771,6 +1834,19 @@ def _make_query(self) -> None: self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True + if isinstance(self.query, PostgreSQLQueryBuilder) and self._distinct_on: + ordering_fields = [ordering[0] for ordering in self._orderings] + len_ordering_fields = len(ordering_fields) + for i, field in enumerate(self._distinct_on): + if ordering_fields and ( + i >= len_ordering_fields or ordering_fields[i] != field + ): + raise OperationalError( + f"DISTINCT ON fields must match the leading ORDER BY fields. " + f"Expected ORDER BY to start with {self._distinct_on!r}." 
+ ) + self.query._distinct_on = [] + self.query.distinct_on(*self._distinct_on) if self._group_bys: self.query._groupbys = self._resolve_group_bys(*self._group_bys) @@ -1837,6 +1913,7 @@ class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): "_group_bys", "_force_indexes", "_use_indexes", + "_distinct_on", ) def __init__( @@ -1856,6 +1933,7 @@ def __init__( group_bys: tuple[str, ...], force_indexes: set[str], use_indexes: set[str], + distinct_on: list[str], ) -> None: super().__init__(model, annotations) self._fields_for_select = fields_for_select @@ -1867,10 +1945,11 @@ def __init__( self._q_objects = q_objects self._single = single self._raise_does_not_exist = raise_does_not_exist - self._db = db + self._apply_db(db) self._group_bys = group_bys self._force_indexes = force_indexes self._use_indexes = use_indexes + self._distinct_on = distinct_on def _make_query(self) -> None: self._joined_tables = [] @@ -1899,6 +1978,19 @@ def _make_query(self) -> None: self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True + if isinstance(self.query, PostgreSQLQueryBuilder) and self._distinct_on: + ordering_fields = [ordering[0] for ordering in self._orderings] + len_ordering_fields = len(ordering_fields) + for i, field in enumerate(self._distinct_on): + if ordering_fields and ( + i >= len_ordering_fields or ordering_fields[i] != field + ): + raise OperationalError( + f"DISTINCT ON fields must match the leading ORDER BY fields. " + f"Expected ORDER BY to start with {self._distinct_on!r}." 
+ ) + self.query._distinct_on = [] + self.query.distinct_on(*self._distinct_on) if self._group_bys: self.query._groupbys = self._resolve_group_bys(*self._group_bys) @@ -1963,7 +2055,7 @@ class RawSQLQuery(AwaitableQuery): def __init__(self, model: type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: super().__init__(model) self._sql = sql - self._db = db + self._apply_db(db) async def _execute(self) -> Any: instance_list = await self._db.executor_class( @@ -2094,7 +2186,7 @@ def __init__( self._objects = objects self._ignore_conflicts = ignore_conflicts self._batch_size = batch_size - self._db = db + self._apply_db(db) self._update_fields = update_fields self._on_conflict = on_conflict @@ -2298,7 +2390,7 @@ def __init__( ) -> None: super().__init__(model) self._union_query = union_query - self._db = db + self._apply_db(db) def _make_query(self) -> None: self._union_query._make_query() @@ -2344,11 +2436,11 @@ def __init__( all: bool = False, ): self.model = model - self.query = QUERY + self.query: QueryBuilder | PostgreSQLQueryBuilder = QUERY self._models: set[type[Model]] = {model, *(qs.model for qs in querysets)} self._union_query: QueryBuilder | _SetOperation | None = None self._selects: list[str] = [] - self._db = db + self._apply_db(db) self._qs = querysets self._all = all self._orderings: list[tuple[str, Order]] | None = None