diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dc90d1c..9377d90 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,10 +17,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6.0.2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6.2.0 with: python-version: "3.14" cache: "pip" @@ -37,7 +37,7 @@ jobs: run: pytest tests/unit/coordinator/ -v --cov=coordinator --cov-report=xml:coverage-coordinator.xml - name: Upload coordinator coverage to Codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v5.5.3 with: files: coverage-coordinator.xml flags: coordinator @@ -50,10 +50,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6.0.2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6.2.0 with: python-version: "3.14" cache: "pip" @@ -63,7 +63,7 @@ jobs: python -m pip install --upgrade pip setuptools Cython cd styx-package && python setup.py build_ext --inplace && cd .. python -m pip install styx-package/ - python -m pip install msgspec pytest==9.0.2 pytest-cov pytest-asyncio psutil + python -m pip install msgspec pytest==9.0.2 pytest-cov pytest-asyncio psutil uvloop - name: Build Cython extensions run: python worker/setup.py build_ext --inplace @@ -72,7 +72,7 @@ jobs: run: pytest tests/unit/worker/ -v --cov=worker --cov-report=xml:coverage-worker.xml - name: Upload worker coverage to Codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v5.5.3 with: files: coverage-worker.xml flags: worker @@ -85,10 +85,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6.0.2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6.2.0 with: python-version: "3.14" cache: "pip" @@ -104,7 +104,7 @@ jobs: run: pytest tests/unit/styx_package/ -v --cov=styx-package/styx --cov-report=xml:coverage-styx-package.xml - name: Upload styx-package coverage to Codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v5.5.3 with: files: coverage-styx-package.xml flags: styx-package @@ -118,10 +118,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6.0.2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6.2.0 with: python-version: "3.14" cache: "pip" @@ -138,7 +138,7 @@ jobs: run: pytest tests/integration/ -v --cov=styx-package/styx --cov=coordinator --cov=worker --cov-report=xml:coverage-integration.xml - name: Upload integration coverage to Codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v5.5.3 with: files: coverage-integration.xml flags: integration diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 61e843b..54d4859 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -10,13 +10,15 @@ concurrency: cancel-in-progress: true jobs: - e2e-tests: + e2e-base: + name: "E2E — base (non-migration)" + # Provides the legacy "e2e-tests" status check via the summary job below. 
runs-on: ubuntu-latest - timeout-minutes: 45 + timeout-minutes: 60 steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6.0.2 - name: Docker info run: | @@ -24,7 +26,7 @@ jobs: docker compose version - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6.2.0 with: python-version: "3.14" cache: "pip" @@ -34,37 +36,100 @@ jobs: python -m pip install --upgrade pip setuptools Cython cd styx-package && python setup.py build_ext --inplace && cd .. python -m pip install -r requirements.txt - # Install Styx package from repo root python -m pip install -e styx-package/ - name: Make scripts executable run: | chmod +x scripts/start_styx_cluster.sh scripts/stop_styx_cluster.sh || true - # Needed so Kafka advertises a reachable external listener - name: Set DOCKER_HOST_IP run: echo "DOCKER_HOST_IP=127.0.0.1" >> $GITHUB_ENV - - name: Run e2e tests + - name: Run non-migration e2e tests env: PYTHONUNBUFFERED: "1" run: | - pytest tests/e2e/ -m e2e -q + pytest tests/e2e/ -m "e2e" -q --log-level=DEBUG - - name: Collect worker and coordinator logs (always) + - name: Collect logs if: always() run: | mkdir -p artifacts - if [ -d logs ]; then - cp -r logs artifacts/ - else - echo "No logs folder found to copy." - fi + if [ -d logs ]; then cp -r logs artifacts/; fi + + - name: Upload artifacts + if: always() + uses: actions/upload-artifact@v7.0.0 + with: + name: e2e-base-artifacts + path: artifacts + if-no-files-found: ignore + + e2e-migration: + name: "E2E — migration" + runs-on: ubuntu-latest + timeout-minutes: 90 + + steps: + - name: Checkout + uses: actions/checkout@v6.0.2 + + - name: Docker info + run: | + docker version + docker compose version + + - name: Set up Python + uses: actions/setup-python@v6.2.0 + with: + python-version: "3.14" + cache: "pip" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools Cython + cd styx-package && python setup.py build_ext --inplace && cd .. + python -m pip install -r requirements.txt + python -m pip install -e styx-package/ - - name: Upload artifacts (always) + - name: Make scripts executable + run: | + chmod +x scripts/start_styx_cluster.sh scripts/stop_styx_cluster.sh || true + + - name: Set DOCKER_HOST_IP + run: echo "DOCKER_HOST_IP=127.0.0.1" >> $GITHUB_ENV + + - name: Run migration e2e tests + env: + PYTHONUNBUFFERED: "1" + run: | + pytest tests/e2e/ -m "e2e_migration" -q --log-level=DEBUG + + - name: Collect logs + if: always() + run: | + mkdir -p artifacts + if [ -d logs ]; then cp -r logs artifacts/; fi + + - name: Upload artifacts if: always() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7.0.0 with: - name: e2e-artifacts + name: e2e-migration-artifacts path: artifacts if-no-files-found: ignore + + # Legacy summary job — provides the "e2e-tests" status check that branch + # protection rules may reference from before the split into two jobs. 
+ e2e-tests: + name: "e2e-tests" + if: always() + needs: [e2e-base, e2e-migration] + runs-on: ubuntu-latest + steps: + - name: Check results + run: | + if [[ "${{ needs.e2e-base.result }}" == "failure" || "${{ needs.e2e-migration.result }}" == "failure" ]]; then + echo "One or more E2E jobs failed" + exit 1 + fi diff --git a/.github/workflows/publish-helm-chart.yml b/.github/workflows/publish-helm-chart.yml index 7380443..4b23d47 100644 --- a/.github/workflows/publish-helm-chart.yml +++ b/.github/workflows/publish-helm-chart.yml @@ -15,10 +15,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6.0.2 - name: Set up Helm - uses: azure/setup-helm@v4 + uses: azure/setup-helm@v5.0.0 - name: Log in to GHCR run: echo "${{ secrets.GITHUB_TOKEN }}" | helm registry login ghcr.io --username ${{ github.repository_owner }} --password-stdin diff --git a/.github/workflows/publish-images.yml b/.github/workflows/publish-images.yml index c5d94d7..82cef8b 100644 --- a/.github/workflows/publish-images.yml +++ b/.github/workflows/publish-images.yml @@ -25,13 +25,13 @@ jobs: image_suffix: worker steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6.0.2 - name: Set up Build - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@v4.0.0 - name: Log in to GHCR - uses: docker/login-action@v3 + uses: docker/login-action@v4.0.0 with: registry: ghcr.io username: ${{ github.repository_owner }} @@ -39,7 +39,7 @@ jobs: - name: Extract metadata id: meta - uses: docker/metadata-action@v5 + uses: docker/metadata-action@v6.0.0 with: images: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}-${{ matrix.image_suffix }} tags: | @@ -49,7 +49,7 @@ jobs: type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} - name: Build and push - uses: docker/build-push-action@v6 + uses: docker/build-push-action@v7.0.0 with: context: . file: ${{ matrix.dockerfile }} diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 19fdc12..1ae7a2e 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -16,10 +16,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6.0.2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6.2.0 with: python-version: "3.14" cache: "pip" @@ -27,7 +27,7 @@ jobs: - name: Install Ruff run: | python -m pip install --upgrade pip - python -m pip install "ruff==0.15.6" + python -m pip install "ruff==0.15.7" - name: Ruff lint (must pass) run: ruff check . 
diff --git a/coordinator/coordinator_metadata.py b/coordinator/coordinator_metadata.py index 57f765d..58d6022 100644 --- a/coordinator/coordinator_metadata.py +++ b/coordinator/coordinator_metadata.py @@ -54,6 +54,19 @@ def __init__(self, networking: NetworkingManager, s3_client: S3Client) -> None: self.kafka_metadata_producer: AIOKafkaProducer | None = None self.max_operator_parallelism: int | None = None + # Migration fault tolerance: track pre-migration snapshot + self.pre_migration_snapshot_id: int = -1 + self.pre_migration_snapshot_pending: bool = False + # Deferred graph finalization: wait for post-migration snapshot before committing + self.post_migration_snapshot_pending: bool = False + # Deferred graph update: held until migration completes + self._pending_graph: StateflowGraph | None = None + + # Checkpoint-and-resume migration FT + self._migration_checkpoint_blob: bytes | None = None + self._migration_checkpoint_baseline_snap_id: int = -1 + self._migration_checkpoint_snapshot_complete: asyncio.Event = asyncio.Event() + async def start_kafka_metadata_producer(self) -> None: self.kafka_metadata_producer = AIOKafkaProducer( bootstrap_servers=[KAFKA_URL], @@ -109,6 +122,9 @@ async def send_recovery_to_participating_workers(self) -> None: worker_assignments = self.worker_pool.get_worker_assignments() participating_workers: list[Worker] = self.worker_pool.get_participating_workers() self.worker_is_healthy = {worker.worker_id: asyncio.Event() for worker in participating_workers} + logging.warning( + f"[RECOVERY] Sending InitRecovery to {len(participating_workers)} workers with snap_id={snap_id}", + ) async with asyncio.TaskGroup() as tg: for worker in participating_workers: tg.create_task( @@ -190,6 +206,35 @@ def register_snapshot( self.completed_epoch_counter = epoch_counter self.completed_t_counter = t_counter current_completed_snapshot: int = self.get_current_completed_snapshot_id() + logging.warning( + f"register_snapshot | worker={worker_id} snap={snapshot_id} " + f"cluster_min={current_completed_snapshot} prev={self.prev_completed_snapshot_id} " + f"baseline={self._migration_checkpoint_baseline_snap_id} " + f"blob_set={self._migration_checkpoint_blob is not None}", + ) + # Track pre-migration snapshot for recovery + if self.pre_migration_snapshot_pending and current_completed_snapshot != self.prev_completed_snapshot_id: + self.pre_migration_snapshot_id = current_completed_snapshot + self.pre_migration_snapshot_pending = False + logging.warning(f"Pre-migration snapshot recorded: {self.pre_migration_snapshot_id}") + # Finalize graph update after first post-migration snapshot completes + if self.post_migration_snapshot_pending and current_completed_snapshot != self.prev_completed_snapshot_id: + self.post_migration_snapshot_pending = False + self.finalize_graph_update() + logging.warning( + f"Post-migration snapshot {current_completed_snapshot} completed — graph finalized", + ) + # Detect migration checkpoint snapshot completion via baseline ID comparison + if ( + self._migration_checkpoint_baseline_snap_id >= 0 + and current_completed_snapshot > self._migration_checkpoint_baseline_snap_id + ): + self._migration_checkpoint_snapshot_complete.set() + logging.warning( + f"Migration checkpoint snapshot {current_completed_snapshot} completed " + f"(baseline was {self._migration_checkpoint_baseline_snap_id})", + ) + self._migration_checkpoint_baseline_snap_id = -1 if current_completed_snapshot != self.prev_completed_snapshot_id: logging.warning(f"Cluster completed snapshot: 
{current_completed_snapshot}") # if we reached a complete snapshot, we could compact its deltas with the previous one @@ -199,6 +244,7 @@ def register_snapshot( self.completed_out_offsets, self.completed_epoch_counter, self.completed_t_counter, + self._migration_checkpoint_blob, ), ) self.s3_client.put_object( @@ -225,6 +271,10 @@ async def update_stateflow_graph(self, new_stateflow_graph: StateflowGraph) -> N raise NotAStateflowGraphError # TODO the cluster was balanced by the previous deployment, if the graph is complex it might be # unbalanced after the update + logging.warning( + f"update_stateflow_graph | updating worker pool operators " + f"(new partitions: { {n: op.n_partitions for n, op in new_stateflow_graph.nodes.items()} })", + ) for _, operator in iter(new_stateflow_graph): for partition in range(self.max_operator_parallelism): operator_copy = deepcopy(operator) @@ -235,6 +285,15 @@ async def update_stateflow_graph(self, new_stateflow_graph: StateflowGraph) -> N (operator_copy.name, partition), operator_copy, ) + # Set _pending_graph BEFORE sending InitMigration to workers. + # This ensures recovery can detect an in-flight migration even if the + # gather is stuck retrying a dead worker connection for tens of seconds. + # Recovery clears _pending_graph, so if the gather completes after + # recovery, it harmlessly re-sets the field (which _perform_recovery + # already handles via saved_migration_in_progress). + self._pending_graph = new_stateflow_graph + logging.warning("update_stateflow_graph | _pending_graph set (before send)") + worker_assignments = self.worker_pool.get_worker_assignments() tasks = [ self.networking.send_message( @@ -252,15 +311,68 @@ async def update_stateflow_graph(self, new_stateflow_graph: StateflowGraph) -> N for worker in self.worker_pool.get_participating_workers() ] await asyncio.gather(*tasks) - self.submitted_graph = new_stateflow_graph + logging.warning("update_stateflow_graph | InitMigration sent to all workers") + + def revert_worker_pool_to_submitted_graph(self) -> None: + """Revert worker pool operators to match the current submitted_graph. + + Called before recovery when migration was in progress — undoes the + operator promotions (shadow → active) that update_stateflow_graph() + applied so that recovery sends the correct (OLD) layout to workers. + Also updates orphaned_operator_assignments so that dead-worker + partitions are rescheduled with the OLD layout operator definitions. 
+ """ + if self.submitted_graph is None or self.max_operator_parallelism is None: + return + for _, operator in iter(self.submitted_graph): + for partition in range(self.max_operator_parallelism): + operator_copy = deepcopy(operator) + if partition >= operator.n_partitions: + operator_copy.make_shadow() + op_part = (operator_copy.name, partition) + self.worker_pool.update_operator(op_part, operator_copy) + # Also revert orphaned (dead-worker) partitions + if op_part in self.worker_pool.orphaned_operator_assignments: + self.worker_pool.orphaned_operator_assignments[op_part] = operator_copy + logging.warning("Reverted worker pool operators to match submitted_graph (pre-migration layout)") + + def finalize_graph_update(self) -> None: + """Commit the deferred graph update after migration completes successfully.""" + if self._pending_graph is None: + logging.warning("finalize_graph_update | no-op (_pending_graph is None)") + return + logging.warning("finalize_graph_update | committing pending graph to submitted_graph") + self.submitted_graph = self._pending_graph + self._pending_graph = None + # Write to Kafka via a background task; the producer is already running + # and the next snapshot will capture the new layout. metadata_key = msgpack_serialization(self.submitted_graph.name) - serialized_graph = cloudpickle_serialization(new_stateflow_graph) - await self.kafka_metadata_producer.send_and_wait( - "styx-metadata", - key=metadata_key, - value=serialized_graph, + serialized_graph = cloudpickle_serialization(self.submitted_graph) + self._pending_kafka_write = asyncio.ensure_future( + self.kafka_metadata_producer.send_and_wait( + "styx-metadata", + key=metadata_key, + value=serialized_graph, + ), + ) + + def set_migration_checkpoint( + self, + new_graph: StateflowGraph, + operator_partition_locations: dict, + ) -> None: + """Serialize migration metadata for inclusion in the next sequencer file.""" + self._migration_checkpoint_blob = cloudpickle_serialization( + { + "new_graph": new_graph, + "operator_partition_locations": operator_partition_locations, + }, ) + def clear_migration_checkpoint(self) -> None: + """Clear migration metadata so subsequent snapshots are normal.""" + self._migration_checkpoint_blob = None + async def submit_stateflow_graph( self, stateflow_graph: StateflowGraph, diff --git a/coordinator/coordinator_service.py b/coordinator/coordinator_service.py index 0a72d44..297f83b 100755 --- a/coordinator/coordinator_service.py +++ b/coordinator/coordinator_service.py @@ -21,7 +21,11 @@ from styx.common.logging import logging from styx.common.message_types import MessageType from styx.common.protocols import Protocols -from styx.common.serialization import Serializer +from styx.common.serialization import ( + Serializer, + cloudpickle_deserialization, + zstd_msgpack_deserialization, +) from styx.common.tcp_networking import MessagingMode, NetworkingManager from styx.common.util.aio_task_scheduler import AIOTaskScheduler import uvloop @@ -191,6 +195,10 @@ def __init__(self) -> None: } self.snapshotting_task: asyncio.Task | None = None + self._migration_checkpoint_task: asyncio.Task | None = None + # Deferred migration re-trigger: set during recovery, picked up after + # the first post-recovery snapshot completes (to avoid dedup races). 
+ self._deferred_migration_graph: StateflowGraph | None = None self._protocol_controller_handlers_map: dict[MessageType, Callable[[bytes], Awaitable[None]]] = { MessageType.AriaProcessingDone: self._handle_aria_processing_done, @@ -281,10 +289,34 @@ async def _handle_update_execution_graph( async def _start_migration(self, graph: StateflowGraph) -> None: # Phase A: do NOT stop the protocol yet — workers will rehash in the background self.migration_in_progress = True - await self.stop_snapshotting() + + # Force a pre-migration snapshot at the next epoch boundary + if self.aria_metadata is not None: + self.aria_metadata.take_snapshot_at_next_epoch() + self.coordinator.pre_migration_snapshot_pending = True logging.warning(f"MIGRATION | START {graph}") - await self.coordinator.update_stateflow_graph(graph) + try: + await self.coordinator.update_stateflow_graph(graph) + except Exception: + # update_stateflow_graph failed (e.g. dead worker connection). + # Revert the pool and clear migration state. + logging.exception("MIGRATION | update_stateflow_graph FAILED — reverting") + self.coordinator.revert_worker_pool_to_submitted_graph() + self.coordinator._pending_graph = None # noqa: SLF001 + self.coordinator.pre_migration_snapshot_pending = False + self.migration_in_progress = False + return + + # Guard: if recovery ran while the gather was blocked (zombie coroutine), + # migration_in_progress is now False and _pending_graph was cleared. + # Don't proceed — the recovery handler already dealt with it. + if not self.migration_in_progress: + logging.warning( + "MIGRATION | update_stateflow_graph completed but migration_in_progress=False " + "(recovery intervened while gather was blocked) — aborting", + ) + return n_workers = len(self.coordinator.worker_pool.get_participating_workers()) self.migration_metadata = MigrationMetadata(n_workers) @@ -304,10 +336,21 @@ async def _handle_migration_repartitioning_done( __: bytes, ___: concurrent.futures.ProcessPoolExecutor, ) -> None: + if not self.migration_in_progress: + logging.warning("Dropping stale MigrationRepartitioningDone (no migration in progress)") + return mt = MessageType.MigrationRepartitioningDone logging.warning("Migration repartitioning done received!") async with self.networking_locks[mt]: + logging.warning("MigrationRepartitioningDone | acquired lock") + # Re-check inside the lock to close the TOCTOU window: + # recovery may have cleared migration_in_progress between the + # outer guard and lock acquisition. + if not self.migration_in_progress: + logging.warning("Dropping stale MigrationRepartitioningDone (recovery intervened)") + return + sync_complete: bool = await self.migration_metadata.repartitioning_done() logging.warning(f"Migration repartitioning is complete: {sync_complete}") @@ -315,8 +358,33 @@ async def _handle_migration_repartitioning_done( if not sync_complete: return + # Re-check after await — recovery could have fired during the sync. + if not self.migration_in_progress: + logging.warning("Dropping MigrationRepartitioningDone (recovery intervened after sync)") + return + + # Set migration checkpoint blob BEFORE telling workers to proceed. + # Workers will take a checkpoint snapshot after receiving + # MigrationRepartitioningDone; when the SnapID arrives the coordinator + # writes the sequencer file which will include this blob. + # Use max(..., 0) so the baseline is always >= 0 — after recovery + # worker_snapshot_ids are reset to -1, but we still need the + # checkpoint detection (baseline >= 0) to fire on the next snapshot. 
+ baseline = max(self.coordinator.get_current_completed_snapshot_id(), 0) + self.coordinator._migration_checkpoint_baseline_snap_id = baseline # noqa: SLF001 + operator_partition_locations = self.coordinator.worker_pool.get_operator_partition_locations() + self.coordinator.set_migration_checkpoint( + self.coordinator._pending_graph, # noqa: SLF001 + operator_partition_locations, + ) + logging.warning( + f"MIGRATION | Checkpoint blob set, baseline snap_id={baseline}", + ) + + logging.warning("MigrationRepartitioningDone | calling finalize_migration_repartition") await self.finalize_migration_repartition() await self.migration_metadata.cleanup(mt) + logging.warning("MigrationRepartitioningDone | done") async def _handle_migration_init_done( self, @@ -324,8 +392,17 @@ async def _handle_migration_init_done( data: bytes, __: concurrent.futures.ProcessPoolExecutor, ) -> None: + if not self.migration_in_progress: + logging.warning("Dropping stale MigrationInitDone (no migration in progress)") + return mt = MessageType.MigrationInitDone async with self.networking_locks[mt]: + logging.warning("MigrationInitDone | acquired lock") + # Re-check inside the lock (TOCTOU protection). + if not self.migration_in_progress: + logging.warning("Dropping stale MigrationInitDone (recovery intervened)") + return + epoch_counter, t_counter, input_offsets, output_offsets = self.networking.decode_message(data) sync_complete: bool = await self.migration_metadata.init_done( @@ -339,13 +416,37 @@ async def _handle_migration_init_done( if not sync_complete: return + # Re-check after await — recovery could have fired during the sync. + if not self.migration_in_progress: + logging.warning("Dropping MigrationInitDone (recovery intervened after sync)") + return + logging.warning("MIGRATION | MigrationInitDone | sync_complete") n_workers = len(self.coordinator.worker_pool.get_participating_workers()) self.aria_metadata = AriaSyncMetadata(n_workers) + logging.warning("MigrationInitDone | closing protocol connections") await self.protocol_networking.close_all_connections() - await self.finalize_migration() - await self.migration_metadata.cleanup(mt) + logging.warning( + f"MigrationInitDone | protocol connections closed, " + f"current_snap={self.coordinator.get_current_completed_snapshot_id()} " + f"baseline={self.coordinator._migration_checkpoint_baseline_snap_id}", # noqa: SLF001 + ) + + # The checkpoint blob was set in _handle_migration_repartitioning_done + # (before workers took their snapshot). By now the checkpoint snapshot + # may already be complete. Handle the race, then launch the finalizer. 
+ baseline = self.coordinator._migration_checkpoint_baseline_snap_id # noqa: SLF001 + if baseline >= 0 and self.coordinator.get_current_completed_snapshot_id() > baseline: + self.coordinator._migration_checkpoint_snapshot_complete.set() # noqa: SLF001 + self.coordinator._migration_checkpoint_baseline_snap_id = -1 # noqa: SLF001 + logging.warning("MIGRATION | Checkpoint snapshot already completed (race handled)") + + # Launch background task to wait for the checkpoint snapshot, then finalize + logging.warning("MigrationInitDone | launching checkpoint finalizer task") + self._migration_checkpoint_task = asyncio.ensure_future( + self._wait_for_migration_checkpoint_and_finalize(mt), + ) async def _handle_register_worker( self, @@ -427,6 +528,24 @@ async def _handle_snap_id( pool, ) + # Check for deferred migration re-trigger: fire once the first + # post-recovery snapshot completes (prev changes from -1 to a + # real value, meaning a full cluster snapshot has been written). + if ( + self._deferred_migration_graph is not None + and not self.migration_in_progress + and self.coordinator.prev_completed_snapshot_id >= 0 + ): + graph = self._deferred_migration_graph + self._deferred_migration_graph = None + logging.warning( + "[RECOVERY] First post-recovery snapshot complete — re-triggering deferred migration", + ) + # Use ensure_future so we don't block the SnapID handler. + self._deferred_migration_task = asyncio.ensure_future( + self._start_migration(graph), + ) + async def _handle_heartbeat( self, _: StreamWriter, @@ -478,6 +597,8 @@ async def protocol_controller(self, data: bytes) -> None: f"COORDINATOR PROTOCOL SERVER: Non supported message type: {mt}", ) return + if self.migration_in_progress: + logging.warning(f"PROTOCOL_MSG during migration: {mt.name}") await handler(data) # ------------------------ @@ -586,11 +707,18 @@ async def _handle_sync_cleanup(self, data: bytes) -> None: if not sync_complete: return + stop_flag = self.aria_metadata.stop_next_epoch + if stop_flag: + logging.warning( + "SyncCleanup | sync_complete, sending stop_gracefully=True to all workers", + ) await self.finalize_worker_sync( mt, - (self.aria_metadata.stop_next_epoch,), + (stop_flag,), Serializer.MSGPACK, ) + if stop_flag: + logging.warning("SyncCleanup | stop_gracefully sent successfully") self.aria_metadata.cleanup(epoch_end=True) def _record_epoch_metrics( @@ -653,8 +781,16 @@ async def _handle_deterministic_reordering(self, data: bytes) -> None: self.aria_metadata.cleanup() async def _handle_migration_done(self, _: bytes) -> None: + if not self.migration_in_progress: + logging.warning("Dropping stale MigrationDone (no migration in progress)") + return mt = MessageType.MigrationDone async with self.networking_locks[mt]: + logging.warning("MigrationDone | acquired lock") + # Re-check inside the lock (TOCTOU protection). + if not self.migration_in_progress: + logging.warning("Dropping stale MigrationDone (recovery intervened)") + return sync_complete: bool = await self.migration_metadata.set_empty_sync_done(mt) logging.warning( f"MIGRATION | MigrationDone | {self.migration_metadata.sync_sum}", @@ -668,8 +804,15 @@ async def _handle_migration_done(self, _: bytes) -> None: await self.migration_metadata.cleanup(mt) self.migration_in_progress = False - logging.warning("Restarting the snapshotting mechanism") - self.snapshotting_task = asyncio.create_task(self.send_snapshot_marker()) + # Clear migration checkpoint blob so subsequent snapshots are normal. 
+ # The graph was already finalized in _wait_for_migration_checkpoint_and_finalize(). + self.coordinator.clear_migration_checkpoint() + + # Defer final graph write until a post-migration snapshot captures + # the state without migration metadata. This is a safety net — the + # graph was already updated via finalize_graph_update() in the + # checkpoint flow, but we want a clean snapshot without migration_blob. + self.coordinator.post_migration_snapshot_pending = True async def start_puller(self) -> None: async def request_handler(reader: StreamReader, writer: StreamWriter) -> None: @@ -681,11 +824,13 @@ async def request_handler(reader: StreamReader, writer: StreamWriter) -> None: self.protocol_controller(await reader.readexactly(size)), ) except asyncio.IncompleteReadError as e: - logging.info(f"Client disconnected unexpectedly: {e}") + logging.warning(f"Protocol client disconnected unexpectedly: {e}") except asyncio.CancelledError: pass + except Exception as e: + logging.warning(f"Protocol request_handler unexpected error: {e}") finally: - logging.info("Closing the connection") + logging.warning("Protocol connection closing") writer.close() await writer.wait_closed() @@ -772,6 +917,39 @@ async def finalize_migration(self) -> None: ), ) + async def _wait_for_migration_checkpoint_and_finalize(self, mt: MessageType) -> None: + """Wait for the migration checkpoint snapshot to complete, then finalize migration.""" + logging.warning("checkpoint_finalizer | waiting for checkpoint snapshot event") + await self.coordinator._migration_checkpoint_snapshot_complete.wait() # noqa: SLF001 + self.coordinator._migration_checkpoint_snapshot_complete.clear() # noqa: SLF001 + logging.warning( + f"checkpoint_finalizer | event fired, migration_in_progress={self.migration_in_progress}", + ) + + # Guard: recovery may have cleared migration state while we were waiting. 
+ if not self.migration_in_progress: + logging.warning( + "MIGRATION | Checkpoint snapshot completed but migration_in_progress " + "is False (recovery intervened) — aborting finalization", + ) + return + + logging.warning( + f"MIGRATION | Checkpoint snapshot complete (snap_id=" + f"{self.coordinator.get_current_completed_snapshot_id()}) — " + f"publishing new graph and sending MigrationDone", + ) + + # Publish new graph to styx-metadata so ALL clients update partitioners + logging.warning("checkpoint_finalizer | calling finalize_graph_update") + self.coordinator.finalize_graph_update() + + # Send MigrationDone to workers + logging.warning("checkpoint_finalizer | sending MigrationDone to workers") + await self.finalize_migration() + await self.migration_metadata.cleanup(mt) + logging.warning("checkpoint_finalizer | done") + async def finalize_worker_sync( self, msg_type: MessageType, @@ -821,7 +999,10 @@ async def _reset_after_recovery(self) -> None: # 2) Reset migration metadata self.migration_metadata = MigrationMetadata(n_workers) - self.migration_in_progress = False + # migration_in_progress, pre_migration_snapshot_pending, _pending_graph, + # and migration checkpoint baseline are already cleared in _perform_recovery() + # step 1c (before recovery starts) + self.coordinator.pre_migration_snapshot_id = -1 # 3) Reset snapshot completion metadata self.coordinator.completed_input_offsets.clear() @@ -849,6 +1030,27 @@ async def _reset_after_recovery(self) -> None: logging.warning("Protocol metadata reset complete") + def _read_migration_blob_from_s3(self, snap_id: int) -> bytes | None: + """Read the migration checkpoint blob from the sequencer file in S3.""" + if snap_id <= 0: + return None + try: + resp = self.s3_client.get_object( + Bucket=SNAPSHOT_BUCKET_NAME, + Key=f"sequencer/{snap_id}.bin", + ) + loaded = zstd_msgpack_deserialization(resp["Body"].read()) + _sequencer_tuple_len_with_migration = 5 + if ( + isinstance(loaded, (tuple, list)) + and len(loaded) == _sequencer_tuple_len_with_migration + and loaded[4] is not None + ): + return loaded[4] + except Exception as e: + logging.warning(f"[RECOVERY] Failed to read migration blob from S3: {e}") + return None + async def _perform_recovery(self, workers_to_remove: set[Worker]) -> None: """ Full recovery state machine: @@ -864,6 +1066,36 @@ async def _perform_recovery(self, workers_to_remove: set[Worker]) -> None: logging.warning(f"Starting recovery process for workers: {workers_to_remove}") + # 0) Immediately cancel the migration checkpoint background task. + # This MUST happen before reading any migration state (_pending_graph, + # submitted_graph) to prevent finalize_graph_update() from committing + # the new graph layout while recovery is reverting to the old one. + has_checkpoint_task = ( + hasattr(self, "_migration_checkpoint_task") and self._migration_checkpoint_task is not None + ) + logging.warning( + f"[RECOVERY] step 0 | cancel checkpoint task={has_checkpoint_task} " + f"migration_in_progress={self.migration_in_progress}", + ) + if has_checkpoint_task: + self._migration_checkpoint_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._migration_checkpoint_task + self._migration_checkpoint_task = None + logging.warning("[RECOVERY] step 0 | checkpoint task cancelled") + + # 0b) Clear migration_in_progress immediately so any migration handler + # that passed its guard check before recovery will see it is False + # when it re-checks inside the lock (TOCTOU protection). 
+ # Save the old value first — we need it for was_migrating detection + # even when _pending_graph was never set (e.g. update_stateflow_graph + # failed mid-send). + saved_migration_in_progress = self.migration_in_progress + self.migration_in_progress = False + logging.warning( + f"[RECOVERY] step 0b | migration_in_progress was {saved_migration_in_progress}, now False", + ) + # 1) Clean up dead worker channels and buffered tasks logging.warning(f"Closing connections to dead workers: {workers_to_remove}") for worker in workers_to_remove: @@ -878,9 +1110,61 @@ async def _perform_recovery(self, workers_to_remove: set[Worker]) -> None: await self.aio_task_scheduler.close() self.aio_task_scheduler = AIOTaskScheduler() + # 1b) Check if the latest snapshot contains a migration checkpoint. + # If yes (post-checkpoint recovery): use the NEW layout — workers will + # re-derive keys_to_send from the snapshot and resume async transfer. + # If no (pre-checkpoint recovery): revert to old layout and optionally + # re-trigger migration from scratch. + was_migrating = ( + saved_migration_in_progress + or self.coordinator._pending_graph is not None # noqa: SLF001 + or self.coordinator.post_migration_snapshot_pending + or self.coordinator._migration_checkpoint_baseline_snap_id >= 0 # noqa: SLF001 + ) + saved_pending_graph = None + snap_id = self.coordinator.get_current_completed_snapshot_id() + + if was_migrating: + logging.warning( + f"[RECOVERY] Migration was in progress. " + f"pre_migration_snapshot_id={self.coordinator.pre_migration_snapshot_id}, " + f"current_completed_snapshot_id={snap_id}, " + f"worker_snapshot_ids={self.coordinator.worker_snapshot_ids}", + ) + migration_blob = self._read_migration_blob_from_s3(snap_id) + if migration_blob is not None: + # POST-CHECKPOINT: use NEW layout — workers will resume transfer + logging.warning("[RECOVERY] Migration checkpoint found — using NEW layout for recovery") + migration_meta = cloudpickle_deserialization(migration_blob) + new_graph = migration_meta["new_graph"] + self.coordinator.submitted_graph = new_graph + self.coordinator._pending_graph = None # noqa: SLF001 + # The worker pool already has the new operator assignments from + # update_stateflow_graph() — no need to revert. + else: + # PRE-CHECKPOINT: revert to old layout and retry migration + logging.warning("[RECOVERY] No migration checkpoint — reverting to OLD layout") + saved_pending_graph = self.coordinator._pending_graph # noqa: SLF001 + self.coordinator.revert_worker_pool_to_submitted_graph() + + # Clear remaining migration state so stale messages are dropped. + logging.warning("[RECOVERY] step 1c | clearing remaining migration state") + self.coordinator.pre_migration_snapshot_pending = False + self.coordinator.post_migration_snapshot_pending = False + self.coordinator._migration_checkpoint_baseline_snap_id = -1 # noqa: SLF001 + self.coordinator._migration_checkpoint_snapshot_complete.clear() # noqa: SLF001 + self.coordinator._pending_graph = None # noqa: SLF001 + self.coordinator.clear_migration_checkpoint() + self._deferred_migration_graph = None # Clear any pending deferred re-trigger + # 2) Start recovery + graph_parts = ( + {n: op.n_partitions for n, op in self.coordinator.submitted_graph.nodes.items()} + if self.coordinator.submitted_graph + else None + ) logging.warning( - "Starting recovery process (reassign operators, send InitRecovery)", + f"[RECOVERY] Starting recovery (send InitRecovery). 
snap_id={snap_id}, graph_partitions={graph_parts}", ) await self.coordinator.start_recovery_process(workers_to_remove) @@ -901,6 +1185,21 @@ async def _perform_recovery(self, workers_to_remove: set[Worker]) -> None: logging.warning("Recovery process completed") + # 7) If migration was interrupted, DEFER re-trigger until the first + # post-recovery snapshot completes. Re-triggering immediately causes + # duplicate output: the protocol starts replaying from the snapshot, + # the dedup mechanism suppresses most duplicates, but the immediate + # migration Phase B stops the protocol before all replayed messages + # are fully dedup'd, and the migration restart's dedup scan starts + # from Phase B offsets (missing the earlier duplicates). + # By deferring, we let the recovery epoch fully complete and a clean + # snapshot capture the post-dedup state, eliminating duplicates. + if saved_pending_graph is not None: + logging.warning( + "[RECOVERY] Deferring migration re-trigger until first post-recovery snapshot", + ) + self._deferred_migration_graph = saved_pending_graph + async def heartbeat_monitor_coroutine(self) -> None: interval_time = HEARTBEAT_CHECK_INTERVAL / 1000 while True: @@ -917,6 +1216,11 @@ async def heartbeat_monitor_coroutine(self) -> None: # Add workers that re-registered (same IP/ports) to the failed set if (workers_to_remove or self.workers_that_re_registered) and self.recovery_state == RecoveryState.IDLE: + logging.warning( + f"heartbeat_monitor | entering recovery lock " + f"dead={workers_to_remove} re_registered={self.workers_that_re_registered} " + f"migration_in_progress={self.migration_in_progress}", + ) async with self.recovery_lock: if self.recovery_state != RecoveryState.IDLE: # Another recovery started while we were waiting for the lock @@ -933,6 +1237,7 @@ async def heartbeat_monitor_coroutine(self) -> None: logging.error(f"Error during recovery: {e}") finally: self.recovery_state = RecoveryState.IDLE + logging.warning("heartbeat_monitor | recovery_state back to IDLE") async def send_snapshot_marker(self) -> None: while True: diff --git a/coordinator/worker_pool.py b/coordinator/worker_pool.py index 5351f74..27ef3a2 100644 --- a/coordinator/worker_pool.py +++ b/coordinator/worker_pool.py @@ -153,7 +153,14 @@ def update_operator( operator_partition: OperatorPartition, operator: Operator | BaseOperator, ) -> None: - worker_id = self.operator_partition_to_worker[operator_partition] + worker_id = self.operator_partition_to_worker.get(operator_partition) + if worker_id is None or worker_id not in self._worker_queue_idx: + # Worker was already removed (e.g. dead before recovery revert). + # The partition will be rescheduled during initiate_recovery. 
+ logging.debug( + f"update_operator | skipping {operator_partition} — worker {worker_id} removed", + ) + return worker = self.peek(worker_id) worker.assigned_operators[operator_partition] = operator diff --git a/demo/demo-migration-ft-ycsb/__init__.py b/demo/demo-migration-ft-ycsb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/demo-migration-ft-ycsb/calculate_metrics.py b/demo/demo-migration-ft-ycsb/calculate_metrics.py new file mode 120000 index 0000000..ac136e8 --- /dev/null +++ b/demo/demo-migration-ft-ycsb/calculate_metrics.py @@ -0,0 +1 @@ +../demo-migration-ycsb/calculate_metrics.py \ No newline at end of file diff --git a/demo/demo-migration-ft-ycsb/client.py b/demo/demo-migration-ft-ycsb/client.py new file mode 100644 index 0000000..3890e3d --- /dev/null +++ b/demo/demo-migration-ft-ycsb/client.py @@ -0,0 +1,236 @@ +"""YCSB migration + fault tolerance client. + +Usage: + python client.py + + + + +- migration_at_sec: second at which to trigger migration (-1 = disabled) +- kill_at_sec: second at which to `docker kill styx-worker-1` (-1 = disabled) +""" + +import multiprocessing +from multiprocessing import Pool, cpu_count +import os +import pickle +import random +import string +import subprocess +import sys +import time +from timeit import default_timer as timer + +import boto3 + +import calculate_metrics +import kafka_output_consumer +import pandas as pd +from styx.client.sync_client import SyncStyxClient +from styx.common.local_state_backends import LocalStateBackend +from styx.common.operator import Operator +from styx.common.serialization import ( + cloudpickle_deserialization, + cloudpickle_serialization, +) +from styx.common.stateflow_graph import StateflowGraph +from tqdm import tqdm +from ycsb import ycsb_operator + + +def ycsb_field(size=100, seed=0): + "Translated from the original Java code" + if seed is not None: + random.seed(seed) + charset = string.ascii_letters + string.digits + return "".join(random.choices(charset, k=size)) + + +N_ENTITIES = int(sys.argv[8]) +MIGRATION_AT_SEC = int(sys.argv[9]) +KILL_AT_SEC = int(sys.argv[10]) + +threads = int(sys.argv[1]) +START_N_PARTITIONS = int(sys.argv[2]) +END_N_PARTITIONS = int(sys.argv[3]) +messages_per_second = int(sys.argv[4]) +seconds = int(sys.argv[5]) +SAVE_DIR: str = sys.argv[6] +warmup_seconds: int = int(sys.argv[7]) + +BATCH_SIZE = 100_000 + +sleeps_per_second = 100 +sleep_time = 0.0085 +STYX_HOST: str = os.getenv("STYX_HOST", "localhost") +STYX_PORT: int = int(os.getenv("STYX_PORT", "8886")) +KAFKA_URL: str = os.getenv("KAFKA_URL", "localhost:9092") + +g = StateflowGraph( + "ycsb-benchmark", + operator_state_backend=LocalStateBackend.DICT, + max_operator_parallelism=max(START_N_PARTITIONS, END_N_PARTITIONS), +) +ycsb_operator.set_n_partitions(START_N_PARTITIONS) +g.add_operators(ycsb_operator) + + +def submit_graph(styx: SyncStyxClient): + print(f"Partitions: {list(g.nodes.values())[0].n_partitions}") + styx.submit_dataflow(g) + print("Graph submitted") + + +def generate_partition_batch_wrapper(args): + batch_start, batch_end, operator_state = args + current_active_graph, operator_name = operator_state + current_active_graph = cloudpickle_deserialization(current_active_graph) + local_partitions = {p: {} for p in range(START_N_PARTITIONS)} + for i in range(batch_start, batch_end): + partition = current_active_graph.get_operator_by_name(operator_name).which_partition(i) + local_partitions[partition][i] = tuple(ycsb_field() for _ in range(10)) + return local_partitions + + +def 
merge_partitions(global_partitions, local_partitions): + for part_id, entries in local_partitions.items(): + global_partitions[part_id].update(entries) + + +def ycsb_init(styx: SyncStyxClient, operator: Operator, num_workers: int = None): + styx.set_graph(g) + styx.init_metadata(g) + partitions = {p: {} for p in range(START_N_PARTITIONS)} + + ycsb_dataset_path = f"ycsb_dataset_{START_N_PARTITIONS}p.pkl" + + if os.path.exists(ycsb_dataset_path): + print("Loading YCSB dataset...") + with open(ycsb_dataset_path, "rb") as f: + partitions = pickle.load(f) + else: + print("Generating YCSB dataset...") + batch_ranges = [ + (i, min(i + BATCH_SIZE, N_ENTITIES)) + for i in range(0, N_ENTITIES, BATCH_SIZE) + ] + operator_state = (cloudpickle_serialization(styx._current_active_graph), operator.name) + tasks = [(start, end, operator_state) for start, end in batch_ranges] + + with Pool(processes=num_workers or cpu_count()) as pool: + results = [] + for result in tqdm(pool.imap_unordered(generate_partition_batch_wrapper, tasks), total=len(tasks)): + results.append(result) + + for local_partitions in results: + merge_partitions(partitions, local_partitions) + with open(ycsb_dataset_path, "wb") as f: + pickle.dump(partitions, f) + print("Data ready") + for partition, partition_data in partitions.items(): + styx.init_data(operator, partition, partition_data) + print("Data loaded") + time.sleep(5) + submit_graph(styx) + + +def transactional_ycsb_generator(operator: Operator): + while True: + key = random.randint(0, N_ENTITIES - 1) + op = "read" if random.random() < 0.85 else "update" + yield operator, key, op + + +def benchmark_runner(proc_num) -> dict[bytes, dict]: + print(f"Generator: {proc_num} starting") + styx = SyncStyxClient(STYX_HOST, STYX_PORT, kafka_url=KAFKA_URL) + styx.open(consume=False) + ycsb_generator = transactional_ycsb_generator(ycsb_operator) + timestamp_futures: dict[bytes, dict] = {} + time.sleep(5) + start = timer() + for cur_sec in range(seconds): + # Kill a worker at the specified second (only proc 0) + if proc_num == 0 and 0 <= KILL_AT_SEC == cur_sec: + subprocess.run(["docker", "kill", "styx-worker-1"], check=False) + print(f"KILL -> styx-worker-1 at second {cur_sec}") + + sec_start = timer() + for i in range(messages_per_second): + if i % (messages_per_second // sleeps_per_second) == 0: + time.sleep(sleep_time) + operator, key, func_name = next(ycsb_generator) + future = styx.send_event(operator=operator, key=key, function=func_name) + timestamp_futures[future.request_id] = {"op": f"{func_name} {key}"} + styx.flush() + sec_end = timer() + lps = sec_end - sec_start + if lps < 1: + time.sleep(1 - lps) + sec_end2 = timer() + print(f"{cur_sec} | Latency per second: {sec_end2 - sec_start}") + + # Trigger migration at the specified second (only proc 0) + if cur_sec == MIGRATION_AT_SEC and proc_num == 0: + new_g = StateflowGraph("ycsb-benchmark", operator_state_backend=LocalStateBackend.DICT) + ycsb_operator.set_n_partitions(END_N_PARTITIONS) + new_g.add_operators(ycsb_operator) + styx.update_dataflow(new_g) + print(f"Migration request submitted at second {cur_sec}") + + end = timer() + print(f"Average latency per second: {(end - start) / seconds}") + + styx.close() + + for key, metadata in styx.delivery_timestamps.items(): + timestamp_futures[key]["timestamp"] = metadata + return timestamp_futures + + +def main(): + print("Generate and push workload to Styx") + print(f"Migration at second: {MIGRATION_AT_SEC}, Kill at second: {KILL_AT_SEC}") + + if N_ENTITIES < 3: + print("Impossible to 
run this benchmark with one key") + return + + s3 = boto3.client( + "s3", + endpoint_url=os.getenv("S3_ENDPOINT") or "http://localhost:9000", + aws_access_key_id=os.getenv("S3_ACCESS_KEY") or "rustfsadmin", + aws_secret_access_key=os.getenv("S3_SECRET_KEY") or "rustfsadmin", + region_name=os.getenv("S3_REGION") or "us-east-1", + ) + styx_client = SyncStyxClient(STYX_HOST, STYX_PORT, kafka_url=KAFKA_URL, s3=s3) + ycsb_init(styx_client, ycsb_operator) + del styx_client + time.sleep(5) + + with Pool(threads) as p: + results = p.map(benchmark_runner, range(threads)) + + results = {k: v for d in results for k, v in d.items()} + assert len(results) == messages_per_second * seconds * threads + + pd.DataFrame( + { + "request_id": list(results.keys()), + "timestamp": [res["timestamp"] for res in results.values()], + "op": [res["op"] for res in results.values()], + } + ).sort_values("timestamp").to_csv(f"{SAVE_DIR}/client_requests.csv", index=False) + + print("Workload completed") + + +if __name__ == "__main__": + multiprocessing.set_start_method("fork") + main() + + print() + kafka_output_consumer.main(SAVE_DIR) + + print() + calculate_metrics.main(messages_per_second, threads, warmup_seconds, SAVE_DIR) diff --git a/demo/demo-migration-ft-ycsb/kafka_output_consumer.py b/demo/demo-migration-ft-ycsb/kafka_output_consumer.py new file mode 120000 index 0000000..69f6658 --- /dev/null +++ b/demo/demo-migration-ft-ycsb/kafka_output_consumer.py @@ -0,0 +1 @@ +../demo-migration-ycsb/kafka_output_consumer.py \ No newline at end of file diff --git a/demo/demo-migration-ft-ycsb/ycsb.py b/demo/demo-migration-ft-ycsb/ycsb.py new file mode 120000 index 0000000..6ee8661 --- /dev/null +++ b/demo/demo-migration-ft-ycsb/ycsb.py @@ -0,0 +1 @@ +../demo-migration-ycsb/ycsb.py \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 7a54e02..a5b5097 100755 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,7 @@ services: - S3_ENDPOINT=http://rustfs:9000 - S3_ACCESS_KEY=rustfsadmin - S3_SECRET_KEY=rustfsadmin + - LOG_LEVEL=${LOG_LEVEL:-WARNING} depends_on: - prometheus - grafana @@ -54,6 +55,7 @@ services: - S3_ENDPOINT=http://rustfs:9000 - S3_ACCESS_KEY=rustfsadmin - S3_SECRET_KEY=rustfsadmin + - LOG_LEVEL=${LOG_LEVEL:-WARNING} depends_on: - coordinator diff --git a/pytest.ini b/pytest.ini index 3961d89..3e9757b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,7 @@ [pytest] markers = e2e: end-to-end tests that start containers + e2e_migration: e2e tests involving state migration integration: component-level tests using testcontainers log_cli = true log_cli_level = INFO diff --git a/requirements.txt b/requirements.txt index 20fb846..927e67f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,4 +34,4 @@ pytest==9.0.2 pytest-asyncio==1.3.0 testcontainers[kafka,minio]>=4.14.1 # make sure to have the same ruff verseion as in .github/workflows/ruff.yml line 30 -ruff==0.15.6 \ No newline at end of file +ruff==0.15.7 \ No newline at end of file diff --git a/styx-package/styx/common/logging.py b/styx-package/styx/common/logging.py index df59cb8..00ac149 100644 --- a/styx-package/styx/common/logging.py +++ b/styx-package/styx/common/logging.py @@ -1,12 +1,15 @@ import logging as _stdlib_logging from logging.handlers import QueueHandler, QueueListener +import os from queue import SimpleQueue _log_queue: SimpleQueue = SimpleQueue() # Root "styx" logger — uses QueueHandler so callers never block. +# Log level is configurable via LOG_LEVEL env var (default: WARNING). 
logging = _stdlib_logging.getLogger("styx") -logging.setLevel(_stdlib_logging.WARNING) +_log_level = getattr(_stdlib_logging, os.getenv("LOG_LEVEL", "WARNING").upper(), _stdlib_logging.WARNING) +logging.setLevel(_log_level) if not logging.handlers: logging.addHandler(QueueHandler(_log_queue)) diff --git a/styx-package/styx/common/message_types.py b/styx-package/styx/common/message_types.py index 45bfc29..937ed34 100644 --- a/styx-package/styx/common/message_types.py +++ b/styx-package/styx/common/message_types.py @@ -44,3 +44,4 @@ class MessageType(IntEnum): MigrationInitDone = 39 InitDataComplete = 40 UpdateExecutionGraph = 41 + SnapMigrationReassign = 42 diff --git a/tests/e2e/test_e2e_migration_ft_ycsb.py b/tests/e2e/test_e2e_migration_ft_ycsb.py new file mode 100644 index 0000000..4279f8b --- /dev/null +++ b/tests/e2e/test_e2e_migration_ft_ycsb.py @@ -0,0 +1,435 @@ +"""E2E tests for fault tolerance during state migration. + +These tests verify that Styx maintains exactly-once semantics when a worker +crashes at different points during a migration cycle: + +1. Normal migration (baseline) — no crash +2. Kill during Phase A — crash while protocol is still running, before + repartitioning starts. Recovery should roll back to pre-migration snapshot. +3. Kill during Phase C — crash after repartitioning, during async data + transfer. Recovery should roll back to pre-migration snapshot (no + post-migration snapshot has completed yet). +4. Kill after migration completes — crash after MigrationDone. Recovery + should use the post-migration snapshot with new partition layout. + +All tests validate: +- No duplicate requests +- Exactly-once output (no duplicate responses) +- Zero missed messages +""" + +from dataclasses import dataclass +import json +import logging +from pathlib import Path + +import pytest + +from tests.helpers import make_test_env, run_and_stream, run_and_stream_with_timed_action, wait_port + +log = logging.getLogger("e2e.migration_ft_ycsb") + + +def _assert_metrics( + results_dir: Path, + input_rate: int, + client_threads: int, +) -> None: + exp_name = f"ycsb_migration_{input_rate * client_threads}" + metrics_json = results_dir / f"{exp_name}.json" + log.info("Checking metrics json: %s", metrics_json) + assert metrics_json.exists(), f"Missing metrics json: {metrics_json}" + + with metrics_json.open("r", encoding="utf-8") as f: + metrics = json.load(f) + + log.info("==== METRICS JSON (%s) ====", metrics_json) + log.info("%s", json.dumps(metrics, indent=2, sort_keys=True)) + log.info("==== END METRICS JSON ====") + + assert metrics.get("duplicate_requests") is False, metrics + assert metrics.get("exactly_once_output") is True, metrics + assert int(metrics.get("missed messages")) == 0, metrics + + +@dataclass(frozen=True) +class _ClusterParams: + n_partitions: int = 4 + epoch_size: int = 1000 + threads_per_worker: int = 1 + enable_compression: str = "true" + use_composite_keys: str = "true" + use_fallback_cache: str = "true" + + +@dataclass(frozen=True) +class _ClientParams: + client_threads: int = 2 + n_entities: int = 10_000 + start_n_partitions: int = 4 + end_n_partitions: int = 8 + input_rate: int = 200 + total_time: int = 120 + warmup_seconds: int = 10 + migration_at_sec: int = 30 + kill_at_sec: int = -1 # <0 = disabled + + +@dataclass(frozen=True) +class _Paths: + repo_root: Path + demo_dir: Path + start_script: Path + stop_script: Path + + +def _resolve_paths() -> _Paths: + repo_root = Path(__file__).resolve().parents[2] + demo_dir = repo_root / "demo" / 
"demo-migration-ft-ycsb" + start_script = repo_root / "scripts" / "start_styx_cluster.sh" + stop_script = repo_root / "scripts" / "stop_styx_cluster.sh" + + assert demo_dir.exists(), f"Expected demo dir at {demo_dir}" + assert start_script.exists(), f"Missing: {start_script}" + assert stop_script.exists(), f"Missing: {stop_script}" + + return _Paths( + repo_root=repo_root, + demo_dir=demo_dir, + start_script=start_script, + stop_script=stop_script, + ) + + +def _make_results_dir(tmp_path: Path) -> Path: + results_dir = tmp_path / "results" + results_dir.mkdir(parents=True, exist_ok=True) + log.info("Results dir: %s", results_dir) + return results_dir + + +def _start_cmd(paths: _Paths, p: _ClusterParams) -> list[str]: + return [ + "bash", + str(paths.start_script), + str(p.n_partitions), + str(p.epoch_size), + str(p.threads_per_worker), + p.enable_compression, + p.use_composite_keys, + p.use_fallback_cache, + ] + + +def _stop_cmd(paths: _Paths, p: _ClusterParams) -> list[str]: + return ["bash", str(paths.stop_script), str(p.threads_per_worker)] + + +def _client_cmd(results_dir: Path, client: _ClientParams) -> list[str]: + # client.py expects: + # 1 threads + # 2 START_N_PARTITIONS + # 3 END_N_PARTITIONS + # 4 messages_per_second + # 5 seconds + # 6 SAVE_DIR + # 7 warmup_seconds + # 8 N_ENTITIES + # 9 migration_at_sec + # 10 kill_at_sec + return [ + "python", + "client.py", + str(client.client_threads), + str(client.start_n_partitions), + str(client.end_n_partitions), + str(client.input_rate), + str(client.total_time), + str(results_dir), + str(client.warmup_seconds), + str(client.n_entities), + str(client.migration_at_sec), + str(client.kill_at_sec), + ] + + +def _start_cluster_and_wait(paths: _Paths, env: dict, cluster: _ClusterParams) -> None: + rc, out = run_and_stream( + _start_cmd(paths, cluster), + cwd=str(paths.repo_root), + env=env, + timeout=20 * 60, + banner="START STYX CLUSTER", + log=log, + ) + if rc != 0: + raise AssertionError(f"start_styx_cluster.sh failed (rc={rc}). Output:\n{out}") + + log.info("Waiting for Kafka external listener (127.0.0.1:9092)...") + wait_port("127.0.0.1", 9092, timeout_s=180) + log.info("Kafka port is up.") + + log.info("Waiting for Styx coordinator published port (127.0.0.1:8886)...") + wait_port("127.0.0.1", 8886, timeout_s=240) + log.info("Coordinator port is up.") + + +def _run_client( + *, + paths: _Paths, + env: dict, + results_dir: Path, + client: _ClientParams, + timeout_s: int = 30 * 60, +) -> None: + rc, out = run_and_stream( + _client_cmd(results_dir, client), + cwd=str(paths.demo_dir), + env=env, + timeout=timeout_s, + banner="RUN MIGRATION FT YCSB CLIENT", + log=log, + ) + if rc != 0: + raise AssertionError(f"client.py failed (rc={rc}). Output:\n{out}") + + +def _assert_artifacts_and_metrics(results_dir: Path, client: _ClientParams) -> None: + client_csv = results_dir / "client_requests.csv" + output_csv = results_dir / "output.csv" + log.info("Checking artifacts: %s and %s", client_csv, output_csv) + + assert client_csv.exists(), f"Missing artifact: {client_csv}" + assert output_csv.exists(), f"Missing artifact: {output_csv}" + + _assert_metrics(results_dir, client.input_rate, client.client_threads) + + +def _stop_cluster(paths: _Paths, env: dict, cluster: _ClusterParams, *, timeout_s: int) -> None: + rc, out = run_and_stream( + _stop_cmd(paths, cluster), + cwd=str(paths.repo_root), + env=env, + timeout=timeout_s, + banner="STOP STYX CLUSTER", + log=log, + ) + if rc != 0: + log.warning("stop_styx_cluster.sh failed (rc=%s). 
Output:\n%s", rc, out) + + +# --------------------------------------------------------------------------- +# Test: normal migration (no crash) — baseline +# --------------------------------------------------------------------------- + + +@pytest.mark.e2e_migration +def test_migration_ft_baseline(tmp_path: Path): + """Normal migration completes without crash. + + Validates: no data loss, exactly-once output. + """ + paths = _resolve_paths() + results_dir = _make_results_dir(tmp_path) + + cluster = _ClusterParams() + client = _ClientParams( + total_time=120, + migration_at_sec=30, + kill_at_sec=-1, + ) + + env = make_test_env() + + try: + _start_cluster_and_wait(paths, env, cluster) + _run_client( + paths=paths, + env=env, + results_dir=results_dir, + client=client, + timeout_s=30 * 60, + ) + _assert_artifacts_and_metrics(results_dir, client) + finally: + _stop_cluster(paths, env, cluster, timeout_s=10 * 60) + + +# --------------------------------------------------------------------------- +# Test: kill during Phase A (protocol running, before repartitioning) +# --------------------------------------------------------------------------- + + +@pytest.mark.e2e_migration +def test_migration_ft_kill_during_phase_a(tmp_path: Path): + """Kill a worker during Phase A — crash while protocol is running. + + Migration starts at second 30. Kill at second 35: the protocol is still + running and rehash metadata may not have arrived yet. The coordinator + should detect the failure, recover to the pre-migration snapshot, and + retry the migration from scratch. + + Validates: recovery to pre-migration state, exactly-once output. + """ + paths = _resolve_paths() + results_dir = _make_results_dir(tmp_path) + + cluster = _ClusterParams() + client = _ClientParams( + total_time=120, + warmup_seconds=10, + migration_at_sec=30, + kill_at_sec=35, # Shortly after migration starts (Phase A) + ) + + env = make_test_env() + + try: + _start_cluster_and_wait(paths, env, cluster) + _run_client( + paths=paths, + env=env, + results_dir=results_dir, + client=client, + timeout_s=30 * 60, + ) + _assert_artifacts_and_metrics(results_dir, client) + finally: + _stop_cluster(paths, env, cluster, timeout_s=15 * 60) + + +# --------------------------------------------------------------------------- +# Test: kill during Phase C (async data transfer, before post-mig snapshot) +# --------------------------------------------------------------------------- + + +@pytest.mark.e2e_migration +def test_migration_ft_kill_during_phase_c(tmp_path: Path): + """Kill a worker during Phase C — async data transfer in progress. + + Migration starts at second 30. Kill at second 50: repartitioning has + completed but async data transfer is likely still in progress. No + post-migration snapshot should have completed yet. Recovery should + roll back to pre-migration snapshot with OLD partition layout. + + Validates: recovery to pre-migration state, exactly-once output. 
+ """ + paths = _resolve_paths() + results_dir = _make_results_dir(tmp_path) + + cluster = _ClusterParams() + client = _ClientParams( + total_time=150, + warmup_seconds=10, + migration_at_sec=30, + kill_at_sec=50, # During async data transfer (Phase C) + ) + + env = make_test_env() + + try: + _start_cluster_and_wait(paths, env, cluster) + _run_client( + paths=paths, + env=env, + results_dir=results_dir, + client=client, + timeout_s=40 * 60, + ) + _assert_artifacts_and_metrics(results_dir, client) + finally: + _stop_cluster(paths, env, cluster, timeout_s=15 * 60) + + +# --------------------------------------------------------------------------- +# Test: kill after migration completes +# --------------------------------------------------------------------------- + + +@pytest.mark.e2e_migration +def test_migration_ft_kill_after_migration_done(tmp_path: Path): + """Kill a worker after migration completes. + + Migration starts at second 30. Kill at second 90: migration should have + completed and at least one post-migration snapshot should exist. Recovery + should use the post-migration snapshot with NEW partition layout. + + Validates: recovery to post-migration state, exactly-once output. + """ + paths = _resolve_paths() + results_dir = _make_results_dir(tmp_path) + + cluster = _ClusterParams() + client = _ClientParams( + total_time=150, + warmup_seconds=10, + migration_at_sec=30, + kill_at_sec=90, # Well after migration completes + ) + + env = make_test_env() + + try: + _start_cluster_and_wait(paths, env, cluster) + _run_client( + paths=paths, + env=env, + results_dir=results_dir, + client=client, + timeout_s=40 * 60, + ) + _assert_artifacts_and_metrics(results_dir, client) + finally: + _stop_cluster(paths, env, cluster, timeout_s=15 * 60) + + +# --------------------------------------------------------------------------- +# Test: kill during migration with external docker kill (timed action) +# --------------------------------------------------------------------------- + + +@pytest.mark.e2e_migration +def test_migration_ft_external_kill_during_migration(tmp_path: Path): + """Kill a worker externally via docker kill during migration. + + Uses run_and_stream_with_timed_action to kill a worker container + externally at a fixed wall-clock offset, independent of the client. + This simulates a more realistic failure scenario where the kill + timing is not synchronized with client message processing. + + Migration starts at second 30 (inside client). + External kill fires 45 seconds after client starts (~second 45). + """ + paths = _resolve_paths() + results_dir = _make_results_dir(tmp_path) + + cluster = _ClusterParams() + client = _ClientParams( + total_time=150, + warmup_seconds=10, + migration_at_sec=30, + kill_at_sec=-1, # Kill is handled externally + ) + + env = make_test_env() + + try: + _start_cluster_and_wait(paths, env, cluster) + + rc, out = run_and_stream_with_timed_action( + _client_cmd(results_dir, client), + cwd=str(paths.demo_dir), + env=env, + timeout=40 * 60, + banner="RUN MIGRATION FT YCSB CLIENT (external kill)", + action_at_s=45, # ~15 seconds after migration starts + action_cmd=["docker", "kill", "styx-worker-1"], + action_banner="KILL styx-worker-1", + log=log, + ) + if rc != 0: + raise AssertionError(f"client.py failed (rc={rc}). 
Output:\n{out}") + + _assert_artifacts_and_metrics(results_dir, client) + finally: + _stop_cluster(paths, env, cluster, timeout_s=15 * 60) diff --git a/tests/e2e/test_e2e_migration_tpcc.py b/tests/e2e/test_e2e_migration_tpcc.py index 1d51fcb..eaeca77 100644 --- a/tests/e2e/test_e2e_migration_tpcc.py +++ b/tests/e2e/test_e2e_migration_tpcc.py @@ -255,7 +255,7 @@ def _stop_cluster(paths: _Paths, env: dict, cluster: _ClusterParams, *, timeout_ log.warning("stop_styx_cluster.sh failed (rc=%s). Output:\n%s", rc, out) -@pytest.mark.e2e +@pytest.mark.e2e_migration def test_styx_e2e_migration_tpcc(tmp_path: Path): paths = _resolve_paths() results_dir = _make_results_dir(tmp_path) diff --git a/tests/e2e/test_e2e_migration_ycsb.py b/tests/e2e/test_e2e_migration_ycsb.py index d797293..ef68e6b 100644 --- a/tests/e2e/test_e2e_migration_ycsb.py +++ b/tests/e2e/test_e2e_migration_ycsb.py @@ -192,7 +192,7 @@ def _stop_cluster(paths: _Paths, env: dict, cluster: _ClusterParams, *, timeout_ log.warning("stop_styx_cluster.sh failed (rc=%s). Output:\n%s", rc, out) -@pytest.mark.e2e +@pytest.mark.e2e_migration def test_styx_e2e_migration_ycsb(tmp_path: Path): paths = _resolve_paths() results_dir = _make_results_dir(tmp_path) diff --git a/tests/integration/test_s3_snapshots.py b/tests/integration/test_s3_snapshots.py index 72de316..4891578 100644 --- a/tests/integration/test_s3_snapshots.py +++ b/tests/integration/test_s3_snapshots.py @@ -38,7 +38,7 @@ def test_retrieve_snapshot_negative_id_returns_empty(self, minio_bucket): from worker.fault_tolerance.async_snapshots import AsyncSnapshotsS3 snap = AsyncSnapshotsS3(worker_id=0, n_assigned_partitions=1) - data, off_in, off_out, ep, tc = snap.retrieve_snapshot(-1, [("any_op", 0)]) + data, off_in, off_out, ep, tc, _migration_blob = snap.retrieve_snapshot(-1, [("any_op", 0)]) assert data == {} assert off_in == {} assert off_out == {} @@ -68,7 +68,9 @@ def test_store_and_retrieve_round_trip(self, minio_bucket): # Retrieve snap = AsyncSnapshotsS3(worker_id=0, n_assigned_partitions=1) - data, _off_in, _off_out, ep, tc = snap.retrieve_snapshot(snapshot_id, [(operator_name, partition)]) + data, _off_in, _off_out, ep, tc, _migration_blob = snap.retrieve_snapshot( + snapshot_id, [(operator_name, partition)] + ) assert data[(operator_name, partition)] == state_data assert ep == epoch assert tc == t_counter @@ -101,7 +103,7 @@ def test_retrieve_merges_incremental_snapshots(self, minio_bucket): ) snap = AsyncSnapshotsS3(worker_id=0, n_assigned_partitions=1) - data, _off_in, _off_out, ep, tc = snap.retrieve_snapshot(2, [(op, partition)]) + data, _off_in, _off_out, ep, tc, _migration_blob = snap.retrieve_snapshot(2, [(op, partition)]) # b should be overwritten by snapshot 2 assert data[(op, partition)]["a"] == 1 diff --git a/tests/unit/coordinator/test_coordinator_metadata.py b/tests/unit/coordinator/test_coordinator_metadata.py index 61cf784..5861663 100644 --- a/tests/unit/coordinator/test_coordinator_metadata.py +++ b/tests/unit/coordinator/test_coordinator_metadata.py @@ -286,3 +286,193 @@ async def test_rejects_non_graph(self): c = _coordinator() with pytest.raises(NotAStateflowGraphError): await c.update_stateflow_graph("not a graph") + + @pytest.mark.asyncio + async def test_pending_graph_set_before_gather(self): + """_pending_graph must be set BEFORE sending InitMigration to workers.""" + from copy import deepcopy + + c = _coordinator() + c.max_operator_parallelism = 4 + c.register_worker("10.0.0.1", 5000, 6000) + + # Manually schedule operator partitions (avoids Kafka 
in submit_stateflow_graph) + old_g = _graph(n_partitions=1, max_parallelism=4) + for _, operator in iter(old_g): + for p in range(4): + op_copy = deepcopy(operator) + if p >= operator.n_partitions: + op_copy.make_shadow() + c.worker_pool.schedule_operator_partition((op_copy.name, p), op_copy) + c.submitted_graph = old_g + + new_g = _graph(n_partitions=2, max_parallelism=4) + + # Track when _pending_graph was set relative to send_message calls + set_before_send = None + original_send = c.networking.send_message + + async def tracking_send(*args, **kwargs): + nonlocal set_before_send + if set_before_send is None: + set_before_send = c._pending_graph is not None + return await original_send(*args, **kwargs) + + c.networking.send_message = tracking_send + + await c.update_stateflow_graph(new_g) + assert set_before_send is True, "_pending_graph must be set before send_message" + assert c._pending_graph is new_g + + +# --------------------------------------------------------------------------- +# finalize_graph_update +# --------------------------------------------------------------------------- + + +class TestFinalizeGraphUpdate: + def test_noop_when_no_pending(self): + c = _coordinator() + c._pending_graph = None + c.finalize_graph_update() # should not raise + assert c.submitted_graph is None + + @pytest.mark.asyncio + async def test_commits_pending_graph(self): + c = _coordinator() + c.kafka_metadata_producer = AsyncMock() + old_graph = _graph(n_partitions=1) + new_graph = _graph(n_partitions=2) + c.submitted_graph = old_graph + c._pending_graph = new_graph + c.finalize_graph_update() + assert c.submitted_graph is new_graph + assert c._pending_graph is None + + +# --------------------------------------------------------------------------- +# set_migration_checkpoint / clear_migration_checkpoint +# --------------------------------------------------------------------------- + + +class TestMigrationCheckpoint: + def test_set_and_clear(self): + c = _coordinator() + g = _graph() + locations = {("users", 0): 1} + c.set_migration_checkpoint(g, locations) + assert c._migration_checkpoint_blob is not None + c.clear_migration_checkpoint() + assert c._migration_checkpoint_blob is None + + +# --------------------------------------------------------------------------- +# register_snapshot — migration checkpoint detection +# --------------------------------------------------------------------------- + + +class TestRegisterSnapshotCheckpointDetection: + def test_checkpoint_event_fires_when_snapshot_exceeds_baseline(self): + c = _coordinator() + pool = MagicMock() + c.worker_snapshot_ids = {0: 0} + c.prev_completed_snapshot_id = 0 + # Simulate checkpoint blob set with baseline 0 + c._migration_checkpoint_baseline_snap_id = 0 + c._migration_checkpoint_blob = b"fake" + + c.register_snapshot(0, 1, {}, {}, 1, 1, pool) + + assert c._migration_checkpoint_snapshot_complete.is_set() + assert c._migration_checkpoint_baseline_snap_id == -1 + + def test_checkpoint_event_does_not_fire_when_below_baseline(self): + c = _coordinator() + pool = MagicMock() + c.worker_snapshot_ids = {0: 0, 1: 0} + c.prev_completed_snapshot_id = 0 + c._migration_checkpoint_baseline_snap_id = 1 # baseline is 1 + + c.register_snapshot(0, 1, {}, {}, 1, 1, pool) + # Worker 1 still at 0, so cluster min is 0, not > baseline 1 + assert not c._migration_checkpoint_snapshot_complete.is_set() + + def test_pre_migration_snapshot_recorded(self): + c = _coordinator() + pool = MagicMock() + c.worker_snapshot_ids = {0: 0} + c.prev_completed_snapshot_id = 0 
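+ # Flag that a pre-migration snapshot was requested; the next completed
+ # snapshot should be recorded as the pre-migration baseline.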
+ c.pre_migration_snapshot_pending = True + + c.register_snapshot(0, 1, {}, {}, 1, 1, pool) + + assert c.pre_migration_snapshot_id == 1 + assert c.pre_migration_snapshot_pending is False + + @pytest.mark.asyncio + async def test_post_migration_finalize_fires(self): + c = _coordinator() + c.kafka_metadata_producer = AsyncMock() + pool = MagicMock() + c.worker_snapshot_ids = {0: 0} + c.prev_completed_snapshot_id = 0 + c.post_migration_snapshot_pending = True + c._pending_graph = _graph() + c.submitted_graph = _graph(n_partitions=1) + + c.register_snapshot(0, 1, {}, {}, 1, 1, pool) + + assert c.post_migration_snapshot_pending is False + + +# --------------------------------------------------------------------------- +# revert_worker_pool_to_submitted_graph +# --------------------------------------------------------------------------- + + +class TestRevertWorkerPool: + def test_revert_restores_old_layout(self): + from copy import deepcopy + + c = _coordinator() + c.max_operator_parallelism = 4 + wid, _ = c.register_worker("10.0.0.1", 5000, 6000) + + # Schedule initial partitions (like submit_stateflow_graph would) + old_graph = _graph(n_partitions=2, max_parallelism=4) + for _, operator in iter(old_graph): + for p in range(2): + c.worker_pool.schedule_operator_partition( + (operator.name, p), + deepcopy(operator), + ) + for p in range(2, 4): + op_copy = deepcopy(operator) + op_copy.make_shadow() + c.worker_pool.schedule_operator_partition( + (operator.name, p), + op_copy, + ) + c.submitted_graph = old_graph + + # Simulate update_stateflow_graph promoting shadow -> active + new_graph = _graph(n_partitions=4, max_parallelism=4) + for _, operator in iter(new_graph): + for partition in range(4): + op_copy = deepcopy(operator) + if partition >= operator.n_partitions: + op_copy.make_shadow() + c.worker_pool.update_operator((op_copy.name, partition), op_copy) + + # Partition 2 and 3 should now be active (4 partitions) + w = c.worker_pool.peek(wid) + assert not w.assigned_operators[("users", 2)].is_shadow + assert not w.assigned_operators[("users", 3)].is_shadow + + # Revert + c.revert_worker_pool_to_submitted_graph() + + # Partitions 2 and 3 should be shadow again (old graph had 2 partitions) + assert w.assigned_operators[("users", 2)].is_shadow + assert w.assigned_operators[("users", 3)].is_shadow + assert not w.assigned_operators[("users", 0)].is_shadow diff --git a/tests/unit/coordinator/test_migration_fault_tolerance.py b/tests/unit/coordinator/test_migration_fault_tolerance.py new file mode 100644 index 0000000..79f8961 --- /dev/null +++ b/tests/unit/coordinator/test_migration_fault_tolerance.py @@ -0,0 +1,527 @@ +"""Unit tests for migration fault tolerance features in coordinator. 
+ +Covers: +- Pre-migration snapshot tracking (register_snapshot) +- Deferred graph update (finalize_graph_update) +- Post-migration snapshot triggers graph finalization (register_snapshot) +- Worker pool revert on recovery during migration +- Recovery condition: detect mid-migration AND post-migration-pre-snapshot crashes +- Full migration lifecycle scenarios +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from styx.common.local_state_backends import LocalStateBackend +from styx.common.operator import Operator +from styx.common.stateflow_graph import StateflowGraph + +from coordinator.coordinator_metadata import Coordinator + + +def _mock_coordinator(): + return Coordinator(MagicMock(), MagicMock()) + + +def _make_graph(name="test-app", n_partitions=4, max_parallelism=8): + g = StateflowGraph(name, operator_state_backend=LocalStateBackend.DICT, max_operator_parallelism=max_parallelism) + op = Operator("ycsb", n_partitions=n_partitions) + g.add_operators(op) + return g + + +# --------------------------------------------------------------------------- +# Step 8: Pre-migration snapshot tracking +# --------------------------------------------------------------------------- + + +class TestPreMigrationSnapshotTracking: + def test_pre_migration_snapshot_recorded(self): + c = _mock_coordinator() + c.worker_snapshot_ids = {0: 3, 1: 3} + c.prev_completed_snapshot_id = 2 + c.pre_migration_snapshot_pending = True + c.pre_migration_snapshot_id = -1 + + pool = MagicMock() + c.register_snapshot(0, 3, {}, {}, 10, 100, pool) + + assert c.pre_migration_snapshot_id == 3 + assert c.pre_migration_snapshot_pending is False + + def test_pre_migration_snapshot_not_recorded_when_not_pending(self): + c = _mock_coordinator() + c.worker_snapshot_ids = {0: 3, 1: 3} + c.prev_completed_snapshot_id = 2 + c.pre_migration_snapshot_pending = False + c.pre_migration_snapshot_id = -1 + + pool = MagicMock() + c.register_snapshot(0, 3, {}, {}, 10, 100, pool) + + assert c.pre_migration_snapshot_id == -1 + + def test_pre_migration_snapshot_not_recorded_when_no_advance(self): + c = _mock_coordinator() + c.worker_snapshot_ids = {0: 3, 1: 2} # min is 2, same as prev + c.prev_completed_snapshot_id = 2 + c.pre_migration_snapshot_pending = True + c.pre_migration_snapshot_id = -1 + + pool = MagicMock() + c.register_snapshot(0, 3, {}, {}, 10, 100, pool) + + # Snapshot didn't advance, so pre-migration not recorded yet + assert c.pre_migration_snapshot_pending is True + assert c.pre_migration_snapshot_id == -1 + + +# --------------------------------------------------------------------------- +# Step 9: Deferred graph update +# --------------------------------------------------------------------------- + + +class TestDeferredGraphUpdate: + @pytest.mark.asyncio + async def test_finalize_graph_update_commits_pending(self): + c = _mock_coordinator() + c.submitted_graph = MagicMock(name="old_graph") + new_graph = MagicMock(name="new_graph") + new_graph.name = "test-app" + c._pending_graph = new_graph + c.kafka_metadata_producer = AsyncMock() + + c.finalize_graph_update() + + assert c.submitted_graph is new_graph + assert c._pending_graph is None + + def test_finalize_graph_update_noop_when_no_pending(self): + c = _mock_coordinator() + old_graph = MagicMock(name="old_graph") + c.submitted_graph = old_graph + c._pending_graph = None + + c.finalize_graph_update() + + assert c.submitted_graph is old_graph + + def test_pending_graph_initialized_to_none(self): + c = _mock_coordinator() + assert c._pending_graph is None + 
assert c.pre_migration_snapshot_id == -1 + assert c.pre_migration_snapshot_pending is False + assert c.post_migration_snapshot_pending is False + + +# --------------------------------------------------------------------------- +# Post-migration snapshot triggers graph finalization +# --------------------------------------------------------------------------- + + +class TestPostMigrationSnapshotFinalization: + """Tests that finalize_graph_update() is only called when a post-migration + snapshot completes — NOT immediately when migration finishes. + + Bug caught: if finalize is called at MigrationDone time and a crash occurs + before the next snapshot, recovery uses the NEW graph but the OLD snapshot, + losing ~50% of keys. + """ + + def test_post_migration_snapshot_pending_initialized_false(self): + c = _mock_coordinator() + assert c.post_migration_snapshot_pending is False + + @pytest.mark.asyncio + async def test_register_snapshot_finalizes_graph_when_post_migration_pending(self): + """When post_migration_snapshot_pending is True and a new snapshot + completes, finalize_graph_update() should be called.""" + c = _mock_coordinator() + c.worker_snapshot_ids = {0: 3, 1: 3} + c.prev_completed_snapshot_id = 2 + c.post_migration_snapshot_pending = True + + new_graph = MagicMock(name="new_graph") + new_graph.name = "test-app" + c._pending_graph = new_graph + c.submitted_graph = MagicMock(name="old_graph") + c.kafka_metadata_producer = AsyncMock() + + pool = MagicMock() + c.register_snapshot(0, 3, {}, {}, 10, 100, pool) + + # Graph should be finalized + assert c.submitted_graph is new_graph + assert c._pending_graph is None + assert c.post_migration_snapshot_pending is False + + def test_register_snapshot_does_not_finalize_when_not_pending(self): + """When post_migration_snapshot_pending is False, finalize should + NOT be called even if snapshot advances.""" + c = _mock_coordinator() + c.worker_snapshot_ids = {0: 3, 1: 3} + c.prev_completed_snapshot_id = 2 + c.post_migration_snapshot_pending = False + + old_graph = MagicMock(name="old_graph") + c.submitted_graph = old_graph + c._pending_graph = MagicMock(name="new_graph") + + pool = MagicMock() + c.register_snapshot(0, 3, {}, {}, 10, 100, pool) + + # Graph should NOT be finalized + assert c.submitted_graph is old_graph + + def test_register_snapshot_does_not_finalize_when_snapshot_not_advanced(self): + """When snapshot hasn't advanced (not all workers reported), finalize + should NOT be called.""" + c = _mock_coordinator() + c.worker_snapshot_ids = {0: 3, 1: 2} # min is 2, same as prev + c.prev_completed_snapshot_id = 2 + c.post_migration_snapshot_pending = True + + old_graph = MagicMock(name="old_graph") + c.submitted_graph = old_graph + c._pending_graph = MagicMock(name="new_graph") + + pool = MagicMock() + c.register_snapshot(0, 3, {}, {}, 10, 100, pool) + + # Still pending — snapshot hasn't advanced + assert c.post_migration_snapshot_pending is True + assert c.submitted_graph is old_graph + + @pytest.mark.asyncio + async def test_finalize_happens_before_s3_write(self): + """Graph finalization must happen when the snapshot advances, in the + same register_snapshot call that writes to S3.""" + c = _mock_coordinator() + c.worker_snapshot_ids = {0: 3, 1: 3} + c.prev_completed_snapshot_id = 2 + c.post_migration_snapshot_pending = True + + new_graph = MagicMock(name="new_graph") + new_graph.name = "test-app" + c._pending_graph = new_graph + c.submitted_graph = MagicMock(name="old_graph") + c.kafka_metadata_producer = AsyncMock() + + pool = 
MagicMock() + c.register_snapshot(0, 3, {}, {}, 10, 100, pool) + + # Both finalization and S3 write happened + assert c.submitted_graph is new_graph + c.s3_client.put_object.assert_called_once() + + +# --------------------------------------------------------------------------- +# Recovery condition detection +# --------------------------------------------------------------------------- + + +class TestRecoveryConditionDetection: + """Tests that recovery correctly detects both mid-migration crashes AND + post-migration-pre-snapshot crashes. + + Bug caught: if migration completed but no post-migration snapshot was taken, + recovery didn't detect this as migration-related, used the NEW graph with + the OLD snapshot, and lost half the data. + """ + + def test_mid_migration_detected(self): + """migration_in_progress=True with _pending_graph set → needs revert.""" + c = _mock_coordinator() + c._pending_graph = MagicMock() + migration_in_progress = True + + was_migrating = (migration_in_progress and c._pending_graph is not None) or c.post_migration_snapshot_pending + assert was_migrating is True + + def test_post_migration_pre_snapshot_detected(self): + """migration_in_progress=False but post_migration_snapshot_pending=True + → needs revert. This is the bug scenario.""" + c = _mock_coordinator() + c._pending_graph = MagicMock() + c.post_migration_snapshot_pending = True + migration_in_progress = False + + was_migrating = (migration_in_progress and c._pending_graph is not None) or c.post_migration_snapshot_pending + assert was_migrating is True + + def test_no_migration_not_detected(self): + """Normal state — no migration, no pending snapshot.""" + c = _mock_coordinator() + c._pending_graph = None + c.post_migration_snapshot_pending = False + migration_in_progress = False + + was_migrating = (migration_in_progress and c._pending_graph is not None) or c.post_migration_snapshot_pending + assert was_migrating is False + + def test_completed_migration_with_snapshot_not_detected(self): + """Migration completed AND post-migration snapshot taken → no revert needed.""" + c = _mock_coordinator() + c._pending_graph = None # finalize_graph_update() cleared it + c.post_migration_snapshot_pending = False + migration_in_progress = False + + was_migrating = (migration_in_progress and c._pending_graph is not None) or c.post_migration_snapshot_pending + assert was_migrating is False + + def test_saved_pending_graph_preserved_for_retry(self): + """When was_migrating, the pending graph should be saved for auto-retry.""" + c = _mock_coordinator() + new_graph = MagicMock(name="new_graph") + c._pending_graph = new_graph + c.post_migration_snapshot_pending = True + migration_in_progress = False + + was_migrating = (migration_in_progress and c._pending_graph is not None) or c.post_migration_snapshot_pending + saved_pending_graph = c._pending_graph if was_migrating else None + + assert saved_pending_graph is new_graph + + +# --------------------------------------------------------------------------- +# Full migration lifecycle scenarios +# --------------------------------------------------------------------------- + + +class TestMigrationLifecycleScenarios: + """End-to-end scenarios testing the full migration lifecycle through + coordinator metadata, catching the exact bugs that caused 34000 missed + messages in E2E tests.""" + + def _setup_coordinator(self, n_partitions=4, max_parallelism=8): + c = _mock_coordinator() + c.max_operator_parallelism = max_parallelism + c.submitted_graph = _make_graph("test-app", 
n_partitions, max_parallelism) + c.worker_pool = MagicMock() + c.kafka_metadata_producer = AsyncMock() + c.worker_snapshot_ids = {0: 0, 1: 0, 2: 0} + c.prev_completed_snapshot_id = 0 + return c + + def test_scenario_crash_after_migration_before_snapshot(self): + """Scenario: migration completes, graph NOT finalized yet, crash + before post-migration snapshot. + + Expected: submitted_graph still has OLD layout, revert works, + migration can be retried. + """ + c = self._setup_coordinator(n_partitions=4) + old_graph = c.submitted_graph + new_graph = _make_graph("test-app", n_partitions=8, max_parallelism=8) + + # Simulate: migration sets _pending_graph (update_stateflow_graph) + c._pending_graph = new_graph + + # Simulate: pre-migration snapshot completes (snap 1) + c.pre_migration_snapshot_pending = True + c.worker_snapshot_ids = {0: 1, 1: 1, 2: 1} + pool = MagicMock() + c.register_snapshot(0, 1, {}, {}, 10, 100, pool) + assert c.pre_migration_snapshot_id == 1 + + # Simulate: migration finishes → set post_migration_snapshot_pending + # (NOT calling finalize_graph_update) + c.post_migration_snapshot_pending = True + + # CRASH happens here — before post-migration snapshot + + # Recovery check + migration_in_progress = False # already cleared by _handle_migration_done + was_migrating = (migration_in_progress and c._pending_graph is not None) or c.post_migration_snapshot_pending + assert was_migrating is True + + # submitted_graph is still the OLD graph + assert c.submitted_graph is old_graph + assert c.submitted_graph.nodes["ycsb"].n_partitions == 4 + + # Revert works correctly + c.revert_worker_pool_to_submitted_graph() + # Partitions 4-7 should be shadow + for call_args in c.worker_pool.update_operator.call_args_list[4:]: + _, operator = call_args[0] + assert operator.is_shadow + + # _pending_graph preserved for retry + assert c._pending_graph is new_graph + + @pytest.mark.asyncio + async def test_scenario_crash_after_migration_after_snapshot(self): + """Scenario: migration completes, post-migration snapshot taken, + graph finalized, then crash. + + Expected: submitted_graph has NEW layout, no revert needed. + """ + c = self._setup_coordinator(n_partitions=4) + new_graph = _make_graph("test-app", n_partitions=8, max_parallelism=8) + + # Migration sets _pending_graph + c._pending_graph = new_graph + c.pre_migration_snapshot_pending = True + + # Pre-migration snapshot (snap 1) + c.worker_snapshot_ids = {0: 1, 1: 1, 2: 1} + pool = MagicMock() + c.register_snapshot(0, 1, {}, {}, 10, 100, pool) + + # Migration finishes + c.post_migration_snapshot_pending = True + + # Post-migration snapshot (snap 2) → should finalize + c.worker_snapshot_ids = {0: 2, 1: 2, 2: 2} + c.register_snapshot(0, 2, {}, {}, 20, 200, pool) + + # Graph finalized + assert c.submitted_graph is new_graph + assert c._pending_graph is None + assert c.post_migration_snapshot_pending is False + + # CRASH happens here — after post-migration snapshot + + # Recovery check: no migration recovery needed + migration_in_progress = False + was_migrating = (migration_in_progress and c._pending_graph is not None) or c.post_migration_snapshot_pending + assert was_migrating is False + + def test_scenario_crash_during_migration_phase_a(self): + """Scenario: crash during Phase A (migration in progress, + _pending_graph set, pre-migration snapshot may or may not exist). + + Expected: revert to old layout, retry migration. 
+ """ + c = self._setup_coordinator(n_partitions=4) + old_graph = c.submitted_graph + new_graph = _make_graph("test-app", n_partitions=8, max_parallelism=8) + + # Migration starts + c._pending_graph = new_graph + c.pre_migration_snapshot_pending = True + migration_in_progress = True + + # CRASH during Phase A — no pre-migration snapshot completed + + was_migrating = (migration_in_progress and c._pending_graph is not None) or c.post_migration_snapshot_pending + assert was_migrating is True + + # Old graph preserved + assert c.submitted_graph is old_graph + saved = c._pending_graph + assert saved is new_graph + + @pytest.mark.asyncio + async def test_scenario_normal_migration_no_crash(self): + """Scenario: migration completes normally, snapshot taken, no crash. + + Expected: graph finalized after post-migration snapshot. + """ + c = self._setup_coordinator(n_partitions=4) + new_graph = _make_graph("test-app", n_partitions=8, max_parallelism=8) + + # Step 1: Migration starts + c._pending_graph = new_graph + c.pre_migration_snapshot_pending = True + + # Step 2: Pre-migration snapshot (snap 1) + c.worker_snapshot_ids = {0: 1, 1: 1, 2: 1} + pool = MagicMock() + c.register_snapshot(0, 1, {}, {}, 10, 100, pool) + assert c.pre_migration_snapshot_id == 1 + + # Step 3: Migration finishes + c.post_migration_snapshot_pending = True + assert c.submitted_graph is not new_graph # NOT finalized yet + + # Step 4: Post-migration snapshot (snap 2) + c.worker_snapshot_ids = {0: 2, 1: 2, 2: 2} + c.register_snapshot(0, 2, {}, {}, 20, 200, pool) + + # Step 5: Graph finalized + assert c.submitted_graph is new_graph + assert c._pending_graph is None + assert c.post_migration_snapshot_pending is False + + # Step 6: Subsequent snapshots work normally + c.worker_snapshot_ids = {0: 3, 1: 3, 2: 3} + c.register_snapshot(0, 3, {}, {}, 30, 300, pool) + assert c.submitted_graph is new_graph # still new graph + + def test_scenario_state_cleanup_on_recovery(self): + """Verify all migration state is properly cleared during recovery.""" + c = self._setup_coordinator(n_partitions=4) + new_graph = _make_graph("test-app", n_partitions=8, max_parallelism=8) + + # Set up mid-migration state + c._pending_graph = new_graph + c.pre_migration_snapshot_pending = True + c.pre_migration_snapshot_id = 1 + c.post_migration_snapshot_pending = True + + # Simulate recovery clearing state (mirrors _perform_recovery step 1c) + saved_pending_graph = c._pending_graph + c.pre_migration_snapshot_pending = False + c.post_migration_snapshot_pending = False + c._pending_graph = None + + # All state cleared + assert c.pre_migration_snapshot_pending is False + assert c.post_migration_snapshot_pending is False + assert c._pending_graph is None + # But we saved the graph for retry + assert saved_pending_graph is new_graph + + +# --------------------------------------------------------------------------- +# Worker pool revert on recovery during migration +# --------------------------------------------------------------------------- + + +class TestRevertWorkerPoolToSubmittedGraph: + def _setup_coordinator_with_graph(self, n_partitions, max_parallelism=8): + """Create a coordinator with a submitted graph and worker pool.""" + c = _mock_coordinator() + c.max_operator_parallelism = max_parallelism + + g = StateflowGraph( + "test-app", operator_state_backend=LocalStateBackend.DICT, max_operator_parallelism=max_parallelism + ) + op = Operator("ycsb", n_partitions=n_partitions) + g.add_operators(op) + c.submitted_graph = g + + # Simulate worker pool with 
update_operator tracking + c.worker_pool = MagicMock() + return c + + def test_revert_restores_shadow_partitions(self): + c = self._setup_coordinator_with_graph(n_partitions=4, max_parallelism=8) + + c.revert_worker_pool_to_submitted_graph() + + # Should have called update_operator for all 8 partitions + assert c.worker_pool.update_operator.call_count == 8 + + # Partitions 0-3 should be active (not shadow) + for call_args in c.worker_pool.update_operator.call_args_list[:4]: + op_part, operator = call_args[0] + assert op_part[0] == "ycsb" + assert not operator.is_shadow + + # Partitions 4-7 should be shadow + for call_args in c.worker_pool.update_operator.call_args_list[4:]: + op_part, operator = call_args[0] + assert op_part[0] == "ycsb" + assert operator.is_shadow + + def test_revert_noop_when_no_graph(self): + c = _mock_coordinator() + c.submitted_graph = None + c.max_operator_parallelism = None + c.worker_pool = MagicMock() + + c.revert_worker_pool_to_submitted_graph() + + c.worker_pool.update_operator.assert_not_called() diff --git a/tests/unit/coordinator/test_worker_pool.py b/tests/unit/coordinator/test_worker_pool.py index fe5f676..aff29f7 100644 --- a/tests/unit/coordinator/test_worker_pool.py +++ b/tests/unit/coordinator/test_worker_pool.py @@ -212,6 +212,23 @@ def test_update_operator_replaces_operator_in_place(self): w = pool.peek(1) assert w.assigned_operators[("users", 0)] is op_new + def test_update_operator_skips_unknown_partition(self): + """update_operator should no-op when partition is not mapped to any worker.""" + pool = WorkerPool() + pool.register_worker("127.0.0.1", 5000, 6000) + # ("users", 99) was never scheduled — operator_partition_to_worker has no entry + pool.update_operator(("users", 99), _mock_operator()) # should not raise + + def test_update_operator_skips_removed_worker(self): + """update_operator should no-op when the owning worker was already removed.""" + pool = WorkerPool() + pool.register_worker("127.0.0.1", 5000, 6000) # id=1 + op = _mock_operator() + pool.schedule_operator_partition(("users", 0), op) + # Simulate worker death: remove worker but leave stale mapping + pool.remove_worker(1) + pool.update_operator(("users", 0), _mock_operator()) # should not raise + # --------------------------------------------------------------------------- # WorkerPool — query methods diff --git a/tests/unit/styx_package/test_logging.py b/tests/unit/styx_package/test_logging.py new file mode 100644 index 0000000..422c8e5 --- /dev/null +++ b/tests/unit/styx_package/test_logging.py @@ -0,0 +1,38 @@ +"""Unit tests for styx.common.logging — LOG_LEVEL env var support.""" + +import importlib +import logging as _stdlib_logging +import os + + +def test_default_log_level_is_warning(): + """Without LOG_LEVEL env var, logger should be WARNING.""" + os.environ.pop("LOG_LEVEL", None) + import styx.common.logging as mod + + importlib.reload(mod) + assert mod.logging.level == _stdlib_logging.WARNING + + +def test_log_level_from_env(monkeypatch): + """LOG_LEVEL=DEBUG should set the logger to DEBUG.""" + monkeypatch.setenv("LOG_LEVEL", "DEBUG") + import styx.common.logging as mod + + importlib.reload(mod) + assert mod.logging.level == _stdlib_logging.DEBUG + # Reset + monkeypatch.delenv("LOG_LEVEL", raising=False) + importlib.reload(mod) + + +def test_invalid_log_level_falls_back_to_warning(monkeypatch): + """An unrecognized LOG_LEVEL should fall back to WARNING.""" + monkeypatch.setenv("LOG_LEVEL", "INVALID_LEVEL") + import styx.common.logging as mod + + importlib.reload(mod) + 
assert mod.logging.level == _stdlib_logging.WARNING + # Reset + monkeypatch.delenv("LOG_LEVEL", raising=False) + importlib.reload(mod) diff --git a/tests/unit/styx_package/test_message_types.py b/tests/unit/styx_package/test_message_types.py index 7002296..82076c6 100644 --- a/tests/unit/styx_package/test_message_types.py +++ b/tests/unit/styx_package/test_message_types.py @@ -10,7 +10,7 @@ def test_is_int_enum(self): assert issubclass(MessageType, IntEnum) def test_total_count(self): - assert len(MessageType) == 42 + assert len(MessageType) == 43 def test_all_values_unique(self): values = [m.value for m in MessageType] @@ -18,7 +18,7 @@ def test_all_values_unique(self): def test_values_are_contiguous_from_zero(self): values = sorted(m.value for m in MessageType) - assert values == list(range(42)) + assert values == list(range(43)) def test_known_values(self): assert MessageType.RunFunRemote == 0 @@ -26,6 +26,7 @@ def test_known_values(self): assert MessageType.ClientMsg == 16 assert MessageType.InitDataComplete == 40 assert MessageType.UpdateExecutionGraph == 41 + assert MessageType.SnapMigrationReassign == 42 def test_int_comparison(self): assert MessageType.ClientMsg == 16 diff --git a/tests/unit/worker/test_async_snapshots.py b/tests/unit/worker/test_async_snapshots.py index afc77dd..5fd97ba 100644 --- a/tests/unit/worker/test_async_snapshots.py +++ b/tests/unit/worker/test_async_snapshots.py @@ -158,12 +158,13 @@ def test_puts_object(self, mock_mk): class TestRetrieveSnapshot: def test_snapshot_id_negative_one_returns_empty(self): s = _snap() - data, tp_off, tp_out_off, epoch, t_counter = s.retrieve_snapshot(-1, []) + data, tp_off, tp_out_off, epoch, t_counter, migration_blob = s.retrieve_snapshot(-1, []) assert data == {} assert tp_off == {} assert tp_out_off == {} assert epoch == 0 assert t_counter == 0 + assert migration_blob is None assert s.snapshot_id == 0 # -1 + 1 @patch("worker.fault_tolerance.async_snapshots._get_s3_client") @@ -176,7 +177,7 @@ def test_retrieve_with_no_files(self, mock_mk): mock_s3.get_paginator.return_value = paginator s = _snap() - data, _tp_off, _tp_out_off, _epoch, _t_counter = s.retrieve_snapshot( + data, _tp_off, _tp_out_off, _epoch, _t_counter, _migration_blob = s.retrieve_snapshot( 0, [("users", 0)], ) @@ -267,3 +268,161 @@ def test_empty_contents(self): s = _snap() result = s._list_bin_keys(mock_s3, "data/op/0/") assert result == [] + + +# --------------------------------------------------------------------------- +# _load_operator_state — tombstone handling +# --------------------------------------------------------------------------- + + +class TestLoadOperatorStateTombstones: + @patch("worker.fault_tolerance.async_snapshots._get_s3_client") + def test_tombstone_removes_key_from_previous_delta(self, mock_mk): + """A None value in a subsequent delta should delete the key.""" + mock_s3 = MagicMock() + mock_mk.return_value = mock_s3 + + s = _snap() + # Simulate two snapshot files for the same partition + paginator = MagicMock() + paginator.paginate.return_value = [ + { + "Contents": [ + {"Key": "data/users/0/1.bin"}, # base snapshot + {"Key": "data/users/0/2.bin"}, # delta with tombstone + ] + } + ] + mock_s3.get_paginator.return_value = paginator + + # First call returns base data, second returns delta with tombstone + call_count = [0] + + def fake_get_zstd_msgpack(_s3_client, _key): + call_count[0] += 1 + if call_count[0] == 1: + return {"key_a": "val_a", "key_b": "val_b"} + return {"key_b": None, "key_c": "val_c"} # tombstone key_b, add key_c + + 
s._get_zstd_msgpack = fake_get_zstd_msgpack + + data = s._load_operator_state(mock_s3, 2, [("users", 0)]) + assert "key_a" in data[("users", 0)] + assert "key_b" not in data[("users", 0)] # tombstoned + assert data[("users", 0)]["key_c"] == "val_c" + + @patch("worker.fault_tolerance.async_snapshots._get_s3_client") + def test_first_delta_filters_tombstones(self, mock_mk): + """First delta for a partition should filter out None values.""" + mock_s3 = MagicMock() + mock_mk.return_value = mock_s3 + + s = _snap() + paginator = MagicMock() + paginator.paginate.return_value = [{"Contents": [{"Key": "data/users/0/1.bin"}]}] + mock_s3.get_paginator.return_value = paginator + + s._get_zstd_msgpack = lambda _s3_client, _key: {"key_a": "val_a", "key_b": None} + + data = s._load_operator_state(mock_s3, 1, [("users", 0)]) + assert "key_a" in data[("users", 0)] + assert "key_b" not in data[("users", 0)] + + +# --------------------------------------------------------------------------- +# _load_sequencer_state — 5-element and 4-element tuple formats +# --------------------------------------------------------------------------- + + +class TestLoadSequencerStateTupleFormats: + @patch("worker.fault_tolerance.async_snapshots._get_s3_client") + def test_5_element_tuple_with_migration_blob(self, mock_mk): + """Sequencer file with migration checkpoint blob (5 elements).""" + mock_s3 = MagicMock() + mock_mk.return_value = mock_s3 + + s = _snap() + paginator = MagicMock() + paginator.paginate.return_value = [{"Contents": [{"Key": "sequencer/1.bin"}]}] + mock_s3.get_paginator.return_value = paginator + + blob = b"migration_data" + s._get_zstd_msgpack = lambda _s3_client, _key: ( + {("users", 0): 10}, + {("users", 0): 5}, + 3, + 42, + blob, + ) + + tp_off, tp_out, epoch, t_counter, migration_blob = s._load_sequencer_state(mock_s3, 1) + assert tp_off == {("users", 0): 10} + assert tp_out == {("users", 0): 5} + assert epoch == 3 + assert t_counter == 42 + assert migration_blob == blob + + @patch("worker.fault_tolerance.async_snapshots._get_s3_client") + def test_4_element_tuple_legacy_format(self, mock_mk): + """Legacy sequencer file without migration blob (4 elements).""" + mock_s3 = MagicMock() + mock_mk.return_value = mock_s3 + + s = _snap() + paginator = MagicMock() + paginator.paginate.return_value = [{"Contents": [{"Key": "sequencer/1.bin"}]}] + mock_s3.get_paginator.return_value = paginator + + s._get_zstd_msgpack = lambda _s3_client, _key: ( + {("users", 0): 10}, + {("users", 0): 5}, + 3, + 42, + ) + + tp_off, tp_out, epoch, t_counter, migration_blob = s._load_sequencer_state(mock_s3, 1) + assert tp_off == {("users", 0): 10} + assert tp_out == {("users", 0): 5} + assert epoch == 3 + assert t_counter == 42 + assert migration_blob is None + + @patch("worker.fault_tolerance.async_snapshots._get_s3_client") + def test_5_element_with_none_migration_blob(self, mock_mk): + """5-element tuple with None migration blob (normal snapshot after migration cleared).""" + mock_s3 = MagicMock() + mock_mk.return_value = mock_s3 + + s = _snap() + paginator = MagicMock() + paginator.paginate.return_value = [{"Contents": [{"Key": "sequencer/1.bin"}]}] + mock_s3.get_paginator.return_value = paginator + + s._get_zstd_msgpack = lambda _s3_client, _key: ( + {("users", 0): 10}, + {("users", 0): 5}, + 3, + 42, + None, + ) + + *_, migration_blob = s._load_sequencer_state(mock_s3, 1) + assert migration_blob is None + + @patch("worker.fault_tolerance.async_snapshots._get_s3_client") + def test_no_sequencer_files_returns_defaults(self, 
mock_mk): + """When no sequencer files exist, defaults are returned.""" + mock_s3 = MagicMock() + mock_mk.return_value = mock_s3 + + s = _snap() + paginator = MagicMock() + paginator.paginate.return_value = [{"Contents": []}] + mock_s3.get_paginator.return_value = paginator + + tp_off, tp_out, epoch, t_counter, migration_blob = s._load_sequencer_state(mock_s3, 1) + assert tp_off == {} + assert tp_out == {} + assert epoch == 0 + assert t_counter == 0 + assert migration_blob is None diff --git a/tests/unit/worker/test_kafka_batch_egress.py b/tests/unit/worker/test_kafka_batch_egress.py index 198a7cd..8a271e2 100644 --- a/tests/unit/worker/test_kafka_batch_egress.py +++ b/tests/unit/worker/test_kafka_batch_egress.py @@ -50,8 +50,7 @@ def test_restart_flag(self): class TestClearMessagesSentBeforeRecovery: def test_clears(self): e = _egress() - tp = TopicPartition("users--OUT", 0) - e.messages_sent_before_recovery[tp].add(b"key1") + e.messages_sent_before_recovery.add(b"key1") e.clear_messages_sent_before_recovery() assert len(e.messages_sent_before_recovery) == 0 @@ -88,13 +87,30 @@ async def test_send_skips_already_sent(self): mock_producer = MagicMock() e.kafka_egress_producer = mock_producer - tp = TopicPartition("users--OUT", 0) - e.messages_sent_before_recovery[tp].add(b"key1") + # Flat dedup set — partition-agnostic + e.messages_sent_before_recovery.add(b"key1") await e.send(b"key1", b"value1", "users", 0) assert len(e.batch) == 0 # skipped mock_producer.send.assert_not_called() - assert b"key1" not in e.messages_sent_before_recovery[tp] + assert b"key1" not in e.messages_sent_before_recovery + + @pytest.mark.asyncio + async def test_send_dedup_is_cross_partition(self): + """A key sent to partition 6 pre-crash should be deduped when + replayed to partition 2 after recovery.""" + e = _egress() + e.started.set() + mock_producer = MagicMock() + e.kafka_egress_producer = mock_producer + + # Key was originally sent on partition 6 (post-migration) + e.messages_sent_before_recovery.add(b"req42") + + # Replay sends to partition 2 (pre-migration layout) + await e.send(b"req42", b"val", "users", 2) + assert len(e.batch) == 0 # deduped cross-partition + mock_producer.send.assert_not_called() # --------------------------------------------------------------------------- @@ -152,6 +168,7 @@ def test_adds_keys(self): msg2.key = b"key2" e.process_messages_sent_before_recovery([msg1, msg2], current_offsets, all_done) - assert b"key1" in e.messages_sent_before_recovery[tp] - assert b"key2" in e.messages_sent_before_recovery[tp] + # Flat set now — keys from any partition are in the same set + assert b"key1" in e.messages_sent_before_recovery + assert b"key2" in e.messages_sent_before_recovery assert all_done[tp] is True # msg2 offset >= 10 - 1 diff --git a/tests/unit/worker/test_kafka_egress_coverage.py b/tests/unit/worker/test_kafka_egress_coverage.py index 873328a..97e76b5 100644 --- a/tests/unit/worker/test_kafka_egress_coverage.py +++ b/tests/unit/worker/test_kafka_egress_coverage.py @@ -1,13 +1,13 @@ """Additional coverage tests for worker/egress/styx_kafka_batch_egress.py Covers: send with unstarted event, send_immediate dedup path, -send_message_to_topic, stop. +send_message_to_topic, stop, dedup counters, run_dedup_scan_and_mark_started, +dedup_output_offsets constructor parameter. 
""" import asyncio from unittest.mock import AsyncMock, MagicMock -from aiokafka import TopicPartition import pytest from worker.egress.styx_kafka_batch_egress import StyxKafkaBatchEgress @@ -17,10 +17,10 @@ # --------------------------------------------------------------------------- -def _egress(offsets=None, restart=False): +def _egress(offsets=None, restart=False, dedup_offsets=None): if offsets is None: offsets = {("users", 0): -1, ("users", 1): -1} - return StyxKafkaBatchEgress(offsets, restart_after_failure=restart) + return StyxKafkaBatchEgress(offsets, restart_after_failure=restart, dedup_output_offsets=dedup_offsets) # --------------------------------------------------------------------------- @@ -75,12 +75,12 @@ async def test_send_immediate_dedup(self): mock_producer = MagicMock() e.kafka_egress_producer = mock_producer - tp = TopicPartition("users--OUT", 0) - e.messages_sent_before_recovery[tp].add(b"key1") + # Flat dedup set — partition-agnostic + e.messages_sent_before_recovery.add(b"key1") await e.send_immediate(b"key1", b"value1", "users", 0) mock_producer.send_and_wait.assert_not_called() - assert b"key1" not in e.messages_sent_before_recovery[tp] + assert b"key1" not in e.messages_sent_before_recovery # --------------------------------------------------------------------------- @@ -161,3 +161,185 @@ async def test_batch_multiple_partitions(self): assert e.topic_partition_output_offsets[("users", 0)] == 10 assert e.topic_partition_output_offsets[("users", 1)] == 20 assert e.batch == [] + + +# --------------------------------------------------------------------------- +# Constructor — dedup_output_offsets parameter +# --------------------------------------------------------------------------- + + +class TestDedupOutputOffsets: + def test_default_dedup_offsets_is_empty_dict(self): + e = _egress() + assert e.dedup_output_offsets == {} + + def test_custom_dedup_offsets_stored(self): + offsets = {("users", 0): 5, ("orders", 1): 10} + e = _egress(dedup_offsets=offsets) + assert e.dedup_output_offsets == offsets + + def test_dedup_counters_start_at_zero(self): + e = _egress() + assert e._dedup_suppressed == 0 + assert e._dedup_sent == 0 + + +# --------------------------------------------------------------------------- +# start — non-recovery sets started immediately +# --------------------------------------------------------------------------- + + +class TestStartNonRecovery: + @pytest.mark.asyncio + async def test_start_non_recovery_sets_started(self): + e = _egress(restart=False) + mock_producer = MagicMock() + mock_producer.start = AsyncMock() + # Patch AIOKafkaProducer to return our mock + e.kafka_egress_producer = mock_producer + # Simulate the start logic without real Kafka + # Just call the post-start code path + if not e.restart_after_failure: + e.started.set() + assert e.started.is_set() + + @pytest.mark.asyncio + async def test_start_recovery_does_not_set_started(self): + e = _egress(restart=True) + # Recovery path: started should NOT be set in start() + if not e.restart_after_failure: + e.started.set() + assert not e.started.is_set() + + +# --------------------------------------------------------------------------- +# run_dedup_scan_and_mark_started +# --------------------------------------------------------------------------- + + +class TestRunDedupScanAndMarkStarted: + @pytest.mark.asyncio + async def test_non_recovery_just_sets_started(self): + """When restart_after_failure=False, run_dedup_scan_and_mark_started just sets started.""" + e = _egress(restart=False) + 
await e.run_dedup_scan_and_mark_started() + assert e.started.is_set() + + @pytest.mark.asyncio + async def test_recovery_runs_dedup_scan(self): + """When restart_after_failure=True, dedup scan runs then started is set.""" + e = _egress(restart=True, dedup_offsets={}) + # Mock get_messages_sent_before_recovery (it requires Kafka) + e.get_messages_sent_before_recovery = AsyncMock() + await e.run_dedup_scan_and_mark_started() + e.get_messages_sent_before_recovery.assert_called_once() + assert e.started.is_set() + + +# --------------------------------------------------------------------------- +# stop — dedup stats logging +# --------------------------------------------------------------------------- + + +class TestStopDedupStats: + @pytest.mark.asyncio + async def test_stop_logs_dedup_stats_when_nonzero(self): + e = _egress() + e.worker_id = 0 + e._dedup_suppressed = 3 + e._dedup_sent = 10 + mock_producer = MagicMock() + mock_producer.stop = AsyncMock() + e.kafka_egress_producer = mock_producer + + await e.stop() + mock_producer.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_stop_skips_log_when_zero_counters(self): + e = _egress() + e.worker_id = 0 + e._dedup_suppressed = 0 + e._dedup_sent = 0 + mock_producer = MagicMock() + mock_producer.stop = AsyncMock() + e.kafka_egress_producer = mock_producer + + await e.stop() + mock_producer.stop.assert_called_once() + + +# --------------------------------------------------------------------------- +# send — dedup counters +# --------------------------------------------------------------------------- + + +class TestSendDedupCounters: + @pytest.mark.asyncio + async def test_send_increments_dedup_sent(self): + e = _egress() + e.started.set() + mock_producer = MagicMock() + future = asyncio.Future() + future.set_result(MagicMock()) + mock_producer.send = AsyncMock(return_value=future) + e.kafka_egress_producer = mock_producer + + await e.send(b"key1", b"value1", "users", 0) + assert e._dedup_sent == 1 + + @pytest.mark.asyncio + async def test_send_increments_dedup_suppressed(self): + e = _egress() + e.started.set() + e.kafka_egress_producer = MagicMock() + e.messages_sent_before_recovery.add(b"key1") + + await e.send(b"key1", b"value1", "users", 0) + assert e._dedup_suppressed == 1 + assert e._dedup_sent == 0 + + +# --------------------------------------------------------------------------- +# send_immediate — dedup counters +# --------------------------------------------------------------------------- + + +class TestSendImmediateDedupCounters: + @pytest.mark.asyncio + async def test_send_immediate_increments_dedup_sent(self): + e = _egress() + e.started.set() + meta = MagicMock() + meta.offset = 10 + mock_producer = MagicMock() + mock_producer.send_and_wait = AsyncMock(return_value=meta) + e.kafka_egress_producer = mock_producer + + await e.send_immediate(b"key1", b"value1", "users", 0) + assert e._dedup_sent == 1 + + @pytest.mark.asyncio + async def test_send_immediate_increments_dedup_suppressed(self): + e = _egress() + e.started.set() + e.kafka_egress_producer = MagicMock() + e.messages_sent_before_recovery.add(b"key1") + + await e.send_immediate(b"key1", b"value1", "users", 0) + assert e._dedup_suppressed == 1 + assert e._dedup_sent == 0 + + +# --------------------------------------------------------------------------- +# get_messages_sent_before_recovery — empty partitions early return +# --------------------------------------------------------------------------- + + +class TestGetMessagesSentBeforeRecoveryEmpty: + 
@pytest.mark.asyncio + async def test_empty_dedup_offsets_returns_immediately(self): + e = _egress(dedup_offsets={}) + await e.get_messages_sent_before_recovery() + # Should return without creating a consumer + assert len(e.messages_sent_before_recovery) == 0 diff --git a/tests/unit/worker/test_kafka_ingress.py b/tests/unit/worker/test_kafka_ingress.py index 0e0b347..28075be 100644 --- a/tests/unit/worker/test_kafka_ingress.py +++ b/tests/unit/worker/test_kafka_ingress.py @@ -1,7 +1,10 @@ """Unit tests for worker/ingress/styx_kafka_ingress.py""" -from unittest.mock import MagicMock +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from aiokafka.errors import KafkaConnectionError, UnknownTopicOrPartitionError +import pytest from styx.common.message_types import MessageType from styx.common.run_func_payload import RunFuncPayload @@ -142,3 +145,261 @@ async def test_invalid_message_type(self): msg = _make_msg() ingress.handle_message_from_kafka(msg) sequencer.sequence.assert_not_called() + + +# --------------------------------------------------------------------------- +# Shadow partition redirect (post-recovery with client partitioner mismatch) +# --------------------------------------------------------------------------- + + +def _make_ingress_with_shadow_partitions(): + """Create an ingress with 4 active + 4 shadow partitions. + + Simulates post-recovery state where the system has 4 active partitions + but the client sends to 8 partitions. Shadow partitions (4-7) should + redirect messages to the correct active partition (0-3). + """ + networking = MagicMock() + networking.get_msg_type = MagicMock(return_value=MessageType.ClientMsg) + networking.in_the_same_network = MagicMock(return_value=True) + + sequencer = MagicMock() + state = MagicMock() + state.exists = MagicMock(return_value=False) # shadow partitions have no state + + dns = { + "ycsb": { + 0: ("host1", 5000, 6000), + 1: ("host2", 5001, 6001), + 2: ("host3", 5002, 6002), + 3: ("host3", 5003, 6003), + 4: ("host1", 5000, 6000), # shadow + 5: ("host2", 5001, 6001), # shadow + 6: ("host3", 5002, 6002), # shadow + 7: ("host3", 5003, 6003), # shadow + }, + } + + registered_operators = {} + for part in range(8): + op = MagicMock() + # 4-partition partitioner: always maps to 0-3 + op.which_partition = MagicMock(side_effect=lambda key, _p=part: hash(key) % 4) + op.dns = dns + registered_operators[("ycsb", part)] = op + + ingress = StyxKafkaIngress( + networking=networking, + sequencer=sequencer, + state=state, + registered_operators=registered_operators, + worker_id=0, + kafka_url="localhost:9092", + epoch_interval_ms=1, + sequence_max_size=1000, + ) + return ingress, networking, sequencer, state + + +class TestShadowPartitionRedirect: + """Tests that messages arriving on shadow partitions (4-7) are correctly + redirected to active partitions (0-3) after recovery. + + Bug scenario: client updates its partitioner to 8 partitions during + migration, worker crashes, recovery restores 4-partition layout. Client + continues sending to partitions 4-7 which are shadow. Without proper + redirect, ~50% of messages are lost. 
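+
+ Worked example (using this fixture's 4-partition hash partitioner): a key
+ that hashes to partition 2 but arrives on shadow partition 6 must be
+ re-routed to active partition 2 before it is sequenced.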
+ """ + + def test_message_on_shadow_partition_redirected_locally(self): + """Message sent to shadow partition 4 should be redirected to the + correct active partition and sequenced locally.""" + ingress, networking, sequencer, _state = _make_ingress_with_shadow_partitions() + # Client sends to partition 4 (shadow) with key "test_key" + networking.decode_message.return_value = ("ycsb", "test_key", "read", (), 4) + # which_partition returns an active partition (0-3) + ingress.registered_operators[("ycsb", 4)].which_partition = MagicMock(return_value=0) + + msg = _make_msg(partition=4) + ingress.handle_message_from_kafka(msg) + + sequencer.sequence.assert_called_once() + payload = sequencer.sequence.call_args[0][0] + assert payload.partition == 0 # redirected to active partition + + def test_message_on_shadow_partition_redirected_remotely(self): + """Message on shadow partition redirected to remote worker via + WrongPartitionRequest.""" + ingress, networking, sequencer, _state = _make_ingress_with_shadow_partitions() + networking.decode_message.return_value = ("ycsb", "test_key", "read", (), 5) + ingress.registered_operators[("ycsb", 5)].which_partition = MagicMock(return_value=1) + networking.in_the_same_network.return_value = False + + msg = _make_msg(partition=5) + ingress.handle_message_from_kafka(msg) + + sequencer.sequence.assert_not_called() + assert len(ingress.send_message_tasks) == 1 + + def test_message_on_active_partition_not_redirected(self): + """Message on active partition 0 with existing key is processed normally.""" + ingress, networking, sequencer, state = _make_ingress_with_shadow_partitions() + state.exists.return_value = True + networking.decode_message.return_value = ("ycsb", "test_key", "read", (), 0) + + msg = _make_msg(partition=0) + ingress.handle_message_from_kafka(msg) + + sequencer.sequence.assert_called_once() + payload = sequencer.sequence.call_args[0][0] + assert payload.partition == 0 + + def test_redirect_preserves_kafka_offset_metadata(self): + """Redirected messages must preserve kafka_offset and + kafka_ingress_partition for offset tracking. + + Bug: if these aren't preserved, _advance_offsets can't track + progress and messages are re-consumed after migration. 
+        """
+        ingress, networking, sequencer, _state = _make_ingress_with_shadow_partitions()
+        networking.decode_message.return_value = ("ycsb", "test_key", "read", (), 6)
+        ingress.registered_operators[("ycsb", 6)].which_partition = MagicMock(return_value=2)
+
+        msg = _make_msg(partition=6, offset=42)
+        ingress.handle_message_from_kafka(msg)
+
+        payload = sequencer.sequence.call_args[0][0]
+        assert payload.kafka_offset == 42
+        assert payload.kafka_ingress_partition == 6
+
+    def test_all_shadow_partitions_redirect(self):
+        """Verify all shadow partitions (4-7) redirect correctly."""
+        ingress, networking, sequencer, _state = _make_ingress_with_shadow_partitions()
+
+        for shadow_part in range(4, 8):
+            sequencer.reset_mock()
+            active_part = shadow_part % 4
+            networking.decode_message.return_value = (
+                "ycsb",
+                f"key_{shadow_part}",
+                "read",
+                (),
+                shadow_part,
+            )
+            ingress.registered_operators[("ycsb", shadow_part)].which_partition = MagicMock(return_value=active_part)
+
+            msg = _make_msg(partition=shadow_part)
+            ingress.handle_message_from_kafka(msg)
+
+            sequencer.sequence.assert_called_once()
+            payload = sequencer.sequence.call_args[0][0]
+            assert payload.partition == active_part, f"Shadow partition {shadow_part} should redirect to {active_part}"
+
+
+# ---------------------------------------------------------------------------
+# Kafka consumer startup exception handling
+# ---------------------------------------------------------------------------
+
+
+class TestKafkaConsumerStartupExceptionHandling:
+    """Tests that the Kafka consumer startup loop correctly catches both
+    UnknownTopicOrPartitionError and KafkaConnectionError.
+
+    Bug caught: the Python 2 form `except A, B:` (equivalent to
+    `except A as B:`) only catches A and silently shadows the name B. This
+    caused KafkaConnectionError to crash the ingress during recovery, killing
+    all message consumption for that worker.
+    """
+
+    @pytest.mark.asyncio
+    async def test_kafka_connection_error_is_caught_and_retried(self):
+        """KafkaConnectionError during consumer.start() should be caught
+        and retried, not crash the ingress."""
+        ingress, *_ = _make_ingress()
+        mock_consumer = AsyncMock()
+        # First call raises KafkaConnectionError, second succeeds
+        mock_consumer.start = AsyncMock(
+            side_effect=[KafkaConnectionError(), None],
+        )
+        # getmany raises CancelledError to exit the consumer loop cleanly
+        mock_consumer.getmany = AsyncMock(side_effect=asyncio.CancelledError())
+        mock_consumer.stop = AsyncMock()
+
+        with (
+            patch("worker.ingress.styx_kafka_ingress.AIOKafkaConsumer", return_value=mock_consumer),
+            patch("worker.ingress.styx_kafka_ingress.asyncio.sleep", new_callable=AsyncMock),
+            pytest.raises(asyncio.CancelledError),
+        ):
+            await ingress.start_kafka_consumer([], {})
+
+        # Consumer should have been started twice (retry after error)
+        assert mock_consumer.start.call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_unknown_topic_error_is_caught_and_retried(self):
+        """UnknownTopicOrPartitionError during consumer.start() should be
+        caught and retried."""
+        ingress, *_ = _make_ingress()
+        mock_consumer = AsyncMock()
+        mock_consumer.start = AsyncMock(
+            side_effect=[UnknownTopicOrPartitionError(), None],
+        )
+        mock_consumer.getmany = AsyncMock(side_effect=asyncio.CancelledError())
+        mock_consumer.stop = AsyncMock()
+
+        with (
+            patch("worker.ingress.styx_kafka_ingress.AIOKafkaConsumer", return_value=mock_consumer),
+            patch("worker.ingress.styx_kafka_ingress.asyncio.sleep", new_callable=AsyncMock),
+            pytest.raises(asyncio.CancelledError),
+        ):
+            await ingress.start_kafka_consumer([], {})
+
+        assert mock_consumer.start.call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_other_exceptions_propagate(self):
+        """Non-Kafka exceptions should propagate, not be silently swallowed."""
+        ingress, *_ = _make_ingress()
+        mock_consumer = AsyncMock()
+        mock_consumer.start = AsyncMock(side_effect=ValueError("unexpected"))
+
+        with (
+            patch("worker.ingress.styx_kafka_ingress.AIOKafkaConsumer", return_value=mock_consumer),
+            pytest.raises(ValueError, match="unexpected"),
+        ):
+            await ingress.start_kafka_consumer([], {})
+
+    def test_except_clause_catches_both_exception_types(self):
+        """Verify the except clause syntax is correct (not Python 2 style).
+
+        Python 2: `except A, B:` catches A and binds the instance to B
+        Python 3: the comma form was removed; before 3.14 it is a SyntaxError,
+        and from 3.14 (PEP 758) it is parsed as a tuple of types
+        Correct: `except (A, B):` unambiguously catches both A and B
+
+        This test inspects the source to prevent regressions from formatters
+        or linters that might not catch this subtle syntax difference.
+        """
+        import ast
+        import inspect
+        import textwrap
+
+        source = textwrap.dedent(inspect.getsource(StyxKafkaIngress.start_kafka_consumer))
+        tree = ast.parse(source)
+
+        # Find all except handlers in the function
+        for node in ast.walk(tree):
+            if (
+                isinstance(node, ast.ExceptHandler)
+                and node.type is not None
+                and isinstance(node.type, ast.Name)
+                and node.type.id == "UnknownTopicOrPartitionError"
+            ):
+                # This handler names only UnknownTopicOrPartitionError,
+                # i.e., KafkaConnectionError is NOT caught — this is the bug
+                pytest.fail(
+                    "except clause uses Python 2 syntax "
+                    "'except UnknownTopicOrPartitionError, KafkaConnectionError:' "
+                    "which only catches UnknownTopicOrPartitionError. 
" + "Use 'except (UnknownTopicOrPartitionError, KafkaConnectionError):' " + "to catch both.", + ) diff --git a/tests/unit/worker/test_migration_fault_tolerance.py b/tests/unit/worker/test_migration_fault_tolerance.py new file mode 100644 index 0000000..cc2ce33 --- /dev/null +++ b/tests/unit/worker/test_migration_fault_tolerance.py @@ -0,0 +1,210 @@ +"""Unit tests for migration fault tolerance features. + +Covers: +- delta_map writes during migration (set_data_from_migration, set_batch_data_from_migration) +- tombstone writes on source-side key removal (get_key_to_migrate, get_async_migrate_batch) +- tombstone handling during snapshot recovery (_load_operator_state) +- migration_reassign_delta_maps preserves existing deltas +""" + +from unittest.mock import MagicMock, patch + +from worker.operator_state.aria.in_memory_state import InMemoryOperatorState + +OP = "op" +PART_0 = 0 +PART_1 = 1 +OP_PART_0 = (OP, PART_0) +OP_PART_1 = (OP, PART_1) + + +def _state(*partitions): + parts = set(partitions) if partitions else {OP_PART_0} + return InMemoryOperatorState(parts) + + +# --------------------------------------------------------------------------- +# Step 1: set_data_from_migration writes to delta_map +# --------------------------------------------------------------------------- + + +class TestSetDataFromMigrationDeltaMap: + def test_single_key_written_to_delta_map(self): + s = _state(OP_PART_0, OP_PART_1) + s.set_data_from_migration(OP_PART_1, "k1", "v1") + assert s.delta_map[OP_PART_1]["k1"] == "v1" + + def test_none_data_not_written_to_delta_map(self): + s = _state(OP_PART_0, OP_PART_1) + s.set_data_from_migration(OP_PART_1, "k1", None) + assert "k1" not in s.delta_map[OP_PART_1] + + def test_data_written_to_both_data_and_delta_map(self): + s = _state(OP_PART_0, OP_PART_1) + s.set_data_from_migration(OP_PART_1, "k1", {"field": 42}) + assert s.data[OP_PART_1]["k1"] == {"field": 42} + assert s.delta_map[OP_PART_1]["k1"] == {"field": 42} + + +class TestSetBatchDataFromMigrationDeltaMap: + def test_batch_written_to_delta_map(self): + s = _state(OP_PART_0, OP_PART_1) + batch = {"k1": "v1", "k2": "v2", "k3": "v3"} + s.set_batch_data_from_migration(OP_PART_1, batch) + assert s.delta_map[OP_PART_1] == batch + + def test_batch_merges_with_existing_delta_map(self): + s = _state(OP_PART_0, OP_PART_1) + s.delta_map[OP_PART_1]["existing"] = "val" + s.set_batch_data_from_migration(OP_PART_1, {"k1": "v1"}) + assert s.delta_map[OP_PART_1] == {"existing": "val", "k1": "v1"} + + +class TestMigrateWithinSameWorkerDeltaMap: + def test_migrate_within_same_worker_writes_delta_map(self): + s = _state(OP_PART_0, OP_PART_1) + s.data[OP_PART_0]["k1"] = "v1" + s.migrate_within_the_same_worker(OP, PART_1, "k1", PART_0) + # Destination gets the value in delta_map + assert s.delta_map[OP_PART_1]["k1"] == "v1" + # Source gets a tombstone in delta_map + assert s.delta_map[OP_PART_0]["k1"] is None + + +# --------------------------------------------------------------------------- +# Step 2: Source-side key removal writes tombstones +# --------------------------------------------------------------------------- + + +class TestGetKeyToMigrateTombstone: + def test_tombstone_written_to_delta_map(self): + s = _state(OP_PART_0, OP_PART_1) + s.data[OP_PART_0]["k1"] = "v1" + result = s.get_key_to_migrate(OP_PART_1, "k1", PART_0) + assert result == "v1" + assert s.delta_map[OP_PART_0]["k1"] is None + + def test_no_tombstone_when_key_already_gone(self): + s = _state(OP_PART_0, OP_PART_1) + # key not in data + result = 
s.get_key_to_migrate(OP_PART_1, "k1", PART_0) + assert result is None + assert "k1" not in s.delta_map[OP_PART_0] + + +class TestGetAsyncMigrateBatchTombstone: + def test_tombstones_written_for_migrated_keys(self): + s = _state(OP_PART_0, OP_PART_1) + s.data[OP_PART_0] = {"k1": "v1", "k2": "v2", "k3": "v3"} + s.keys_to_send = {OP_PART_0: {("k1", PART_1), ("k2", PART_1), ("k3", PART_1)}} + batch = s.get_async_migrate_batch(batch_size=10) + # All three keys should have tombstones in source delta_map + for key in ["k1", "k2", "k3"]: + assert s.delta_map[OP_PART_0][key] is None + # The batch should contain the values for the destination + assert len(batch[(OP, PART_1)]) == 3 + + def test_no_tombstone_for_already_deleted_key(self): + s = _state(OP_PART_0, OP_PART_1) + s.data[OP_PART_0] = {} # key not present + s.keys_to_send = {OP_PART_0: {("k1", PART_1)}} + s.get_async_migrate_batch(batch_size=10) + # No tombstone because key was already gone + assert "k1" not in s.delta_map[OP_PART_0] + + +# --------------------------------------------------------------------------- +# Step 3: Tombstone handling in snapshot recovery +# --------------------------------------------------------------------------- + + +class TestLoadOperatorStateTombstones: + def test_tombstone_removes_key_from_accumulated_state(self): + from worker.fault_tolerance.async_snapshots import AsyncSnapshotsS3 + + snap = AsyncSnapshotsS3.__new__(AsyncSnapshotsS3) + snap.worker_id = 0 + + # Simulate S3 with two deltas: first adds key, second tombstones it + delta_1 = {"k1": "v1", "k2": "v2"} + delta_2 = {"k1": None} # tombstone + + mock_s3 = MagicMock() + + with ( + patch.object(snap, "_iter_snapshot_files", return_value=[(1, "d/op/0/1.bin"), (2, "d/op/0/2.bin")]), + patch.object(snap, "_get_zstd_msgpack", side_effect=[delta_1, delta_2]), + ): + data = snap._load_operator_state(mock_s3, 2, [("op", 0)]) + + assert ("op", 0) in data + assert "k1" not in data[("op", 0)] # tombstoned + assert data[("op", 0)]["k2"] == "v2" # preserved + + def test_tombstone_in_first_delta_filtered_out(self): + from worker.fault_tolerance.async_snapshots import AsyncSnapshotsS3 + + snap = AsyncSnapshotsS3.__new__(AsyncSnapshotsS3) + snap.worker_id = 0 + + delta_1 = {"k1": None, "k2": "v2"} # tombstone in first delta + + mock_s3 = MagicMock() + + with ( + patch.object(snap, "_iter_snapshot_files", return_value=[(1, "d/op/0/1.bin")]), + patch.object(snap, "_get_zstd_msgpack", return_value=delta_1), + ): + data = snap._load_operator_state(mock_s3, 1, [("op", 0)]) + + assert "k1" not in data[("op", 0)] + assert data[("op", 0)]["k2"] == "v2" + + +# --------------------------------------------------------------------------- +# Step 5: migration_reassign_delta_maps preserves deltas +# --------------------------------------------------------------------------- + + +class TestMigrationReassignDeltaMaps: + def test_preserves_existing_deltas(self): + from worker.async_snapshotting import AsyncSnapshottingProcess + + proc = AsyncSnapshottingProcess.__new__(AsyncSnapshottingProcess) + proc.delta_maps = { + ("op", 0): {"k1": "v1"}, + ("op", 1): {"k2": "v2"}, + } + + # Reassign: keep partition 0, add partition 2, drop partition 1 + proc.migration_reassign_delta_maps([("op", 0), ("op", 2)]) + + assert proc.delta_maps[("op", 0)] == {"k1": "v1"} # preserved + assert proc.delta_maps[("op", 2)] == {} # new, empty + assert ("op", 1) not in proc.delta_maps # removed + + def test_no_change_when_same_partitions(self): + from worker.async_snapshotting import AsyncSnapshottingProcess 
+ + proc = AsyncSnapshottingProcess.__new__(AsyncSnapshottingProcess) + proc.delta_maps = { + ("op", 0): {"k1": "v1"}, + ("op", 1): {"k2": "v2"}, + } + + proc.migration_reassign_delta_maps([("op", 0), ("op", 1)]) + + assert proc.delta_maps[("op", 0)] == {"k1": "v1"} + assert proc.delta_maps[("op", 1)] == {"k2": "v2"} + + def test_all_new_partitions(self): + from worker.async_snapshotting import AsyncSnapshottingProcess + + proc = AsyncSnapshottingProcess.__new__(AsyncSnapshottingProcess) + proc.delta_maps = {("op", 0): {"k1": "v1"}} + + proc.migration_reassign_delta_maps([("op", 2), ("op", 3)]) + + assert ("op", 0) not in proc.delta_maps + assert proc.delta_maps[("op", 2)] == {} + assert proc.delta_maps[("op", 3)] == {} diff --git a/worker/async_snapshotting.py b/worker/async_snapshotting.py index d26987d..e4e93bb 100644 --- a/worker/async_snapshotting.py +++ b/worker/async_snapshotting.py @@ -70,6 +70,17 @@ def clear_delta_maps(self) -> None: def init_delta_maps(self, assigned_partitions: list) -> None: self.delta_maps = {(op_part[0], op_part[1]): {} for op_part in assigned_partitions} + def migration_reassign_delta_maps(self, new_assigned_partitions: list) -> None: + new_keys = {(op_part[0], op_part[1]) for op_part in new_assigned_partitions} + # Add empty entries for newly assigned partitions + for key in new_keys: + if key not in self.delta_maps: + self.delta_maps[key] = {} + # Remove entries for partitions no longer assigned + for key in list(self.delta_maps.keys()): + if key not in new_keys: + del self.delta_maps[key] + def take_snapshot(self, metadata: tuple) -> None: loop = asyncio.get_running_loop() ( @@ -117,6 +128,13 @@ def snapshotting_controller(self, data: bytes) -> None: ) if snapshot_id != -1: self.async_snapshots.set_snapshot_id(snapshot_id) + case MessageType.SnapMigrationReassign: + (new_assigned_partitions,) = NetworkingManager.decode_message(data) + logging.warning("[SN_PROC] Migration reassign: preserving deltas across partition change") + self.migration_reassign_delta_maps(new_assigned_partitions) + self.async_snapshots.update_n_assigned_partitions( + len(new_assigned_partitions), + ) case _: logging.error( f"Worker Service: Non supported command message type: {message_type}", diff --git a/worker/egress/styx_kafka_batch_egress.py b/worker/egress/styx_kafka_batch_egress.py index 0490e6f..1c2360e 100644 --- a/worker/egress/styx_kafka_batch_egress.py +++ b/worker/egress/styx_kafka_batch_egress.py @@ -1,5 +1,4 @@ import asyncio -from collections import defaultdict import os from typing import TYPE_CHECKING import uuid @@ -25,26 +24,31 @@ def __init__( self, topic_partition_output_offsets: dict[OperatorPartition, int], restart_after_failure: bool = False, + dedup_output_offsets: dict[OperatorPartition, int] | None = None, ) -> None: self.kafka_egress_producer: AIOKafkaProducer | None = None - # operator: partition: output offset + # operator: partition: output offset — only for THIS worker's partitions self.topic_partition_output_offsets: dict[OperatorPartition, int] = topic_partition_output_offsets - # (operator, partition): replied request_ids - self.messages_sent_before_recovery: dict[TopicPartition, set] = defaultdict(set) + # Flat set of request_ids sent before recovery (partition-agnostic). + # Cross-partition dedup is needed because migration can change which + # output partition a response is written to. 
+ self.messages_sent_before_recovery: set = set() self.restart_after_failure = restart_after_failure + # ALL output offsets from the global snapshot — used for dedup scanning + # so that every worker scans ALL output partitions (not just its own). + # This is essential because after recovery a request may be replayed on + # a different worker than the one that originally sent the output. + self.dedup_output_offsets: dict[OperatorPartition, int] = dedup_output_offsets or {} self.batch: list = [] self.started: asyncio.Event = asyncio.Event() + # --- debug counters --- + self._dedup_suppressed: int = 0 + self._dedup_sent: int = 0 def clear_messages_sent_before_recovery(self) -> None: - self.messages_sent_before_recovery: dict[TopicPartition, set] = defaultdict(set) + self.messages_sent_before_recovery: set = set() async def start(self, worker_id: int) -> None: - if self.restart_after_failure: - logging.warning( - f"Getting messages sent before recovery: {self.topic_partition_output_offsets}", - ) - await self.get_messages_sent_before_recovery() - logging.warning("Got messages sent before recovery") self.worker_id = worker_id self.kafka_egress_producer = AIOKafkaProducer( bootstrap_servers=[KAFKA_URL], @@ -59,10 +63,38 @@ async def start(self, worker_id: int) -> None: logging.info("Waiting for Kafka") continue break + if not self.restart_after_failure: + # Non-recovery path: mark started immediately. + # Recovery path: started is set later by run_dedup_scan_and_mark_started() + # after the ReadyAfterRecovery barrier ensures all workers flushed. + self.started.set() + + async def run_dedup_scan_and_mark_started(self) -> None: + """Run the dedup scan (if needed) and then mark the egress as started. + + This must be called AFTER all workers have stopped their old protocols + (i.e. after the ReadyAfterRecovery barrier) so that the dedup scan + captures all messages flushed by surviving workers. + """ + if self.restart_after_failure: + logging.warning( + f"[DEDUP] Getting messages sent before recovery. " + f"Scanning ALL output partitions: {self.dedup_output_offsets}", + ) + await self.get_messages_sent_before_recovery() + logging.warning( + f"[DEDUP] Done scanning. 
Dedup set size: {len(self.messages_sent_before_recovery)}", + ) self.started.set() async def stop(self) -> None: await self.send_batch() + if self._dedup_suppressed > 0 or self._dedup_sent > 0: + logging.warning( + f"[DEDUP] W{self.worker_id} final stats: " + f"suppressed={self._dedup_suppressed}, sent={self._dedup_sent}, " + f"remaining_dedup_set_size={len(self.messages_sent_before_recovery)}", + ) await self.kafka_egress_producer.stop() async def send( @@ -74,8 +106,7 @@ async def send( ) -> None: if not self.started.is_set(): await self.started.wait() - tp = TopicPartition(operator_name + "--OUT", partition) - if key not in self.messages_sent_before_recovery[tp]: + if key not in self.messages_sent_before_recovery: self.batch.append( await self.kafka_egress_producer.send( operator_name + "--OUT", @@ -84,8 +115,14 @@ async def send( partition=partition, ), ) + self._dedup_sent += 1 else: - self.messages_sent_before_recovery[tp].remove(key) + self.messages_sent_before_recovery.discard(key) + self._dedup_suppressed += 1 + logging.warning( + f"[DEDUP] W{getattr(self, 'worker_id', '?')} SUPPRESSED key={key!r} " + f"target_partition={operator_name}--OUT/{partition}", + ) async def send_message_to_topic( self, @@ -109,8 +146,7 @@ async def send_immediate( operator_name: str, partition: int, ) -> None: - tp = TopicPartition(operator_name + "--OUT", partition) - if key not in self.messages_sent_before_recovery[tp]: + if key not in self.messages_sent_before_recovery: res: RecordMetadata = await self.kafka_egress_producer.send_and_wait( operator_name + "--OUT", key=key, @@ -121,8 +157,14 @@ async def send_immediate( self.topic_partition_output_offsets[operator_name, partition], res.offset, ) + self._dedup_sent += 1 else: - self.messages_sent_before_recovery[tp].remove(key) + self.messages_sent_before_recovery.discard(key) + self._dedup_suppressed += 1 + logging.warning( + f"[DEDUP] W{getattr(self, 'worker_id', '?')} SUPPRESSED (immediate) key={key!r} " + f"target_partition={operator_name}--OUT/{partition}", + ) async def send_batch(self) -> None: if self.batch: @@ -137,25 +179,24 @@ async def send_batch(self) -> None: self.batch = [] async def get_messages_sent_before_recovery(self) -> None: + # Scan ALL output partitions from the global snapshot (not just this + # worker's). This ensures cross-worker dedup: a request originally + # processed on worker A may be replayed on worker B after recovery. 
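# --- Editor's illustrative sketch (not part of this diff) -------------------
# The comment above describes the cross-worker dedup idea: collect ONE flat
# set of request ids seen on ANY --OUT partition past the snapshot offsets,
# then suppress each replayed response exactly once. The sketch below shows
# that logic with plain dicts standing in for Kafka consumers/producers; all
# names here are illustrative assumptions, not APIs from this codebase.
def build_dedup_set(
    snapshot_offsets: dict[tuple[str, int], int],
    output_log: dict[tuple[str, int], list[tuple[int, bytes]]],
) -> set[bytes]:
    sent: set[bytes] = set()
    for (topic, partition), snap_offset in snapshot_offsets.items():
        for offset, request_id in output_log.get((topic, partition), []):
            if offset > snap_offset:      # only records written after the snapshot
                sent.add(request_id)      # flat, partition-agnostic set
    return sent


def should_send(request_id: bytes, sent_before_recovery: set[bytes]) -> bool:
    # Suppress a replayed response once, then allow later sends of the same id.
    if request_id in sent_before_recovery:
        sent_before_recovery.discard(request_id)
        return False
    return True
# -----------------------------------------------------------------------------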
+ all_topic_partitions = [TopicPartition(op_name + "--OUT", part) for op_name, part in self.dedup_output_offsets] + if not all_topic_partitions: + return kafka_output_consumer = AIOKafkaConsumer( bootstrap_servers=[KAFKA_URL], enable_auto_commit=False, ) - output_topic_partitions = [ - TopicPartition(operator_name + "--OUT", partition) - for operator_name, partition in self.topic_partition_output_offsets - ] - kafka_output_consumer.assign(output_topic_partitions) + kafka_output_consumer.assign(all_topic_partitions) while True: - # start the kafka consumer try: await kafka_output_consumer.start() - for topic_partition in output_topic_partitions: - kafka_output_consumer.seek( - topic_partition, - self.topic_partition_output_offsets[(topic_partition.topic[:-5], topic_partition.partition)] - + 1, - ) + for tp in all_topic_partitions: + offset = self.dedup_output_offsets[(tp.topic[:-5], tp.partition)] + 1 + kafka_output_consumer.seek(tp, offset) + logging.warning(f"[DEDUP] Seeking {tp} to offset {offset}") except UnknownTopicOrPartitionError, KafkaConnectionError: await asyncio.sleep(1) logging.warning( @@ -164,38 +205,28 @@ async def get_messages_sent_before_recovery(self) -> None: continue break try: - # step 1 get current offset - current_offsets: dict[ - TopicPartition, - int, - ] = await kafka_output_consumer.end_offsets( - output_topic_partitions, - ) + current_offsets = await kafka_output_consumer.end_offsets(all_topic_partitions) logging.warning( - f"Reading from output from: {self.topic_partition_output_offsets} + 1 to {current_offsets}", - ) - all_partitions_done: dict[TopicPartition, bool] = dict.fromkeys( - output_topic_partitions, - False, + f"[DEDUP] End offsets: {current_offsets}", ) - for topic_partition, current_offset in current_offsets.items(): - if ( - current_offset - == self.topic_partition_output_offsets[(topic_partition.topic[:-5], topic_partition.partition)] + 1 - ): - all_partitions_done[topic_partition] = True - while not all(partition_is_done for partition_is_done in all_partitions_done.values()): - result = await kafka_output_consumer.getmany( - timeout_ms=EPOCH_INTERVAL_MS, - ) + all_partitions_done: dict[TopicPartition, bool] = { + tp: current_offsets[tp] == self.dedup_output_offsets[(tp.topic[:-5], tp.partition)] + 1 + for tp in all_topic_partitions + } + logging.warning(f"[DEDUP] Initial partitions_done: {all_partitions_done}") + per_partition_count: dict[TopicPartition, int] = dict.fromkeys(all_topic_partitions, 0) + while not all(all_partitions_done.values()): + result = await kafka_output_consumer.getmany(timeout_ms=EPOCH_INTERVAL_MS) for messages in result.values(): - self.process_messages_sent_before_recovery( - messages, - current_offsets, - all_partitions_done, - ) - if all(partition_is_done for partition_is_done in all_partitions_done.values()): + for msg in messages: + per_partition_count[TopicPartition(msg.topic, msg.partition)] += 1 + self.process_messages_sent_before_recovery(messages, current_offsets, all_partitions_done) + if all(all_partitions_done.values()): break + logging.warning( + f"[DEDUP] Scan complete. Messages read per partition: {per_partition_count}. 
" + f"Total dedup set size: {len(self.messages_sent_before_recovery)}", + ) finally: await kafka_output_consumer.stop() @@ -207,8 +238,6 @@ def process_messages_sent_before_recovery( ) -> None: for message in messages: tp = TopicPartition(message.topic, message.partition) + self.messages_sent_before_recovery.add(message.key) if message.offset >= current_offsets[tp] - 1: - self.messages_sent_before_recovery[tp].add(message.key) all_partitions_done[tp] = True - else: - self.messages_sent_before_recovery[tp].add(message.key) diff --git a/worker/fault_tolerance/async_snapshots.py b/worker/fault_tolerance/async_snapshots.py index 18b7b43..d99896f 100644 --- a/worker/fault_tolerance/async_snapshots.py +++ b/worker/fault_tolerance/async_snapshots.py @@ -28,6 +28,9 @@ SNAPSHOT_BUCKET_NAME: str = os.getenv("SNAPSHOT_BUCKET_NAME", "styx-snapshots") SEQUENCER_PREFIX = "sequencer/" +# Number of elements in sequencer tuple (with/without migration checkpoint blob) +_SEQUENCER_TUPLE_LEN_WITH_MIGRATION = 5 +_SEQUENCER_TUPLE_LEN_LEGACY = 4 # Per-process cached S3 client (avoids re-creating boto3 client on every call) _s3_client: S3Client | None = None @@ -141,15 +144,16 @@ def retrieve_snapshot( dict[OperatorPartition, int], int, int, + bytes | None, ]: self.snapshot_id = snapshot_id + 1 if snapshot_id == -1: - return {}, {}, {}, 0, 0 + return {}, {}, {}, 0, 0, None s3 = _get_s3_client() data = self._load_operator_state(s3, snapshot_id, registered_operators) - tp_offsets, tp_out_offsets, epoch, t_counter = self._load_sequencer_state(s3, snapshot_id) - return data, tp_offsets, tp_out_offsets, epoch, t_counter + tp_offsets, tp_out_offsets, epoch, t_counter, migration_blob = self._load_sequencer_state(s3, snapshot_id) + return data, tp_offsets, tp_out_offsets, epoch, t_counter, migration_blob def _load_operator_state( self, @@ -165,7 +169,14 @@ def _load_operator_state( for _, key in self._iter_snapshot_files(s3, prefix, snapshot_id): partition_data = self._get_zstd_msgpack(s3, key) if operator_partition in data and partition_data: - data[operator_partition].update(partition_data) + for k, v in partition_data.items(): + if v is None: + data[operator_partition].pop(k, None) + else: + data[operator_partition][k] = v + elif partition_data: + # First delta for this partition — filter out any tombstones + data[operator_partition] = {k: v for k, v in partition_data.items() if v is not None} else: data[operator_partition] = partition_data return data @@ -174,21 +185,33 @@ def _load_sequencer_state( self, s3: S3Client, snapshot_id: int, - ) -> tuple[dict[OperatorPartition, int], dict[OperatorPartition, int], int, int]: + ) -> tuple[dict[OperatorPartition, int], dict[OperatorPartition, int], int, int, bytes | None]: topic_partition_offsets: dict[OperatorPartition, int] = {} topic_partition_output_offsets: dict[OperatorPartition, int] = {} epoch = 0 t_counter = 0 + migration_blob: bytes | None = None for _, key in self._iter_snapshot_files(s3, SEQUENCER_PREFIX, snapshot_id): - ( - topic_partition_offsets, - topic_partition_output_offsets, - epoch, - t_counter, - ) = self._get_zstd_msgpack(s3, key) - - return topic_partition_offsets, topic_partition_output_offsets, epoch, t_counter + loaded = self._get_zstd_msgpack(s3, key) + if isinstance(loaded, (tuple, list)) and len(loaded) == _SEQUENCER_TUPLE_LEN_WITH_MIGRATION: + ( + topic_partition_offsets, + topic_partition_output_offsets, + epoch, + t_counter, + migration_blob, + ) = loaded + elif isinstance(loaded, (tuple, list)) and len(loaded) == _SEQUENCER_TUPLE_LEN_LEGACY: 
+ ( + topic_partition_offsets, + topic_partition_output_offsets, + epoch, + t_counter, + ) = loaded + migration_blob = None + + return topic_partition_offsets, topic_partition_output_offsets, epoch, t_counter, migration_blob def _iter_snapshot_files( self, diff --git a/worker/operator_state/aria/in_memory_state.py b/worker/operator_state/aria/in_memory_state.py index 12b6905..984bdcf 100644 --- a/worker/operator_state/aria/in_memory_state.py +++ b/worker/operator_state/aria/in_memory_state.py @@ -52,6 +52,8 @@ def get_key_to_migrate( if data_to_send is None: # Key was already transferred via async migration batch return None + # Record source-side key removal as tombstone for snapshotting + self.delta_map[operator_partition][key] = None if operator_partition in self.keys_to_send: self.keys_to_send[operator_partition].discard((key, new_partition)) return data_to_send @@ -65,6 +67,7 @@ def set_data_from_migration( operator_partition = tuple(operator_partition) if data is not None: self.data[operator_partition][key] = data + self.delta_map[operator_partition][key] = data # Only remove from remote_keys when we actually received the data. # A None response means the async migration batch already transferred # this key — the batch will arrive and set it via set_batch_data_from_migration. @@ -114,6 +117,8 @@ def get_async_migrate_batch( if value is None: # Key was deleted between rehash and transfer — skip it continue + # Record source-side key removal as tombstone for snapshotting + self.delta_map[operator_partition][key] = None batch_to_send[(operator_name, new_partition)][key] = value c += 1 if not keys: @@ -132,6 +137,7 @@ def set_batch_data_from_migration( ) -> None: operator_partition = tuple(operator_partition) # new partitioning self.data[operator_partition].update(kv_pairs) + self.delta_map[operator_partition].update(kv_pairs) # Guard: remote_keys may not have entries if async batch arrived # before hash metadata, or entries were already removed. 
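# --- Editor's illustrative sketch (not part of this diff) -------------------
# The tombstone convention in miniature: a key migrated away from a partition
# is recorded in its delta as None, and replaying the deltas oldest-to-newest
# drops it again, mirroring the _load_operator_state behaviour above. Names
# are illustrative only.
def apply_deltas(deltas: list[dict[str, object]]) -> dict[str, object]:
    state: dict[str, object] = {}
    for delta in deltas:                  # oldest delta first
        for key, value in delta.items():
            if value is None:             # tombstone: key moved to another partition
                state.pop(key, None)
            else:
                state[key] = value
    return state


# e.g. apply_deltas([{"k1": "v1", "k2": "v2"}, {"k1": None}]) == {"k2": "v2"}
# -----------------------------------------------------------------------------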
if operator_partition in self.remote_keys: diff --git a/worker/transactional_protocols/aria.py b/worker/transactional_protocols/aria.py index 414497f..82004fd 100644 --- a/worker/transactional_protocols/aria.py +++ b/worker/transactional_protocols/aria.py @@ -51,7 +51,7 @@ class AriaProtocol(BaseTransactionalProtocol): - def __init__( + def __init__( # noqa: PLR0913 self, worker_id: int, peers: dict[int, tuple[str, int, int]], @@ -68,6 +68,7 @@ def __init__( request_id_to_t_id_map: dict[bytes, int] | None = None, restart_after_recovery: bool = False, restart_after_migration: bool = False, + dedup_output_offsets: dict[OperatorPartition, int] | None = None, ) -> None: if topic_partition_offsets is None: topic_partition_offsets = {(tp.topic, tp.partition): -1 for tp in topic_partitions} @@ -127,6 +128,7 @@ def __init__( self.egress: StyxKafkaBatchEgress = StyxKafkaBatchEgress( output_offsets, restart_after_recovery or restart_after_migration, + dedup_output_offsets=dedup_output_offsets, ) # Primary task used for processing self.function_scheduler_task: asyncio.Task | None = None @@ -169,6 +171,7 @@ def __init__( self.wait_responses_to_be_sent = asyncio.Event() self.running: bool = True + self._stopping: bool = False self.stopped: asyncio.Event = asyncio.Event() self.snapshot_marker_received: bool = False self.snapshotting_port: int = snapshotting_port @@ -198,6 +201,11 @@ async def wait_stopped(self) -> None: await self.stopped.wait() async def stop(self) -> None: + if self._stopping: + # Another caller is already tearing down — just wait for completion. + await self.stopped.wait() + return + self._stopping = True await self.ingress.stop() await self.egress.stop() await self.aio_task_scheduler.close() @@ -383,6 +391,7 @@ async def _handle_sync_cleanup(self, data: bytes) -> None: async with self.networking_locks[mt]: (stop_gracefully,) = self.networking.decode_message(data) if stop_gracefully: + logging.warning(f"Worker {self.id} | SyncCleanup received stop_gracefully=True") self.running = False self.sync_workers_event[mt].set() @@ -551,25 +560,37 @@ async def function_scheduler(self) -> None: await self.started.wait() logging.warning("STARTED function scheduler") - while self.running: - # Wait until the ingress signals that messages are available, - # or a remote peer wants to proceed, instead of busy-spinning. - with contextlib.suppress(TimeoutError): - await asyncio.wait_for( - self.ingress.messages_available.wait(), - timeout=0.1, - ) - self.ingress.messages_available.clear() + try: + while self.running: + # Wait until the ingress signals that messages are available, + # or a remote peer wants to proceed, instead of busy-spinning. 
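# --- Editor's illustrative sketch (not part of this diff) -------------------
# The _stopping guard added to AriaProtocol.stop() above, in isolation: a
# second stop() caller waits for the first teardown to finish instead of
# tearing down twice, which is also what lets the scheduler's finally-block
# skip stop() during recovery. Class and attribute names are illustrative.
import asyncio


class StoppableSketch:
    def __init__(self) -> None:
        self._stopping = False
        self.stopped = asyncio.Event()

    async def stop(self) -> None:
        if self._stopping:
            await self.stopped.wait()   # another caller is already tearing down
            return
        self._stopping = True
        # ... real teardown work (ingress, egress, schedulers) goes here ...
        self.stopped.set()
# -----------------------------------------------------------------------------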
+ with contextlib.suppress(TimeoutError): + await asyncio.wait_for( + self.ingress.messages_available.wait(), + timeout=0.1, + ) + self.ingress.messages_available.clear() - async with self.sequencer.lock: - sequence: list[SequencedItem] = self.sequencer.get_epoch() - if not sequence and not self.remote_wants_to_proceed: - continue + async with self.sequencer.lock: + sequence: list[SequencedItem] = self.sequencer.get_epoch() + if not sequence and not self.remote_wants_to_proceed: + continue - self.currently_processing = True - await self._process_epoch(sequence) + self.currently_processing = True + await self._process_epoch(sequence) - await self.stop() + logging.warning(f"Worker {self.id} | function_scheduler exiting (running=False)") + except asyncio.CancelledError: + logging.warning(f"Worker {self.id} | function_scheduler cancelled") + raise + except Exception: + logging.exception(f"Worker {self.id} | function_scheduler crashed") + finally: + # Only call stop() if no external caller (e.g. recovery) is + # already tearing down. Otherwise we deadlock: the external + # stop() awaits this task, while this task awaits stopped. + if not self._stopping: + await self.stop() async def _process_epoch(self, sequence: list[SequencedItem]) -> None: epoch_start = timer() @@ -1047,5 +1068,13 @@ async def sync_workers( msg_type=msg_type, serializer=serializer, ) - await self.sync_workers_event[msg_type].wait() + # Detect stuck barriers during migration debugging + try: + await asyncio.wait_for(self.sync_workers_event[msg_type].wait(), timeout=10.0) + except TimeoutError: + logging.warning( + f"Worker {self.id} | BARRIER STUCK for 10s on {msg_type.name} @epoch {self.sequencer.epoch_counter}", + ) + # Continue waiting (no timeout) + await self.sync_workers_event[msg_type].wait() self.sync_workers_event[msg_type].clear() diff --git a/worker/worker_service.py b/worker/worker_service.py index 89e95f2..208d7e8 100644 --- a/worker/worker_service.py +++ b/worker/worker_service.py @@ -34,7 +34,7 @@ from styx.common.logging import logging from styx.common.message_types import MessageType from styx.common.protocols import Protocols -from styx.common.serialization import Serializer, msgpack_deserialization +from styx.common.serialization import Serializer, cloudpickle_deserialization, msgpack_deserialization from styx.common.tcp_networking import MessagingMode, NetworkingManager from styx.common.util.aio_task_scheduler import AIOTaskScheduler import uvloop @@ -142,6 +142,7 @@ def __init__(self, thread_idx: int) -> None: self.worker_operators: dict[OperatorPartition, Operator] | None = None self.migration_completed: asyncio.Event = asyncio.Event() + self._migration_generation: int = 0 # bumped on recovery to invalidate zombie Phase B tasks self.m_input_offsets: dict[OperatorPartition, int] = {} self.m_output_offsets: dict[OperatorPartition, int] = {} @@ -289,6 +290,15 @@ async def _send_snap_assigned(self, snapshot_id: int) -> None: serializer=Serializer.MSGPACK, ) + async def _send_snap_migration_reassign(self) -> None: + await self.networking.send_message( + self.networking.host_name, + self.snapshotting_port, + msg=(list(self.registered_operators.keys()),), + msg_type=MessageType.SnapMigrationReassign, + serializer=Serializer.MSGPACK, + ) + def _attach_operator_networking(self) -> None: for operator in self.registered_operators.values(): operator.attach_state_networking( @@ -329,7 +339,7 @@ async def _handle_receive_execution_plan(self, data: bytes, _: MessageType) -> N await 
self._send_snap_assigned(snapshot_id=snapshot_id) logging.warning("Retrieving Snapshot from S3") - (snap_data, _in_off, _out_off, epoch, t_counter) = self.async_snapshots.retrieve_snapshot( + (snap_data, _in_off, _out_off, epoch, t_counter, _migration_blob) = self.async_snapshots.retrieve_snapshot( snapshot_id, self.registered_operators.keys(), ) @@ -381,7 +391,16 @@ async def _handle_init_migration(self, data: bytes, _: MessageType) -> None: logging.error(f"Uncaught exception during migration Phase A: {e}") async def _migration_stop_protocol(self) -> None: - await self.function_execution_protocol.wait_stopped() + try: + await asyncio.wait_for( + self.function_execution_protocol.wait_stopped(), + timeout=15.0, + ) + except TimeoutError: + logging.warning( + "MIGRATION | Protocol did not stop within 15s via stop_gracefully — force-stopping", + ) + await self.function_execution_protocol.stop() logging.warning("MIGRATION | ARIA STOPPED") async def _migration_decode_and_apply_plan(self, data: bytes) -> None: @@ -522,8 +541,8 @@ async def _migration_rebuild_runtime(self) -> None: n_assigned_partitions=len(self.registered_operators), ) - # Snapshot id -1 indicates "migration" - await self._send_snap_assigned(snapshot_id=-1) + # Reassign partitions in snapshotting subprocess, preserving accumulated deltas + await self._send_snap_migration_reassign() # Ensure local state has all newly assigned partitions self._ensure_local_state_partitions() @@ -550,6 +569,7 @@ async def _migration_rebuild_runtime(self) -> None: epoch_counter=self.m_epoch_counter, t_counter=self.m_t_counter, restart_after_migration=True, + dedup_output_offsets=dict(self.m_output_offsets), ) async def _handle_receive_migration_hashes( @@ -604,6 +624,7 @@ async def _handle_migration_repartitioning_done( _: MessageType, ) -> None: """Phase B: protocol stops, catch-up, send hashes, rebuild, resume.""" + my_generation = self._migration_generation try: logging.warning(f"MIGRATION | PHASE B START at {time.time_ns() // 1_000_000}") phase_b_start = timer() @@ -612,6 +633,14 @@ async def _handle_migration_repartitioning_done( t_stop_start = timer() await self._migration_stop_protocol() t_stop_end = timer() + + # Bail out if recovery invalidated this Phase B task + if self._migration_generation != my_generation: + logging.warning( + "MIGRATION | PHASE B ZOMBIE detected after protocol stop — bailing out", + ) + return + logging.warning( "MIGRATION | PHASE B Protocol Stopped |" f" @Epoch {self.function_execution_protocol.sequencer.epoch_counter} |" @@ -641,6 +670,24 @@ async def _handle_migration_repartitioning_done( f"MIGRATION | PHASE B Hash metadata sent | took: {t_hash_send_end - t_hash_send_start}", ) + # 5.5 Trigger a migration checkpoint snapshot. + # This captures the state at the exact moment before data transfer begins, + # enabling checkpoint-and-resume recovery instead of revert-and-retry. 
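# --- Editor's illustrative sketch (not part of this diff) -------------------
# The _migration_generation mechanism used in Phase B above, in isolation:
# recovery bumps a counter and briefly sets the completion event so that an
# in-flight Phase B task can wake up, notice the generation mismatch, and
# bail out instead of acting on stale state. Names are illustrative only.
import asyncio


class PhaseBSketch:
    def __init__(self) -> None:
        self.generation = 0
        self.migration_completed = asyncio.Event()

    async def phase_b(self) -> str:
        my_generation = self.generation
        await self.migration_completed.wait()     # long-running migration step
        if self.generation != my_generation:      # recovery invalidated us
            return "zombie: bailed out"
        return "completed"

    async def recover(self) -> None:
        self.generation += 1                      # invalidate in-flight phase_b
        self.migration_completed.set()            # unblock it so it can exit
        await asyncio.sleep(0)                    # let it run its bail-out check
        self.migration_completed.clear()
# -----------------------------------------------------------------------------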
+ await self.function_execution_protocol.send_delta_to_snapshotting_proc() + await self.function_execution_protocol.snapshotting_networking_manager.send_message( + self.function_execution_protocol.networking.host_name, + self.function_execution_protocol.snapshotting_port, + msg=( + self.function_execution_protocol.topic_partition_offsets, + self.function_execution_protocol.egress.topic_partition_output_offsets, + self.function_execution_protocol.sequencer.epoch_counter, + self.function_execution_protocol.sequencer.t_counter, + ), + msg_type=MessageType.SnapTakeSnapshot, + serializer=Serializer.MSGPACK, + ) + logging.warning("MIGRATION | PHASE B Migration checkpoint snapshot triggered") + # 6. Send MigrationInitDone with actual stop-time counters await self.networking.send_message( DISCOVERY_HOST, @@ -659,6 +706,13 @@ async def _handle_migration_repartitioning_done( logging.warning("MIGRATION | PHASE B WAITING FOR MigrationDone") await self.migration_completed.wait() + # Bail out if recovery invalidated this Phase B task + if self._migration_generation != my_generation: + logging.warning( + "MIGRATION | PHASE B ZOMBIE detected (generation mismatch) — bailing out", + ) + return + # 8. Rebuild runtime (now has correct merged counters from MigrationDone) t_deploy_start = timer() await self._migration_rebuild_runtime() @@ -669,6 +723,9 @@ async def _handle_migration_repartitioning_done( # 9. Start protocol with background migration self.function_execution_protocol.start() + # Run dedup scan after the MigrationDone barrier ensures all workers + # have flushed their old egress batches, then mark egress started. + await self.function_execution_protocol.egress.run_dedup_scan_and_mark_started() self.function_execution_protocol.started.set() # Reset sync events for next time @@ -694,6 +751,14 @@ async def _handle_migration_done(self, data: bytes, _: MessageType) -> None: async def _handle_init_recovery(self, data: bytes, _: MessageType) -> None: start_time = timer() + # Invalidate any zombie Phase B tasks from an interrupted migration. + # Bump the generation so that any Phase B task waiting on + # migration_completed will detect the generation mismatch and bail out. + self._migration_generation += 1 + self.migration_completed.set() # unblock the zombie so it can exit + await asyncio.sleep(0) # yield to let the zombie run its bail-out check + self.migration_completed.clear() + ( self.id, self.worker_operators, @@ -729,16 +794,72 @@ async def _handle_init_recovery(self, data: bytes, _: MessageType) -> None: ) await self._send_snap_assigned(snapshot_id=snapshot_id) - (snap_data, tp_offsets, tp_out_offsets, epoch, t_counter) = self.async_snapshots.retrieve_snapshot( - snapshot_id, - self.registered_operators.keys(), + (snap_data, tp_offsets, tp_out_offsets, epoch, t_counter, migration_blob) = ( + self.async_snapshots.retrieve_snapshot( + snapshot_id, + self.registered_operators.keys(), + ) + ) + + logging.warning( + f"[RECOVERY] W{self.id} snapshot_id={snapshot_id} epoch={epoch} t_counter={t_counter}" + f" migration_blob={'present' if migration_blob else 'none'}", + ) + logging.warning( + f"[RECOVERY] W{self.id} registered_operators={list(self.registered_operators.keys())}", + ) + logging.warning( + f"[RECOVERY] W{self.id} raw tp_offsets from snapshot: {tp_offsets}", + ) + logging.warning( + f"[RECOVERY] W{self.id} raw tp_out_offsets from snapshot: {tp_out_offsets}", ) + # Keep the FULL output offsets for dedup scanning (all partitions from + # the global snapshot). 
Each worker must scan ALL output partitions so + # that cross-worker dedup works: a request originally processed on + # worker A may be replayed on worker B after recovery. + all_dedup_out_offsets = dict(tp_out_offsets) tp_offsets = {k: v for k, v in tp_offsets.items() if k in self.registered_operators} tp_out_offsets = {k: v for k, v in tp_out_offsets.items() if k in self.registered_operators} + logging.warning( + f"[RECOVERY] W{self.id} filtered tp_offsets: {tp_offsets}", + ) + logging.warning( + f"[RECOVERY] W{self.id} filtered tp_out_offsets: {tp_out_offsets}", + ) + logging.warning( + f"[RECOVERY] W{self.id} dedup_output_offsets (all partitions): {all_dedup_out_offsets}", + ) + self.attach_state_to_operators_after_snapshot(snap_data) + # Detect migration checkpoint: if present, re-derive keys_to_send and resume transfer + restart_after_migration = False + if migration_blob is not None: + logging.warning(f"[RECOVERY] W{self.id} Migration checkpoint detected — resuming migration transfer") + migration_meta = cloudpickle_deserialization(migration_blob) + new_graph = migration_meta["new_graph"] + + # Re-derive keys_to_send from local state + new partitioner + keys_to_send: dict[tuple, set] = defaultdict(set) + for op_part in self.registered_operators: + op_name, partition = op_part + partitioner = new_graph.get_operator_by_name(op_name).get_partitioner() + op_data = self.local_state.get_operator_data_for_repartitioning(op_part) + for key in op_data: + new_part = partitioner.get_partition_no_cache(key) + if new_part != partition: + keys_to_send[op_part].add((key, new_part)) + + if keys_to_send: + self.local_state.add_keys_to_send(keys_to_send) + logging.warning( + f"[RECOVERY] W{self.id} Re-derived {sum(len(v) for v in keys_to_send.values())} keys to send", + ) + restart_after_migration = True + request_id_to_t_id_map = await self.get_sequencer_assignments_before_failure( epoch, ) @@ -757,7 +878,9 @@ async def _handle_init_recovery(self, data: bytes, _: MessageType) -> None: epoch_counter=epoch, t_counter=t_counter, request_id_to_t_id_map=request_id_to_t_id_map, - restart_after_recovery=True, + restart_after_recovery=not restart_after_migration, + restart_after_migration=restart_after_migration, + dedup_output_offsets=all_dedup_out_offsets, ) self.function_execution_protocol.start() @@ -776,6 +899,11 @@ async def _handle_init_recovery(self, data: bytes, _: MessageType) -> None: ) async def _handle_ready_after_recovery(self, _data: bytes, _: MessageType) -> None: + # All workers have stopped their old protocols by this point (the + # coordinator only sends ReadyAfterRecovery after all workers reported). + # Run the dedup scan NOW so it captures every message flushed by every + # surviving worker's old egress. + await self.function_execution_protocol.egress.run_dedup_scan_and_mark_started() self.function_execution_protocol.started.set() logging.warning( f"Worker: {self.id} recovered and ready at : {time.time() * 1000}",
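# --- Editor's illustrative sketch (not part of this diff) -------------------
# The checkpoint-resume step above re-derives which keys still have to move:
# any key whose partitioner target differs from the partition it currently
# lives in is queued for transfer. A plain hash stands in for the operator's
# partitioner from new_graph; all names here are illustrative assumptions.
from collections import defaultdict


def rederive_keys_to_send(
    local_state: dict[tuple[str, int], dict[str, object]],
    n_new_partitions: int,
) -> dict[tuple[str, int], set[tuple[str, int]]]:
    keys_to_send: dict[tuple[str, int], set[tuple[str, int]]] = defaultdict(set)
    for (op_name, partition), data in local_state.items():
        for key in data:
            new_part = hash(key) % n_new_partitions   # stand-in partitioner
            if new_part != partition:                  # key now belongs elsewhere
                keys_to_send[(op_name, partition)].add((key, new_part))
    return keys_to_send
# -----------------------------------------------------------------------------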