Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tests/fields/test_m2m_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,15 @@ async def test__add_uninstantiated(db, m2m_uuid_models):
two = await UUIDM2MRelatedModel.create()
with pytest.raises(OperationalError, match=r"You should first call .save\(\) on"):
await two.models.add(one)


@pytest.mark.asyncio
async def test_prefetch_related(db, m2m_uuid_models):
UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models
one = await UUIDPkModel.create()
two = await UUIDM2MRelatedModel.create()
await one.peers.add(two)

fetched = await UUIDPkModel.get(pk=one.pk).prefetch_related("peers")

assert list(fetched.peers) == [two]
55 changes: 55 additions & 0 deletions tests/test_prefetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,61 @@ async def test_prefetch_m2m_to_attr(db):
assert list(event.to_attr_participants_2) == [team_second]


@pytest.mark.asyncio
async def test_prefetch_m2m_annotate(db):
tournament = await Tournament.create(name="tournament")
team = await Team.create(name="1")
event = await Event.create(name="First", tournament=tournament)
await event.participants.add(team)
event = await Event.first().prefetch_related(
Prefetch("participants", Team.annotate(count_events=Count("events")))
)
for team in event.participants:
assert team.count_events == 1


@pytest.mark.asyncio
async def test_prefetch_m2m_select_related(db):
tournament = await Tournament.create(name="tournament")
team = await Team.create(name="1")
event = await Event.create(name="First", tournament=tournament)
await team.events.add(event)
team = await Team.first().prefetch_related(
Prefetch("events", Event.all().select_related("tournament"))
)
for event in team.events:
assert event.tournament == tournament


@pytest.mark.asyncio
async def test_prefetch_m2m_order_by(db):
tournament = await Tournament.create(name="tournament")
team_1 = await Team.create(name="1")
team_2 = await Team.create(name="2")
event = await Event.create(name="First", tournament=tournament)
await event.participants.add(team_1, team_2)
event_1 = await Event.first().prefetch_related(
Prefetch("participants", Team.all().order_by("name"))
)
event_2 = await Event.first().prefetch_related(
Prefetch("participants", Team.all().order_by("-name"))
)
assert [team.name for team in event_1.participants] == ["1", "2"]
assert [team.name for team in event_2.participants] == ["2", "1"]


@pytest.mark.asyncio
async def test_prefetch_m2m_only(db):
tournament = await Tournament.create(name="tournament")
team = await Team.create(name="1")
event = await Event.create(name="First", tournament=tournament)
await team.events.add(event)
team = await Team.first().prefetch_related(Prefetch("events", Event.all().only("name")))
assert len(team.events) == 1
for event in team.events:
assert bool(event.pk)


@pytest.mark.asyncio
async def test_prefetch_o2o_to_attr(db):
tournament = await Tournament.create(name="tournament")
Expand Down
111 changes: 39 additions & 72 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from copy import copy
from typing import TYPE_CHECKING, Any, cast

from pypika_tortoise import JoinType, Parameter, Table
from pypika_tortoise.queries import QueryBuilder
from pypika_tortoise.queries import QueryBuilder, Table
from pypika_tortoise.terms import Parameter

from tortoise.exceptions import OperationalError, UnSupportedError
from tortoise.expressions import Expression, ResolveContext
Expand All @@ -19,7 +19,6 @@
ManyToManyFieldInstance,
RelationalField,
)
from tortoise.query_utils import QueryModifier

if TYPE_CHECKING: # pragma: nocoverage
from tortoise.backends.base.client import BaseDBAsyncClient
Expand Down Expand Up @@ -602,85 +601,53 @@ async def _prefetch_m2m_relation(
field: str,
related_query: tuple[str | None, QuerySet],
) -> Iterable[Model]:
to_attr, related_query = related_query
instance_id_set: set = {
instance._meta.pk.to_db_value(instance.pk, instance) for instance in instance_list
}
to_attr, queryset = related_query

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VladislavYar Why do you change it from related_query to queryset?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the tuple[str | None, QuerySet] (related_query) type is assigned a value with the QuerySet type. Therefore, it is better to create a separate variable and the name of the queryset is clearer.
I will rename it to related_queryset, because the _prefetch_direct_relation method has the same name.


field_object: ManyToManyFieldInstance = self.model._meta.fields_map[field] # type: ignore

through_table = Table(field_object.through, schema=field_object.through_schema)
model_pk = self.model._meta.pk

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model._meta.pk is a field object, not a model class or instance; therefore, the variable name can be pk_field.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected it

instance_pks = [model_pk.to_db_value(instance.pk, instance) for instance in instance_list]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to keep instance_id_set instead of changing it to a list. Keeping the change minimal makes it easier to review.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected it


subquery = (
self.db.query_class.from_(through_table)
.select(
through_table[field_object.backward_key].as_("_backward_relation_key"),
through_table[field_object.forward_key].as_("_forward_relation_key"),
)
.where(through_table[field_object.backward_key].isin(instance_id_set))
)
related_model_pk = queryset.model._meta.pk
model_field_name_pk = related_model_pk.model_field_name
fields_for_select = queryset._fields_for_select
if fields_for_select and model_field_name_pk not in fields_for_select:
queryset = queryset.only(*queryset._fields_for_select, model_field_name_pk)

related_query_table = related_query.model._meta.basetable
related_pk_field = related_query.model._meta.db_pk_column
related_query.resolve_ordering(related_query.model, related_query_table, [], {})
query = (
related_query.query.join(subquery)
.on(subquery._forward_relation_key == related_query_table[related_pk_field])
.select(
subquery._backward_relation_key.as_("_backward_relation_key"),
*[related_query_table[field].as_(field) for field in related_query.fields],
relation_map: dict = {}
related_objects_by_pks = {
related_model_pk.to_db_value(obj.pk, obj): obj
for obj in await queryset.filter(**{f"{field_object.related_name}__in": instance_pks})
}
if related_objects_by_pks:
through_table = Table(field_object.through, schema=field_object.through_schema)
backward_field = through_table[field_object.backward_key]
forward_field = through_table[field_object.forward_key]

_, through_rows = await self.db.execute_query(
*(
self.db.query_class.from_(through_table)
.select(backward_field, forward_field)
.where(backward_field.isin(instance_pks))
.where(forward_field.isin(tuple(related_objects_by_pks)))
.get_parameterized_sql()
)
)
)

if related_query._q_objects:
joined_tables: list[Table] = []
modifier = QueryModifier()
for node in related_query._q_objects:
modifier &= node.resolve(
ResolveContext(
model=related_query.model,
table=related_query_table,
annotations=related_query._annotations,
custom_filters=related_query._custom_filters,
)
)
reverse_map: dict = {}
for row in through_rows:
forward_key_value = related_model_pk.to_python_value(row[field_object.forward_key])
backward_key_value = model_pk.to_python_value(row[field_object.backward_key])
reverse_map.setdefault(forward_key_value, []).append(backward_key_value)

for join in modifier.joins:
if join[0] not in joined_tables:
query = query.join(join[0], how=JoinType.left_outer).on(join[1])
joined_tables.append(join[0])

if modifier.where_criterion:
query = query.where(modifier.where_criterion)

if modifier.having_criterion:
query = query.having(modifier.having_criterion)

_, raw_results = await self.db.execute_query(*query.get_parameterized_sql())
relations: list[tuple[Any, Any]] = []
related_object_list: list[Model] = []
model_pk, related_pk = self.model._meta.pk, field_object.related_model._meta.pk
for e in raw_results:
pk_values: tuple[Any, Any] = (
model_pk.to_python_value(e["_backward_relation_key"]),
related_pk.to_python_value(e[related_pk_field]),
)
relations.append(pk_values)
related_object_list.append(related_query.model._init_from_db(**e))
await self.__class__(
model=related_query.model, db=self.db, prefetch_map=related_query._prefetch_map
)._execute_prefetch_queries(related_object_list)
related_object_map = {e.pk: e for e in related_object_list}
relation_map: dict[str, list] = {}

for object_id, related_object_id in relations:
if object_id not in relation_map:
relation_map[object_id] = []
relation_map[object_id].append(related_object_map[related_object_id])
for related_object in related_objects_by_pks.values():
for instance_pk in reverse_map.get(related_object.pk, []):
relation_map.setdefault(instance_pk, []).append(related_object)

for instance in instance_list:
relation_container = getattr(instance, field)
relation_container._set_result_for_query(relation_map.get(instance.pk, []), to_attr)
getattr(instance, field)._set_result_for_query(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One-line style reduces readability; it would be better to roll it back.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected it

relation_map.get(instance.pk, []), to_attr
)
return instance_list

async def _prefetch_direct_relation(
Expand Down
Loading