From 511df088a335b4c4fef2af83f1318adb71cc839e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 8 May 2026 21:46:33 +0800 Subject: [PATCH 1/9] feat(triggerer): share one poll across sibling event triggers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When several AssetWatcher instances back triggers that read from the same upstream resource (one SQS queue, one Kafka topic, etc.), the triggerer spins up N independent poll loops today — one per trigger. Issue #66476 asks for one shared poller serving all of them. Add an opt-in path on BaseEventTrigger via three new hooks (`shared_stream_key`, `open_shared_stream`, `filter_shared_stream`) and a new SharedStreamManager that runs one poll task per distinct key and broadcasts events to per-subscriber queues. The key is read once when run_trigger starts and identifies the group for the trigger's lifetime. Per-trigger cleanup runs in run_trigger's finally; SharedStreamManager.stop_all() runs in the triggerer's shutdown path as a safety net. Triggers whose shared_stream_key() returns None (the default) keep their existing run() loop unchanged. The per-subscriber buffer size is exposed as [triggerer] shared_stream_subscriber_queue_size (default 1024) so deployments with a fast upstream can raise it without code changes. 
--- .../event-scheduling.rst | 122 ++++ airflow-core/newsfragments/66584.feature.rst | 1 + .../src/airflow/config_templates/config.yml | 11 + .../example_asset_with_watchers.py | 44 +- .../src/airflow/jobs/triggerer_job_runner.py | 53 +- airflow-core/src/airflow/triggers/base.py | 122 +++- .../src/airflow/triggers/shared_stream.py | 389 +++++++++++ .../tests/unit/jobs/test_triggerer_job.py | 64 ++ .../tests/unit/triggers/test_base_trigger.py | 96 ++- .../tests/unit/triggers/test_shared_stream.py | 635 ++++++++++++++++++ .../providers/standard/triggers/file.py | 106 ++- .../tests/unit/standard/triggers/test_file.py | 169 ++++- 12 files changed, 1800 insertions(+), 12 deletions(-) create mode 100644 airflow-core/newsfragments/66584.feature.rst create mode 100644 airflow-core/src/airflow/triggers/shared_stream.py create mode 100644 airflow-core/tests/unit/triggers/test_shared_stream.py diff --git a/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst b/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst index 4cd72edc63b75..b5d947bd6a027 100644 --- a/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst +++ b/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst @@ -64,6 +64,128 @@ event-driven scheduling, then a new trigger must be created. This new trigger must inherit ``BaseEventTrigger`` and ensure it properly works with event-driven scheduling. It might inherit from the existing trigger as well if both triggers share some common code. +Sharing one poll across sibling triggers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 3.3 + +When several ``AssetWatcher`` instances on different assets back triggers that read from the **same upstream resource** +— a directory of flag files, a polling REST endpoint, a Kafka topic with auto-commit, and similar idempotent or +subscriber-side-effect sources — the triggerer would otherwise spin up one independent poll loop per trigger. 
For a +shared source with twenty subscribers that means twenty poll loops, twenty connections, twenty sets of API calls per +cadence. See "Suitable upstreams" below for the precise scope. + +``BaseEventTrigger`` supports an opt-in path so that sibling triggers share a single underlying poll, while each +trigger keeps its own DB row, its own ``run_trigger`` task, and its own per-instance filtering. To participate, a +subclass overrides three hooks: + +* :py:meth:`~airflow.triggers.base.BaseEventTrigger.shared_stream_key` — return a key identifying the shared + upstream (typically a tuple of strings). Triggers whose key compares equal will share one poll. Returning ``None`` + (the default) opts out — the trigger runs its own independent ``run()`` loop, exactly as before. The return value + is read **once** when the triggerer starts this trigger; changing it mid-lifetime has no effect on group + membership, so siblings that should share a poll must return the same key from the outset. + +* :py:meth:`~airflow.triggers.base.BaseEventTrigger.open_shared_stream` — a ``@classmethod`` coroutine the triggerer + drives **once per shared-stream group** to yield raw events from the upstream. Because the triggerer reuses one + trigger's kwargs to drive the shared poll, only rely on fields whose values participate in ``shared_stream_key``. + +* :py:meth:`~airflow.triggers.base.BaseEventTrigger.filter_shared_stream` — an instance method that consumes the + broadcast raw stream and yields the ``TriggerEvent`` instances this trigger should fire. Per-trigger filtering + (e.g. only events matching this instance's ``filename``) lives here. + +Example: a ``DirectoryFileDeleteTrigger`` that fires when a per-asset flag file appears in a shared inbox directory: + +.. 
code-block:: python + + from collections.abc import AsyncIterator, Hashable + from typing import Any + + from airflow.triggers.base import BaseEventTrigger, TriggerEvent + + + class DirectoryFileDeleteTrigger(BaseEventTrigger): + def __init__(self, *, directory, filename, poke_interval=5.0): + super().__init__() + self.directory = directory + self.filename = filename + self.poke_interval = poke_interval + + def shared_stream_key(self) -> Hashable | None: + # All triggers on the same directory + cadence share one scan. + return ("directory-scan", self.directory, self.poke_interval) + + @classmethod + async def open_shared_stream(cls, kwargs: dict[str, Any]) -> AsyncIterator[Any]: + # Drives one directory listing loop per group. + ... + + async def filter_shared_stream(self, shared_stream: AsyncIterator[Any]) -> AsyncIterator[TriggerEvent]: + # Each instance fires only for its own filename. + async for snapshot in shared_stream: + if self.filename in snapshot["names"]: + yield TriggerEvent(...) + return + +A complete example using this trigger ships in +``airflow.example_dags.example_asset_with_watchers``, where two sibling +``DirectoryFileDeleteTrigger`` watchers share one directory scan alongside +a standalone ``FileDeleteTrigger`` watcher in the same Dag. + +What is and isn't shared +^^^^^^^^^^^^^^^^^^^^^^^^ + +The sharing is narrower than the name might suggest: + +* **Shared** (one per ``shared_stream_key``): the ``open_shared_stream`` async generator and its upstream I/O — for + example, the actual ``iterdir`` calls on the directory or polling REST API calls. + +* **Not shared** (one per trigger): the ``Trigger`` DB row, the trigger instance, the ``run_trigger`` + asyncio task, and the ``filter_shared_stream`` async generator. Each ``AssetWatcher`` still appears as its own + trigger in the UI and in the metadata database. + +In other words, the savings is at the poll-loop and upstream-I/O layer, not at the persistence or scheduling layer. 
+ +Suitable upstreams +^^^^^^^^^^^^^^^^^^ + +The shared-stream channel is **one-way** today: events flow from +``open_shared_stream`` out to each subscriber's ``filter_shared_stream``, +and there is no way for a subscriber to tell the producer "I accepted / +dropped / committed this event". That restricts the pattern to upstreams +whose consumption does **not** depend on a side effect on a handle that +only the producer holds. Good fits: + +* Idempotent / read-only reads — directory scans, polling REST APIs. +* Auto-commit Kafka consumers (``enable.auto.commit=true``). +* Subscriber-side-effect cleanup, where the trigger's per-event action + (``unlink``, local marking, …) goes through APIs the subscriber owns + independently of the shared producer handle. + +Currently **not** in scope: Kafka consumers with manual commit, SQS with +delete-on-process or visibility extension, and any source where progress +on the producer's handle is tied to the subscriber's accept / reject +decision. A producer-side ack channel to cover those cases is a planned +follow-up; it should be designed against a concrete Kafka or SQS consumer +rather than against an abstract API, so it is intentionally left out of +the first iteration. + +Verifying that sharing is active +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The triggerer logs the creation of each shared-stream group, and names the poll task after its key: + +.. code-block:: text + + Shared stream group started key=('directory-scan', '/tmp/region-flags', 5.0) + +.. code-block:: text + + asyncio task name: shared-stream-poll[('directory-scan', '/tmp/region-flags', 5.0)] + +If sharing is active you should see exactly one ``Shared stream group started`` line per distinct key, regardless of +how many subscribers join it. If you see one log line per subscriber instead, the keys probably do not compare equal +— verify that ``shared_stream_key`` returns identical values across the siblings. 
+ Avoid infinite scheduling ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/airflow-core/newsfragments/66584.feature.rst b/airflow-core/newsfragments/66584.feature.rst new file mode 100644 index 0000000000000..e8a547d10baf1 --- /dev/null +++ b/airflow-core/newsfragments/66584.feature.rst @@ -0,0 +1 @@ +Sibling ``BaseEventTrigger`` instances on different ``AssetWatcher``\ s can now share a single underlying poll loop in the triggerer by overriding ``shared_stream_key``, ``open_shared_stream``, and ``filter_shared_stream``. Triggers that opt out (the default) keep their existing independent ``run()`` loop behavior. diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index f7f958bb84b6f..a6f3de25e2196 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -2773,6 +2773,17 @@ triggerer: type: boolean example: ~ default: "False" + shared_stream_subscriber_queue_size: + description: | + Per-subscriber buffer size for shared-stream triggers (triggers that opt into a shared poll loop + via ``BaseEventTrigger.shared_stream_key``). Each subscribing trigger keeps an in-memory queue of + raw events produced by the shared poll; if a slow subscriber fills its queue, only that subscriber + fails and sibling subscribers are unaffected. Increase if a slow subscriber must tolerate bursts from + a fast upstream. + version_added: 3.3.0 + type: integer + example: ~ + default: "1024" kerberos: description: ~ options: diff --git a/airflow-core/src/airflow/example_dags/example_asset_with_watchers.py b/airflow-core/src/airflow/example_dags/example_asset_with_watchers.py index 673be86a7964b..e28b93f482dd3 100644 --- a/airflow-core/src/airflow/example_dags/example_asset_with_watchers.py +++ b/airflow-core/src/airflow/example_dags/example_asset_with_watchers.py @@ -15,23 +15,55 @@ # specific language governing permissions and limitations # under the License. 
""" -Example DAG for demonstrating the usage of event driven scheduling using assets and triggers. +Example Dag for event-driven scheduling using Assets and AssetWatchers. + +Three watchers demonstrate the two trigger patterns in one place: + +* The first watcher uses ``FileDeleteTrigger`` for a single specific path — + one watcher, one independent poll loop in the triggerer. +* The other two use ``DirectoryFileDeleteTrigger`` with a matching + ``shared_stream_key`` of ``("directory-scan", directory, poke_interval)``; + the triggerer runs **one** directory listing loop for the pair and + broadcasts the result to both. Each still fires only for its own filename. + +The Dag runs when any of the three watchers' assets is updated. Touch +``/tmp/test``, ``/tmp/region-flags/us.flag``, or ``/tmp/region-flags/eu.flag`` +to trigger a run. """ from __future__ import annotations -from airflow.providers.standard.triggers.file import FileDeleteTrigger +from airflow.providers.standard.triggers.file import ( + DirectoryFileDeleteTrigger, + FileDeleteTrigger, +) from airflow.sdk import DAG, Asset, AssetWatcher, chain, task -file_path = "/tmp/test" +# Independent single-file watcher — has its own poll loop in the triggerer. +single_file_trigger = FileDeleteTrigger(filepath="/tmp/test") +single_file_asset = Asset( + "example_asset", + watchers=[AssetWatcher(name="test_asset_watcher", trigger=single_file_trigger)], +) -trigger = FileDeleteTrigger(filepath=file_path) -asset = Asset("example_asset", watchers=[AssetWatcher(name="test_asset_watcher", trigger=trigger)]) +# Shared-stream watchers — same directory + poke interval, so the triggerer +# runs one scan for both. Each watcher's ``filter_shared_stream`` matches on +# its own filename and ``unlink``s the flag file as a subscriber-side effect. 
+us_trigger = DirectoryFileDeleteTrigger(directory="/tmp/region-flags", filename="us.flag", poke_interval=5.0) +eu_trigger = DirectoryFileDeleteTrigger(directory="/tmp/region-flags", filename="eu.flag", poke_interval=5.0) +us_asset = Asset( + "region_us_flag", + watchers=[AssetWatcher(name="us_flag_watcher", trigger=us_trigger)], +) +eu_asset = Asset( + "region_eu_flag", + watchers=[AssetWatcher(name="eu_flag_watcher", trigger=eu_trigger)], +) with DAG( dag_id="example_asset_with_watchers", - schedule=[asset], + schedule=[single_file_asset, us_asset, eu_asset], catchup=False, ): diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 6f8f7baae84ab..cb8a8be4ef3e2 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -27,7 +27,7 @@ import threading import time from collections import deque -from collections.abc import Callable, Generator, Iterable, Iterator +from collections.abc import Callable, Generator, Hashable, Iterable, Iterator from contextlib import contextmanager, suppress from datetime import datetime from socket import socket @@ -109,6 +109,7 @@ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.serialization.serialized_objects import DagSerialization from airflow.triggers.base import BaseEventTrigger, BaseTrigger, DiscrimatedTriggerEvent, TriggerEvent +from airflow.triggers.shared_stream import SharedStreamManager from airflow.utils.helpers import log_filename_template_renderer from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import create_session, provide_session @@ -1069,6 +1070,10 @@ def __init__(self): self.failed_triggers = deque() self.job_id = None self._stop_event = None + self._shared_streams = SharedStreamManager( + log=self.log, + max_subscriber_queue=conf.getint("triggerer", "shared_stream_subscriber_queue_size"), + ) 
self.blocked_main_thread_warning_threshold = conf.getfloat( "triggerer", "blocked_main_thread_warning_threshold" ) @@ -1136,6 +1141,12 @@ async def arun(self): reader_task.cancel() with suppress(asyncio.CancelledError): await reader_task + # Safety net: cancel any shared-stream poll tasks whose group + # survived per-trigger cleanup. The normal eviction path is + # ``SharedStreamManager.unsubscribe`` in ``run_trigger``'s + # finally; this call only matters when that path was bypassed + # (e.g. the unsubscribe coroutine raised and was swallowed). + await self._shared_streams.stop_all() # Wait for supporting tasks to complete await watchdog @@ -1437,12 +1448,39 @@ async def run_trigger( name = self.triggers[trigger_id]["name"] self.log.info("trigger %s starting", name) + + # Triggers that opt into a shared underlying I/O stream + # (BaseEventTrigger.shared_stream_key returns non-None) consume a + # broadcast stream produced by SharedStreamManager and convert it + # via filter_shared_stream(). Everything else stays on the original + # standalone-run() path. 
+ shared_key: Hashable | None = None + event_trigger: BaseEventTrigger | None = None + if isinstance(trigger, BaseEventTrigger): + event_trigger = trigger + try: + shared_key = event_trigger.shared_stream_key() + except Exception: + self.log.exception( + "shared_stream_key() raised; falling back to standalone run", + trigger_id=trigger_id, + ) + shared_key = None + with _make_trigger_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span: try: if context is not None: trigger.render_template_fields(context=context) - async for event in trigger.run(): + if shared_key is not None and event_trigger is not None: + shared_stream = self._shared_streams.subscribe( + trigger_id=trigger_id, trigger=event_trigger, key=shared_key + ) + event_stream = event_trigger.filter_shared_stream(shared_stream) + else: + event_stream = trigger.run() + + async for event in event_stream: await self.log.ainfo( "Trigger fired event", name=self.triggers[trigger_id]["name"], result=event ) @@ -1486,6 +1524,17 @@ async def run_trigger( # fine, the cleanup process will understand that, but we want to # allow triggers a chance to cleanup, either in that case or if # they exit cleanly. Exception from cleanup methods are ignored. + if shared_key is not None: + try: + await self._shared_streams.unsubscribe(trigger_id, shared_key) + except Exception: + # Best-effort cleanup, but log so we don't lose + # cancel-propagation or _handle_poll_terminate bugs. 
+ self.log.exception( + "Failed to unsubscribe trigger from shared stream", + trigger_id=trigger_id, + key=shared_key, + ) with suppress(Exception): await trigger.cleanup() diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py index f39b62facf7b2..e7c83bada61d6 100644 --- a/airflow-core/src/airflow/triggers/base.py +++ b/airflow-core/src/airflow/triggers/base.py @@ -18,7 +18,7 @@ import abc import json -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Hashable from dataclasses import dataclass from datetime import timedelta from typing import TYPE_CHECKING, Annotated, Any @@ -251,6 +251,50 @@ class BaseEventTrigger(BaseTrigger): ``BaseEventTrigger`` is a subclass of ``BaseTrigger`` designed to identify triggers compatible with event-driven scheduling. + + **Sharing an underlying I/O stream between triggers** + + A subclass that polls an upstream resource which can be safely consumed + by multiple sibling triggers (e.g. a directory scan, a polling REST API, + a Kafka topic read with ``enable.auto.commit=true``) may opt in to having + the triggerer run a single underlying poll loop and fan its raw events + out to every trigger in the group. To do so, override: + + * :meth:`shared_stream_key` — return a key identifying the + shared stream (a ``tuple`` of strings is a common choice). Triggers + whose key compares equal share one poll. + * :meth:`open_shared_stream` — open the shared stream and yield raw + events. Called once per group in the triggerer. + * :meth:`filter_shared_stream` — convert the shared raw stream into this + trigger's own ``TriggerEvent`` instances, applying any per-trigger + filtering or transformation. + + Triggers whose ``shared_stream_key`` returns ``None`` (the default) + keep the existing behavior: each trigger gets its own poll loop via + :meth:`run`. 
+ + **Suitable upstreams** + + The shared-stream channel is **one-way** today: events flow from the + producer (``open_shared_stream``) to each subscriber's + ``filter_shared_stream``, with no path back to tell the producer that a + subscriber accepted, dropped, or finished processing an event. That + restricts the pattern to upstreams whose consumption does **not** depend + on a side effect on a handle that only the producer holds: + + * Idempotent / read-only reads (filesystem listings, polling REST APIs). + * Auto-commit consumers, e.g. Kafka with ``enable.auto.commit=true``. + * Subscriber-side-effect cleanup, where the trigger's per-event action + (``unlink``, local marking, …) operates through APIs the subscriber + already owns, independent of the shared producer handle. + + Upstreams that do **not** fit this scope today include Kafka consumers + with manual commit, SQS with delete-on-process or visibility extension, + and any source where producer-side commit / advance is tied to the + subscriber's accept / reject decision. Adding a producer-side ack + channel to support those cases is tracked as a follow-up — to be + designed against a concrete Kafka or SQS consumer rather than against + an abstract API. """ supports_triggerer_queue: bool = False @@ -269,6 +313,82 @@ def hash(classpath: str, kwargs: dict[str, Any]) -> int: normalized = encode_trigger({"classpath": classpath, "kwargs": kwargs})["kwargs"] return hash((classpath, json.dumps(BaseSerialization.serialize(normalized)).encode("utf-8"))) + def shared_stream_key(self) -> Hashable | None: + """ + Identify an underlying I/O stream that can be shared with sibling triggers. + + Two trigger instances whose ``shared_stream_key()`` return values + compare equal (and are not ``None``) will share a single underlying + poll loop in the triggerer. Each instance still receives the events + it cares about through its own :meth:`filter_shared_stream` call. 
+ + Returning ``None`` (the default) opts out of sharing — the trigger + runs its own independent poll loop via :meth:`run`, exactly as today. + + The return value is read **once** when ``run_trigger`` first starts + this trigger; any change to the key afterwards has no effect on + group membership for this instance. To share one poll across a set + of sibling triggers, ensure every trigger in the set returns the + same key from the outset. + + .. warning:: + + This method is called **before** :meth:`render_template_fields`, + so any templated attribute (for example a ``directory`` derived + from a Jinja expression) is still in its unrendered form here. + Keying on such an attribute means two sibling triggers that + render to the same path will not share their poll. Either base + the key only on already-resolved attributes, or render the + relevant fields yourself before constructing the key. + """ + return None + + @classmethod + async def open_shared_stream(cls, kwargs: dict[str, Any]) -> AsyncIterator[Any]: + """ + Open the shared underlying stream and yield raw events. + + Called **once per shared-stream group** in the triggerer. ``kwargs`` + is taken from one trigger in the group; implementations should rely + only on fields whose values participate in :meth:`shared_stream_key`, + because other fields may differ between siblings in the group. + + Implementations are expected to run for the lifetime of the group — + the triggerer drives the iterator from a single task and cancels it + when the last subscriber leaves. Returning without raising (e.g. + because the upstream resource closed) is treated as an error and + propagated to every subscriber, so the contract is "yield forever, or + raise". If an upstream EOF is a meaningful end-of-life condition, + raise an exception that conveys it. 
+ + Declared as a classmethod (not staticmethod) so subclasses can + compose via ``super().open_shared_stream(kwargs)`` and reach + ``cls`` for class-scoped state or diagnostics. + + Required only when :meth:`shared_stream_key` returns non-``None``. + """ + raise NotImplementedError( + f"{cls.__name__} declares a shared_stream_key but does not implement open_shared_stream" + ) + yield # pragma: no cover - convince mypy this is an async iterator + + async def filter_shared_stream(self, shared_stream: AsyncIterator[Any]) -> AsyncIterator[TriggerEvent]: + """ + Transform the shared raw event stream into this trigger's events. + + The triggerer calls this method (instead of :meth:`run`) when this + trigger participates in a shared-stream group. Iterate + ``shared_stream`` to receive raw events from the shared poll, and + ``yield`` a :class:`TriggerEvent` for each one that should fire this + trigger. + + Required only when :meth:`shared_stream_key` returns non-``None``. + """ + raise NotImplementedError( + f"{type(self).__name__} declares a shared_stream_key but does not implement filter_shared_stream" + ) + yield # pragma: no cover - convince mypy this is an async iterator + class TriggerEvent(BaseModel): """ diff --git a/airflow-core/src/airflow/triggers/shared_stream.py b/airflow-core/src/airflow/triggers/shared_stream.py new file mode 100644 index 0000000000000..388a832509f3d --- /dev/null +++ b/airflow-core/src/airflow/triggers/shared_stream.py @@ -0,0 +1,389 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Shared underlying I/O between :class:`BaseEventTrigger` instances in the triggerer. + +When multiple triggers declare the same non-``None`` +:meth:`~airflow.triggers.base.BaseEventTrigger.shared_stream_key`, the +triggerer routes them through :class:`SharedStreamManager` so that one +underlying poll loop produces raw events that are broadcast to every +participating trigger. Each trigger then runs +:meth:`~airflow.triggers.base.BaseEventTrigger.filter_shared_stream` to +convert the broadcast into its own :class:`~airflow.triggers.base.TriggerEvent` +instances. Triggers that opt out (the default) keep their independent +``run()``-based poll loops untouched. + +Scope and the missing ack channel +--------------------------------- + +The shared-stream channel is **one-way**: events flow from +``open_shared_stream`` out to each subscriber's ``filter_shared_stream``, +with no path back. Subscribers cannot tell the producer "I accepted this +event; please advance / commit / ack". The pattern is therefore only safe +for upstreams whose consumption does not need a producer-side side effect +tied to a subscriber's accept / reject decision: + +* Idempotent / read-only reads (filesystem listings, polling REST APIs). +* Auto-commit Kafka consumers (``enable.auto.commit=true``). +* Subscriber-side-effect cleanup (``unlink``, local marking, …) where the + per-event action goes through APIs the subscriber owns independently. 
+ +Kafka manual-commit consumers, SQS delete-on-process / visibility +extension, and similar message-broker patterns where progress is per-message +and tied to the subscriber's decision are **not** in scope here today. A +producer-side ack channel to cover them is a follow-up that should be +designed against a concrete Kafka or SQS consumer rather than against an +abstract API. See :class:`~airflow.triggers.base.BaseEventTrigger` for the +matching subclass-facing notes. + +Lifecycle invariants +-------------------- + +The manager and groups cooperate to keep a single invariant true at every +``await``-point: + + A key is present in :attr:`SharedStreamManager._groups` only while its + group's poll task is alive and accepting new subscribers. + +This rules out the late-subscriber races that the naive design admits — a +new subscriber for a key whose poll has died or is in the middle of being +torn down always falls through to "create a fresh group" rather than +attaching to a dead one and hanging on an empty queue. The invariant is +maintained synchronously: + +* When ``_poll`` ends for any reason other than cancellation (the upstream + iterator raised, or returned), the group's ``finally`` block evicts the + key from ``_groups`` and broadcasts a terminal sentinel to current + subscribers — all without yielding, so no other coroutine can interleave. +* When the last subscriber leaves, :meth:`SharedStreamManager.unsubscribe` + evicts the key from ``_groups`` *before* awaiting ``group.stop()``, so a + new subscriber arriving while we wait for cancellation creates a fresh + group. +* :meth:`SharedStreamManager.stop_all` clears ``_groups`` in one synchronous + step before awaiting any stop, applying the same rule to shutdown. 
+""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Hashable +from contextlib import suppress +from typing import TYPE_CHECKING, Any + +import structlog + +if TYPE_CHECKING: + from structlog.stdlib import BoundLogger + + from airflow.triggers.base import BaseEventTrigger + +log = structlog.get_logger(__name__) + +DEFAULT_SUBSCRIBER_QUEUE_MAX = 1024 +"""Default per-subscriber queue size for shared streams. + +The :class:`SharedStreamManager` admits up to this many unconsumed raw events +per subscriber before treating the subscriber as too slow to keep up — at +which point the subscriber's trigger is failed with +:class:`_SubscriberOverflow` rather than the queue growing without bound. + +Used as the fallback when no value is passed to ``SharedStreamManager``; +in the triggerer this is overridden from the +``[triggerer] shared_stream_subscriber_queue_size`` config option. +""" + + +class _PollTerminated(Exception): + """ + Raised inside subscribers when ``open_shared_stream`` returns without yielding more events. + + Implementations are expected to run for the lifetime of the group; an + early return would otherwise leave subscribers waiting forever on an + empty queue. + """ + + +class _SubscriberOverflow(Exception): + """ + Raised in a subscriber whose queue exceeded its maxsize. + + Surfaces the slow subscriber loudly through the standard trigger-failure + path (rather than silently dropping events) so Airflow's retry / failure + semantics apply. Other subscribers in the same group are unaffected. + """ + + +class _PollFailure: + """Sentinel propagated through subscriber queues when the shared poll ends.""" + + __slots__ = ("exc",) + + def __init__(self, exc: BaseException) -> None: + self.exc = exc + + +async def _drain(queue: asyncio.Queue) -> AsyncGenerator[Any, None]: + """ + Yield items from ``queue`` until a poll termination sentinel arrives. 
+ + Subscribers exit either by their consuming task being cancelled + (Airflow's standard idiom — :class:`CancelledError` propagates through + ``queue.get()``) or by the shared poll ending, in which case the + :class:`_PollFailure` sentinel re-raises here. + """ + while True: + item = await queue.get() + if isinstance(item, _PollFailure): + raise item.exc + yield item + + +class _SharedStreamGroup: + """One shared poll loop broadcasting raw events to N subscriber queues.""" + + def __init__( + self, + *, + key: Hashable, + trigger_class: type[BaseEventTrigger], + kwargs: dict[str, Any], + on_poll_terminate: Callable[[_SharedStreamGroup], None], + max_subscriber_queue: int, + log: BoundLogger, + ) -> None: + self.key = key + self.trigger_class = trigger_class + self.kwargs = kwargs + self.log = log + self._on_poll_terminate = on_poll_terminate + self._max_subscriber_queue = max_subscriber_queue + self._subscribers: dict[int, asyncio.Queue] = {} + self._overflowed: set[int] = set() + self._poll_task: asyncio.Task | None = None + + def start(self) -> None: + """Start the underlying poll loop. Call exactly once per group.""" + if self._poll_task is not None: + raise RuntimeError(f"Shared stream group {self.key!r} already started") + self._poll_task = asyncio.create_task( + self._poll(), + name=f"shared-stream-poll[{self.key!r}]", + ) + + async def _poll(self) -> None: + terminal_exc: BaseException | None = None + try: + async for raw_event in self.trigger_class.open_shared_stream(self.kwargs): + for trigger_id, queue in self._subscribers.items(): + if trigger_id in self._overflowed: + # Subscriber has been force-failed on a previous + # overflow; the failure sentinel is already in its + # queue and unsubscribe will drop it on next pass. 
+ continue + try: + queue.put_nowait(raw_event) + except asyncio.QueueFull: + self._fail_overflowed_subscriber(trigger_id, queue) + terminal_exc = _PollTerminated( + f"open_shared_stream for {self.key!r} returned without raising; " + "shared streams are expected to run for the lifetime of the group" + ) + except asyncio.CancelledError: + # ``stop()`` initiated this; the manager has already evicted the + # group and is awaiting our exit. Do not run the terminate path. + raise + except Exception as exc: + terminal_exc = exc + self.log.exception("Shared stream poll failed; propagating to subscribers", key=self.key) + finally: + if terminal_exc is not None: + # Synchronous: evict from the manager and broadcast the + # sentinel before returning to the loop, so no coroutine can + # observe ``_groups[key]`` pointing at a dead poll. + self._on_poll_terminate(self) + failure = _PollFailure(terminal_exc) + for queue in self._subscribers.values(): + # A subscriber whose queue is already at capacity (slow + # consumer, or an unread ``_SubscriberOverflow`` sentinel) + # would raise ``QueueFull`` here and abort the broadcast, + # leaving later subscribers without the terminal sentinel. + # Drain whatever stale events are queued — they become + # irrelevant once the poll is terminating — and then put + # the failure so every subscriber wakes up. 
+ while not queue.empty(): + try: + queue.get_nowait() + except asyncio.QueueEmpty: + break + queue.put_nowait(failure) + + def subscribe(self, trigger_id: int) -> AsyncIterator[Any]: + """Register ``trigger_id`` as a subscriber and return its raw event stream.""" + if trigger_id in self._subscribers: + raise RuntimeError(f"Trigger {trigger_id} already subscribed to shared stream {self.key!r}") + queue: asyncio.Queue = asyncio.Queue(maxsize=self._max_subscriber_queue) + self._subscribers[trigger_id] = queue + return _drain(queue) + + def unsubscribe(self, trigger_id: int) -> None: + # Active subscribers exit through their consuming task being cancelled + # (Airflow's standard idiom); dropping the queue is enough here. + self._subscribers.pop(trigger_id, None) + self._overflowed.discard(trigger_id) + + def _fail_overflowed_subscriber(self, trigger_id: int, queue: asyncio.Queue) -> None: + """ + Force a slow subscriber to fail with :class:`_SubscriberOverflow`. + + The broadcast hit ``QueueFull`` for this subscriber's queue, which + means the subscriber's :meth:`filter_shared_stream` is falling behind + the upstream cadence. Rather than dropping events silently — which + would invisibly violate Asset event-driven semantics — we drain + whatever stale events are pending and replace them with a + :class:`_PollFailure` so the subscriber's ``run_trigger`` sees the + error on its next ``__anext__``. Other subscribers in the same group + are unaffected. + """ + self.log.warning( + "Shared stream subscriber overflowed; failing this trigger", + key=self.key, + trigger_id=trigger_id, + queue_maxsize=queue.maxsize, + ) + # Discard the unread backlog so the subscriber jumps straight to the + # failure on its next get() instead of chewing through stale events. 
+ while not queue.empty(): + try: + queue.get_nowait() + except asyncio.QueueEmpty: + break + queue.put_nowait( + _PollFailure( + _SubscriberOverflow( + f"shared stream {self.key!r} fell behind for trigger {trigger_id}: " + f"subscriber queue exceeded maxsize={queue.maxsize}" + ) + ) + ) + self._overflowed.add(trigger_id) + + def is_empty(self) -> bool: + return not self._subscribers + + async def stop(self) -> None: + """Cancel the poll task if it is still running and wait for it to exit.""" + if self._poll_task is None or self._poll_task.done(): + return + self._poll_task.cancel() + with suppress(asyncio.CancelledError): + await self._poll_task + + +class SharedStreamManager: + """ + Coordinate :class:`BaseEventTrigger` instances that share underlying I/O. + + The manager owns one :class:`_SharedStreamGroup` per distinct + ``shared_stream_key``. Each group runs a single async task that drives + ``open_shared_stream``; subscribers receive raw events through their own + asyncio queues and convert them to :class:`TriggerEvent` instances + independently. + + The manager is single-event-loop and not thread-safe. The triggerer's + ``TriggerRunner`` is its sole owner. + """ + + def __init__( + self, + *, + log: BoundLogger | None = None, + max_subscriber_queue: int = DEFAULT_SUBSCRIBER_QUEUE_MAX, + ) -> None: + self.log = log or structlog.get_logger(__name__) + self._max_subscriber_queue = max_subscriber_queue + self._groups: dict[Hashable, _SharedStreamGroup] = {} + + def subscribe( + self, + *, + trigger_id: int, + trigger: BaseEventTrigger, + key: Hashable, + ) -> AsyncIterator[Any]: + """ + Subscribe a trigger to the shared stream identified by ``key``. + + On first subscriber for a given key the group is created and the + underlying poll loop is started. Returns an async iterator of raw + events the trigger should feed into ``filter_shared_stream``. 
+ """ + if key is None: + raise ValueError("shared stream key must not be None") + group = self._groups.get(key) + if group is None: + _, kwargs = trigger.serialize() + group = _SharedStreamGroup( + key=key, + trigger_class=type(trigger), + kwargs=kwargs, + on_poll_terminate=self._handle_poll_terminate, + max_subscriber_queue=self._max_subscriber_queue, + log=self.log, + ) + self._groups[key] = group + group.start() + self.log.debug("Shared stream group started", key=key) + return group.subscribe(trigger_id) + + async def unsubscribe(self, trigger_id: int, key: Hashable) -> None: + """ + Remove a subscriber. + + When the last subscriber for ``key`` leaves, the key is evicted from + ``_groups`` synchronously and the underlying poll task is cancelled. + Eviction happens *before* awaiting ``stop()`` so that a new subscriber + arriving while we wait for cancellation builds a fresh group rather + than attaching to the dying one. + """ + group = self._groups.get(key) + if group is None: + return + group.unsubscribe(trigger_id) + if group.is_empty(): + del self._groups[key] + await group.stop() + self.log.debug("Shared stream group stopped", key=key) + + async def stop_all(self) -> None: + """Cancel every active group; used during triggerer shutdown.""" + groups = list(self._groups.values()) + self._groups.clear() + for group in groups: + await group.stop() + + def _handle_poll_terminate(self, group: _SharedStreamGroup) -> None: + """ + Evict a group synchronously when its poll task ends on its own. + + Invoked from ``_SharedStreamGroup._poll``'s ``finally`` before any + ``await`` hands control to another coroutine, so the eviction races no + ``subscribe`` call. The ``is`` check is defensive — under normal flow + a group only enters this path while it is still the live entry for + its key. 
+ """ + if self._groups.get(group.key) is group: + del self._groups[group.key] diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 0501783b992d2..73643ab3867af 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -776,6 +776,70 @@ def test_run_trigger_on_kill_exception_does_not_swallow_cancelled_error(self, se mock_trigger.on_kill.assert_awaited_once() mock_trigger.cleanup.assert_awaited_once() + def test_run_trigger_routes_shared_stream_trigger_through_manager(self, session) -> None: + """A BaseEventTrigger that opts into a shared stream consumes filter_shared_stream().""" + from airflow.triggers.base import BaseEventTrigger, TriggerEvent + + class _SharedTrigger(BaseEventTrigger): + def __init__(self, queue_url: str, region: str | None = None): + super().__init__() + self.queue_url = queue_url + self.region = region + + def serialize(self): + return ( + f"{type(self).__module__}.{type(self).__qualname__}", + {"queue_url": self.queue_url, "region": self.region}, + ) + + def shared_stream_key(self): + return ("queue", self.queue_url) + + @classmethod + async def open_shared_stream(cls, kwargs): + yield {"region": "us"} + yield {"region": "eu"} + # Stay alive so the manager can tear us down on unsubscribe. 
+ await asyncio.Event().wait() + + async def filter_shared_stream(self, shared_stream): + async for raw in shared_stream: + if self.region is None or raw["region"] == self.region: + yield TriggerEvent(raw) + + async def run(self): # pragma: no cover - replaced by filter_shared_stream + yield TriggerEvent({}) + + trigger_runner = TriggerRunner() + trigger_runner.triggers = { + 1: {"task": MagicMock(spec=asyncio.Task), "is_watcher": True, "name": "us", "events": 0} + } + trigger = _SharedTrigger(queue_url="https://q", region="us") + trigger.task_instance = MagicMock() + trigger.task_instance.map_index = -1 + + async def _drive(): + run_task = asyncio.create_task(trigger_runner.run_trigger(1, trigger)) + # Wait until the "us" event has been pushed onto the outbound queue, + # then cancel the trigger so the test can exit deterministically. + for _ in range(100): + await asyncio.sleep(0.01) + if trigger_runner.events: + break + run_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await run_task + + asyncio.run(_drive()) + + events = list(trigger_runner.events) + assert len(events) == 1 + trigger_id, event = events[0] + assert trigger_id == 1 + assert event.payload == {"region": "us"} + # Group is torn down on unsubscribe. 
+ assert trigger_runner._shared_streams._groups == {} + def test_run_trigger_on_kill_timeout_does_not_block_cleanup(self, session) -> None: """A hanging on_kill() is interrupted after the timeout and cleanup still runs.""" trigger_runner = TriggerRunner() diff --git a/airflow-core/tests/unit/triggers/test_base_trigger.py b/airflow-core/tests/unit/triggers/test_base_trigger.py index d9e38385a2052..8e5690a3ef991 100644 --- a/airflow-core/tests/unit/triggers/test_base_trigger.py +++ b/airflow-core/tests/unit/triggers/test_base_trigger.py @@ -20,7 +20,7 @@ import pytest from airflow.sdk.bases.operator import BaseOperator -from airflow.triggers.base import BaseTrigger, StartTriggerArgs +from airflow.triggers.base import BaseEventTrigger, BaseTrigger, StartTriggerArgs, TriggerEvent class DummyOperator(BaseOperator): @@ -138,3 +138,97 @@ def test_render_template_fields_empty_when_no_trigger_kwargs(create_task_instanc # Rendering with empty template_fields is a no-op trigger.render_template_fields(context={"name": "world"}) assert trigger.name == "Hello {{ name }}" + + +class _PlainEventTrigger(BaseEventTrigger): + """A BaseEventTrigger that does not opt into shared streams.""" + + def __init__(self, name: str = "plain"): + super().__init__() + self.name = name + + def serialize(self): + return (f"{type(self).__module__}.{type(self).__qualname__}", {"name": self.name}) + + async def run(self): + yield TriggerEvent({"name": self.name}) + + +class _SharedQueueTrigger(BaseEventTrigger): + """A BaseEventTrigger that opts into shared streams.""" + + def __init__(self, queue_url: str, region: str | None = None): + super().__init__() + self.queue_url = queue_url + self.region = region + + def serialize(self): + return ( + f"{type(self).__module__}.{type(self).__qualname__}", + {"queue_url": self.queue_url, "region": self.region}, + ) + + def shared_stream_key(self): + return ("shared-queue", self.queue_url) + + @classmethod + async def open_shared_stream(cls, kwargs): + for 
region in ("us", "eu", "us"): + yield {"queue_url": kwargs["queue_url"], "region": region} + + async def filter_shared_stream(self, shared_stream): + async for raw in shared_stream: + if self.region is None or raw["region"] == self.region: + yield TriggerEvent(raw) + + async def run(self): # pragma: no cover - replaced by filter_shared_stream + yield TriggerEvent({}) + + +def test_base_event_trigger_defaults_no_sharing(): + trigger = _PlainEventTrigger() + assert trigger.shared_stream_key() is None + + +async def _drain_async_iter(it): + async for _ in it: + pass + + +@pytest.mark.asyncio +async def test_base_event_trigger_default_open_shared_stream_raises(): + with pytest.raises(NotImplementedError, match="open_shared_stream"): + await _drain_async_iter(_PlainEventTrigger.open_shared_stream({})) + + +@pytest.mark.asyncio +async def test_base_event_trigger_default_filter_shared_stream_raises(): + trigger = _PlainEventTrigger() + + async def empty_stream(): + if False: + yield # pragma: no cover + + with pytest.raises(NotImplementedError, match="filter_shared_stream"): + await _drain_async_iter(trigger.filter_shared_stream(empty_stream())) + + +def test_subclass_can_declare_shared_stream_key(): + a = _SharedQueueTrigger(queue_url="https://q", region="us") + b = _SharedQueueTrigger(queue_url="https://q", region="eu") + c = _SharedQueueTrigger(queue_url="https://other", region="us") + + assert a.shared_stream_key() == b.shared_stream_key() + assert a.shared_stream_key() != c.shared_stream_key() + + +@pytest.mark.asyncio +async def test_subclass_filter_shared_stream_applies_per_instance_match(): + us = _SharedQueueTrigger(queue_url="https://q", region="us") + + async def stream(): + for region in ("us", "eu", "us"): + yield {"queue_url": "https://q", "region": region} + + payloads = [event.payload async for event in us.filter_shared_stream(stream())] + assert [p["region"] for p in payloads] == ["us", "us"] diff --git 
a/airflow-core/tests/unit/triggers/test_shared_stream.py b/airflow-core/tests/unit/triggers/test_shared_stream.py new file mode 100644 index 0000000000000..0158f7cfdb5a2 --- /dev/null +++ b/airflow-core/tests/unit/triggers/test_shared_stream.py @@ -0,0 +1,635 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from contextlib import suppress + +import pytest + +from airflow.triggers.base import BaseEventTrigger, TriggerEvent +from airflow.triggers.shared_stream import SharedStreamManager, _SubscriberOverflow + + +class _ProgrammableSharedStreamTrigger(BaseEventTrigger): + """ + Test helper trigger whose shared poll yields whatever the test class attr says. + + Subclass per test so each scenario gets its own ``open_shared_stream`` + behavior without leaking state between tests. 
+ """ + + queue_url: str = "https://q" + + def __init__(self, queue_url: str = "https://q", region: str | None = None): + super().__init__() + self.queue_url = queue_url + self.region = region + + def serialize(self): + return ( + f"{type(self).__module__}.{type(self).__qualname__}", + {"queue_url": self.queue_url, "region": self.region}, + ) + + def shared_stream_key(self): + return ("queue", self.queue_url) + + async def filter_shared_stream(self, shared_stream): + async for raw in shared_stream: + if self.region is None or raw["region"] == self.region: + yield TriggerEvent(raw) + + async def run(self): # pragma: no cover - replaced by filter_shared_stream + yield TriggerEvent({}) + + +def _events_then_block(events: list[dict]): + async def _open_shared_stream(cls, kwargs): + for event in events: + yield event + # Stay alive forever so tests can observe broadcast then tear down. + await asyncio.Event().wait() + + return classmethod(_open_shared_stream) + + +def _make_trigger_class(open_shared_stream): + """Return a fresh subclass with the given open_shared_stream classmethod.""" + + class _Trigger(_ProgrammableSharedStreamTrigger): + pass + + _Trigger.open_shared_stream = open_shared_stream + return _Trigger + + +async def _collect(stream, *, n: int, timeout: float = 1.0) -> list: + """Pull ``n`` items off an async iterator with a per-item timeout.""" + out = [] + it = stream.__aiter__() + for _ in range(n): + out.append(await asyncio.wait_for(it.__anext__(), timeout=timeout)) + return out + + +@pytest.mark.asyncio +async def test_single_subscriber_receives_broadcast_events(): + cls = _make_trigger_class( + _events_then_block( + [ + {"region": "us"}, + {"region": "eu"}, + ] + ) + ) + trigger = cls(region="us") + manager = SharedStreamManager() + try: + stream = manager.subscribe(trigger_id=1, trigger=trigger, key=trigger.shared_stream_key()) + events = await _collect(trigger.filter_shared_stream(stream), n=1) + assert [e.payload["region"] for e in events] == 
["us"] + finally: + await manager.unsubscribe(1, trigger.shared_stream_key()) + + +@pytest.mark.asyncio +async def test_two_subscribers_share_one_poll_and_filter_independently(): + cls = _make_trigger_class( + _events_then_block( + [ + {"region": "us"}, + {"region": "eu"}, + {"region": "us"}, + ] + ) + ) + us, eu = cls(region="us"), cls(region="eu") + key = us.shared_stream_key() + assert key == eu.shared_stream_key() + + manager = SharedStreamManager() + try: + us_stream = manager.subscribe(trigger_id=1, trigger=us, key=key) + eu_stream = manager.subscribe(trigger_id=2, trigger=eu, key=key) + + # The shared group is created exactly once. + assert len(manager._groups) == 1 + + us_events, eu_events = await asyncio.gather( + _collect(us.filter_shared_stream(us_stream), n=2), + _collect(eu.filter_shared_stream(eu_stream), n=1), + ) + assert [e.payload["region"] for e in us_events] == ["us", "us"] + assert [e.payload["region"] for e in eu_events] == ["eu"] + finally: + await manager.unsubscribe(1, key) + await manager.unsubscribe(2, key) + + +@pytest.mark.asyncio +async def test_group_is_torn_down_when_last_subscriber_leaves(): + cls = _make_trigger_class(_events_then_block([{"region": "us"}])) + trigger = cls(region="us") + manager = SharedStreamManager() + key = trigger.shared_stream_key() + + manager.subscribe(trigger_id=1, trigger=trigger, key=key) + assert key in manager._groups + + await manager.unsubscribe(1, key) + assert key not in manager._groups + + +@pytest.mark.asyncio +async def test_independent_keys_use_independent_groups(): + cls = _make_trigger_class(_events_then_block([{"region": "us"}])) + a = cls(queue_url="https://a") + b = cls(queue_url="https://b") + manager = SharedStreamManager() + + manager.subscribe(trigger_id=1, trigger=a, key=a.shared_stream_key()) + manager.subscribe(trigger_id=2, trigger=b, key=b.shared_stream_key()) + try: + assert set(manager._groups) == {a.shared_stream_key(), b.shared_stream_key()} + finally: + await 
manager.unsubscribe(1, a.shared_stream_key()) + await manager.unsubscribe(2, b.shared_stream_key()) + + +@pytest.mark.asyncio +async def test_poll_failure_propagates_to_subscribers_and_evicts_group(): + async def _open_shared_stream(cls, kwargs): + raise RuntimeError("boom") + yield # pragma: no cover + + cls = _make_trigger_class(classmethod(_open_shared_stream)) + trigger = cls() + manager = SharedStreamManager() + key = trigger.shared_stream_key() + try: + stream = manager.subscribe(trigger_id=1, trigger=trigger, key=key) + with pytest.raises(RuntimeError, match="boom"): + await asyncio.wait_for(_collect(trigger.filter_shared_stream(stream), n=1), timeout=1.0) + # The failing poll evicts its own group from the manager in _poll's + # finally, before any subscriber resumes — so by the time the + # subscriber observes "boom" the manager already has no group for + # this key. A late subscriber arriving here would create a fresh + # group rather than attaching to a dead one. + assert key not in manager._groups + finally: + await manager.unsubscribe(1, key) + + +@pytest.mark.asyncio +async def test_subscribe_rejects_none_key(): + cls = _make_trigger_class(_events_then_block([])) + trigger = cls() + manager = SharedStreamManager() + with pytest.raises(ValueError, match="must not be None"): + manager.subscribe(trigger_id=1, trigger=trigger, key=None) + + +@pytest.mark.asyncio +async def test_double_subscribe_same_id_is_rejected(): + cls = _make_trigger_class(_events_then_block([])) + trigger = cls() + manager = SharedStreamManager() + key = trigger.shared_stream_key() + try: + manager.subscribe(trigger_id=1, trigger=trigger, key=key) + with pytest.raises(RuntimeError, match="already subscribed"): + manager.subscribe(trigger_id=1, trigger=trigger, key=key) + finally: + await manager.unsubscribe(1, key) + + +@pytest.mark.asyncio +async def test_stop_all_clears_every_group(): + cls = _make_trigger_class(_events_then_block([])) + a = cls(queue_url="https://a") + b = 
cls(queue_url="https://b") + manager = SharedStreamManager() + + manager.subscribe(trigger_id=1, trigger=a, key=a.shared_stream_key()) + manager.subscribe(trigger_id=2, trigger=b, key=b.shared_stream_key()) + assert len(manager._groups) == 2 + + await manager.stop_all() + assert manager._groups == {} + + +@pytest.mark.asyncio +async def test_late_subscriber_after_poll_failure_gets_fresh_group(): + """The first call's open_shared_stream raises; a subsequent subscribe for the same key should + start a brand new poll rather than attach to the dead group. + """ + invocations: list[int] = [] + + async def _open_shared_stream(cls, kwargs): + n = len(invocations) + invocations.append(n) + if n == 0: + raise RuntimeError("first invocation fails") + yield {"region": "us"} + await asyncio.Event().wait() + + cls = _make_trigger_class(classmethod(_open_shared_stream)) + trigger = cls() + manager = SharedStreamManager() + key = trigger.shared_stream_key() + + stream1 = manager.subscribe(trigger_id=1, trigger=trigger, key=key) + with pytest.raises(RuntimeError, match="first invocation fails"): + await asyncio.wait_for( + _collect(trigger.filter_shared_stream(stream1), n=1), + timeout=1.0, + ) + await manager.unsubscribe(1, key) + + stream2 = manager.subscribe(trigger_id=2, trigger=trigger, key=key) + try: + events = await asyncio.wait_for( + _collect(trigger.filter_shared_stream(stream2), n=1), + timeout=1.0, + ) + assert [e.payload["region"] for e in events] == ["us"] + finally: + await manager.unsubscribe(2, key) + + assert invocations == [0, 1], "open_shared_stream should be called twice (failed, then fresh)" + + +@pytest.mark.asyncio +async def test_late_subscriber_during_poll_failure_window_does_not_attach_to_dead_group(): + """Reproduce the race the lifecycle rewrite closes: a new subscriber arriving after _poll has + raised but before the original subscriber has finished propagating the failure must see no + existing group and create a fresh one — otherwise it would 
attach to a queue nothing will ever + put events on. + """ + invocations: list[int] = [] + + async def _open_shared_stream(cls, kwargs): + n = len(invocations) + invocations.append(n) + if n == 0: + raise RuntimeError("boom") + yield {"region": "fresh"} + await asyncio.Event().wait() + + cls = _make_trigger_class(classmethod(_open_shared_stream)) + trigger = cls() + manager = SharedStreamManager() + key = trigger.shared_stream_key() + + stream1 = manager.subscribe(trigger_id=1, trigger=trigger, key=key) + + # Wait for the poll task to finish its lifecycle — including the synchronous self-eviction in + # its finally block — but do NOT consume the _PollFailure from stream1 yet. This simulates the + # "broadcast done, subscriber not yet unwound" window described in the bug report. + poll_task = manager._groups[key]._poll_task + assert poll_task is not None + with suppress(RuntimeError): + await poll_task + + assert key not in manager._groups, ( + "the failing poll must evict its group synchronously in _poll's finally, so this window " + "is closed before any other coroutine can subscribe" + ) + + stream2 = manager.subscribe(trigger_id=2, trigger=trigger, key=key) + try: + events = await asyncio.wait_for( + _collect(trigger.filter_shared_stream(stream2), n=1), + timeout=1.0, + ) + assert events[0].payload == {"region": "fresh"} + finally: + # Original subscriber still has _PollFailure waiting for it. + with pytest.raises(RuntimeError, match="boom"): + await asyncio.wait_for( + _collect(trigger.filter_shared_stream(stream1), n=1), + timeout=1.0, + ) + await manager.unsubscribe(1, key) + await manager.unsubscribe(2, key) + + assert invocations == [0, 1] + + +@pytest.mark.asyncio +async def test_resubscribe_during_last_unsubscribe_creates_fresh_group(): + """If the last subscriber leaves and the manager is mid-``await group.stop()``, a concurrent + subscribe for the same key must build a new group instead of attaching to the dying one. 
+ """ + invocations: list[int] = [] + + async def _open_shared_stream(cls, kwargs): + n = len(invocations) + invocations.append(n) + yield {"n": n} + await asyncio.Event().wait() + + cls = _make_trigger_class(classmethod(_open_shared_stream)) + trigger = cls() + manager = SharedStreamManager() + key = trigger.shared_stream_key() + + stream1 = manager.subscribe(trigger_id=1, trigger=trigger, key=key) + await asyncio.wait_for( + _collect(trigger.filter_shared_stream(stream1), n=1), + timeout=1.0, + ) + + unsub_task = asyncio.create_task(manager.unsubscribe(1, key)) + # One tick: unsubscribe runs synchronously through the pop-from-_groups step, then yields at + # `await group.stop()`. After this yield returns to us, _groups is already cleared. + await asyncio.sleep(0) + assert key not in manager._groups, ( + "manager.unsubscribe must evict the group from _groups before awaiting stop(), so a " + "racing subscribe sees no group and creates a fresh one" + ) + + stream2 = manager.subscribe(trigger_id=2, trigger=trigger, key=key) + try: + events = await asyncio.wait_for( + _collect(trigger.filter_shared_stream(stream2), n=1), + timeout=1.0, + ) + # Second invocation (index 1) — proves stream2 is bound to a fresh poll, not the dying one. + assert events[0].payload == {"n": 1} + finally: + await unsub_task + await manager.unsubscribe(2, key) + + assert invocations == [0, 1] + + +@pytest.mark.asyncio +async def test_open_shared_stream_returning_naturally_propagates_as_failure(): + """A shared poll that exhausts its iterator instead of running indefinitely would otherwise + leave subscribers blocked on queue.get() forever; the manager surfaces it as an error. 
+ """ + + async def _open_shared_stream(cls, kwargs): + yield {"region": "us"} + + cls = _make_trigger_class(classmethod(_open_shared_stream)) + trigger = cls() + manager = SharedStreamManager() + key = trigger.shared_stream_key() + + stream = manager.subscribe(trigger_id=1, trigger=trigger, key=key) + with pytest.raises(Exception, match="expected to run for the lifetime of the group"): + await asyncio.wait_for( + _collect(trigger.filter_shared_stream(stream), n=2), + timeout=1.0, + ) + + assert key not in manager._groups, "natural exhaustion should evict the group like a failure" + await manager.unsubscribe(1, key) + + +@pytest.mark.asyncio +async def test_slow_subscriber_overflow_fails_only_that_subscriber(): + """A subscriber whose ``filter_shared_stream`` lags behind the upstream cadence enough to + overflow its bounded queue must fail loudly with ``_SubscriberOverflow`` — silent drops are + unacceptable for Asset event-driven semantics. Sibling subscribers in the same group keep + receiving events. + """ + + async def _open_shared_stream(cls, kwargs): + for i in range(5): + yield {"i": i} + # Yield to the loop so the fast consumer gets a chance to drain; + # the slow consumer never runs while sleep(0) ticks pass, so its + # queue fills up. + await asyncio.sleep(0) + await asyncio.Event().wait() + + cls = _make_trigger_class(classmethod(_open_shared_stream)) + slow_trigger = cls() + fast_trigger = cls() + manager = SharedStreamManager(max_subscriber_queue=2) + key = slow_trigger.shared_stream_key() + + slow_stream = manager.subscribe(trigger_id=1, trigger=slow_trigger, key=key) + fast_stream = manager.subscribe(trigger_id=2, trigger=fast_trigger, key=key) + + async def drain_fast(): + out = [] + async for ev in fast_trigger.filter_shared_stream(fast_stream): + out.append(ev) + if len(out) >= 5: + break + return out + + # Start fast first so it drains its queue as the producer broadcasts. 
+ fast_task = asyncio.create_task(drain_fast()) + + # Hand control back so the producer can broadcast all 5 events. The fast + # consumer keeps its queue around 1; the slow consumer has no task yet, + # so its queue fills past maxsize=2 and the overflow handler swaps the + # backlog for a failure sentinel. + fast_events = await asyncio.wait_for(fast_task, timeout=2.0) + + # Slow consumer starts after the overflow; first event should be the failure. + with pytest.raises(_SubscriberOverflow, match="exceeded maxsize"): + await asyncio.wait_for( + _collect(slow_trigger.filter_shared_stream(slow_stream), n=1), + timeout=2.0, + ) + + assert [e.payload["i"] for e in fast_events] == [0, 1, 2, 3, 4], ( + "fast subscriber must not be affected by the slow subscriber's overflow" + ) + # The group is still alive — only the slow subscriber was failed; fast is still subscribed. + assert key in manager._groups + assert 1 in manager._groups[key]._overflowed + + await manager.unsubscribe(1, key) + await manager.unsubscribe(2, key) + + +@pytest.mark.asyncio +async def test_concurrent_unsubscribes_tear_down_group_cleanly(): + """N subscribers leaving at once via concurrent ``unsubscribe`` must end with the group fully + torn down and the poll task cancelled — mirrors a triggerer cancelling many deferred tasks in + the same tick. 
+ """ + cls = _make_trigger_class(_events_then_block([])) + n_subscribers = 8 + triggers = [cls() for _ in range(n_subscribers)] + key = triggers[0].shared_stream_key() + manager = SharedStreamManager() + + for trigger_id, trigger in enumerate(triggers): + manager.subscribe(trigger_id=trigger_id, trigger=trigger, key=key) + assert len(manager._groups[key]._subscribers) == n_subscribers + poll_task = manager._groups[key]._poll_task + assert poll_task is not None + + await asyncio.gather(*(manager.unsubscribe(i, key) for i in range(n_subscribers))) + + assert manager._groups == {}, "every subscriber gone means the group is gone" + assert poll_task.done(), "the poll task must exit when the last subscriber leaves" + assert poll_task.cancelled() + + +@pytest.mark.asyncio +async def test_stop_all_with_blocked_consumer_does_not_inject_failure_sentinel(): + """A consumer blocked on ``queue.get()`` when ``stop_all`` runs must not be woken with a + poison sentinel. The poll task's ``CancelledError`` path explicitly skips the terminate + broadcast, leaving the standard idiom — the trigger's consuming task is cancelled separately + — as the only exit. Verifies the asymmetry between cancel-driven and failure-driven teardown. + """ + cls = _make_trigger_class(_events_then_block([])) # never yields; consumer always blocks + trigger = cls() + key = trigger.shared_stream_key() + manager = SharedStreamManager() + + stream = manager.subscribe(trigger_id=1, trigger=trigger, key=key) + + async def consume(): + async for event in trigger.filter_shared_stream(stream): + return event + return None + + consumer = asyncio.create_task(consume()) + # Let the consumer reach ``await queue.get()``. 
+ await asyncio.sleep(0) + assert not consumer.done() + + poll_task = manager._groups[key]._poll_task + assert poll_task is not None + + await manager.stop_all() + + assert manager._groups == {} + assert poll_task.done() + assert poll_task.cancelled() + # No sentinel was injected — the consumer is still parked on queue.get(). + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(asyncio.shield(consumer), timeout=0.05) + + consumer.cancel() + with suppress(asyncio.CancelledError): + await consumer + + +@pytest.mark.asyncio +async def test_sibling_non_key_kwargs_diverge_first_subscriber_wins(): + """Two siblings with the same ``shared_stream_key`` but divergent non-key kwargs share the + group built from the **first** subscriber's kwargs. The second subscriber's non-key kwargs are + silently ignored — this is the documented contract; the test locks the behavior so any future + change (e.g. adding a runtime warning) is a deliberate decision rather than a regression. + """ + captured_kwargs: list[dict] = [] + + async def _open_shared_stream(cls, kwargs): + captured_kwargs.append(kwargs) + yield {"region": kwargs.get("region")} + await asyncio.Event().wait() + + cls = _make_trigger_class(classmethod(_open_shared_stream)) + first = cls(region="us") + second = cls(region="eu") # same queue_url (key), different region (non-key) + key = first.shared_stream_key() + assert key == second.shared_stream_key() + + manager = SharedStreamManager() + try: + stream1 = manager.subscribe(trigger_id=1, trigger=first, key=key) + manager.subscribe(trigger_id=2, trigger=second, key=key) + + # First subscriber accepts (region="us"); second's filter rejects since the raw event + # carries the first subscriber's region. Verify by consuming from the first subscriber. 
+ events = await _collect(first.filter_shared_stream(stream1), n=1) + assert [e.payload for e in events] == [{"region": "us"}] + + assert len(captured_kwargs) == 1, "open_shared_stream must be called exactly once per group" + assert captured_kwargs[0]["region"] == "us", ( + "first subscriber's non-key kwargs become the group's kwargs" + ) + finally: + await manager.unsubscribe(1, key) + await manager.unsubscribe(2, key) + + +@pytest.mark.asyncio +async def test_serialize_failure_in_subscribe_leaves_groups_clean(): + """If ``trigger.serialize()`` raises while a fresh group is being built, ``subscribe`` must + propagate the exception without leaving an orphan entry in ``_groups``. A subsequent subscribe + for the same key must build a clean group. + """ + cls = _make_trigger_class(_events_then_block([{"region": "us"}])) + + class _BrokenSerializeTrigger(cls): + def serialize(self): + raise RuntimeError("serialize boom") + + broken = _BrokenSerializeTrigger() + manager = SharedStreamManager() + key = broken.shared_stream_key() + + with pytest.raises(RuntimeError, match="serialize boom"): + manager.subscribe(trigger_id=1, trigger=broken, key=key) + + assert key not in manager._groups, "failed subscribe must not leave an orphan group entry" + + clean = cls() + stream = manager.subscribe(trigger_id=2, trigger=clean, key=key) + try: + events = await _collect(clean.filter_shared_stream(stream), n=1) + assert events[0].payload == {"region": "us"} + assert key in manager._groups + finally: + await manager.unsubscribe(2, key) + + +@pytest.mark.asyncio +async def test_terminal_failure_reaches_every_subscriber_even_with_full_queues(): + """When the shared poll raises right after a broadcast that filled every subscriber's queue, + the terminal :class:`_PollFailure` sentinel must still reach all of them. 
Without draining + each queue before the terminal ``put_nowait``, the first overflowed subscriber would raise + ``QueueFull``, abort the broadcast loop, and silently strand the remaining subscribers on + ``queue.get()`` forever. + """ + + async def _open_shared_stream(cls, kwargs): + yield {"region": "us"} + raise RuntimeError("upstream died") + + cls = _make_trigger_class(classmethod(_open_shared_stream)) + first = cls() + second = cls() + manager = SharedStreamManager(max_subscriber_queue=1) + key = first.shared_stream_key() + + first_stream = manager.subscribe(trigger_id=1, trigger=first, key=key) + second_stream = manager.subscribe(trigger_id=2, trigger=second, key=key) + + # Both queues sit at maxsize=1 with the broadcast event unread when the + # terminal _PollFailure goes out. The fix must drain each queue so the + # sentinel lands; both consumers should observe the same RuntimeError. + with pytest.raises(RuntimeError, match="upstream died"): + await asyncio.wait_for(_collect(first.filter_shared_stream(first_stream), n=2), timeout=2.0) + with pytest.raises(RuntimeError, match="upstream died"): + await asyncio.wait_for(_collect(second.filter_shared_stream(second_stream), n=2), timeout=2.0) + + await manager.unsubscribe(1, key) + await manager.unsubscribe(2, key) diff --git a/providers/standard/src/airflow/providers/standard/triggers/file.py b/providers/standard/src/airflow/providers/standard/triggers/file.py index 699be775ffa8b..aa4548d7dabca 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/file.py +++ b/providers/standard/src/airflow/providers/standard/triggers/file.py @@ -18,8 +18,9 @@ import asyncio import datetime +import logging import os -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Hashable from glob import glob from typing import Any @@ -36,6 +37,8 @@ TriggerEvent, ) +log = logging.getLogger(__name__) + class FileTrigger(BaseTrigger): """ @@ -132,3 +135,104 @@ async def run(self) -> 
AsyncIterator[TriggerEvent]: yield TriggerEvent(True) return await asyncio.sleep(self.poke_interval) + + +class DirectoryFileDeleteTrigger(BaseEventTrigger): + """ + Fire once when ``filename`` appears in ``directory``, then delete it. + + Functionally equivalent to ``FileDeleteTrigger`` for a single file, but + sibling triggers that point at the same ``directory`` and ``poke_interval`` + share a single underlying directory scan in the triggerer; each instance + only fires for its own ``filename``. This is useful when many assets are + driven by per-flag-file events landing in a shared inbox directory. + + :param directory: Directory to scan. + :param filename: File name (without directory) whose appearance fires this + trigger. The matched file is deleted before the event is yielded. + :param poke_interval: Time to wait between scans. + """ + + def __init__(self, *, directory: str, filename: str, poke_interval: float = 5.0) -> None: + super().__init__() + self.directory = directory + self.filename = filename + self.poke_interval = poke_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize DirectoryFileDeleteTrigger arguments and classpath.""" + return ( + "airflow.providers.standard.triggers.file.DirectoryFileDeleteTrigger", + { + "directory": self.directory, + "filename": self.filename, + "poke_interval": self.poke_interval, + }, + ) + + def shared_stream_key(self) -> Hashable | None: + """All triggers on the same directory + cadence share one scan.""" + # Normalise so trivial path variants (``/tmp/flags`` vs ``/tmp/flags/``, + # or ``./flags`` vs an absolute equivalent in the same cwd) key to the + # same group instead of silently running N independent scans. + return ("directory-scan", os.path.normpath(self.directory), self.poke_interval) + + @classmethod + async def open_shared_stream(cls, kwargs: dict[str, Any]) -> AsyncIterator[Any]: + """ + Drive one directory-listing loop and broadcast each snapshot. 
+ + Missing directories yield an empty snapshot so subscribers keep + polling for the file to appear. Other ``OSError`` cases (permission + denied, transient I/O errors) are logged and the snapshot is skipped + for this cadence — failing the shared poll outright would + cascade-fail every sibling watcher on the same directory for what + may be a transient blip. + """ + directory = anyio.Path(kwargs["directory"]) + poke_interval: float = kwargs["poke_interval"] + while True: + try: + names = {p.name async for p in directory.iterdir()} + except FileNotFoundError: + names = set() + except OSError: + log.warning( + "Failed to list %s; retrying after %ss", + directory, + poke_interval, + exc_info=True, + ) + await asyncio.sleep(poke_interval) + continue + yield {"directory": str(directory), "names": names} + await asyncio.sleep(poke_interval) + + async def filter_shared_stream(self, shared_stream: AsyncIterator[Any]) -> AsyncIterator[TriggerEvent]: + """Fire once for this instance's own filename and delete the file.""" + async for snapshot in shared_stream: + if self.filename not in snapshot["names"]: + continue + filepath = anyio.Path(snapshot["directory"]) / self.filename + try: + await filepath.unlink() + except FileNotFoundError: + # Lost a race with a sibling, or the file disappeared between + # snapshot and unlink. Wait for the next scan. + continue + self.log.info("File %s has been deleted", filepath) + yield TriggerEvent({"filepath": str(filepath)}) + return + + async def run(self) -> AsyncIterator[TriggerEvent]: + """ + Standalone fallback when the shared-stream manager is unavailable. + + Mirrors the shared path so the trigger remains usable in unit tests + and on Airflow versions without the manager wired in. It does not + deduplicate I/O — that requires the triggerer to drive the shared + stream. 
+ """ + kwargs = self.serialize()[1] + async for event in self.filter_shared_stream(type(self).open_shared_stream(kwargs)): + yield event diff --git a/providers/standard/tests/unit/standard/triggers/test_file.py b/providers/standard/tests/unit/standard/triggers/test_file.py index 793f0aeb62861..17b97c8a5d710 100644 --- a/providers/standard/tests/unit/standard/triggers/test_file.py +++ b/providers/standard/tests/unit/standard/triggers/test_file.py @@ -21,7 +21,11 @@ import anyio import pytest -from airflow.providers.standard.triggers.file import FileDeleteTrigger, FileTrigger +from airflow.providers.standard.triggers.file import ( + DirectoryFileDeleteTrigger, + FileDeleteTrigger, + FileTrigger, +) from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS @@ -106,3 +110,166 @@ async def test_file_delete_trigger(self, tmp_path): # returns, so once the task is done, the file is guaranteed gone. await asyncio.wait_for(task, timeout=5.0) assert await anyio.Path(p).exists() is False + + +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Skip on Airflow < 3.0") +class TestDirectoryFileDeleteTrigger: + DIRECTORY = "/data/flags" + + def test_serialization(self): + trigger = DirectoryFileDeleteTrigger( + directory=self.DIRECTORY, filename="orders_us.flag", poke_interval=5 + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.standard.triggers.file.DirectoryFileDeleteTrigger" + assert kwargs == { + "directory": self.DIRECTORY, + "filename": "orders_us.flag", + "poke_interval": 5, + } + + def test_shared_stream_key_groups_same_directory_and_cadence(self): + a = DirectoryFileDeleteTrigger(directory=self.DIRECTORY, filename="us.flag", poke_interval=1.0) + b = DirectoryFileDeleteTrigger(directory=self.DIRECTORY, filename="eu.flag", poke_interval=1.0) + c = DirectoryFileDeleteTrigger(directory=self.DIRECTORY, filename="us.flag", poke_interval=2.0) + d = DirectoryFileDeleteTrigger(directory="/other", filename="us.flag", 
poke_interval=1.0) + + assert a.shared_stream_key() == b.shared_stream_key() + assert a.shared_stream_key() != c.shared_stream_key() + assert a.shared_stream_key() != d.shared_stream_key() + + @pytest.mark.parametrize( + ("first", "second"), + [ + ("/data/flags", "/data/flags/"), + ("/data/flags", "/data//flags"), + ("/data/flags", "/data/./flags"), + ("/data/parent/../flags", "/data/flags"), + ], + ) + def test_shared_stream_key_normalises_trivial_path_variants(self, first, second): + a = DirectoryFileDeleteTrigger(directory=first, filename="us.flag", poke_interval=1.0) + b = DirectoryFileDeleteTrigger(directory=second, filename="us.flag", poke_interval=1.0) + assert a.shared_stream_key() == b.shared_stream_key() + + @pytest.mark.asyncio + async def test_filter_shared_stream_fires_only_for_own_filename(self, tmp_path): + directory = tmp_path / "flags" + await anyio.Path(directory).mkdir() + await (anyio.Path(directory) / "us.flag").touch() + + async def stream(): + yield {"directory": str(directory), "names": {"us.flag", "eu.flag"}} + + us = DirectoryFileDeleteTrigger(directory=str(directory), filename="us.flag", poke_interval=1.0) + events = [event async for event in us.filter_shared_stream(stream())] + + assert len(events) == 1 + assert events[0].payload == {"filepath": str(directory / "us.flag")} + assert await (anyio.Path(directory) / "us.flag").exists() is False + + @pytest.mark.asyncio + async def test_filter_shared_stream_skips_other_filenames(self, tmp_path): + directory = tmp_path / "flags" + await anyio.Path(directory).mkdir() + await (anyio.Path(directory) / "eu.flag").touch() + + async def stream(): + yield {"directory": str(directory), "names": {"eu.flag"}} + + us = DirectoryFileDeleteTrigger(directory=str(directory), filename="us.flag", poke_interval=1.0) + events = [event async for event in us.filter_shared_stream(stream())] + + # Did not fire, did not delete the unrelated file. 
+ assert events == [] + assert await (anyio.Path(directory) / "eu.flag").exists() is True + + @pytest.mark.asyncio + async def test_filter_shared_stream_recovers_when_sibling_unlinks_first(self, tmp_path): + directory = tmp_path / "flags" + await anyio.Path(directory).mkdir() + + async def stream(): + # Snapshot says the file is there; in reality a sibling already + # consumed it, so unlink raises FileNotFoundError. We must keep + # iterating, not crash. After the snapshot drops the filename, + # we exit the iterator without firing. + yield {"directory": str(directory), "names": {"us.flag"}} + yield {"directory": str(directory), "names": set()} + + us = DirectoryFileDeleteTrigger(directory=str(directory), filename="us.flag", poke_interval=1.0) + events = [event async for event in us.filter_shared_stream(stream())] + + assert events == [] + + @pytest.mark.asyncio + async def test_open_shared_stream_handles_missing_directory(self, tmp_path): + missing = tmp_path / "does_not_exist" + snapshots = [] + + async def consume(): + it = DirectoryFileDeleteTrigger.open_shared_stream( + {"directory": str(missing), "poke_interval": 0.01} + ).__aiter__() + for _ in range(2): + snapshots.append(await it.__anext__()) + + await asyncio.wait_for(consume(), timeout=1.0) + + assert all(s["names"] == set() for s in snapshots) + + @pytest.mark.asyncio + async def test_open_shared_stream_logs_and_retries_on_permission_error(self, tmp_path, mocker): + """A transient ``PermissionError`` from ``iterdir`` must not cascade-fail every sibling + watcher. The shared poll logs at warning level, sleeps for one poke, and tries again on + the next cadence so a brief perms blip is recoverable. + """ + # Two failures, then succeed -- proves the poll keeps retrying instead + # of propagating to subscribers. 
+ states: list[set[str]] = [set(), {"us.flag"}] + + async def _iterdir(self): + if not states: + if False: + yield # pragma: no cover - sentinel for async generator typing + return + state = states.pop(0) + if state == set(): + raise PermissionError("denied") + for name in state: + yield anyio.Path("/tmp") / name + + mocker.patch.object(anyio.Path, "iterdir", _iterdir) + warning = mocker.patch("airflow.providers.standard.triggers.file.log.warning") + + directory = tmp_path / "flags" + snapshots = [] + + async def consume(): + it = DirectoryFileDeleteTrigger.open_shared_stream( + {"directory": str(directory), "poke_interval": 0.01} + ).__aiter__() + snapshots.append(await it.__anext__()) + + await asyncio.wait_for(consume(), timeout=2.0) + + assert snapshots == [{"directory": str(directory), "names": {"us.flag"}}] + assert warning.called, "PermissionError must produce a warning, not be silently swallowed" + + @pytest.mark.asyncio + async def test_run_standalone_fallback_polls_until_filename_appears(self, tmp_path): + directory = tmp_path / "flags" + await anyio.Path(directory).mkdir() + target = anyio.Path(directory) / "us.flag" + + trigger = DirectoryFileDeleteTrigger(directory=str(directory), filename="us.flag", poke_interval=0.05) + task = asyncio.create_task(trigger.run().__anext__()) + + await asyncio.sleep(0.2) + assert task.done() is False + + await target.touch() + event = await asyncio.wait_for(task, timeout=1.0) + + assert event.payload == {"filepath": str(target)} + assert await target.exists() is False From dd43683b3eb20ce789a35bcaf4c1a45d54fe7106 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 21 May 2026 20:54:09 +0800 Subject: [PATCH 2/9] fixup! 
feat(triggerer): share one poll across sibling event triggers --- .../tests/unit/triggers/test_shared_stream.py | 52 ++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/airflow-core/tests/unit/triggers/test_shared_stream.py b/airflow-core/tests/unit/triggers/test_shared_stream.py index 0158f7cfdb5a2..28577ccbb528f 100644 --- a/airflow-core/tests/unit/triggers/test_shared_stream.py +++ b/airflow-core/tests/unit/triggers/test_shared_stream.py @@ -22,7 +22,12 @@ import pytest from airflow.triggers.base import BaseEventTrigger, TriggerEvent -from airflow.triggers.shared_stream import SharedStreamManager, _SubscriberOverflow +from airflow.triggers.shared_stream import ( + SharedStreamManager, + _PollFailure, + _SharedStreamGroup, + _SubscriberOverflow, +) class _ProgrammableSharedStreamTrigger(BaseEventTrigger): @@ -633,3 +638,48 @@ async def _open_shared_stream(cls, kwargs): await manager.unsubscribe(1, key) await manager.unsubscribe(2, key) + + +@pytest.mark.asyncio +async def test_fail_overflowed_subscriber_drains_full_queue_before_putting_sentinel(): + """``_fail_overflowed_subscriber`` must drain the backlog *before* placing the + failure sentinel, not after. + + White-box invariant: given a queue already at capacity, calling + ``_fail_overflowed_subscriber`` must leave exactly one item in the queue — + the :class:`_PollFailure` wrapping a :class:`_SubscriberOverflow` — regardless + of how many stale events were sitting there beforehand. + + If the drain loop were moved to *after* the ``put_nowait``, the put would + raise :exc:`asyncio.QueueFull` before any draining occurred and the + subscriber would never receive its failure sentinel. + """ + import structlog + + cap = 3 + queue: asyncio.Queue = asyncio.Queue(maxsize=cap) + # Pre-fill the queue to capacity with stale events. 
+ for i in range(cap): + queue.put_nowait({"stale": i}) + + assert queue.full(), "pre-condition: queue must be full before the call" + + group = _SharedStreamGroup( + key="test-key", + trigger_class=_ProgrammableSharedStreamTrigger, + kwargs={}, + on_poll_terminate=lambda g: None, + max_subscriber_queue=cap, + log=structlog.get_logger("test"), + ) + trigger_id = 42 + group._subscribers[trigger_id] = queue + + group._fail_overflowed_subscriber(trigger_id, queue) + + # Post-conditions that pin the drain-before-put ordering: + assert queue.qsize() == 1, "exactly one item must remain: the failure sentinel" + sentinel = queue.get_nowait() + assert isinstance(sentinel, _PollFailure), "sentinel must be a _PollFailure" + assert isinstance(sentinel.exc, _SubscriberOverflow), "the wrapped exception must be _SubscriberOverflow" + assert trigger_id in group._overflowed, "trigger_id must be recorded in _overflowed" From 1d1cea102cd7c8faad6b90a5b97e9b89b956ea79 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 22 May 2026 10:01:34 +0800 Subject: [PATCH 3/9] fixup! fixup! feat(triggerer): share one poll across sibling event triggers --- airflow-core/src/airflow/triggers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py index e7c83bada61d6..523c9359532a9 100644 --- a/airflow-core/src/airflow/triggers/base.py +++ b/airflow-core/src/airflow/triggers/base.py @@ -335,7 +335,7 @@ def shared_stream_key(self) -> Hashable | None: This method is called **before** :meth:`render_template_fields`, so any templated attribute (for example a ``directory`` derived - from a Jinja expression) is still in its unrendered form here. + from a Jinja expression) is still its raw template string here. Keying on such an attribute means two sibling triggers that render to the same path will not share their poll. 
Either base the key only on already-resolved attributes, or render the From cfb3359222e856db744eb3ab8f353ed83c94eafa Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 22 May 2026 11:16:52 +0800 Subject: [PATCH 4/9] fixup! feat(triggerer): share one poll across sibling event triggers - triggerer: compute shared_stream_key after render_template_fields so templated attributes resolve before keying - shared_stream: extract _drain_and_offer_failure helper; reuse from terminal broadcast and overflow paths - DirectoryFileDeleteTrigger: normalise directory via realpath so relative/absolute/symlink/trailing-slash variants share one scan --- .../src/airflow/jobs/triggerer_job_runner.py | 24 ++++++----- airflow-core/src/airflow/triggers/base.py | 12 +++--- .../src/airflow/triggers/shared_stream.py | 43 +++++++++---------- .../providers/standard/triggers/file.py | 9 ++-- .../tests/unit/standard/triggers/test_file.py | 29 +++++++++++++ 5 files changed, 74 insertions(+), 43 deletions(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index cb8a8be4ef3e2..66576da230456 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -1453,25 +1453,29 @@ async def run_trigger( # (BaseEventTrigger.shared_stream_key returns non-None) consume a # broadcast stream produced by SharedStreamManager and convert it # via filter_shared_stream(). Everything else stays on the original - # standalone-run() path. - shared_key: Hashable | None = None + # standalone-run() path. The key is computed after + # render_template_fields so any templated attributes are already + # resolved when the key is constructed. 
event_trigger: BaseEventTrigger | None = None if isinstance(trigger, BaseEventTrigger): event_trigger = trigger - try: - shared_key = event_trigger.shared_stream_key() - except Exception: - self.log.exception( - "shared_stream_key() raised; falling back to standalone run", - trigger_id=trigger_id, - ) - shared_key = None + shared_key: Hashable | None = None with _make_trigger_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span: try: if context is not None: trigger.render_template_fields(context=context) + if event_trigger is not None: + try: + shared_key = event_trigger.shared_stream_key() + except Exception: + self.log.exception( + "shared_stream_key() raised; falling back to standalone run", + trigger_id=trigger_id, + ) + shared_key = None + if shared_key is not None and event_trigger is not None: shared_stream = self._shared_streams.subscribe( trigger_id=trigger_id, trigger=event_trigger, key=shared_key diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py index 523c9359532a9..ae9832d3b2328 100644 --- a/airflow-core/src/airflow/triggers/base.py +++ b/airflow-core/src/airflow/triggers/base.py @@ -331,15 +331,13 @@ def shared_stream_key(self) -> Hashable | None: of sibling triggers, ensure every trigger in the set returns the same key from the outset. - .. warning:: + .. note:: - This method is called **before** :meth:`render_template_fields`, + This method is called **after** :meth:`render_template_fields`, so any templated attribute (for example a ``directory`` derived - from a Jinja expression) is still its raw template string here. - Keying on such an attribute means two sibling triggers that - render to the same path will not share their poll. Either base - the key only on already-resolved attributes, or render the - relevant fields yourself before constructing the key. + from a Jinja expression) is already resolved when the key is + constructed. 
Two sibling triggers that render to the same path + will correctly share their poll. """ return None diff --git a/airflow-core/src/airflow/triggers/shared_stream.py b/airflow-core/src/airflow/triggers/shared_stream.py index 388a832509f3d..f1bd1ecdf6f4f 100644 --- a/airflow-core/src/airflow/triggers/shared_stream.py +++ b/airflow-core/src/airflow/triggers/shared_stream.py @@ -217,19 +217,9 @@ async def _poll(self) -> None: self._on_poll_terminate(self) failure = _PollFailure(terminal_exc) for queue in self._subscribers.values(): - # A subscriber whose queue is already at capacity (slow - # consumer, or an unread ``_SubscriberOverflow`` sentinel) - # would raise ``QueueFull`` here and abort the broadcast, - # leaving later subscribers without the terminal sentinel. - # Drain whatever stale events are queued — they become - # irrelevant once the poll is terminating — and then put - # the failure so every subscriber wakes up. - while not queue.empty(): - try: - queue.get_nowait() - except asyncio.QueueEmpty: - break - queue.put_nowait(failure) + # Drain stale events then put the failure sentinel so every + # subscriber wakes up even if its queue was at capacity. + self._drain_and_offer_failure(queue, failure) def subscribe(self, trigger_id: int) -> AsyncIterator[Any]: """Register ``trigger_id`` as a subscriber and return its raw event stream.""" @@ -264,23 +254,32 @@ def _fail_overflowed_subscriber(self, trigger_id: int, queue: asyncio.Queue) -> trigger_id=trigger_id, queue_maxsize=queue.maxsize, ) - # Discard the unread backlog so the subscriber jumps straight to the - # failure on its next get() instead of chewing through stale events. 
- while not queue.empty(): - try: - queue.get_nowait() - except asyncio.QueueEmpty: - break - queue.put_nowait( + self._drain_and_offer_failure( + queue, _PollFailure( _SubscriberOverflow( f"shared stream {self.key!r} fell behind for trigger {trigger_id}: " f"subscriber queue exceeded maxsize={queue.maxsize}" ) - ) + ), ) self._overflowed.add(trigger_id) + def _drain_and_offer_failure(self, queue: asyncio.Queue, failure: _PollFailure) -> None: + """ + Drain ``queue`` and put ``failure`` so the subscriber wakes on the failure. + + The drain releases capacity so the subsequent ``put_nowait`` cannot raise + ``QueueFull``; this is the single point that both the terminal-broadcast + and the per-subscriber overflow path go through. + """ + while not queue.empty(): + try: + queue.get_nowait() + except asyncio.QueueEmpty: + break + queue.put_nowait(failure) + def is_empty(self) -> bool: return not self._subscribers diff --git a/providers/standard/src/airflow/providers/standard/triggers/file.py b/providers/standard/src/airflow/providers/standard/triggers/file.py index aa4548d7dabca..b4b2285db3ed3 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/file.py +++ b/providers/standard/src/airflow/providers/standard/triggers/file.py @@ -172,10 +172,11 @@ def serialize(self) -> tuple[str, dict[str, Any]]: def shared_stream_key(self) -> Hashable | None: """All triggers on the same directory + cadence share one scan.""" - # Normalise so trivial path variants (``/tmp/flags`` vs ``/tmp/flags/``, - # or ``./flags`` vs an absolute equivalent in the same cwd) key to the - # same group instead of silently running N independent scans. 
- return ("directory-scan", os.path.normpath(self.directory), self.poke_interval) + # Use realpath so trivial path variants all resolve to the same canonical + # path: trailing slashes (``/tmp/flags`` vs ``/tmp/flags/``), relative vs + # absolute paths (``./flags`` vs ``/tmp/flags``), and symlinks vs their + # targets all key to the same group instead of running N independent scans. + return ("directory-scan", os.path.realpath(self.directory), self.poke_interval) @classmethod async def open_shared_stream(cls, kwargs: dict[str, Any]) -> AsyncIterator[Any]: diff --git a/providers/standard/tests/unit/standard/triggers/test_file.py b/providers/standard/tests/unit/standard/triggers/test_file.py index 17b97c8a5d710..2ee97fa9c5cf8 100644 --- a/providers/standard/tests/unit/standard/triggers/test_file.py +++ b/providers/standard/tests/unit/standard/triggers/test_file.py @@ -152,6 +152,35 @@ def test_shared_stream_key_normalises_trivial_path_variants(self, first, second) b = DirectoryFileDeleteTrigger(directory=second, filename="us.flag", poke_interval=1.0) assert a.shared_stream_key() == b.shared_stream_key() + def test_shared_stream_key_realpath_trailing_slash(self, tmp_path): + """Trailing slash variant keys to the same group as the plain path.""" + real_dir = str(tmp_path / "flags") + a = DirectoryFileDeleteTrigger(directory=real_dir, filename="f", poke_interval=1.0) + b = DirectoryFileDeleteTrigger(directory=real_dir + "/", filename="f", poke_interval=1.0) + assert a.shared_stream_key() == b.shared_stream_key() + + def test_shared_stream_key_realpath_relative_vs_absolute(self, tmp_path, monkeypatch): + """A relative path resolves to the same key as its absolute equivalent.""" + monkeypatch.chdir(tmp_path) + a = DirectoryFileDeleteTrigger(directory=".", filename="f", poke_interval=1.0) + b = DirectoryFileDeleteTrigger(directory=str(tmp_path), filename="f", poke_interval=1.0) + assert a.shared_stream_key() == b.shared_stream_key() + + @pytest.mark.skipif( + not 
hasattr(__import__("os"), "symlink"), + reason="symlinks not supported on this platform", + ) + def test_shared_stream_key_realpath_symlink_vs_target(self, tmp_path): + """A symlink and its target resolve to the same key.""" + + real_dir = tmp_path / "real" + real_dir.mkdir() + link_dir = tmp_path / "link" + link_dir.symlink_to(real_dir) + a = DirectoryFileDeleteTrigger(directory=str(real_dir), filename="f", poke_interval=1.0) + b = DirectoryFileDeleteTrigger(directory=str(link_dir), filename="f", poke_interval=1.0) + assert a.shared_stream_key() == b.shared_stream_key() + @pytest.mark.asyncio async def test_filter_shared_stream_fires_only_for_own_filename(self, tmp_path): directory = tmp_path / "flags" From 162c9055737875499651062a8e44a12ac9849d2e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 22 May 2026 17:43:01 +0800 Subject: [PATCH 5/9] fixup! feat(triggerer): share one poll across sibling event triggers - DirectoryFileDeleteTrigger.open_shared_stream: raise on PermissionError / NotADirectoryError / IsADirectoryError so config bugs surface in the UI instead of silently spinning a warning loop; keep swallow + retry for the rest of OSError (transient I/O) --- .../providers/standard/triggers/file.py | 17 +++-- .../tests/unit/standard/triggers/test_file.py | 67 +++++++++++-------- 2 files changed, 52 insertions(+), 32 deletions(-) diff --git a/providers/standard/src/airflow/providers/standard/triggers/file.py b/providers/standard/src/airflow/providers/standard/triggers/file.py index b4b2285db3ed3..8c7b8894c26d5 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/file.py +++ b/providers/standard/src/airflow/providers/standard/triggers/file.py @@ -184,11 +184,16 @@ async def open_shared_stream(cls, kwargs: dict[str, Any]) -> AsyncIterator[Any]: Drive one directory-listing loop and broadcast each snapshot. Missing directories yield an empty snapshot so subscribers keep - polling for the file to appear. 
Other ``OSError`` cases (permission - denied, transient I/O errors) are logged and the snapshot is skipped - for this cadence — failing the shared poll outright would - cascade-fail every sibling watcher on the same directory for what - may be a transient blip. + polling for the file to appear. Configuration-class failures + (``PermissionError``, ``NotADirectoryError``, ``IsADirectoryError``) + propagate — these are almost always permanent (wrong mount, wrong + mode, path points at a file), so silently retrying just hides the + misconfiguration from the operator; surfacing them as a + ``_PollFailure`` makes the trigger visibly fail in the UI, where it + can be diagnosed and restarted after the operator corrects the + config. Other ``OSError`` subclasses (transient I/O blips, NFS + hiccups, etc.) are logged at warning and the snapshot is skipped for + this cadence, since those may self-heal. """ directory = anyio.Path(kwargs["directory"]) poke_interval: float = kwargs["poke_interval"] @@ -197,6 +202,8 @@ async def open_shared_stream(cls, kwargs: dict[str, Any]) -> AsyncIterator[Any]: names = {p.name async for p in directory.iterdir()} except FileNotFoundError: names = set() + except (PermissionError, NotADirectoryError, IsADirectoryError): + raise except OSError: log.warning( "Failed to list %s; retrying after %ss", diff --git a/providers/standard/tests/unit/standard/triggers/test_file.py b/providers/standard/tests/unit/standard/triggers/test_file.py index 2ee97fa9c5cf8..309e922bbfec9 100644 --- a/providers/standard/tests/unit/standard/triggers/test_file.py +++ b/providers/standard/tests/unit/standard/triggers/test_file.py @@ -247,43 +247,56 @@ async def consume(): assert all(s["names"] == set() for s in snapshots) + @pytest.mark.parametrize( + "exc_cls", + [PermissionError, NotADirectoryError, IsADirectoryError], + ) @pytest.mark.asyncio - async def test_open_shared_stream_logs_and_retries_on_permission_error(self, tmp_path, mocker): - """A transient 
``PermissionError`` from ``iterdir`` must not cascade-fail every sibling - watcher. The shared poll logs at warning level, sleeps for one poke, and tries again on - the next cadence so a brief perms blip is recoverable. - """ - # Two failures, then succeed -- proves the poll keeps retrying instead - # of propagating to subscribers. - states: list[set[str]] = [set(), {"us.flag"}] + async def test_open_shared_stream_raises_on_config_bug_oserror(self, mocker, tmp_path, exc_cls): + """PermissionError, NotADirectoryError, and IsADirectoryError must propagate rather than spin.""" async def _iterdir(self): - if not states: - if False: - yield # pragma: no cover - sentinel for async generator typing - return - state = states.pop(0) - if state == set(): - raise PermissionError("denied") - for name in state: - yield anyio.Path("/tmp") / name + raise exc_cls("config bug") + if False: + yield # pragma: no cover - sentinel for async generator typing mocker.patch.object(anyio.Path, "iterdir", _iterdir) - warning = mocker.patch("airflow.providers.standard.triggers.file.log.warning") directory = tmp_path / "flags" - snapshots = [] + gen = DirectoryFileDeleteTrigger.open_shared_stream( + {"directory": str(directory), "poke_interval": 0.01} + ) + with pytest.raises(exc_cls): + await gen.__anext__() - async def consume(): - it = DirectoryFileDeleteTrigger.open_shared_stream( - {"directory": str(directory), "poke_interval": 0.01} - ).__aiter__() - snapshots.append(await it.__anext__()) + @pytest.mark.asyncio + async def test_open_shared_stream_swallows_transient_oserror(self, tmp_path, mocker): + """A generic OSError is logged and retried; the snapshot from the next call is yielded.""" + call_count = 0 - await asyncio.wait_for(consume(), timeout=2.0) + async def _iterdir(self): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise OSError("transient blip") + if False: + yield # pragma: no cover - sentinel for async generator typing + + mocker.patch.object(anyio.Path, 
"iterdir", _iterdir) + + async def _noop_sleep(_duration): + pass + + mocker.patch("asyncio.sleep", side_effect=_noop_sleep) + + directory = tmp_path / "flags" + gen = DirectoryFileDeleteTrigger.open_shared_stream( + {"directory": str(directory), "poke_interval": 0.01} + ) + snapshot = await gen.__anext__() - assert snapshots == [{"directory": str(directory), "names": {"us.flag"}}] - assert warning.called, "PermissionError must produce a warning, not be silently swallowed" + assert snapshot == {"directory": str(directory), "names": set()} + assert call_count == 2 @pytest.mark.asyncio async def test_run_standalone_fallback_polls_until_filename_appears(self, tmp_path): From 7957067152c3aa142bb72b6cf19391de3356b315 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sat, 23 May 2026 13:09:08 +0800 Subject: [PATCH 6/9] fixup! fixup! feat(triggerer): share one poll across sibling event triggers --- airflow-core/src/airflow/jobs/triggerer_job_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 66576da230456..c73dfd857bd92 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -1453,7 +1453,7 @@ async def run_trigger( # (BaseEventTrigger.shared_stream_key returns non-None) consume a # broadcast stream produced by SharedStreamManager and convert it # via filter_shared_stream(). Everything else stays on the original - # standalone-run() path. The key is computed after + # standalone-run() path. The key is computed after # render_template_fields so any templated attributes are already # resolved when the key is constructed. event_trigger: BaseEventTrigger | None = None From a432e05636cc8a398654d8b1d97794ffb3a009b5 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sat, 23 May 2026 13:49:36 +0800 Subject: [PATCH 7/9] fixup! 
feat(triggerer): share one poll across sibling event triggers Round-3 doc cleanup for jason's review (C1/C2/C6): - Drop Kafka/SQS recommendations from event-scheduling.rst and BaseEventTrigger class docstring; the producer-side ack channel is out of scope this iteration. - Document the deterministic-key requirement on shared_stream_key and add a Slow-subscriber overflow mitigations section (raise subscriber queue size, redesign the key to narrow groups). --- .../event-scheduling.rst | 36 ++++++++++++++----- airflow-core/src/airflow/triggers/base.py | 26 +++++++------- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst b/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst index b5d947bd6a027..ed3411e5410bd 100644 --- a/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst +++ b/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst @@ -70,7 +70,7 @@ Sharing one poll across sibling triggers .. versionadded:: 3.3 When several ``AssetWatcher`` instances on different assets back triggers that read from the **same upstream resource** -— a directory of flag files, a polling REST endpoint, a Kafka topic with auto-commit, and similar idempotent or +— a directory of flag files, a polling REST endpoint, and similar idempotent or subscriber-side-effect sources — the triggerer would otherwise spin up one independent poll loop per trigger. For a shared source with twenty subscribers that means twenty poll loops, twenty connections, twenty sets of API calls per cadence. See "Suitable upstreams" below for the precise scope. @@ -84,6 +84,8 @@ subclass overrides three hooks: (the default) opts out — the trigger runs its own independent ``run()`` loop, exactly as before. 
The return value is read **once** when the triggerer starts this trigger; changing it mid-lifetime has no effect on group membership, so siblings that should share a poll must return the same key from the outset. + The key must be deterministic — derive it from configuration fields, never from per-call values such as + ``time.time()`` or ``uuid.uuid4()``, because the comparison must be stable across the lifetime of the group. * :py:meth:`~airflow.triggers.base.BaseEventTrigger.open_shared_stream` — a ``@classmethod`` coroutine the triggerer drives **once per shared-stream group** to yield raw events from the upstream. Because the triggerer reuses one @@ -156,18 +158,16 @@ whose consumption does **not** depend on a side effect on a handle that only the producer holds. Good fits: * Idempotent / read-only reads — directory scans, polling REST APIs. -* Auto-commit Kafka consumers (``enable.auto.commit=true``). * Subscriber-side-effect cleanup, where the trigger's per-event action (``unlink``, local marking, …) goes through APIs the subscriber owns independently of the shared producer handle. -Currently **not** in scope: Kafka consumers with manual commit, SQS with -delete-on-process or visibility extension, and any source where progress -on the producer's handle is tied to the subscriber's accept / reject -decision. A producer-side ack channel to cover those cases is a planned -follow-up; it should be designed against a concrete Kafka or SQS consumer -rather than against an abstract API, so it is intentionally left out of -the first iteration. +Currently **not** in scope: Kafka consumers (regardless of commit mode), +SQS with delete-on-process or visibility extension, and any source where +progress on the producer's handle is tied to the subscriber's accept / +reject decision. These sources need a way for the subscriber to signal +acceptance back to the producer, which the current shared-stream API does +not provide. 
Verifying that sharing is active ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -186,6 +186,24 @@ If sharing is active you should see exactly one ``Shared stream group started`` how many subscribers join it. If you see one log line per subscriber instead, the keys probably do not compare equal — verify that ``shared_stream_key`` returns identical values across the siblings. +Slow-subscriber overflow +^^^^^^^^^^^^^^^^^^^^^^^^ + +Each subscriber in a shared-stream group has a bounded in-memory queue. If the poll loop +produces events faster than a subscriber's ``filter_shared_stream`` can consume them, the +queue fills and that trigger is failed with ``_SubscriberOverflow`` — a deliberate fail-fast +rather than unbounded memory growth. + +If subscribers repeatedly overflow, there are two mitigations: + +* Raise ``[triggerer] shared_stream_subscriber_queue_size`` to give the + filter more slack before the overflow threshold is reached. +* Redesign :py:meth:`~airflow.triggers.base.BaseEventTrigger.shared_stream_key` so fewer + sibling triggers share a single group — a narrower group reduces the rate at which any + one subscriber needs to consume events. + +Both reduce the mismatch between producer throughput and per-subscriber consume rate. + Avoid infinite scheduling ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py index ae9832d3b2328..7cda8cc4272cd 100644 --- a/airflow-core/src/airflow/triggers/base.py +++ b/airflow-core/src/airflow/triggers/base.py @@ -255,10 +255,10 @@ class BaseEventTrigger(BaseTrigger): **Sharing an underlying I/O stream between triggers** A subclass that polls an upstream resource which can be safely consumed - by multiple sibling triggers (e.g. a directory scan, a polling REST API, - a Kafka topic read with ``enable.auto.commit=true``) may opt in to having - the triggerer run a single underlying poll loop and fan its raw events - out to every trigger in the group. 
To do so, override: + by multiple sibling triggers (e.g. a directory scan, a polling REST API) + may opt in to having the triggerer run a single underlying poll loop + and fan its raw events out to every trigger in the group. To do so, + override: * :meth:`shared_stream_key` — return a key identifying the shared stream (a ``tuple`` of strings is a common choice). Triggers @@ -283,18 +283,16 @@ class BaseEventTrigger(BaseTrigger): on a side effect on a handle that only the producer holds: * Idempotent / read-only reads (filesystem listings, polling REST APIs). - * Auto-commit consumers, e.g. Kafka with ``enable.auto.commit=true``. * Subscriber-side-effect cleanup, where the trigger's per-event action (``unlink``, local marking, …) operates through APIs the subscriber already owns, independent of the shared producer handle. - Upstreams that do **not** fit this scope today include Kafka consumers - with manual commit, SQS with delete-on-process or visibility extension, - and any source where producer-side commit / advance is tied to the - subscriber's accept / reject decision. Adding a producer-side ack - channel to support those cases is tracked as a follow-up — to be - designed against a concrete Kafka or SQS consumer rather than against - an abstract API. + Upstreams **not** in scope include Kafka consumers (regardless of + commit mode), SQS with delete-on-process or visibility extension, + and any source where progress on the producer's handle is tied to + the subscriber's accept / reject decision. These sources need a way + for the subscriber to signal acceptance back to the producer, which + the current shared-stream API does not provide. """ supports_triggerer_queue: bool = False @@ -331,6 +329,10 @@ def shared_stream_key(self) -> Hashable | None: of sibling triggers, ensure every trigger in the set returns the same key from the outset. 
+ The key must be deterministic — derive it from configuration fields, + never from per-call values such as ``time.time()`` or ``uuid.uuid4()``, + because the comparison must be stable across the lifetime of the group. + .. note:: This method is called **after** :meth:`render_template_fields`, From 3ddd18ad1943a6e70c28f9756de6ccc8f04e9050 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sat, 23 May 2026 13:52:07 +0800 Subject: [PATCH 8/9] fixup! feat(triggerer): share one poll across sibling event triggers Address jason's C4 suggestion: collapse the get-and-None-check on SharedStreamManager.subscribe into a single walrus expression. --- airflow-core/src/airflow/triggers/shared_stream.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/triggers/shared_stream.py b/airflow-core/src/airflow/triggers/shared_stream.py index f1bd1ecdf6f4f..f347a265524a1 100644 --- a/airflow-core/src/airflow/triggers/shared_stream.py +++ b/airflow-core/src/airflow/triggers/shared_stream.py @@ -332,8 +332,7 @@ def subscribe( """ if key is None: raise ValueError("shared stream key must not be None") - group = self._groups.get(key) - if group is None: + if (group := self._groups.get(key)) is None: _, kwargs = trigger.serialize() group = _SharedStreamGroup( key=key, From 0e0abfd15099b706fa6661b04385cdd4b94c70fb Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sat, 23 May 2026 15:38:41 +0800 Subject: [PATCH 9/9] fixup! fixup! 
feat(triggerer): share one poll across sibling event triggers --- airflow-core/docs/authoring-and-scheduling/event-scheduling.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst b/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst index ed3411e5410bd..b0118bf0d8be5 100644 --- a/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst +++ b/airflow-core/docs/authoring-and-scheduling/event-scheduling.rst @@ -194,7 +194,7 @@ produces events faster than a subscriber's ``filter_shared_stream`` can consume queue fills and that trigger is failed with ``_SubscriberOverflow`` — a deliberate fail-fast rather than unbounded memory growth. -If subscribers repeatedly overflow, there are two mitigations: +If subscribers repeatedly overflow, there are two ways to address this: * Raise ``[triggerer] shared_stream_subscriber_queue_size`` to give the filter more slack before the overflow threshold is reached.