Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 204 additions & 0 deletions tests/test_distinct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import pytest

from tests.testmodels import Tournament
from tortoise.contrib import test
from tortoise.exceptions import OperationalError

# ---------------------------------------------------------------------------
# Basic DISTINCT (all databases)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_distinct_no_args(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
tournaments = await Tournament.all().distinct()
assert len(tournaments) == 2


# ---------------------------------------------------------------------------
# DISTINCT ON (PostgreSQL only)
# ---------------------------------------------------------------------------


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_single_field(db):
tournament_1 = await Tournament.create(name="1", desc="1")
await Tournament.create(name="1", desc="2")
await Tournament.create(name="1", desc="3")

tournaments = await Tournament.all().distinct("name")
assert tournaments == [tournament_1]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_single_field_with_order_by(db):
await Tournament.create(name="1", desc="1")
await Tournament.create(name="1", desc="2")
tournament_3 = await Tournament.create(name="1", desc="3")

tournaments = await Tournament.all().distinct("name").order_by("name", "-desc")
assert tournaments == [tournament_3]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_multiple_fields(db):
tournament_1 = await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="a")
tournament_3 = await Tournament.create(name="2", desc="b")

tournaments = await Tournament.all().distinct("name", "desc").order_by("name", "desc")
assert tournaments == [tournament_1, tournament_3]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_values_list_single_field(db):
"""values_list selects one field, same as DISTINCT ON field."""
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = await Tournament.all().distinct("name").values_list("name", flat=True)
assert tournaments == ["1", "2"]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_values_list_multiple_fields(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = await Tournament.all().distinct("name").values_list("name", "desc")
assert tournaments == [("1", "a"), ("2", "c")]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_values_list_extra_fields(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = await Tournament.all().distinct("name").values_list("desc", flat=True)
assert tournaments == ["a", "c"]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_values_list_extra_field_respects_order_by(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = (
await Tournament.all()
.distinct("name")
.order_by("name", "-desc")
.values_list("desc", flat=True)
)
assert tournaments == ["b", "c"]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_values_single_field(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = await Tournament.all().distinct("name").values("name")
assert tournaments == [{"name": "1"}, {"name": "2"}]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_values_multiple_fields(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = await Tournament.all().distinct("name").values("name", "desc")
assert tournaments == [{"name": "1", "desc": "a"}, {"name": "2", "desc": "c"}]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_values_extra_fields(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = await Tournament.all().distinct("name").values("desc")
assert tournaments == [{"desc": "a"}, {"desc": "c"}]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_values_extra_field_respects_order_by(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = await Tournament.all().distinct("name").order_by("name", "-desc").values("desc")
assert tournaments == [{"desc": "b"}, {"desc": "c"}]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_only_same_field(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = await Tournament.all().distinct("name").only("name")
assert [t.name for t in tournaments] == ["1", "2"]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_only_extra_field(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = await Tournament.all().distinct("name").only("name", "desc")
assert [(t.name, t.desc) for t in tournaments] == [("1", "a"), ("2", "c")]


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_only_with_order_by(db):
await Tournament.create(name="1", desc="a")
await Tournament.create(name="1", desc="b")
await Tournament.create(name="2", desc="c")

tournaments = (
await Tournament.all().distinct("name").order_by("name", "-desc").only("name", "desc")
)
assert [(t.name, t.desc) for t in tournaments] == [("1", "b"), ("2", "c")]


# ---------------------------------------------------------------------------
# DISTINCT ON validation errors
# ---------------------------------------------------------------------------


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_distinct_on_invalid_order_by(db):
await Tournament.create(name="1")
with pytest.raises(OperationalError):
await Tournament.all().distinct("name").order_by("desc")


@test.skipCapability(dialect="postgres")
Comment thread
waketzheng marked this conversation as resolved.
Outdated
@pytest.mark.asyncio
async def test_distinct_on_not_supported_outside_postgres(db):
with pytest.raises(OperationalError):
Tournament.all().distinct("name")
64 changes: 64 additions & 0 deletions tortoise/contrib/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def test_sqlite_only(db):
"TortoiseContext",
"tortoise_test_context",
"requireCapability",
"skipCapability",
"truncate_all_models",
"init_memory_sqlite",
"SkipTest",
Expand Down Expand Up @@ -235,6 +236,69 @@ def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return decorator


def skipCapability(
Comment thread
waketzheng marked this conversation as resolved.
Outdated
connection_name: str = "models", **conditions: typing.Any
) -> Callable[[_FT], _FT]:
"""
Skip a test if the specified capabilities are matched.

This is the inverse of :func:`requireCapability`.

Usage:

.. code-block:: python3

@skipCapability(dialect='postgres')
@pytest.mark.asyncio
async def test_skip_on_postgres(db):
...

:param connection_name: name of the connection to retrieve capabilities from.
:param conditions: capability tests — if all match, the test is skipped.
"""

def decorator(test_item: _FT) -> _FT:
if not isinstance(test_item, type):

def check_capabilities() -> None:
db = get_connection(connection_name)
if all(getattr(db.capabilities, key) == val for key, val in conditions.items()):
raise SkipTest(f"Skipped because capabilities match: {conditions}")

if inspect.iscoroutinefunction(test_item):

@wraps(test_item)
async def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
check_capabilities()
return await test_item(*args, **kwargs)

else:

@wraps(test_item)
def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
check_capabilities()
return test_item(*args, **kwargs)

return cast(_FT, skip_wrapper)

# Assume a class is decorated
funcs = {
var: f
for var in dir(test_item)
if var.startswith("test_") and callable(f := getattr(test_item, var))
}
for name, func in funcs.items():
setattr(
test_item,
name,
skipCapability(connection_name=connection_name, **conditions)(func),
)

return test_item

return decorator


@typing.overload
def init_memory_sqlite(models: ModulesConfigType | None = None) -> AsyncFuncDeco: ...

Expand Down
Loading
Loading