diff --git a/tests/fields/test_m2m_uuid.py b/tests/fields/test_m2m_uuid.py index 7648750b2..f1c41b891 100644 --- a/tests/fields/test_m2m_uuid.py +++ b/tests/fields/test_m2m_uuid.py @@ -163,3 +163,15 @@ async def test__add_uninstantiated(db, m2m_uuid_models): two = await UUIDM2MRelatedModel.create() with pytest.raises(OperationalError, match=r"You should first call .save\(\) on"): await two.models.add(one) + + +@pytest.mark.asyncio +async def test_prefetch_related(db, m2m_uuid_models): + UUIDPkModel, UUIDM2MRelatedModel = m2m_uuid_models + one = await UUIDPkModel.create() + two = await UUIDM2MRelatedModel.create() + await one.peers.add(two) + + fetched = await UUIDPkModel.get(pk=one.pk).prefetch_related("peers") + + assert list(fetched.peers) == [two] diff --git a/tests/test_prefetching.py b/tests/test_prefetching.py index face651cc..85f39413e 100644 --- a/tests/test_prefetching.py +++ b/tests/test_prefetching.py @@ -154,6 +154,61 @@ async def test_prefetch_m2m_to_attr(db): assert list(event.to_attr_participants_2) == [team_second] +@pytest.mark.asyncio +async def test_prefetch_m2m_annotate(db): + tournament = await Tournament.create(name="tournament") + team = await Team.create(name="1") + event = await Event.create(name="First", tournament=tournament) + await event.participants.add(team) + event = await Event.first().prefetch_related( + Prefetch("participants", Team.annotate(count_events=Count("events"))) + ) + for team in event.participants: + assert team.count_events == 1 + + +@pytest.mark.asyncio +async def test_prefetch_m2m_select_related(db): + tournament = await Tournament.create(name="tournament") + team = await Team.create(name="1") + event = await Event.create(name="First", tournament=tournament) + await team.events.add(event) + team = await Team.first().prefetch_related( + Prefetch("events", Event.all().select_related("tournament")) + ) + for event in team.events: + assert event.tournament == tournament + + +@pytest.mark.asyncio +async def test_prefetch_m2m_order_by(db): + tournament = await Tournament.create(name="tournament") + team_1 = await Team.create(name="1") + team_2 = await Team.create(name="2") + event = await Event.create(name="First", tournament=tournament) + await event.participants.add(team_1, team_2) + event_1 = await Event.first().prefetch_related( + Prefetch("participants", Team.all().order_by("name")) + ) + event_2 = await Event.first().prefetch_related( + Prefetch("participants", Team.all().order_by("-name")) + ) + assert [team.name for team in event_1.participants] == ["1", "2"] + assert [team.name for team in event_2.participants] == ["2", "1"] + + +@pytest.mark.asyncio +async def test_prefetch_m2m_only(db): + tournament = await Tournament.create(name="tournament") + team = await Team.create(name="1") + event = await Event.create(name="First", tournament=tournament) + await team.events.add(event) + team = await Team.first().prefetch_related(Prefetch("events", Event.all().only("name"))) + assert len(team.events) == 1 + for event in team.events: + assert bool(event.pk) + + @pytest.mark.asyncio async def test_prefetch_o2o_to_attr(db): tournament = await Tournament.create(name="tournament") diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 9865b78fc..f84fd0e34 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -7,8 +7,8 @@ from copy import copy from typing import TYPE_CHECKING, Any, cast -from pypika_tortoise import JoinType, Parameter, Table -from pypika_tortoise.queries import QueryBuilder +from pypika_tortoise.queries import QueryBuilder, Table +from pypika_tortoise.terms import Parameter from tortoise.exceptions import OperationalError, UnSupportedError from tortoise.expressions import Expression, ResolveContext @@ -19,7 +19,6 @@ ManyToManyFieldInstance, RelationalField, ) -from tortoise.query_utils import QueryModifier if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient @@ -602,81 +601,54 @@ async def _prefetch_m2m_relation( field: str, related_query: tuple[str | None, QuerySet], ) -> Iterable[Model]: - to_attr, related_query = related_query - instance_id_set: set = { - instance._meta.pk.to_db_value(instance.pk, instance) for instance in instance_list - } + to_attr, related_queryset = related_query field_object: ManyToManyFieldInstance = self.model._meta.fields_map[field] # type: ignore - through_table = Table(field_object.through, schema=field_object.through_schema) + pk_field = self.model._meta.pk + instance_id_set: set = { + pk_field.to_db_value(instance.pk, instance) for instance in instance_list + } - subquery = ( - self.db.query_class.from_(through_table) - .select( - through_table[field_object.backward_key].as_("_backward_relation_key"), - through_table[field_object.forward_key].as_("_forward_relation_key"), + related_pk_field = related_queryset.model._meta.pk + related_pk_field_name = related_pk_field.model_field_name + fields_for_select = related_queryset._fields_for_select + if fields_for_select and related_pk_field_name not in fields_for_select: + related_queryset = related_queryset.only( + *related_queryset._fields_for_select, related_pk_field_name ) - .where(through_table[field_object.backward_key].isin(instance_id_set)) - ) - related_query_table = related_query.model._meta.basetable - related_pk_field = related_query.model._meta.db_pk_column - related_query.resolve_ordering(related_query.model, related_query_table, [], {}) - query = ( - related_query.query.join(subquery) - .on(subquery._forward_relation_key == related_query_table[related_pk_field]) - .select( - subquery._backward_relation_key.as_("_backward_relation_key"), - *[related_query_table[field].as_(field) for field in related_query.fields], + relation_map: dict = {} + related_objects_by_pks = { + related_pk_field.to_db_value(obj.pk, obj): obj + for obj in await related_queryset.filter( + **{f"{field_object.related_name}__in": instance_id_set} ) - ) - - if related_query._q_objects: - joined_tables: list[Table] = [] - modifier = QueryModifier() - for node in related_query._q_objects: - modifier &= node.resolve( - ResolveContext( - model=related_query.model, - table=related_query_table, - annotations=related_query._annotations, - custom_filters=related_query._custom_filters, - ) + } + if related_objects_by_pks: + through_table = Table(field_object.through, schema=field_object.through_schema) + backward_field = through_table[field_object.backward_key] + forward_field = through_table[field_object.forward_key] + + _, through_rows = await self.db.execute_query( + *( + self.db.query_class.from_(through_table) + .select(backward_field, forward_field) + .where(backward_field.isin(instance_id_set)) + .where(forward_field.isin(tuple(related_objects_by_pks))) + .get_parameterized_sql() ) - - for join in modifier.joins: - if join[0] not in joined_tables: - query = query.join(join[0], how=JoinType.left_outer).on(join[1]) - joined_tables.append(join[0]) - - if modifier.where_criterion: - query = query.where(modifier.where_criterion) - - if modifier.having_criterion: - query = query.having(modifier.having_criterion) - - _, raw_results = await self.db.execute_query(*query.get_parameterized_sql()) - relations: list[tuple[Any, Any]] = [] - related_object_list: list[Model] = [] - model_pk, related_pk = self.model._meta.pk, field_object.related_model._meta.pk - for e in raw_results: - pk_values: tuple[Any, Any] = ( - model_pk.to_python_value(e["_backward_relation_key"]), - related_pk.to_python_value(e[related_pk_field]), ) - relations.append(pk_values) - related_object_list.append(related_query.model._init_from_db(**e)) - await self.__class__( - model=related_query.model, db=self.db, prefetch_map=related_query._prefetch_map - )._execute_prefetch_queries(related_object_list) - related_object_map = {e.pk: e for e in related_object_list} - relation_map: dict[str, list] = {} - - for object_id, related_object_id in relations: - if object_id not in relation_map: - relation_map[object_id] = [] - relation_map[object_id].append(related_object_map[related_object_id]) + + reverse_map: dict = {} + for row in through_rows: + forward_key_value = related_pk_field.to_python_value(row[field_object.forward_key]) + backward_key_value = pk_field.to_python_value(row[field_object.backward_key]) + reverse_map.setdefault(forward_key_value, []).append(backward_key_value) + + for related_object in related_objects_by_pks.values(): + for instance_pk in reverse_map.get(related_object.pk, []): + relation_map.setdefault(instance_pk, []).append(related_object) for instance in instance_list: relation_container = getattr(instance, field)