Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions backend/pg_queue/migrations/0004_pgbarrierstate_and_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Generated by Django 4.2.1 on 2026-06-15 06:55

import django.utils.timezone
from django.db import migrations, models


class Migration(migrations.Migration):
    """Create the ``pg_barrier_state`` table and its check constraint.

    Two operations, applied in order: ``CreateModel`` for ``PgBarrierState``
    (including the ``expires_at`` index), then ``AddConstraint`` for the
    ``expires_at > created_at`` invariant.
    """

    # Applies on top of the previous pg_queue migration.
    dependencies = [
        ("pg_queue", "0003_pgqueuemessage_pg_queue_message_priority_range"),
    ]

    operations = [
        migrations.CreateModel(
            name="PgBarrierState",
            fields=[
                # Text primary key: the model is keyed by execution id, so no
                # auto-generated ``id`` column is created.
                ("execution_id", models.TextField(primary_key=True, serialize=False)),
                ("remaining", models.IntegerField()),
                # ``default=list`` passes the callable, not a shared list instance.
                ("results", models.JSONField(default=list)),
                ("created_at", models.DateTimeField(default=django.utils.timezone.now)),
                # No default: every writer must supply an explicit expiry.
                ("expires_at", models.DateTimeField()),
            ],
            options={
                "db_table": "pg_barrier_state",
                "indexes": [
                    models.Index(fields=["expires_at"], name="pg_barrier_expires_idx")
                ],
            },
        ),
        # Database-level guard: a row's expiry must be strictly later than its
        # creation time.
        migrations.AddConstraint(
            model_name="pgbarrierstate",
            constraint=models.CheckConstraint(
                check=models.Q(("expires_at__gt", models.F("created_at"))),
                name="pg_barrier_expires_after_created",
            ),
        ),
    ]
49 changes: 49 additions & 0 deletions backend/pg_queue/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,52 @@ class Meta:
name="pg_queue_message_dequeue_idx",
)
]


class PgBarrierState(models.Model):
    """Fan-in barrier bookkeeping row for ``PgBarrier``.

    ``PgBarrier`` is the Postgres ``WORKER_BARRIER_BACKEND``. Each in-flight
    barrier owns exactly one row, keyed by ``execution_id``.

    Lifecycle, as driven by the worker-side link tasks:

    * ``barrier_pg_decr_and_check`` decrements ``remaining`` and appends to
      ``results`` atomically, in one ``UPDATE … RETURNING``. Whichever task
      brings ``remaining`` to 0 dispatches the aggregating callback and then
      deletes the row.
    * When a header task fails, the barrier is aborted by deleting the row
      outright (``DELETE … RETURNING`` — an atomic claim+teardown), so the
      callback can never fire with partial results.
    * ``expires_at`` bounds an orphaned barrier, i.e. one whose header tasks
      never complete. The intended reclaim backstop is a periodic sweep job,
      which is not yet implemented.

    Managed=True with a generated migration: no DB-side function and
    extension-free (UN-3533), the same posture as ``PgQueueMessage``.
    """

    execution_id = models.TextField(primary_key=True)

    # Count of header tasks that have not finished yet. The task whose
    # decrement lands on 0 fires the callback. A negative value (a decrement
    # arriving after expiry/cleanup) signals that the barrier was already torn
    # down, so that task cleans up without firing.
    remaining = models.IntegerField()

    # Header-task results accumulated in completion order (JSONB array).
    results = models.JSONField(default=list)

    created_at = models.DateTimeField(default=timezone.now)

    # Orphan bound, the equivalent of the Redis TTL: once past this instant a
    # barrier whose header tasks never completed may be reclaimed. Budget it
    # like WORKER_BARRIER_KEY_TTL_SECONDS — longer than the longest execution
    # wall-clock.
    expires_at = models.DateTimeField()

    class Meta:
        db_table = "pg_barrier_state"
        indexes = [
            # Serves the (future) periodic expiry-sweep job.
            models.Index(name="pg_barrier_expires_idx", fields=["expires_at"]),
        ]
        constraints = [
            # The single invariant enforced regardless of writer (the worker
            # SQL cannot import this model). There is intentionally NO
            # `remaining >= 0` check: the teardown path uses a negative
            # `remaining` as its "barrier already gone" signal, which a
            # non-negative constraint would break.
            models.CheckConstraint(
                name="pg_barrier_expires_after_created",
                check=models.Q(expires_at__gt=models.F("created_at")),
            ),
        ]
11 changes: 11 additions & 0 deletions workers/queue_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
from .decorator import worker_task
from .dispatch import dispatch
from .fairness import FairnessKey
from .pg_barrier import (
PgBarrier,
barrier_pg_abort,
barrier_pg_decr_and_check,
)
from .redis_barrier import (
RedisDecrBarrier,
barrier_abort,
Expand All @@ -54,10 +59,13 @@
"BarrierHandle",
"CeleryChordBarrier",
"FairnessKey",
"PgBarrier",
"QueueBackend",
"RedisDecrBarrier",
"barrier_abort",
"barrier_decr_and_check",
"barrier_pg_abort",
"barrier_pg_decr_and_check",
"dispatch",
"get_barrier",
"select_backend",
Expand All @@ -77,6 +85,7 @@ class BarrierBackend(StrEnum):

CHORD = "chord"
REDIS = "redis"
PG = "pg"


def get_barrier() -> Barrier:
Expand Down Expand Up @@ -110,6 +119,8 @@ def get_barrier() -> Barrier:
return CeleryChordBarrier()
if backend is BarrierBackend.REDIS:
return RedisDecrBarrier()
if backend is BarrierBackend.PG:
return PgBarrier()
# Unreachable — StrEnum constructor would have raised above for
# anything not in the enum. Defensive raise so the type checker
# sees an exhaustive match.
Expand Down
48 changes: 47 additions & 1 deletion workers/queue_backend/barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Protocol
import os
from typing import TYPE_CHECKING, Any, Protocol, TypedDict

from celery import chord

Expand All @@ -47,6 +48,51 @@

logger = logging.getLogger(__name__)

# One TTL knob for every barrier backend. Redis and PG both bound an orphaned
# barrier (header tasks that never complete) with the same env var, and only
# one backend is active per deployment, so a single definition avoids drift.
_DEFAULT_BARRIER_TTL_SECONDS = 6 * 60 * 60  # 6h


def barrier_ttl_seconds() -> int:
    """Return the barrier TTL, in seconds.

    The value comes from ``WORKER_BARRIER_KEY_TTL_SECONDS`` and falls back to
    6h when the variable is unset. The environment is consulted on every call
    so tests can flip the setting.

    Raises:
        ValueError: the variable is set but is not an integer, or is zero or
            negative. This is deliberately loud, like ``get_barrier()`` on
            misconfiguration: a TTL shorter than execution wall-clock would
            tear barriers down early (spurious behaviour).
    """
    raw = os.environ.get("WORKER_BARRIER_KEY_TTL_SECONDS")
    if raw is None:
        # Unset is the only case that silently takes the default.
        return _DEFAULT_BARRIER_TTL_SECONDS

    try:
        value = int(raw)
    except ValueError as exc:
        raise ValueError(
            f"WORKER_BARRIER_KEY_TTL_SECONDS={raw!r} is not an integer. Unset it "
            f"to default to {_DEFAULT_BARRIER_TTL_SECONDS}s (6h)."
        ) from exc

    if value > 0:
        return value

    raise ValueError(
        f"WORKER_BARRIER_KEY_TTL_SECONDS={value} must be a positive integer. "
        f"Unset it to default to {_DEFAULT_BARRIER_TTL_SECONDS}s (6h)."
    )


class CallbackDescriptor(TypedDict):
    """Describe the aggregating callback carried inside a barrier link signature.

    The descriptor is serialised by Celery on its way from producer to broker
    to worker. Typing the four-key contract lets a typo or rename be caught
    statically, rather than surfacing as a remote ``KeyError`` mid-aggregation.
    The Redis and PG barrier backends share this one shape.
    """

    # Identifies the callback task to dispatch — presumably its registered
    # Celery task name; confirm against the dispatching code.
    task_name: str
    # Keyword arguments for the callback (assumed serialisable — TODO confirm).
    kwargs: dict[str, Any]
    # Queue the callback is sent to.
    queue: str
    # ``None`` when the producer passed no fairness key.
    fairness_headers: dict[str, Any] | None


class Barrier(Protocol):
"""Fan-out-then-callback primitive.
Expand Down
Loading