Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tests/fields/test_m2m_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,15 @@ async def test__add_uninstantiated(db, m2m_uuid_models):
two = await UUIDM2MRelatedModel.create()
with pytest.raises(OperationalError, match=r"You should first call .save\(\) on"):
await two.models.add(one)


@pytest.mark.asyncio
async def test_prefetch_related(db, m2m_uuid_models):
UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models
one = await UUIDPkModel.create()
two = await UUIDM2MRelatedModel.create()
await one.peers.add(two)

fetched = await UUIDPkModel.get(pk=one.pk).prefetch_related("peers")

assert list(fetched.peers) == [two]
55 changes: 55 additions & 0 deletions tests/test_prefetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,61 @@ async def test_prefetch_m2m_to_attr(db):
assert list(event.to_attr_participants_2) == [team_second]


@pytest.mark.asyncio
async def test_prefetch_m2m_annotate(db):
tournament = await Tournament.create(name="tournament")
team = await Team.create(name="1")
event = await Event.create(name="First", tournament=tournament)
await event.participants.add(team)
event = await Event.first().prefetch_related(
Prefetch("participants", Team.annotate(count_events=Count("events")))
)
for team in event.participants:
assert team.count_events == 1


@pytest.mark.asyncio
async def test_prefetch_m2m_select_related(db):
tournament = await Tournament.create(name="tournament")
team = await Team.create(name="1")
event = await Event.create(name="First", tournament=tournament)
await team.events.add(event)
team = await Team.first().prefetch_related(
Prefetch("events", Event.all().select_related("tournament"))
)
for event in team.events:
assert event.tournament == tournament


@pytest.mark.asyncio
async def test_prefetch_m2m_order_by(db):
tournament = await Tournament.create(name="tournament")
team_1 = await Team.create(name="1")
team_2 = await Team.create(name="2")
event = await Event.create(name="First", tournament=tournament)
await event.participants.add(team_1, team_2)
event_1 = await Event.first().prefetch_related(
Prefetch("participants", Team.all().order_by("name"))
)
event_2 = await Event.first().prefetch_related(
Prefetch("participants", Team.all().order_by("-name"))
)
assert [team.name for team in event_1.participants] == ["1", "2"]
assert [team.name for team in event_2.participants] == ["2", "1"]


@pytest.mark.asyncio
async def test_prefetch_m2m_only(db):
tournament = await Tournament.create(name="tournament")
team = await Team.create(name="1")
event = await Event.create(name="First", tournament=tournament)
await team.events.add(event)
team = await Team.first().prefetch_related(Prefetch("events", Event.all().only("name")))
assert len(team.events) == 1
for event in team.events:
assert bool(event.pk)


@pytest.mark.asyncio
async def test_prefetch_o2o_to_attr(db):
tournament = await Tournament.create(name="tournament")
Expand Down
110 changes: 41 additions & 69 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from copy import copy
from typing import TYPE_CHECKING, Any, cast

from pypika_tortoise import JoinType, Parameter, Table
from pypika_tortoise.queries import QueryBuilder
from pypika_tortoise.queries import QueryBuilder, Table
from pypika_tortoise.terms import Parameter

from tortoise.exceptions import OperationalError, UnSupportedError
from tortoise.expressions import Expression, ResolveContext
Expand All @@ -19,7 +19,6 @@
ManyToManyFieldInstance,
RelationalField,
)
from tortoise.query_utils import QueryModifier

if TYPE_CHECKING: # pragma: nocoverage
from tortoise.backends.base.client import BaseDBAsyncClient
Expand Down Expand Up @@ -602,81 +601,54 @@ async def _prefetch_m2m_relation(
field: str,
related_query: tuple[str | None, QuerySet],
) -> Iterable[Model]:
to_attr, related_query = related_query
instance_id_set: set = {
instance._meta.pk.to_db_value(instance.pk, instance) for instance in instance_list
}
to_attr, related_queryset = related_query

field_object: ManyToManyFieldInstance = self.model._meta.fields_map[field] # type: ignore

through_table = Table(field_object.through, schema=field_object.through_schema)
pk_field = self.model._meta.pk
instance_id_set: set = {
pk_field.to_db_value(instance.pk, instance) for instance in instance_list
}

subquery = (
self.db.query_class.from_(through_table)
.select(
through_table[field_object.backward_key].as_("_backward_relation_key"),
through_table[field_object.forward_key].as_("_forward_relation_key"),
related_pk_field = related_queryset.model._meta.pk
related_pk_field_name = related_pk_field.model_field_name
fields_for_select = related_queryset._fields_for_select
if fields_for_select and related_pk_field_name not in fields_for_select:
related_queryset = related_queryset.only(
*related_queryset._fields_for_select, related_pk_field_name
)
.where(through_table[field_object.backward_key].isin(instance_id_set))
)

related_query_table = related_query.model._meta.basetable
related_pk_field = related_query.model._meta.db_pk_column
related_query.resolve_ordering(related_query.model, related_query_table, [], {})
query = (
related_query.query.join(subquery)
.on(subquery._forward_relation_key == related_query_table[related_pk_field])
.select(
subquery._backward_relation_key.as_("_backward_relation_key"),
*[related_query_table[field].as_(field) for field in related_query.fields],
relation_map: dict = {}
related_objects_by_pks = {
related_pk_field.to_db_value(obj.pk, obj): obj
for obj in await related_queryset.filter(
**{f"{field_object.related_name}__in": instance_id_set}
)
)

if related_query._q_objects:
joined_tables: list[Table] = []
modifier = QueryModifier()
for node in related_query._q_objects:
modifier &= node.resolve(
ResolveContext(
model=related_query.model,
table=related_query_table,
annotations=related_query._annotations,
custom_filters=related_query._custom_filters,
)
}
if related_objects_by_pks:
through_table = Table(field_object.through, schema=field_object.through_schema)
backward_field = through_table[field_object.backward_key]
forward_field = through_table[field_object.forward_key]

_, through_rows = await self.db.execute_query(
*(
self.db.query_class.from_(through_table)
.select(backward_field, forward_field)
.where(backward_field.isin(instance_id_set))
.where(forward_field.isin(tuple(related_objects_by_pks)))
.get_parameterized_sql()
)

for join in modifier.joins:
if join[0] not in joined_tables:
query = query.join(join[0], how=JoinType.left_outer).on(join[1])
joined_tables.append(join[0])

if modifier.where_criterion:
query = query.where(modifier.where_criterion)

if modifier.having_criterion:
query = query.having(modifier.having_criterion)

_, raw_results = await self.db.execute_query(*query.get_parameterized_sql())
relations: list[tuple[Any, Any]] = []
related_object_list: list[Model] = []
model_pk, related_pk = self.model._meta.pk, field_object.related_model._meta.pk
for e in raw_results:
pk_values: tuple[Any, Any] = (
model_pk.to_python_value(e["_backward_relation_key"]),
related_pk.to_python_value(e[related_pk_field]),
)
relations.append(pk_values)
related_object_list.append(related_query.model._init_from_db(**e))
await self.__class__(
model=related_query.model, db=self.db, prefetch_map=related_query._prefetch_map
)._execute_prefetch_queries(related_object_list)
related_object_map = {e.pk: e for e in related_object_list}
relation_map: dict[str, list] = {}

for object_id, related_object_id in relations:
if object_id not in relation_map:
relation_map[object_id] = []
relation_map[object_id].append(related_object_map[related_object_id])

reverse_map: dict = {}
for row in through_rows:
forward_key_value = related_pk_field.to_python_value(row[field_object.forward_key])
backward_key_value = pk_field.to_python_value(row[field_object.backward_key])
reverse_map.setdefault(forward_key_value, []).append(backward_key_value)

for related_object in related_objects_by_pks.values():
for instance_pk in reverse_map.get(related_object.pk, []):
relation_map.setdefault(instance_pk, []).append(related_object)

for instance in instance_list:
relation_container = getattr(instance, field)
Expand Down
Loading