From a46bf262d8ff811c9cda257ca43df6813f1f174a Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 31 Mar 2026 09:46:14 +0000 Subject: [PATCH 1/5] Add delta weight synchronization support to AsyncGRPO - Add `huggingface-hub` as dependency - Introduce sparse weight patching via `DeltaWeightTransferEngine` - Add `ULPChangeDetector` for optimizer-level change tracking - Add config parameters for delta sync control (repo, anchor interval, checksum verification) - Support both anchor checkpoints and delta patches via HF Hub (Xet storage) --- pyproject.toml | 1 + .../async_grpo/async_grpo_config.py | 42 +++ .../async_grpo/async_grpo_trainer.py | 41 +++ .../async_grpo/async_rollout_worker.py | 41 ++- trl/experimental/async_grpo/delta_engine.py | 342 ++++++++++++++++++ trl/experimental/async_grpo/weight_diff.py | 298 +++++++++++++++ 6 files changed, 764 insertions(+), 1 deletion(-) create mode 100644 trl/experimental/async_grpo/delta_engine.py create mode 100644 trl/experimental/async_grpo/weight_diff.py diff --git a/pyproject.toml b/pyproject.toml index ac9f4cc991f..90e0ee0bf49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ requires-python = ">=3.10" dependencies = [ "accelerate>=1.4.0", "datasets>=3.0.0", + "huggingface-hub>=0.36.2", "packaging>20.0", "transformers>=4.56.2", ] diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index 2afd760e7fc..a2fca996e33 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -185,6 +185,45 @@ class AsyncGRPOConfig(_BaseConfig): metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."}, ) + # Delta weight sync + delta_sync_enabled: bool = field( + default=False, + metadata={ + "help": "Enable delta-compressed weight synchronization. Instead of transferring all " + "weights over NCCL, encode only changed bf16 weights as sparse safetensors patches, " + "upload them to HuggingFace Hub (Xet storage), and signal vLLM to fetch and apply them." + }, + ) + delta_sync_repo_id: str | None = field( + default=None, + metadata={ + "help": "HuggingFace Hub repository for storing delta weight patches and anchor " + "checkpoints (e.g. 'user/training-run-xyz'). Required when delta_sync_enabled=True. " + "The repo is created automatically if it does not exist." + }, + ) + delta_sync_anchor_interval: int = field( + default=10, + metadata={ + "help": "Save a full bf16 anchor checkpoint every N weight sync steps. Between anchors " + "only sparse delta patches are saved. Fireworks blog uses N=25, PULSE paper uses N=50." + }, + ) + delta_sync_verify_checksum: bool = field( + default=True, + metadata={ + "help": "Verify SHA256 checksum after applying each delta patch on the vLLM server. " + "Adds overhead per sync but guarantees bit-exact reconstruction." + }, + ) + delta_sync_log_ulp_accuracy: bool = field( + default=True, + metadata={ + "help": "Log precision/recall/F1 of ULP-based change predictions vs. actual bf16 " + "changes. Useful for validating the optimizer hook approach." 
+ }, + ) + # Parameters that control the logging log_completions: bool = field( default=False, @@ -201,6 +240,9 @@ class AsyncGRPOConfig(_BaseConfig): def __post_init__(self): super().__post_init__() + if self.delta_sync_enabled and self.delta_sync_repo_id is None: + raise ValueError("delta_sync_repo_id is required when delta_sync_enabled=True") + # Accelerator config: required for the async IterableDataset-backed dataloader to work correctly. # split_batches=True and dispatch_batches=True ensure that the main process drives the dataloader # and batches are broadcast to other processes rather than each process pulling independently. diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index ba148e2fa71..9c9d71e542c 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -35,6 +35,7 @@ from .async_grpo_config import AsyncGRPOConfig from .async_rollout_worker import AsyncRolloutWorker +from .weight_diff import ULPChangeDetector logger = get_logger(__name__) @@ -380,6 +381,11 @@ def __init__( weight_names=weight_names, weight_dtype_names=weight_dtype_names, weight_shapes=weight_shapes, + delta_sync_enabled=self.args.delta_sync_enabled, + delta_sync_repo_id=self.args.delta_sync_repo_id, + delta_sync_anchor_interval=self.args.delta_sync_anchor_interval, + delta_sync_verify_checksum=self.args.delta_sync_verify_checksum, + delta_sync_base_model_id=model_name, ) self.rollout_queue = self.rollout_worker.rollout_buffer else: @@ -389,6 +395,9 @@ def __init__( # Add callbacks self.add_callback(StepIntervalCallback(self._sync_weight, self.args.weight_sync_steps)) + # ULP change detector for diagnostic logging (delta sync only) + self._ulp_detector: ULPChangeDetector | None = None + def get_train_dataloader(self) -> DataLoader: if self.accelerator.is_main_process: dataset = RolloutQueueDataset( @@ -571,6 +580,35 @@ def _streaming_iter(self): yield name, full def _sync_weight(self): + # Lazy-init ULP detector for diagnostic logging (delta sync only). + # Optimizer only exists after Trainer creates it inside super()._inner_training_loop(). 
+ if ( + self.args.delta_sync_enabled + and self._ulp_detector is None + and hasattr(self, "optimizer") + and self.optimizer is not None + ): + # Unwrap AcceleratedOptimizer to get the native PyTorch optimizer + # (register_step_pre_hook requires torch.optim.Optimizer internals) + raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer) + self._ulp_detector = ULPChangeDetector(self.model, raw_optimizer) + + # Log ULP prediction accuracy (diagnostic, doesn't affect sync) + if ( + self.args.delta_sync_enabled + and self.args.delta_sync_log_ulp_accuracy + and self._ulp_detector is not None + and self.accelerator.is_main_process + ): + accuracy = self._ulp_detector.get_prediction_accuracy() + for k, v in accuracy.items(): + self._metrics["train"][f"delta/{k}"].append(v) + logger.info( + f"ULP accuracy: precision={accuracy['precision']:.3f} " + f"recall={accuracy['recall']:.3f} " + f"sparsity={accuracy['sparsity']:.4%}" + ) + t0 = time.time() logger.info("Weight sync: pausing vLLM...") if self.accelerator.is_main_process and self.rollout_worker: @@ -614,3 +652,6 @@ def _inner_training_loop(self, *args, **kwargs): finally: if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.stop() + if self._ulp_detector is not None: + self._ulp_detector.close() + self._ulp_detector = None diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index eb720a2dda5..6d274e285dd 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -102,11 +102,17 @@ def __init__( weight_names: list[str] | None = None, weight_dtype_names: list[str] | None = None, weight_shapes: list[list[int]] | None = None, + delta_sync_enabled: bool = False, + delta_sync_repo_id: str | None = None, + delta_sync_anchor_interval: int = 10, + delta_sync_verify_checksum: bool = True, + delta_sync_base_model_id: str = "", ): if not is_vllm_available(min_version="0.17.1"): raise ImportError( "vLLM >= 0.17.1 is required to use AsyncRolloutWorker. Install it with: pip install 'vllm>=0.17.1'" ) + self.delta_sync_enabled = delta_sync_enabled self.model_name = model_name self.max_tool_calling_iterations = max_tool_calling_iterations self.dataset = dataset @@ -171,7 +177,12 @@ def __init__( self.model_version = 0 self.session = None - # Wait for the vLLM server and initialize NCCL weight transfer. + self._delta_sync_repo_id = delta_sync_repo_id + self._delta_sync_anchor_interval = delta_sync_anchor_interval + self._delta_sync_verify_checksum = delta_sync_verify_checksum + self._delta_sync_base_model_id = delta_sync_base_model_id + + # Wait for the vLLM server and initialize weight transfer. 
self._wait_for_server_ready_sync(timeout_s=self.server_timeout) self._init_weight_transfer() @@ -199,6 +210,24 @@ def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_interval_s: time.sleep(poll_interval_s) def _init_weight_transfer(self) -> None: + if self.delta_sync_enabled: + from .delta_engine import DeltaWeightTransferEngine + + self._delta_trainer_args = DeltaWeightTransferEngine.trainer_init( + repo_id=self._delta_sync_repo_id, + url=self.vllm_server_url, + anchor_interval=self._delta_sync_anchor_interval, + verify_checksum=self._delta_sync_verify_checksum, + base_model_id=self._delta_sync_base_model_id, + ) + requests.post( + f"{self.vllm_server_url}/init_weight_transfer_engine", + json={"init_info": {}}, + timeout=120, + ) + logger.info("Init delta weight transfer with HF Hub repo %s", self._delta_sync_repo_id) + return + response = requests.get(f"{self.vllm_server_url}/get_world_size") inference_world_size = response.json()["world_size"] world_size = inference_world_size + 1 @@ -287,6 +316,16 @@ def resume(self) -> None: logger.debug(f"[weight_sync] resume HTTP took {time.time() - t0:.1f}s") def send_weights(self, iterator) -> None: + if self.delta_sync_enabled: + from .delta_engine import DeltaWeightTransferEngine + + t0 = time.time() + DeltaWeightTransferEngine.trainer_send_weights( + iterator=iterator, + trainer_args=self._delta_trainer_args, + ) + logger.info(f"[delta_sync] send_weights took {time.time() - t0:.1f}s") + return if self.model_update_group is None: return t0 = time.time() diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py new file mode 100644 index 00000000000..66edc6858fa --- /dev/null +++ b/trl/experimental/async_grpo/delta_engine.py @@ -0,0 +1,342 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Delta weight transfer engine for vLLM. + +Uses HuggingFace Hub (Xet storage) as the data plane for sparse weight patches. +The trainer uploads patches to HF Hub, then sends a lightweight metadata signal +to vLLM via ``/update_weights``. The vLLM worker downloads and applies patches. + +Registration happens at module import time so that vLLM's ``WeightTransferEngineFactory`` +can find the ``"delta"`` backend. 
Use ``--worker-extension-cls`` to trigger the import:: + + vllm serve model_name \\ + --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \\ + --weight-transfer-backend delta +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Callable, Iterator +from dataclasses import dataclass, field +from typing import Any + +import requests +import torch +from huggingface_hub import HfApi, hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file, save +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) +from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory + +from .weight_diff import compute_bf16_checksum, encode_sparse_patch + + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class DeltaWeightTransferInitInfo(WeightTransferInitInfo): + """No initialization needed for file-based Hub transport.""" + + pass + + +@dataclass +class DeltaWeightTransferUpdateInfo(WeightTransferUpdateInfo): + """Metadata sent via ``/update_weights`` — no weight data, just Hub coordinates.""" + + repo_id: str = "" + filename: str = "" + revision: str = "main" + patch_type: str = "anchor" # "anchor" | "delta" + expected_checksum: str = "" + # is_checkpoint_format: True for anchor (layerwise reload), False for delta (param.copy_) + + +@dataclass +class DeltaTrainerSendWeightsArgs: + """Trainer-side state passed to ``trainer_send_weights``. + + This is a mutable object — ``prev_bf16_snapshot`` and ``model_version`` + are updated after each call so the next call can compute a diff. + """ + + repo_id: str + url: str # vLLM server URL (for the /update_weights signal only) + hf_api: HfApi = field(default_factory=HfApi) + anchor_interval: int = 10 + verify_checksum: bool = True + revision: str = "main" + base_model_id: str = "" + # Mutable state — updated after each call: + prev_bf16_snapshot: dict[str, torch.Tensor] | None = None + model_version: int = 0 + _last_anchor_step: int = 0 + + +class DeltaWeightTransferEngine(WeightTransferEngine[DeltaWeightTransferInitInfo, DeltaWeightTransferUpdateInfo]): + """Weight transfer engine that uses HF Hub (Xet) for sparse delta patches. + + Worker side: downloads patches from Hub and applies them via a CPU bf16 snapshot. + Trainer side: encodes sparse patches and uploads them to Hub. 
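+
+    For orientation, the signal sent to ``/update_weights`` carries metadata only; a sketch
+    of the payload dict (field values here are illustrative, not real run output)::
+
+        {
+            "update_info": {
+                "repo_id": "user/training-run-xyz",
+                "filename": "deltas/step_000012.safetensors",
+                "revision": "main",
+                "patch_type": "delta",
+                "expected_checksum": "<sha256 of the full bf16 state>",
+                "is_checkpoint_format": True,
+            }
+        }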
+ """ + + init_info_cls = DeltaWeightTransferInitInfo + update_info_cls = DeltaWeightTransferUpdateInfo + + def __init__(self, config: WeightTransferConfig, parallel_config: ParallelConfig) -> None: + super().__init__(config, parallel_config) + self._bf16_snapshot: dict[str, torch.Tensor] | None = None + + def init_transfer_engine(self, init_info: DeltaWeightTransferInitInfo) -> None: + pass # No process group setup needed + + def receive_weights( + self, + update_info: DeltaWeightTransferUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + local_path = hf_hub_download( + repo_id=update_info.repo_id, + filename=update_info.filename, + revision=update_info.revision, + force_download=True, + ) + + if update_info.patch_type == "anchor": + self._receive_anchor(local_path, load_weights) + else: + self._receive_delta(local_path, update_info.expected_checksum, load_weights) + + def _receive_anchor( + self, + local_path: str, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """Load a full anchor checkpoint and rebuild the snapshot.""" + state = load_file(local_path, device="cpu") + self._bf16_snapshot = {} + for name, tensor in state.items(): + self._bf16_snapshot[name] = tensor.to(torch.bfloat16).clone() + load_weights([(name, tensor)]) + logger.info("Loaded anchor checkpoint with %d parameters", len(self._bf16_snapshot)) + + def _receive_delta( + self, + local_path: str, + expected_checksum: str, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """Apply a sparse delta patch to the snapshot, then feed changed params to load_weights.""" + if self._bf16_snapshot is None: + raise RuntimeError( + "Cannot apply delta patch without a prior anchor. " + "Ensure the first weight sync is an anchor (is_checkpoint_format=True)." + ) + + with safe_open(local_path, framework="pt", device="cpu") as f: + meta = f.metadata() + changed_names = json.loads(meta.get("changed_params", "[]")) + + for name in changed_names: + indices = f.get_tensor(f"{name}.indices").long() + values = f.get_tensor(f"{name}.values") + # Apply to CPU snapshot (bit-exact, no FP arithmetic) + snap_flat = self._bf16_snapshot[name].flatten() + snap_flat[indices] = values + self._bf16_snapshot[name] = snap_flat.reshape(self._bf16_snapshot[name].shape) + # Pass reconstructed full tensor to load_weights + load_weights([(name, self._bf16_snapshot[name].to("cuda"))]) + + if expected_checksum: + actual = compute_bf16_checksum(self._bf16_snapshot) + if actual != expected_checksum: + raise ValueError(f"Checksum mismatch: expected {expected_checksum[:16]}..., got {actual[:16]}...") + + logger.info( + "Applied delta patch: %d params changed, sparsity=%s", + len(changed_names), + meta.get("sparsity", "?"), + ) + + def shutdown(self) -> None: + self._bf16_snapshot = None + + @staticmethod + def trainer_send_weights( + iterator: Iterator[tuple[str, torch.Tensor]], + trainer_args: dict[str, Any] | DeltaTrainerSendWeightsArgs, + ) -> None: + """Encode a sparse patch, upload to HF Hub, and signal vLLM. + + Args: + iterator: (name, tensor) pairs from the model (e.g. FSDP streaming iterator). + trainer_args: :class:`DeltaTrainerSendWeightsArgs` (mutable, updated in-place). + """ + if isinstance(trainer_args, dict): + args = DeltaTrainerSendWeightsArgs(**trainer_args) + else: + args = trainer_args + + # 1. 
Collect bf16 snapshot from the streaming iterator + curr_bf16: dict[str, torch.Tensor] = {} + for name, tensor in iterator: + curr_bf16[name] = tensor.to(torch.bfloat16).cpu().clone() + + args.model_version += 1 + is_anchor = args.prev_bf16_snapshot is None or args.model_version % args.anchor_interval == 0 + + # 2. Encode to safetensors bytes (no local disk write) + if is_anchor: + checksum = compute_bf16_checksum(curr_bf16) + metadata = { + "format": "anchor_checkpoint", + "version": "1", + "model_version": str(args.model_version), + "base_model_id": args.base_model_id, + "checksum_sha256": checksum, + } + buf = save(curr_bf16, metadata=metadata) + filename = f"anchors/step_{args.model_version:06d}.safetensors" + args._last_anchor_step = args.model_version + else: + tensors, meta_dict = encode_sparse_patch( + prev_bf16=args.prev_bf16_snapshot, + curr_bf16=curr_bf16, + model_version=args.model_version, + prev_model_version=args.model_version - 1, + anchor_step=args._last_anchor_step, + base_model_id=args.base_model_id, + ) + checksum = meta_dict["checksum_sha256"] + # safetensors requires at least one tensor + if not tensors: + tensors["__empty_delta__"] = torch.zeros(1, dtype=torch.int32) + buf = save(tensors, metadata=meta_dict) + filename = f"deltas/step_{args.model_version:06d}.safetensors" + + # 3. Upload to HF Hub (Xet handles chunking/dedup) + args.hf_api.upload_file( + path_or_fileobj=buf, + path_in_repo=filename, + repo_id=args.repo_id, + revision=args.revision, + commit_message=f"step {args.model_version} ({'anchor' if is_anchor else 'delta'})", + ) + + logger.info( + "[delta_engine] uploaded %s to %s/%s (%.1f MB)", + "anchor" if is_anchor else "delta", + args.repo_id, + filename, + len(buf) / 1e6, + ) + + # 4. Signal vLLM (metadata only — no weight data) + update_info = { + "repo_id": args.repo_id, + "filename": filename, + "revision": args.revision, + "patch_type": "anchor" if is_anchor else "delta", + "expected_checksum": checksum if args.verify_checksum else "", + "is_checkpoint_format": True, # Always True: vLLM fuses params (e.g. gate_up_proj), needs model.load_weights for name mapping + } + resp = requests.post( + f"{args.url}/update_weights", + json={"update_info": update_info}, + timeout=300, + ) + resp.raise_for_status() + + # 5. Update mutable state for next call + args.prev_bf16_snapshot = curr_bf16 + + @staticmethod + def trainer_init( + repo_id: str, + url: str, + anchor_interval: int = 10, + verify_checksum: bool = True, + revision: str = "main", + base_model_id: str = "", + token: str | None = None, + ) -> DeltaTrainerSendWeightsArgs: + """Initialize trainer-side state: create/ensure HF repo and return args. + + Call once at startup. Pass the returned object to ``trainer_send_weights`` + on every weight sync. + """ + api = HfApi(token=token) + api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True) + logger.info("[delta_engine] trainer_init: repo=%s, anchor_interval=%d", repo_id, anchor_interval) + return DeltaTrainerSendWeightsArgs( + repo_id=repo_id, + url=url, + hf_api=api, + anchor_interval=anchor_interval, + verify_checksum=verify_checksum, + revision=revision, + base_model_id=base_model_id, + ) + + +class DeltaWorkerExtension: + """vLLM worker extension for the delta weight transfer backend. + + This class is intentionally minimal. Its primary role is to trigger the + import of this module (via ``--worker-extension-cls``) which registers the + engine with ``WeightTransferEngineFactory`` at module level below. 
+ + Usage with standard ``vllm serve``:: + + VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \\ + --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \\ + --weight-transfer-config '{"backend":"nccl"}' \\ + --max-model-len 8192 --enforce-eager --logprobs-mode processed_logprobs + + Note: ``backend`` must be ``"nccl"`` in the CLI (pydantic ``Literal`` validation). + This module overrides the ``"nccl"`` factory entry so that the actual engine + created is ``DeltaWeightTransferEngine``. + """ + + pass + + +# --------------------------------------------------------------------------- +# Module-level registration — runs when this module is first imported +# --------------------------------------------------------------------------- + +if "delta" not in WeightTransferEngineFactory._registry: + WeightTransferEngineFactory.register_engine("delta", DeltaWeightTransferEngine) + +# Override the "nccl" factory entry so that --weight-transfer-config '{"backend":"nccl"}' +# (which passes pydantic Literal["nccl","ipc"] validation) actually creates a +# DeltaWeightTransferEngine. This is safe: the trainer side never reads the factory, +# and the worker side is explicitly opting in via --worker-extension-cls. +WeightTransferEngineFactory._registry["nccl"] = lambda: DeltaWeightTransferEngine diff --git a/trl/experimental/async_grpo/weight_diff.py b/trl/experimental/async_grpo/weight_diff.py new file mode 100644 index 00000000000..c18edec1563 --- /dev/null +++ b/trl/experimental/async_grpo/weight_diff.py @@ -0,0 +1,298 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Delta-compressed weight synchronization engine. + +Implements sparse weight patching for AsyncGRPOTrainer: detects which bf16 weights changed +between consecutive optimizer steps, encodes only the changed elements as sparse safetensors +patches, and provides a checkpoint chain (anchor + deltas) for reconstructing any step. + +References: +- PULSE paper: arXiv:2602.03839 (Feb 2026) +""" + +from __future__ import annotations + +import hashlib +import json +import logging +from dataclasses import asdict, dataclass +from pathlib import Path + +import torch +from safetensors import safe_open + + +logger = logging.getLogger(__name__) + + +def bf16_absorption_threshold(w: torch.Tensor) -> torch.Tensor: + """BF16 absorption threshold per element: |delta_w| must exceed this to survive rounding. + + BF16 has 7 mantissa bits. An fp32 update is absorbed when |delta_w| < |w| / 256. + Reference: PULSE paper Definition A.3, Equation (4). + """ + return w.abs() * (2.0**-8) + + +class ULPChangeDetector: + """Detects which bf16 weights change across an optimizer step. + + Hooks into the Adam optimizer via ``register_step_pre_hook`` / ``register_step_post_hook`` + (PyTorch >= 2.1). Runs two passes per optimizer step: + + Pre-step (ULP prediction): uses existing Adam state (m, v) to predict which weights will + change after casting back to bf16. 
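+
+    A quick numeric check of the absorption rule, as a doctest-style sketch that uses this
+    module's ``bf16_absorption_threshold`` (near ``w = 1.0`` the threshold is ``2**-8``)::
+
+        >>> import torch
+        >>> w = torch.tensor([1.0])
+        >>> bf16_absorption_threshold(w).item()  # |w| / 256
+        0.00390625
+        >>> # a 1e-4 update is absorbed: both sides round to the same bf16 value
+        >>> bool((w + 1e-4).to(torch.bfloat16) != w.to(torch.bfloat16))
+        False
+        >>> # a 1e-2 update clears the threshold and survives the cast
+        >>> bool((w + 1e-2).to(torch.bfloat16) != w.to(torch.bfloat16))
+        True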
+ + Post-step (ground truth): compares post-step bf16 cast against pre-step snapshot. + """ + + def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer): + self.model = model + self.optimizer = optimizer + + self._predicted_masks: dict[str, torch.Tensor] = {} + self._validated_masks: dict[str, torch.Tensor] = {} + self._pre_step_bf16: dict[str, torch.Tensor] = {} + + # Build param_id -> name mapping + self._param_id_to_name: dict[int, str] = {} + for name, param in model.named_parameters(): + name = name.removeprefix("module.") + self._param_id_to_name[id(param)] = name + + self._pre_hook_handle = optimizer.register_step_pre_hook(self._pre_step_hook) + self._post_hook_handle = optimizer.register_step_post_hook(self._post_step_hook) + + def _pre_step_hook(self, optimizer, args, kwargs) -> None: + self._predicted_masks.clear() + self._pre_step_bf16.clear() + + for group in optimizer.param_groups: + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + weight_decay = group.get("weight_decay", 0.0) + + for p in group["params"]: + if p.grad is None: + continue + pid = id(p) + name = self._param_id_to_name.get(pid) + if name is None: + continue + + state = optimizer.state.get(p, {}) + if "exp_avg" not in state or "exp_avg_sq" not in state: + self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone() + continue + + step_count = state.get("step", torch.tensor(1)).item() if "step" in state else 1 + m = state["exp_avg"] + v = state["exp_avg_sq"] + + with torch.no_grad(): + m_hat = m / (1 - beta1**step_count) + v_hat = v / (1 - beta2**step_count) + predicted_delta = lr * m_hat / (v_hat.sqrt() + eps) + if weight_decay > 0: + predicted_delta = predicted_delta + lr * weight_decay * p.data + threshold = bf16_absorption_threshold(p.data) + self._predicted_masks[name] = (predicted_delta.abs() > threshold).cpu() + + self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone() + + def _post_step_hook(self, optimizer, args, kwargs) -> None: + self._validated_masks.clear() + + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + pid = id(p) + name = self._param_id_to_name.get(pid) + if name is None or name not in self._pre_step_bf16: + continue + + post_bf16 = p.detach().to(torch.bfloat16).cpu() + self._validated_masks[name] = post_bf16 != self._pre_step_bf16[name] + + def get_changed_params(self, use_validated: bool = True) -> dict[str, torch.Tensor]: + masks = self._validated_masks if use_validated else self._predicted_masks + return {name: mask for name, mask in masks.items() if mask.any()} + + def get_prediction_accuracy(self) -> dict[str, float]: + total_tp, total_fp, total_fn = 0, 0, 0 + total_changed, total_elements = 0, 0 + + for name, validated in self._validated_masks.items(): + predicted = self._predicted_masks.get(name) + n_validated = validated.sum().item() + total_changed += n_validated + total_elements += validated.numel() + + if predicted is None: + total_fn += n_validated + continue + + tp = (predicted & validated).sum().item() + fp = (predicted & ~validated).sum().item() + fn = (~predicted & validated).sum().item() + total_tp += tp + total_fp += fp + total_fn += fn + + precision = total_tp / max(total_tp + total_fp, 1) + recall = total_tp / max(total_tp + total_fn, 1) + f1 = 2 * precision * recall / max(precision + recall, 1e-12) + sparsity = 1.0 - total_changed / max(total_elements, 1) + + return { + "precision": precision, + "recall": recall, + "f1": f1, + "sparsity": sparsity, + 
"total_changed": total_changed, + "total_elements": total_elements, + } + + def close(self): + self._pre_hook_handle.remove() + self._post_hook_handle.remove() + + +@dataclass +class PatchMetadata: + format: str = "sparse_weight_patch" + version: str = "1" + model_version: int = 0 + prev_model_version: int = -1 + anchor_step: int = 0 + base_model_id: str = "" + num_changed_params: int = 0 + total_changed_elements: int = 0 + total_elements: int = 0 + sparsity: float = 0.0 + checksum_sha256: str = "" + changed_params: str = "[]" + + def to_metadata_dict(self) -> dict[str, str]: + return {k: str(v) for k, v in asdict(self).items()} + + @classmethod + def from_metadata_dict(cls, d: dict[str, str]) -> PatchMetadata: + field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} + kwargs = {} + for k, v in d.items(): + if k not in field_types: + continue + ft = field_types[k] + if ft == "int": + kwargs[k] = int(v) + elif ft == "float": + kwargs[k] = float(v) + else: + kwargs[k] = v + return cls(**kwargs) + + +def compute_bf16_checksum(bf16_state_dict: dict[str, torch.Tensor]) -> str: + h = hashlib.sha256() + for name in sorted(bf16_state_dict.keys()): + h.update(bf16_state_dict[name].detach().cpu().contiguous().view(torch.uint8).numpy().tobytes()) + return h.hexdigest() + + +def encode_sparse_patch( + prev_bf16: dict[str, torch.Tensor], + curr_bf16: dict[str, torch.Tensor], + model_version: int, + prev_model_version: int, + anchor_step: int, + base_model_id: str = "", +) -> tuple[dict[str, torch.Tensor], dict[str, str]]: + """Encode a sparse weight patch between two bf16 snapshots. + + Returns ``(tensors_dict, metadata_dict)`` for ``safetensors.save_file()``. + Values are actual values, not additive deltas, so reconstruction is bit-exact (PULSE Prop A.6). + """ + tensors: dict[str, torch.Tensor] = {} + changed_param_names: list[str] = [] + total_changed = 0 + total_elements = 0 + + for name in sorted(curr_bf16.keys()): + total_elements += curr_bf16[name].numel() + changed_mask = curr_bf16[name] != prev_bf16[name] + num_changed = changed_mask.sum().item() + if num_changed == 0: + continue + changed_param_names.append(name) + total_changed += num_changed + flat_indices = changed_mask.flatten().nonzero(as_tuple=False).squeeze(1).to(torch.int32) + flat_values = curr_bf16[name].flatten()[flat_indices.long()] + tensors[f"{name}.indices"] = flat_indices.cpu() + tensors[f"{name}.values"] = flat_values.cpu() + + checksum = compute_bf16_checksum(curr_bf16) + sparsity = 1.0 - total_changed / max(total_elements, 1) + + meta = PatchMetadata( + model_version=model_version, + prev_model_version=prev_model_version, + anchor_step=anchor_step, + base_model_id=base_model_id, + num_changed_params=len(changed_param_names), + total_changed_elements=total_changed, + total_elements=total_elements, + sparsity=sparsity, + checksum_sha256=checksum, + changed_params=json.dumps(changed_param_names), + ) + + return tensors, meta.to_metadata_dict() + + +def apply_sparse_patch( + base_bf16: dict[str, torch.Tensor], + patch_path: str | Path, + verify_checksum: bool = True, +) -> dict[str, torch.Tensor]: + """Apply a sparse patch to ``base_bf16`` in-place. + + For each changed parameter: ``flat[indices] = values`` (direct assignment, no FP arithmetic). + Optionally verifies SHA256 checksum. Raises ``ValueError`` on mismatch. 
+ """ + with safe_open(str(patch_path), framework="pt", device="cpu") as f: + meta = f.metadata() + changed_names = json.loads(meta["changed_params"]) + for name in changed_names: + indices = f.get_tensor(f"{name}.indices").long() + values = f.get_tensor(f"{name}.values") + flat = base_bf16[name].flatten() + flat[indices] = values + base_bf16[name] = flat.reshape(base_bf16[name].shape) + + if verify_checksum: + expected = meta.get("checksum_sha256", "") + if expected: + actual = compute_bf16_checksum(base_bf16) + if actual != expected: + raise ValueError( + f"Checksum mismatch after applying patch {patch_path}: " + f"expected {expected[:16]}..., got {actual[:16]}..." + ) + + return base_bf16 From a3a382585f9ba8b30bafbc61b058ce9dd0d640f9 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 31 Mar 2026 12:20:43 +0000 Subject: [PATCH 2/5] Add delta weight synchronization support to AsyncGRPO - Add `huggingface-hub` as dependency - Introduce sparse weight patching via `DeltaWeightTransferEngine` - Add `ULPChangeDetector` for optimizer-level change tracking - Add config parameters for delta sync control (repo, anchor interval, checksum verification) - Support both anchor checkpoints and delta patches via HF Hub (Xet storage) Add delta weight synchronization support to AsyncGRPO Implements two-phase delta sync workflow: non-blocking upload to HF Hub while inference continues, followed by a signal to vLLM to fetch and apply. Adds ULP change detection to selectively sync only modified parameters with element-level masks. Simplifies delta engine API by removing anchor/checksum logic; now uses HF Hub directly without intermediate configuration objects. --- .../async_grpo/async_grpo_trainer.py | 84 +++-- .../async_grpo/async_rollout_worker.py | 74 ++-- trl/experimental/async_grpo/delta_engine.py | 333 ++++++------------ 3 files changed, 220 insertions(+), 271 deletions(-) diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index 9c9d71e542c..ed0677c554a 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -383,9 +383,6 @@ def __init__( weight_shapes=weight_shapes, delta_sync_enabled=self.args.delta_sync_enabled, delta_sync_repo_id=self.args.delta_sync_repo_id, - delta_sync_anchor_interval=self.args.delta_sync_anchor_interval, - delta_sync_verify_checksum=self.args.delta_sync_verify_checksum, - delta_sync_base_model_id=model_name, ) self.rollout_queue = self.rollout_worker.rollout_buffer else: @@ -572,15 +569,39 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._metrics[mode].clear() def _streaming_iter(self): - # Iterate parameters one at a time. For FSDP2 (DTensor), full_tensor() all-gathers just this parameter across - # FSDP ranks, then frees it once the generator advances — avoiding materializing the full model in memory. + """Yield ``(name, tensor, mask)`` tuples. + + - ``_ulp_detector is None`` (NCCL path): all params, ``mask=None``. + - ULP active, no masks yet (first sync): all params, ``mask=None``. + - ULP active, masks available: only changed params with element-level masks. 
+ """ + if self._ulp_detector is None: + for name, param in self.model.named_parameters(): + name = name.removeprefix("module.") + yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None + return + + masks = self._ulp_detector._validated_masks + if not masks: + # First sync after ULP init — no optimizer step yet, send everything + for name, param in self.model.named_parameters(): + name = name.removeprefix("module.") + yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None + return + + total, yielded = 0, 0 for name, param in self.model.named_parameters(): - name = name.removeprefix("module.") # DDP/FSDP1 wrapping - full = param.full_tensor() if isinstance(param, DTensor) else param.detach() - yield name, full + total += 1 + name = name.removeprefix("module.") + mask = masks.get(name) + if mask is None or not mask.any(): + continue + yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), mask + yielded += 1 + logger.info(f"ULP: {yielded}/{total} params changed") def _sync_weight(self): - # Lazy-init ULP detector for diagnostic logging (delta sync only). + # Lazy-init ULP detector for diagnostic logging (delta sync only) bc # Optimizer only exists after Trainer creates it inside super()._inner_training_loop(). if ( self.args.delta_sync_enabled @@ -588,12 +609,12 @@ def _sync_weight(self): and hasattr(self, "optimizer") and self.optimizer is not None ): + # TODO(@aminediro): check this works with FSDP2 # Unwrap AcceleratedOptimizer to get the native PyTorch optimizer # (register_step_pre_hook requires torch.optim.Optimizer internals) raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer) self._ulp_detector = ULPChangeDetector(self.model, raw_optimizer) - # Log ULP prediction accuracy (diagnostic, doesn't affect sync) if ( self.args.delta_sync_enabled and self.args.delta_sync_log_ulp_accuracy @@ -610,34 +631,57 @@ def _sync_weight(self): ) t0 = time.time() - logger.info("Weight sync: pausing vLLM...") + is_delta = self.args.delta_sync_enabled + + if is_delta: + # Phase 1: Upload to HF Hub while inference continues + logger.info("Weight sync: uploading to HF Hub (inference still running)...") + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.send_weights(self._streaming_iter()) + else: + for _ in self._streaming_iter(): + pass + self.accelerator.wait_for_everyone() + t_upload = time.time() + logger.info(f"Weight sync: upload took {t_upload - t0:.1f}s, now pausing vLLM...") + + # Phase 2: Pause inference if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.pause() t_pause = time.time() - logger.info(f"Weight sync: pause took {t_pause - t0:.1f}s, waiting for all ranks...") self.accelerator.wait_for_everyone() t_barrier = time.time() - logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)") - if self.accelerator.is_main_process and self.rollout_worker: - self.rollout_worker.send_weights(self._streaming_iter()) + if is_delta: + # Phase 3: Signal vLLM to fetch the already-uploaded weights + logger.info(f"Weight sync: signaling vLLM to apply... (pause took {t_pause - t_upload:.1f}s)") + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.send_weights(iter([])) else: - # Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2. 
- for _ in self._streaming_iter(): - pass + # NCCL: transfer all weights directly + logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)") + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.send_weights(self._streaming_iter()) + else: + for _ in self._streaming_iter(): + pass + t_transfer = time.time() self.accelerator.wait_for_everyone() - logger.info(f"Weight sync: resuming vLLM... (transfer took {t_transfer - t_barrier:.1f}s)") + # Phase 4: Resume + logger.info(f"Weight sync: resuming vLLM... (apply took {t_transfer - t_barrier:.1f}s)") if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.resume() self.model_version += 1 self.rollout_worker.update_model_version(self.model_version) weight_sync_time_s = time.time() - t0 self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s) - logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s") + logger.info( + f"Weight sync: done. Total {weight_sync_time_s:.1f}s (inference paused {t_transfer - t_pause:.1f}s)" + ) def _inner_training_loop(self, *args, **kwargs): # Start the rollout worker here (not in __init__) so that checkpoint loading in Trainer.train() diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 6d274e285dd..b62b3b341ef 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -104,9 +104,6 @@ def __init__( weight_shapes: list[list[int]] | None = None, delta_sync_enabled: bool = False, delta_sync_repo_id: str | None = None, - delta_sync_anchor_interval: int = 10, - delta_sync_verify_checksum: bool = True, - delta_sync_base_model_id: str = "", ): if not is_vllm_available(min_version="0.17.1"): raise ImportError( @@ -178,9 +175,6 @@ def __init__( self.session = None self._delta_sync_repo_id = delta_sync_repo_id - self._delta_sync_anchor_interval = delta_sync_anchor_interval - self._delta_sync_verify_checksum = delta_sync_verify_checksum - self._delta_sync_base_model_id = delta_sync_base_model_id # Wait for the vLLM server and initialize weight transfer. 
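# The two-phase protocol the trainer drives through this class, as a hedged
# sketch (method names are the ones defined on AsyncRolloutWorker; streaming_iter
# stands for the trainer's _streaming_iter generator):
#
#     worker.send_weights(streaming_iter())  # phase 1: encode + upload, vLLM keeps serving
#     worker.pause()                         # phase 2: stop generation
#     worker.send_weights(iter([]))          # phase 3: empty iterator signals "fetch and apply"
#     worker.resume()                        # phase 4: generation resumes on the new weights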
self._wait_for_server_ready_sync(timeout_s=self.server_timeout) @@ -211,15 +205,12 @@ def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_interval_s: def _init_weight_transfer(self) -> None: if self.delta_sync_enabled: - from .delta_engine import DeltaWeightTransferEngine + from huggingface_hub import HfApi - self._delta_trainer_args = DeltaWeightTransferEngine.trainer_init( - repo_id=self._delta_sync_repo_id, - url=self.vllm_server_url, - anchor_interval=self._delta_sync_anchor_interval, - verify_checksum=self._delta_sync_verify_checksum, - base_model_id=self._delta_sync_base_model_id, - ) + self._delta_hf_api = HfApi() + self._delta_hf_api.create_repo(repo_id=self._delta_sync_repo_id, repo_type="model", exist_ok=True) + self._delta_model_version = 0 + self._delta_pending_update_info: dict | None = None requests.post( f"{self.vllm_server_url}/init_weight_transfer_engine", json={"init_info": {}}, @@ -317,15 +308,51 @@ def resume(self) -> None: def send_weights(self, iterator) -> None: if self.delta_sync_enabled: - from .delta_engine import DeltaWeightTransferEngine - - t0 = time.time() - DeltaWeightTransferEngine.trainer_send_weights( - iterator=iterator, - trainer_args=self._delta_trainer_args, + self._send_weights_delta(iterator) + else: + self._send_weights_nccl(iterator) + + def _send_weights_delta(self, iterator) -> None: + """Delta sync via HF Hub. + + - Non-empty iterator (upload phase): encode + upload + stash update_info. + - Empty iterator (signal phase): POST /update_weights with stashed info. + """ + from .delta_engine import DeltaWeightTransferEngine + + # Peek first item to distinguish upload vs signal + first = next(iterator, None) + if first is not None: + # Re-chain the first item back + import itertools + + full_iter = itertools.chain([first], iterator) + self._delta_model_version += 1 + filename = f"steps/step_{self._delta_model_version:06d}.safetensors" + DeltaWeightTransferEngine.upload( + iterator=full_iter, + repo_id=self._delta_sync_repo_id, + filename=filename, + hf_api=self._delta_hf_api, ) - logger.info(f"[delta_sync] send_weights took {time.time() - t0:.1f}s") - return + self._delta_pending_update_info = { + "repo_id": self._delta_sync_repo_id, + "filename": filename, + "revision": "main", + "is_checkpoint_format": True, + } + else: + # Empty iterator — signal vLLM to fetch and apply + resp = requests.post( + f"{self.vllm_server_url}/update_weights", + json={"update_info": self._delta_pending_update_info}, + timeout=300, + ) + resp.raise_for_status() + self._delta_pending_update_info = None + + def _send_weights_nccl(self, iterator) -> None: + """NCCL sync: broadcast all params via NCCL + signal /update_weights.""" if self.model_update_group is None: return t0 = time.time() @@ -337,8 +364,9 @@ def send_weights(self, iterator) -> None: t_update.start() logger.debug(f"[weight_sync] /update_weights POST sent ({time.time() - t0:.1f}s)") t_nccl = time.time() + # Strip the mask — NCCL expects (name, tensor) pairs NCCLWeightTransferEngine.trainer_send_weights( - iterator=iterator, + iterator=((name, tensor) for name, tensor, _mask in iterator), trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, packed=True), ) logger.debug(f"[weight_sync] NCCL transfer took {time.time() - t_nccl:.1f}s") diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py index 66edc6858fa..9c106ac66c8 100644 --- a/trl/experimental/async_grpo/delta_engine.py +++ b/trl/experimental/async_grpo/delta_engine.py @@ -12,34 
+12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Delta weight transfer engine for vLLM. - -Uses HuggingFace Hub (Xet storage) as the data plane for sparse weight patches. -The trainer uploads patches to HF Hub, then sends a lightweight metadata signal -to vLLM via ``/update_weights``. The vLLM worker downloads and applies patches. - -Registration happens at module import time so that vLLM's ``WeightTransferEngineFactory`` -can find the ``"delta"`` backend. Use ``--worker-extension-cls`` to trigger the import:: - - vllm serve model_name \\ - --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \\ - --weight-transfer-backend delta -""" - from __future__ import annotations -import json import logging from collections.abc import Callable, Iterator -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any -import requests import torch -from huggingface_hub import HfApi, hf_hub_download +from huggingface_hub import hf_hub_download from safetensors import safe_open -from safetensors.torch import load_file, save +from safetensors.torch import save from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer.base import ( @@ -49,62 +32,29 @@ ) from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory -from .weight_diff import compute_bf16_checksum, encode_sparse_patch - logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Dataclasses -# --------------------------------------------------------------------------- - - @dataclass class DeltaWeightTransferInitInfo(WeightTransferInitInfo): - """No initialization needed for file-based Hub transport.""" - pass @dataclass class DeltaWeightTransferUpdateInfo(WeightTransferUpdateInfo): - """Metadata sent via ``/update_weights`` — no weight data, just Hub coordinates.""" + """Metadata sent via ``/update_weights`` — just Hub coordinates.""" repo_id: str = "" filename: str = "" revision: str = "main" - patch_type: str = "anchor" # "anchor" | "delta" - expected_checksum: str = "" - # is_checkpoint_format: True for anchor (layerwise reload), False for delta (param.copy_) - - -@dataclass -class DeltaTrainerSendWeightsArgs: - """Trainer-side state passed to ``trainer_send_weights``. - - This is a mutable object — ``prev_bf16_snapshot`` and ``model_version`` - are updated after each call so the next call can compute a diff. - """ - - repo_id: str - url: str # vLLM server URL (for the /update_weights signal only) - hf_api: HfApi = field(default_factory=HfApi) - anchor_interval: int = 10 - verify_checksum: bool = True - revision: str = "main" - base_model_id: str = "" - # Mutable state — updated after each call: - prev_bf16_snapshot: dict[str, torch.Tensor] | None = None - model_version: int = 0 - _last_anchor_step: int = 0 class DeltaWeightTransferEngine(WeightTransferEngine[DeltaWeightTransferInitInfo, DeltaWeightTransferUpdateInfo]): - """Weight transfer engine that uses HF Hub (Xet) for sparse delta patches. + """Weight transfer engine that uses HF Hub (Xet) as the data plane. - Worker side: downloads patches from Hub and applies them via a CPU bf16 snapshot. - Trainer side: encodes sparse patches and uploads them to Hub. + Worker side: downloads safetensors from Hub, feeds to ``load_weights``. 
+ Trainer side: uploads changed params as safetensors to Hub. """ init_info_cls = DeltaWeightTransferInitInfo @@ -112,16 +62,28 @@ class DeltaWeightTransferEngine(WeightTransferEngine[DeltaWeightTransferInitInfo def __init__(self, config: WeightTransferConfig, parallel_config: ParallelConfig) -> None: super().__init__(config, parallel_config) + # CPU-side bf16 snapshot — needed because vLLM's load_weights expects full + # tensors, so we must reconstruct them from sparse (indices, values) patches. + # Kept on CPU to avoid GPU memory overhead (~2 bytes/param, e.g. ~1.2 GB for 0.6B model). self._bf16_snapshot: dict[str, torch.Tensor] | None = None def init_transfer_engine(self, init_info: DeltaWeightTransferInitInfo) -> None: - pass # No process group setup needed + pass def receive_weights( self, update_info: DeltaWeightTransferUpdateInfo, load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], ) -> None: + """Download safetensors from Hub and feed to load_weights. + + Handles two formats based on the ``sparse`` metadata flag: + + - **Full** (first sync): keys are param names → feed directly to load_weights, + build snapshot for future sparse applies. + - **Sparse** (subsequent): keys are ``{name}.indices`` + ``{name}.values`` → + apply to snapshot, feed reconstructed full tensors to load_weights. + """ local_path = hf_hub_download( repo_id=update_info.repo_id, filename=update_info.filename, @@ -129,61 +91,33 @@ def receive_weights( force_download=True, ) - if update_info.patch_type == "anchor": - self._receive_anchor(local_path, load_weights) - else: - self._receive_delta(local_path, update_info.expected_checksum, load_weights) - - def _receive_anchor( - self, - local_path: str, - load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], - ) -> None: - """Load a full anchor checkpoint and rebuild the snapshot.""" - state = load_file(local_path, device="cpu") - self._bf16_snapshot = {} - for name, tensor in state.items(): - self._bf16_snapshot[name] = tensor.to(torch.bfloat16).clone() - load_weights([(name, tensor)]) - logger.info("Loaded anchor checkpoint with %d parameters", len(self._bf16_snapshot)) - - def _receive_delta( - self, - local_path: str, - expected_checksum: str, - load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], - ) -> None: - """Apply a sparse delta patch to the snapshot, then feed changed params to load_weights.""" - if self._bf16_snapshot is None: - raise RuntimeError( - "Cannot apply delta patch without a prior anchor. " - "Ensure the first weight sync is an anchor (is_checkpoint_format=True)." 
- ) - with safe_open(local_path, framework="pt", device="cpu") as f: - meta = f.metadata() - changed_names = json.loads(meta.get("changed_params", "[]")) - - for name in changed_names: - indices = f.get_tensor(f"{name}.indices").long() - values = f.get_tensor(f"{name}.values") - # Apply to CPU snapshot (bit-exact, no FP arithmetic) - snap_flat = self._bf16_snapshot[name].flatten() - snap_flat[indices] = values - self._bf16_snapshot[name] = snap_flat.reshape(self._bf16_snapshot[name].shape) - # Pass reconstructed full tensor to load_weights - load_weights([(name, self._bf16_snapshot[name].to("cuda"))]) - - if expected_checksum: - actual = compute_bf16_checksum(self._bf16_snapshot) - if actual != expected_checksum: - raise ValueError(f"Checksum mismatch: expected {expected_checksum[:16]}..., got {actual[:16]}...") - - logger.info( - "Applied delta patch: %d params changed, sparsity=%s", - len(changed_names), - meta.get("sparsity", "?"), - ) + is_sparse = f.metadata().get("sparse", "False") == "True" + + if not is_sparse: + # Full tensors — feed to load_weights and build snapshot + self._bf16_snapshot = {} + for name in f.keys(): + tensor = f.get_tensor(name) + self._bf16_snapshot[name] = tensor.to(torch.bfloat16).clone() + load_weights([(name, tensor)]) + logger.info("Applied full weights (%d params)", len(self._bf16_snapshot)) + else: + # Sparse — apply indices/values to snapshot, feed full tensors + changed = set() + for key in f.keys(): + if key.endswith(".indices"): + changed.add(key.removesuffix(".indices")) + + for name in changed: + indices = f.get_tensor(f"{name}.indices").long() + values = f.get_tensor(f"{name}.values") + snap = self._bf16_snapshot[name].flatten() + snap[indices] = values + self._bf16_snapshot[name] = snap.reshape(self._bf16_snapshot[name].shape) + load_weights([(name, self._bf16_snapshot[name])]) + + logger.info("Applied sparse weights (%d params changed)", len(changed)) def shutdown(self) -> None: self._bf16_snapshot = None @@ -191,136 +125,82 @@ def shutdown(self) -> None: @staticmethod def trainer_send_weights( iterator: Iterator[tuple[str, torch.Tensor]], - trainer_args: dict[str, Any] | DeltaTrainerSendWeightsArgs, + trainer_args: dict[str, Any] | Any, ) -> None: - """Encode a sparse patch, upload to HF Hub, and signal vLLM. + """Not used directly — the rollout worker manages upload + signaling.""" + raise NotImplementedError("Use AsyncRolloutWorker._send_weights_delta instead") - Args: - iterator: (name, tensor) pairs from the model (e.g. FSDP streaming iterator). - trainer_args: :class:`DeltaTrainerSendWeightsArgs` (mutable, updated in-place). + @staticmethod + def upload( + iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]], + repo_id: str, + filename: str, + hf_api: Any, + revision: str = "main", + ) -> int: + """Encode params as safetensors and upload to HF Hub. + + Each item is ``(name, tensor, mask)``: + + - ``mask is None``: full tensor stored as ``name`` (anchor). + - ``mask`` provided: sparse encoding — only changed elements stored + as ``{name}.indices`` (int32) + ``{name}.values`` (bf16). + + Returns the number of params encoded. """ - if isinstance(trainer_args, dict): - args = DeltaTrainerSendWeightsArgs(**trainer_args) - else: - args = trainer_args - - # 1. 
Collect bf16 snapshot from the streaming iterator - curr_bf16: dict[str, torch.Tensor] = {} - for name, tensor in iterator: - curr_bf16[name] = tensor.to(torch.bfloat16).cpu().clone() - - args.model_version += 1 - is_anchor = args.prev_bf16_snapshot is None or args.model_version % args.anchor_interval == 0 - - # 2. Encode to safetensors bytes (no local disk write) - if is_anchor: - checksum = compute_bf16_checksum(curr_bf16) - metadata = { - "format": "anchor_checkpoint", - "version": "1", - "model_version": str(args.model_version), - "base_model_id": args.base_model_id, - "checksum_sha256": checksum, - } - buf = save(curr_bf16, metadata=metadata) - filename = f"anchors/step_{args.model_version:06d}.safetensors" - args._last_anchor_step = args.model_version - else: - tensors, meta_dict = encode_sparse_patch( - prev_bf16=args.prev_bf16_snapshot, - curr_bf16=curr_bf16, - model_version=args.model_version, - prev_model_version=args.model_version - 1, - anchor_step=args._last_anchor_step, - base_model_id=args.base_model_id, - ) - checksum = meta_dict["checksum_sha256"] - # safetensors requires at least one tensor - if not tensors: - tensors["__empty_delta__"] = torch.zeros(1, dtype=torch.int32) - buf = save(tensors, metadata=meta_dict) - filename = f"deltas/step_{args.model_version:06d}.safetensors" - - # 3. Upload to HF Hub (Xet handles chunking/dedup) - args.hf_api.upload_file( + tensors: dict[str, torch.Tensor] = {} + num_params = 0 + sparse = False + + for name, tensor, mask in iterator: + num_params += 1 + bf16 = tensor.to(torch.bfloat16).cpu() + if mask is None: + tensors[name] = bf16.clone() + else: + sparse = True + indices = mask.flatten().nonzero(as_tuple=False).squeeze(1).to(torch.int32) + values = bf16.flatten()[indices.long()] + tensors[f"{name}.indices"] = indices + tensors[f"{name}.values"] = values + + if not tensors: + return 0 + + metadata = {"num_params": str(num_params), "sparse": str(sparse)} + buf = save(tensors, metadata=metadata) + + hf_api.upload_file( path_or_fileobj=buf, path_in_repo=filename, - repo_id=args.repo_id, - revision=args.revision, - commit_message=f"step {args.model_version} ({'anchor' if is_anchor else 'delta'})", + repo_id=repo_id, + revision=revision, + commit_message=f"weight update ({num_params} params, {len(buf) / 1e6:.1f} MB, sparse={sparse})", ) logger.info( - "[delta_engine] uploaded %s to %s/%s (%.1f MB)", - "anchor" if is_anchor else "delta", - args.repo_id, + "[delta_engine] uploaded %s/%s (%.1f MB, %d params, sparse=%s)", + repo_id, filename, len(buf) / 1e6, + num_params, + sparse, ) + return num_params - # 4. Signal vLLM (metadata only — no weight data) - update_info = { - "repo_id": args.repo_id, - "filename": filename, - "revision": args.revision, - "patch_type": "anchor" if is_anchor else "delta", - "expected_checksum": checksum if args.verify_checksum else "", - "is_checkpoint_format": True, # Always True: vLLM fuses params (e.g. gate_up_proj), needs model.load_weights for name mapping - } - resp = requests.post( - f"{args.url}/update_weights", - json={"update_info": update_info}, - timeout=300, - ) - resp.raise_for_status() - - # 5. Update mutable state for next call - args.prev_bf16_snapshot = curr_bf16 - @staticmethod - def trainer_init( - repo_id: str, - url: str, - anchor_interval: int = 10, - verify_checksum: bool = True, - revision: str = "main", - base_model_id: str = "", - token: str | None = None, - ) -> DeltaTrainerSendWeightsArgs: - """Initialize trainer-side state: create/ensure HF repo and return args. 
- - Call once at startup. Pass the returned object to ``trainer_send_weights`` - on every weight sync. - """ - api = HfApi(token=token) - api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True) - logger.info("[delta_engine] trainer_init: repo=%s, anchor_interval=%d", repo_id, anchor_interval) - return DeltaTrainerSendWeightsArgs( - repo_id=repo_id, - url=url, - hf_api=api, - anchor_interval=anchor_interval, - verify_checksum=verify_checksum, - revision=revision, - base_model_id=base_model_id, - ) +# --------------------------------------------------------------------------- +# Worker extension — its import triggers engine registration +# --------------------------------------------------------------------------- class DeltaWorkerExtension: """vLLM worker extension for the delta weight transfer backend. - This class is intentionally minimal. Its primary role is to trigger the - import of this module (via ``--worker-extension-cls``) which registers the - engine with ``WeightTransferEngineFactory`` at module level below. - - Usage with standard ``vllm serve``:: - - VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \\ - --worker-extension-cls trl.experimental.async_grpo.delta_engine.DeltaWorkerExtension \\ - --weight-transfer-config '{"backend":"nccl"}' \\ - --max-model-len 8192 --enforce-eager --logprobs-mode processed_logprobs + This class is intentionally minimal. Its import (via ``--worker-extension-cls``) + registers the engine and overrides the ``"nccl"`` factory entry. - Note: ``backend`` must be ``"nccl"`` in the CLI (pydantic ``Literal`` validation). + ``backend`` must be ``"nccl"`` in the CLI (pydantic ``Literal`` validation). This module overrides the ``"nccl"`` factory entry so that the actual engine created is ``DeltaWeightTransferEngine``. """ @@ -329,14 +209,11 @@ class DeltaWorkerExtension: # --------------------------------------------------------------------------- -# Module-level registration — runs when this module is first imported +# Module-level registration # --------------------------------------------------------------------------- if "delta" not in WeightTransferEngineFactory._registry: WeightTransferEngineFactory.register_engine("delta", DeltaWeightTransferEngine) -# Override the "nccl" factory entry so that --weight-transfer-config '{"backend":"nccl"}' -# (which passes pydantic Literal["nccl","ipc"] validation) actually creates a -# DeltaWeightTransferEngine. This is safe: the trainer side never reads the factory, -# and the worker side is explicitly opting in via --worker-extension-cls. +# Override "nccl" so --weight-transfer-config '{"backend":"nccl"}' creates our engine. WeightTransferEngineFactory._registry["nccl"] = lambda: DeltaWeightTransferEngine From 592734f2aebcdfd471db42e1e2b5a8e15878ac36 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 31 Mar 2026 14:07:29 +0000 Subject: [PATCH 3/5] Simplify weight diff detector and remove unused features Remove ULP prediction logic, diagnostic logging config, and checkpoint chain reconstruction. Keep only ground-truth bf16 change detection via optimizer hooks and sparse patch metadata. 
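The retained mechanism, as a minimal hedged sketch (it mirrors the hook-based
ground-truth detection this patch keeps, not the exact `BF16ChangeDetector` code):

    import torch

    class ChangeMaskHooks:
        """Record which elements changed across an optimizer step, at bf16 precision."""

        def __init__(self, model, optimizer):
            self._names = {id(p): n.removeprefix("module.") for n, p in model.named_parameters()}
            self._pre = {}
            self.masks = {}
            self._h1 = optimizer.register_step_pre_hook(self._before)  # torch >= 2.1
            self._h2 = optimizer.register_step_post_hook(self._after)

        def _before(self, optimizer, args, kwargs):
            # Snapshot the bf16 view of every param that will be stepped
            for group in optimizer.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        self._pre[self._names[id(p)]] = p.detach().to(torch.bfloat16).cpu().clone()

        def _after(self, optimizer, args, kwargs):
            # Ground truth: which elements differ after the step, at bf16 precision
            for group in optimizer.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        name = self._names[id(p)]
                        self.masks[name] = p.detach().to(torch.bfloat16).cpu() != self._pre[name]

        def close(self):
            self._h1.remove()
            self._h2.remove()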
--- .../async_grpo/async_grpo_config.py | 7 - .../async_grpo/async_grpo_trainer.py | 57 +++-- .../async_grpo/async_rollout_worker.py | 46 ++-- trl/experimental/async_grpo/delta_engine.py | 66 +++-- trl/experimental/async_grpo/weight_diff.py | 233 +++--------------- 5 files changed, 129 insertions(+), 280 deletions(-) diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index a2fca996e33..775c6811824 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -216,13 +216,6 @@ class AsyncGRPOConfig(_BaseConfig): "Adds overhead per sync but guarantees bit-exact reconstruction." }, ) - delta_sync_log_ulp_accuracy: bool = field( - default=True, - metadata={ - "help": "Log precision/recall/F1 of ULP-based change predictions vs. actual bf16 " - "changes. Useful for validating the optimizer hook approach." - }, - ) # Parameters that control the logging log_completions: bool = field( diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index ed0677c554a..63ed86124bf 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -35,7 +35,7 @@ from .async_grpo_config import AsyncGRPOConfig from .async_rollout_worker import AsyncRolloutWorker -from .weight_diff import ULPChangeDetector +from .weight_diff import BF16ChangeDetector logger = get_logger(__name__) @@ -383,6 +383,7 @@ def __init__( weight_shapes=weight_shapes, delta_sync_enabled=self.args.delta_sync_enabled, delta_sync_repo_id=self.args.delta_sync_repo_id, + delta_sync_anchor_interval=self.args.delta_sync_anchor_interval, ) self.rollout_queue = self.rollout_worker.rollout_buffer else: @@ -393,7 +394,7 @@ def __init__( self.add_callback(StepIntervalCallback(self._sync_weight, self.args.weight_sync_steps)) # ULP change detector for diagnostic logging (delta sync only) - self._ulp_detector: ULPChangeDetector | None = None + self._change_detector: BF16ChangeDetector | None = None def get_train_dataloader(self) -> DataLoader: if self.accelerator.is_main_process: @@ -571,19 +572,25 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: def _streaming_iter(self): """Yield ``(name, tensor, mask)`` tuples. - - ``_ulp_detector is None`` (NCCL path): all params, ``mask=None``. - - ULP active, no masks yet (first sync): all params, ``mask=None``. - - ULP active, masks available: only changed params with element-level masks. + - No weight diff (NCCL path): all params, ``mask=None``. + - Anchor step (every ``anchor_interval``): all params, ``mask=None``. + - Delta step: only changed params with element-level masks. + - Nothing changed: yields nothing (no-op). 
""" - if self._ulp_detector is None: + # NCCL path + if self._change_detector is None: for name, param in self.model.named_parameters(): name = name.removeprefix("module.") yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None return - masks = self._ulp_detector._validated_masks - if not masks: - # First sync after ULP init — no optimizer step yet, send everything + # Force full upload on anchor interval or when no masks yet (first sync) + masks = self._change_detector._validated_masks + next_version = self.model_version + 1 + force_anchor = not masks or next_version % self.args.delta_sync_anchor_interval == 0 + + if force_anchor: + logger.info("Anchor step %d: sending all params", next_version) for name, param in self.model.named_parameters(): name = name.removeprefix("module.") yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None @@ -605,7 +612,7 @@ def _sync_weight(self): # Optimizer only exists after Trainer creates it inside super()._inner_training_loop(). if ( self.args.delta_sync_enabled - and self._ulp_detector is None + and self._change_detector is None and hasattr(self, "optimizer") and self.optimizer is not None ): @@ -613,22 +620,24 @@ def _sync_weight(self): # Unwrap AcceleratedOptimizer to get the native PyTorch optimizer # (register_step_pre_hook requires torch.optim.Optimizer internals) raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer) - self._ulp_detector = ULPChangeDetector(self.model, raw_optimizer) + self._change_detector = BF16ChangeDetector(self.model, raw_optimizer) if ( self.args.delta_sync_enabled - and self.args.delta_sync_log_ulp_accuracy - and self._ulp_detector is not None + and self._change_detector is not None + and self._change_detector._validated_masks and self.accelerator.is_main_process ): - accuracy = self._ulp_detector.get_prediction_accuracy() - for k, v in accuracy.items(): - self._metrics["train"][f"delta/{k}"].append(v) - logger.info( - f"ULP accuracy: precision={accuracy['precision']:.3f} " - f"recall={accuracy['recall']:.3f} " - f"sparsity={accuracy['sparsity']:.4%}" - ) + total_changed = 0 + total_elements = 0 + for mask in self._change_detector._validated_masks.values(): + total_changed += mask.sum().item() + total_elements += mask.numel() + sparsity = 1.0 - total_changed / max(total_elements, 1) + self._metrics["train"]["delta/sparsity"].append(sparsity) + self._metrics["train"]["delta/total_changed"].append(total_changed) + self._metrics["train"]["delta/total_elements"].append(total_elements) + logger.info(f"Delta: {total_changed}/{total_elements} elements changed (sparsity={sparsity:.4%})") t0 = time.time() is_delta = self.args.delta_sync_enabled @@ -696,6 +705,6 @@ def _inner_training_loop(self, *args, **kwargs): finally: if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.stop() - if self._ulp_detector is not None: - self._ulp_detector.close() - self._ulp_detector = None + if self._change_detector is not None: + self._change_detector.close() + self._change_detector = None diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index b62b3b341ef..fb1fe5f61f5 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -14,6 +14,7 @@ import asyncio import inspect +import itertools import queue import threading import time @@ -32,6 +33,8 @@ from trl.import_utils import is_vllm_available from 
trl.trainer.utils import print_prompt_completions_sample +from .delta_engine import DeltaWeightTransferEngine + if is_vllm_available(min_version="0.17.1"): from vllm.distributed.weight_transfer.nccl_engine import NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine @@ -104,6 +107,7 @@ def __init__( weight_shapes: list[list[int]] | None = None, delta_sync_enabled: bool = False, delta_sync_repo_id: str | None = None, + delta_sync_anchor_interval: int = 10, ): if not is_vllm_available(min_version="0.17.1"): raise ImportError( @@ -175,6 +179,7 @@ def __init__( self.session = None self._delta_sync_repo_id = delta_sync_repo_id + self._delta_sync_anchor_interval = delta_sync_anchor_interval # Wait for the vLLM server and initialize weight transfer. self._wait_for_server_ready_sync(timeout_s=self.server_timeout) @@ -313,36 +318,29 @@ def send_weights(self, iterator) -> None: self._send_weights_nccl(iterator) def _send_weights_delta(self, iterator) -> None: - """Delta sync via HF Hub. - - - Non-empty iterator (upload phase): encode + upload + stash update_info. - - Empty iterator (signal phase): POST /update_weights with stashed info. - """ - from .delta_engine import DeltaWeightTransferEngine - - # Peek first item to distinguish upload vs signal first = next(iterator, None) if first is not None: - # Re-chain the first item back - import itertools - full_iter = itertools.chain([first], iterator) self._delta_model_version += 1 - filename = f"steps/step_{self._delta_model_version:06d}.safetensors" - DeltaWeightTransferEngine.upload( + is_anchor = first[2] is None + subdir = "anchors" if is_anchor else "deltas" + filename = f"{subdir}/step_{self._delta_model_version:06d}.safetensors" + meta = DeltaWeightTransferEngine.upload( iterator=full_iter, repo_id=self._delta_sync_repo_id, filename=filename, hf_api=self._delta_hf_api, + model_version=self._delta_model_version, ) - self._delta_pending_update_info = { - "repo_id": self._delta_sync_repo_id, - "filename": filename, - "revision": "main", - "is_checkpoint_format": True, - } - else: - # Empty iterator — signal vLLM to fetch and apply + if meta is not None: + self._delta_pending_update_info = { + "repo_id": self._delta_sync_repo_id, + "filename": filename, + "revision": "main", + "is_checkpoint_format": True, + } + elif self._delta_pending_update_info is not None: + # Signal vLLM to fetch and apply resp = requests.post( f"{self.vllm_server_url}/update_weights", json={"update_info": self._delta_pending_update_info}, @@ -350,6 +348,9 @@ def _send_weights_delta(self, iterator) -> None: ) resp.raise_for_status() self._delta_pending_update_info = None + else: + # Nothing changed and nothing pending — no-op + logger.debug("[delta_sync] nothing changed, skipping") def _send_weights_nccl(self, iterator) -> None: """NCCL sync: broadcast all params via NCCL + signal /update_weights.""" @@ -364,7 +365,6 @@ def _send_weights_nccl(self, iterator) -> None: t_update.start() logger.debug(f"[weight_sync] /update_weights POST sent ({time.time() - t0:.1f}s)") t_nccl = time.time() - # Strip the mask — NCCL expects (name, tensor) pairs NCCLWeightTransferEngine.trainer_send_weights( iterator=((name, tensor) for name, tensor, _mask in iterator), trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, packed=True), @@ -423,7 +423,7 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None: # Current assumption: reset side effects matter, return value is ignored. 
self.environments[slot].reset(**row) - logger.info(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}") + logger.debug(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}") task = asyncio.create_task( self._generate_one(pending_groups[group_id].prompt, tool_dict=self._sync_tool_dicts[slot]) ) diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py index 9c106ac66c8..182ff38c018 100644 --- a/trl/experimental/async_grpo/delta_engine.py +++ b/trl/experimental/async_grpo/delta_engine.py @@ -14,6 +14,7 @@ from __future__ import annotations +import json import logging from collections.abc import Callable, Iterator from dataclasses import dataclass @@ -32,6 +33,8 @@ ) from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory +from .weight_diff import PatchMetadata + logger = logging.getLogger(__name__) @@ -62,6 +65,7 @@ class DeltaWeightTransferEngine(WeightTransferEngine[DeltaWeightTransferInitInfo def __init__(self, config: WeightTransferConfig, parallel_config: ParallelConfig) -> None: super().__init__(config, parallel_config) + # TODO: might be able to eliminate completely # CPU-side bf16 snapshot — needed because vLLM's load_weights expects full # tensors, so we must reconstruct them from sparse (indices, values) patches. # Kept on CPU to avoid GPU memory overhead (~2 bytes/param, e.g. ~1.2 GB for 0.6B model). @@ -92,32 +96,32 @@ def receive_weights( ) with safe_open(local_path, framework="pt", device="cpu") as f: - is_sparse = f.metadata().get("sparse", "False") == "True" + meta = PatchMetadata.from_metadata_dict(f.metadata()) - if not is_sparse: + if not meta.sparse: # Full tensors — feed to load_weights and build snapshot self._bf16_snapshot = {} for name in f.keys(): tensor = f.get_tensor(name) self._bf16_snapshot[name] = tensor.to(torch.bfloat16).clone() load_weights([(name, tensor)]) - logger.info("Applied full weights (%d params)", len(self._bf16_snapshot)) + logger.info("Applied anchor (step %d, %d params)", meta.model_version, meta.num_changed_params) else: # Sparse — apply indices/values to snapshot, feed full tensors - changed = set() - for key in f.keys(): - if key.endswith(".indices"): - changed.add(key.removesuffix(".indices")) - - for name in changed: + changed_names = json.loads(meta.changed_params) + for name in changed_names: indices = f.get_tensor(f"{name}.indices").long() values = f.get_tensor(f"{name}.values") snap = self._bf16_snapshot[name].flatten() snap[indices] = values self._bf16_snapshot[name] = snap.reshape(self._bf16_snapshot[name].shape) load_weights([(name, self._bf16_snapshot[name])]) - - logger.info("Applied sparse weights (%d params changed)", len(changed)) + logger.info( + "Applied delta (step %d, %d params, sparsity=%.4f)", + meta.model_version, + meta.num_changed_params, + meta.sparsity, + ) def shutdown(self) -> None: self._bf16_snapshot = None @@ -136,8 +140,9 @@ def upload( repo_id: str, filename: str, hf_api: Any, + model_version: int = 0, revision: str = "main", - ) -> int: + ) -> PatchMetadata | None: """Encode params as safetensors and upload to HF Hub. Each item is ``(name, tensor, mask)``: @@ -146,47 +151,62 @@ def upload( - ``mask`` provided: sparse encoding — only changed elements stored as ``{name}.indices`` (int32) + ``{name}.values`` (bf16). - Returns the number of params encoded. + Returns :class:`PatchMetadata` or ``None`` if the iterator was empty. 
""" tensors: dict[str, torch.Tensor] = {} - num_params = 0 + changed_names: list[str] = [] + total_changed = 0 + total_elements = 0 sparse = False for name, tensor, mask in iterator: - num_params += 1 bf16 = tensor.to(torch.bfloat16).cpu() + total_elements += bf16.numel() if mask is None: tensors[name] = bf16.clone() + changed_names.append(name) + total_changed += bf16.numel() else: sparse = True indices = mask.flatten().nonzero(as_tuple=False).squeeze(1).to(torch.int32) values = bf16.flatten()[indices.long()] tensors[f"{name}.indices"] = indices tensors[f"{name}.values"] = values + changed_names.append(name) + total_changed += len(indices) if not tensors: - return 0 - - metadata = {"num_params": str(num_params), "sparse": str(sparse)} - buf = save(tensors, metadata=metadata) + return None + + meta = PatchMetadata( + sparse=sparse, + model_version=model_version, + num_changed_params=len(changed_names), + total_changed_elements=total_changed, + total_elements=total_elements, + sparsity=1.0 - total_changed / max(total_elements, 1), + changed_params=json.dumps(changed_names), + ) + buf = save(tensors, metadata=meta.to_metadata_dict()) hf_api.upload_file( path_or_fileobj=buf, path_in_repo=filename, repo_id=repo_id, revision=revision, - commit_message=f"weight update ({num_params} params, {len(buf) / 1e6:.1f} MB, sparse={sparse})", + commit_message=f"step {model_version} ({len(changed_names)} params, {len(buf) / 1e6:.1f} MB, sparse={sparse})", ) logger.info( - "[delta_engine] uploaded %s/%s (%.1f MB, %d params, sparse=%s)", + "[delta_engine] uploaded %s/%s (%.1f MB, %d params, sparse=%s, sparsity=%.4f)", repo_id, filename, len(buf) / 1e6, - num_params, + len(changed_names), sparse, + meta.sparsity, ) - return num_params + return meta # --------------------------------------------------------------------------- diff --git a/trl/experimental/async_grpo/weight_diff.py b/trl/experimental/async_grpo/weight_diff.py index c18edec1563..c53f578aea2 100644 --- a/trl/experimental/async_grpo/weight_diff.py +++ b/trl/experimental/async_grpo/weight_diff.py @@ -13,160 +13,78 @@ # limitations under the License. """ -Delta-compressed weight synchronization engine. +Delta-compressed weight synchronization utilities. -Implements sparse weight patching for AsyncGRPOTrainer: detects which bf16 weights changed -between consecutive optimizer steps, encodes only the changed elements as sparse safetensors -patches, and provides a checkpoint chain (anchor + deltas) for reconstructing any step. - -References: -- PULSE paper: arXiv:2602.03839 (Feb 2026) +- ``BF16ChangeDetector``: hooks into the optimizer to detect which bf16 elements + actually changed after each step. +- ``PatchMetadata``: structured metadata stored in safetensors headers for both + anchor (full) and delta (sparse) weight files. """ from __future__ import annotations -import hashlib -import json import logging from dataclasses import asdict, dataclass -from pathlib import Path import torch -from safetensors import safe_open logger = logging.getLogger(__name__) -def bf16_absorption_threshold(w: torch.Tensor) -> torch.Tensor: - """BF16 absorption threshold per element: |delta_w| must exceed this to survive rounding. - - BF16 has 7 mantissa bits. An fp32 update is absorbed when |delta_w| < |w| / 256. - Reference: PULSE paper Definition A.3, Equation (4). - """ - return w.abs() * (2.0**-8) - - -class ULPChangeDetector: - """Detects which bf16 weights change across an optimizer step. 
- - Hooks into the Adam optimizer via ``register_step_pre_hook`` / ``register_step_post_hook`` - (PyTorch >= 2.1). Runs two passes per optimizer step: +class BF16ChangeDetector: + """Detects which bf16 weights actually changed across an optimizer step. - Pre-step (ULP prediction): uses existing Adam state (m, v) to predict which weights will - change after casting back to bf16. + Hooks into the optimizer via ``register_step_pre_hook`` / ``register_step_post_hook`` + (PyTorch >= 2.1). Snapshots bf16 values before the step, compares after. - Post-step (ground truth): compares post-step bf16 cast against pre-step snapshot. + ``_validated_masks[name]`` is a boolean tensor with True for each element that changed. """ def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer): - self.model = model - self.optimizer = optimizer - - self._predicted_masks: dict[str, torch.Tensor] = {} self._validated_masks: dict[str, torch.Tensor] = {} self._pre_step_bf16: dict[str, torch.Tensor] = {} - # Build param_id -> name mapping + # Match model param names to optimizer param objects via data_ptr() + # (id() doesn't work because Accelerate wraps params as different objects) + model_params = {p.data_ptr(): name.removeprefix("module.") for name, p in model.named_parameters()} self._param_id_to_name: dict[int, str] = {} - for name, param in model.named_parameters(): - name = name.removeprefix("module.") - self._param_id_to_name[id(param)] = name + for group in optimizer.param_groups: + for p in group["params"]: + name = model_params.get(p.data_ptr()) + if name is not None: + self._param_id_to_name[id(p)] = name + + logger.info( + "BF16ChangeDetector: matched %d/%d optimizer params", + len(self._param_id_to_name), + sum(1 for _ in model.named_parameters()), + ) self._pre_hook_handle = optimizer.register_step_pre_hook(self._pre_step_hook) self._post_hook_handle = optimizer.register_step_post_hook(self._post_step_hook) def _pre_step_hook(self, optimizer, args, kwargs) -> None: - self._predicted_masks.clear() self._pre_step_bf16.clear() - for group in optimizer.param_groups: - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - weight_decay = group.get("weight_decay", 0.0) - for p in group["params"]: if p.grad is None: continue - pid = id(p) - name = self._param_id_to_name.get(pid) + name = self._param_id_to_name.get(id(p)) if name is None: continue - - state = optimizer.state.get(p, {}) - if "exp_avg" not in state or "exp_avg_sq" not in state: - self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone() - continue - - step_count = state.get("step", torch.tensor(1)).item() if "step" in state else 1 - m = state["exp_avg"] - v = state["exp_avg_sq"] - - with torch.no_grad(): - m_hat = m / (1 - beta1**step_count) - v_hat = v / (1 - beta2**step_count) - predicted_delta = lr * m_hat / (v_hat.sqrt() + eps) - if weight_decay > 0: - predicted_delta = predicted_delta + lr * weight_decay * p.data - threshold = bf16_absorption_threshold(p.data) - self._predicted_masks[name] = (predicted_delta.abs() > threshold).cpu() - self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone() def _post_step_hook(self, optimizer, args, kwargs) -> None: self._validated_masks.clear() - for group in optimizer.param_groups: for p in group["params"]: if p.grad is None: continue - pid = id(p) - name = self._param_id_to_name.get(pid) + name = self._param_id_to_name.get(id(p)) if name is None or name not in self._pre_step_bf16: continue - - post_bf16 = p.detach().to(torch.bfloat16).cpu() - 
self._validated_masks[name] = post_bf16 != self._pre_step_bf16[name] - - def get_changed_params(self, use_validated: bool = True) -> dict[str, torch.Tensor]: - masks = self._validated_masks if use_validated else self._predicted_masks - return {name: mask for name, mask in masks.items() if mask.any()} - - def get_prediction_accuracy(self) -> dict[str, float]: - total_tp, total_fp, total_fn = 0, 0, 0 - total_changed, total_elements = 0, 0 - - for name, validated in self._validated_masks.items(): - predicted = self._predicted_masks.get(name) - n_validated = validated.sum().item() - total_changed += n_validated - total_elements += validated.numel() - - if predicted is None: - total_fn += n_validated - continue - - tp = (predicted & validated).sum().item() - fp = (predicted & ~validated).sum().item() - fn = (~predicted & validated).sum().item() - total_tp += tp - total_fp += fp - total_fn += fn - - precision = total_tp / max(total_tp + total_fp, 1) - recall = total_tp / max(total_tp + total_fn, 1) - f1 = 2 * precision * recall / max(precision + recall, 1e-12) - sparsity = 1.0 - total_changed / max(total_elements, 1) - - return { - "precision": precision, - "recall": recall, - "f1": f1, - "sparsity": sparsity, - "total_changed": total_changed, - "total_elements": total_elements, - } + self._validated_masks[name] = p.detach().to(torch.bfloat16).cpu() != self._pre_step_bf16[name] def close(self): self._pre_hook_handle.remove() @@ -177,15 +95,12 @@ def close(self): class PatchMetadata: format: str = "sparse_weight_patch" version: str = "1" + sparse: bool = False model_version: int = 0 - prev_model_version: int = -1 - anchor_step: int = 0 - base_model_id: str = "" num_changed_params: int = 0 total_changed_elements: int = 0 total_elements: int = 0 sparsity: float = 0.0 - checksum_sha256: str = "" changed_params: str = "[]" def to_metadata_dict(self) -> dict[str, str]: @@ -203,96 +118,8 @@ def from_metadata_dict(cls, d: dict[str, str]) -> PatchMetadata: kwargs[k] = int(v) elif ft == "float": kwargs[k] = float(v) + elif ft == "bool": + kwargs[k] = v == "True" else: kwargs[k] = v return cls(**kwargs) - - -def compute_bf16_checksum(bf16_state_dict: dict[str, torch.Tensor]) -> str: - h = hashlib.sha256() - for name in sorted(bf16_state_dict.keys()): - h.update(bf16_state_dict[name].detach().cpu().contiguous().view(torch.uint8).numpy().tobytes()) - return h.hexdigest() - - -def encode_sparse_patch( - prev_bf16: dict[str, torch.Tensor], - curr_bf16: dict[str, torch.Tensor], - model_version: int, - prev_model_version: int, - anchor_step: int, - base_model_id: str = "", -) -> tuple[dict[str, torch.Tensor], dict[str, str]]: - """Encode a sparse weight patch between two bf16 snapshots. - - Returns ``(tensors_dict, metadata_dict)`` for ``safetensors.save_file()``. - Values are actual values, not additive deltas, so reconstruction is bit-exact (PULSE Prop A.6). 
- """ - tensors: dict[str, torch.Tensor] = {} - changed_param_names: list[str] = [] - total_changed = 0 - total_elements = 0 - - for name in sorted(curr_bf16.keys()): - total_elements += curr_bf16[name].numel() - changed_mask = curr_bf16[name] != prev_bf16[name] - num_changed = changed_mask.sum().item() - if num_changed == 0: - continue - changed_param_names.append(name) - total_changed += num_changed - flat_indices = changed_mask.flatten().nonzero(as_tuple=False).squeeze(1).to(torch.int32) - flat_values = curr_bf16[name].flatten()[flat_indices.long()] - tensors[f"{name}.indices"] = flat_indices.cpu() - tensors[f"{name}.values"] = flat_values.cpu() - - checksum = compute_bf16_checksum(curr_bf16) - sparsity = 1.0 - total_changed / max(total_elements, 1) - - meta = PatchMetadata( - model_version=model_version, - prev_model_version=prev_model_version, - anchor_step=anchor_step, - base_model_id=base_model_id, - num_changed_params=len(changed_param_names), - total_changed_elements=total_changed, - total_elements=total_elements, - sparsity=sparsity, - checksum_sha256=checksum, - changed_params=json.dumps(changed_param_names), - ) - - return tensors, meta.to_metadata_dict() - - -def apply_sparse_patch( - base_bf16: dict[str, torch.Tensor], - patch_path: str | Path, - verify_checksum: bool = True, -) -> dict[str, torch.Tensor]: - """Apply a sparse patch to ``base_bf16`` in-place. - - For each changed parameter: ``flat[indices] = values`` (direct assignment, no FP arithmetic). - Optionally verifies SHA256 checksum. Raises ``ValueError`` on mismatch. - """ - with safe_open(str(patch_path), framework="pt", device="cpu") as f: - meta = f.metadata() - changed_names = json.loads(meta["changed_params"]) - for name in changed_names: - indices = f.get_tensor(f"{name}.indices").long() - values = f.get_tensor(f"{name}.values") - flat = base_bf16[name].flatten() - flat[indices] = values - base_bf16[name] = flat.reshape(base_bf16[name].shape) - - if verify_checksum: - expected = meta.get("checksum_sha256", "") - if expected: - actual = compute_bf16_checksum(base_bf16) - if actual != expected: - raise ValueError( - f"Checksum mismatch after applying patch {patch_path}: " - f"expected {expected[:16]}..., got {actual[:16]}..." - ) - - return base_bf16 From 47114d65a105d0283d8cc6218f8e14ab61d40772 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 31 Mar 2026 14:47:44 +0000 Subject: [PATCH 4/5] Simplify delta sync to use HF Buckets instead of Hub - Move anchor/delta decision from trainer to rollout worker - Remove change detector from streaming iter; only check for validated masks - Migrate from HfApi to bucket_id and HF Bucket APIs - Simplify upload/download paths and remove revision parameter - Refactor _send_weights_delta with clearer empty/non-empty logic --- .../async_grpo/async_grpo_trainer.py | 24 ++--- .../async_grpo/async_rollout_worker.py | 80 +++++++++-------- trl/experimental/async_grpo/delta_engine.py | 88 ++++++++----------- 3 files changed, 89 insertions(+), 103 deletions(-) diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index 63ed86124bf..18e06967db1 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -572,30 +572,18 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: def _streaming_iter(self): """Yield ``(name, tensor, mask)`` tuples. - - No weight diff (NCCL path): all params, ``mask=None``. 
-        - Anchor step (every ``anchor_interval``): all params, ``mask=None``.
-        - Delta step: only changed params with element-level masks.
-        - Nothing changed: yields nothing (no-op).
+        - No change detector (NCCL path or first sync): all params, ``mask=None``.
+        - Change detector active: only changed params with element-level masks.
+
+        The anchor/delta decision is NOT made here — the rollout worker handles that.
         """
-        # NCCL path
-        if self._change_detector is None:
+        if self._change_detector is None or not self._change_detector._validated_masks:
             for name, param in self.model.named_parameters():
                 name = name.removeprefix("module.")
                 yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None
             return
 
-        # Force full upload on anchor interval or when no masks yet (first sync)
         masks = self._change_detector._validated_masks
-        next_version = self.model_version + 1
-        force_anchor = not masks or next_version % self.args.delta_sync_anchor_interval == 0
-
-        if force_anchor:
-            logger.info("Anchor step %d: sending all params", next_version)
-            for name, param in self.model.named_parameters():
-                name = name.removeprefix("module.")
-                yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None
-            return
-
         total, yielded = 0, 0
         for name, param in self.model.named_parameters():
             total += 1
@@ -605,7 +593,7 @@ def _streaming_iter(self):
                 continue
             yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), mask
             yielded += 1
-        logger.info(f"ULP: {yielded}/{total} params changed")
+        logger.info(f"Delta: {yielded}/{total} params changed")
 
     def _sync_weight(self):
         # Lazy-init ULP detector for diagnostic logging (delta sync only)
diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py
index fb1fe5f61f5..89c5d65b840 100644
--- a/trl/experimental/async_grpo/async_rollout_worker.py
+++ b/trl/experimental/async_grpo/async_rollout_worker.py
@@ -27,6 +27,7 @@
 import requests
 from accelerate.logging import get_logger
 from datasets import Dataset
+from huggingface_hub import create_bucket
 from transformers import AutoTokenizer
 
 from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response
@@ -210,10 +211,7 @@ def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_interval_s:
 
     def _init_weight_transfer(self) -> None:
         if self.delta_sync_enabled:
-            from huggingface_hub import HfApi
-
-            self._delta_hf_api = HfApi()
-            self._delta_hf_api.create_repo(repo_id=self._delta_sync_repo_id, repo_type="model", exist_ok=True)
+            create_bucket(self._delta_sync_repo_id, exist_ok=True)
             self._delta_model_version = 0
             self._delta_pending_update_info: dict | None = None
             requests.post(
@@ -318,39 +316,49 @@ def send_weights(self, iterator) -> None:
             self._send_weights_nccl(iterator)
 
     def _send_weights_delta(self, iterator) -> None:
+        """Delta sync via HF Bucket.
+
+        - Non-empty iterator: upload (anchor or delta based on step count).
+        - Empty iterator + pending info: signal vLLM to apply.
+        - Empty iterator + nothing pending: no-op.
+ """ first = next(iterator, None) - if first is not None: - full_iter = itertools.chain([first], iterator) - self._delta_model_version += 1 - is_anchor = first[2] is None - subdir = "anchors" if is_anchor else "deltas" - filename = f"{subdir}/step_{self._delta_model_version:06d}.safetensors" - meta = DeltaWeightTransferEngine.upload( - iterator=full_iter, - repo_id=self._delta_sync_repo_id, - filename=filename, - hf_api=self._delta_hf_api, - model_version=self._delta_model_version, - ) - if meta is not None: - self._delta_pending_update_info = { - "repo_id": self._delta_sync_repo_id, - "filename": filename, - "revision": "main", - "is_checkpoint_format": True, - } - elif self._delta_pending_update_info is not None: - # Signal vLLM to fetch and apply - resp = requests.post( - f"{self.vllm_server_url}/update_weights", - json={"update_info": self._delta_pending_update_info}, - timeout=300, - ) - resp.raise_for_status() - self._delta_pending_update_info = None - else: - # Nothing changed and nothing pending — no-op - logger.debug("[delta_sync] nothing changed, skipping") + + # (empty iterator) + if first is None: + if self._delta_pending_update_info is not None: + resp = requests.post( + f"{self.vllm_server_url}/update_weights", + json={"update_info": self._delta_pending_update_info}, + timeout=300, + ) + resp.raise_for_status() + self._delta_pending_update_info = None + return + + # Upload phase + self._delta_model_version += 1 + is_anchor = self._delta_model_version == 1 or self._delta_model_version % self._delta_sync_anchor_interval == 0 + + full_iter = itertools.chain([first], iterator) + if is_anchor: + # Force full tensors — strip masks + full_iter = ((name, tensor, None) for name, tensor, _mask in full_iter) + + subdir = "anchors" if is_anchor else "deltas" + filename = f"{subdir}/step_{self._delta_model_version:06d}.safetensors" + meta = DeltaWeightTransferEngine.upload( + iterator=full_iter, + bucket_id=self._delta_sync_repo_id, + filename=filename, + model_version=self._delta_model_version, + ) + if meta is not None: + self._delta_pending_update_info = { + "repo_id": self._delta_sync_repo_id, + "filename": filename, + "is_checkpoint_format": True, + } def _send_weights_nccl(self, iterator) -> None: """NCCL sync: broadcast all params via NCCL + signal /update_weights.""" diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py index 182ff38c018..4ae63670c09 100644 --- a/trl/experimental/async_grpo/delta_engine.py +++ b/trl/experimental/async_grpo/delta_engine.py @@ -16,12 +16,13 @@ import json import logging +import tempfile from collections.abc import Callable, Iterator from dataclasses import dataclass from typing import Any import torch -from huggingface_hub import hf_hub_download +from huggingface_hub import batch_bucket_files, download_bucket_files from safetensors import safe_open from safetensors.torch import save from vllm.config.parallel import ParallelConfig @@ -46,11 +47,10 @@ class DeltaWeightTransferInitInfo(WeightTransferInitInfo): @dataclass class DeltaWeightTransferUpdateInfo(WeightTransferUpdateInfo): - """Metadata sent via ``/update_weights`` — just Hub coordinates.""" + """Metadata sent via ``/update_weights`` — just bucket coordinates.""" - repo_id: str = "" + repo_id: str = "" # bucket_id filename: str = "" - revision: str = "main" class DeltaWeightTransferEngine(WeightTransferEngine[DeltaWeightTransferInitInfo, DeltaWeightTransferUpdateInfo]): @@ -88,40 +88,38 @@ def receive_weights( - **Sparse** (subsequent): keys 
are ``{name}.indices`` + ``{name}.values`` → apply to snapshot, feed reconstructed full tensors to load_weights. """ - local_path = hf_hub_download( - repo_id=update_info.repo_id, - filename=update_info.filename, - revision=update_info.revision, - force_download=True, - ) - - with safe_open(local_path, framework="pt", device="cpu") as f: - meta = PatchMetadata.from_metadata_dict(f.metadata()) - - if not meta.sparse: - # Full tensors — feed to load_weights and build snapshot - self._bf16_snapshot = {} - for name in f.keys(): - tensor = f.get_tensor(name) - self._bf16_snapshot[name] = tensor.to(torch.bfloat16).clone() - load_weights([(name, tensor)]) - logger.info("Applied anchor (step %d, %d params)", meta.model_version, meta.num_changed_params) - else: - # Sparse — apply indices/values to snapshot, feed full tensors - changed_names = json.loads(meta.changed_params) - for name in changed_names: - indices = f.get_tensor(f"{name}.indices").long() - values = f.get_tensor(f"{name}.values") - snap = self._bf16_snapshot[name].flatten() - snap[indices] = values - self._bf16_snapshot[name] = snap.reshape(self._bf16_snapshot[name].shape) - load_weights([(name, self._bf16_snapshot[name])]) - logger.info( - "Applied delta (step %d, %d params, sparsity=%.4f)", - meta.model_version, - meta.num_changed_params, - meta.sparsity, - ) + with tempfile.TemporaryDirectory() as tmpdir: + local_path = f"{tmpdir}/weights.safetensors" + download_bucket_files( + update_info.repo_id, + files=[(update_info.filename, local_path)], + ) + + with safe_open(local_path, framework="pt", device="cpu") as f: + meta = PatchMetadata.from_metadata_dict(f.metadata()) + + if not meta.sparse: + self._bf16_snapshot = {} + for name in f.keys(): + tensor = f.get_tensor(name) + self._bf16_snapshot[name] = tensor.to(torch.bfloat16).clone() + load_weights([(name, tensor)]) + logger.info("Applied anchor (step %d, %d params)", meta.model_version, meta.num_changed_params) + else: + changed_names = json.loads(meta.changed_params) + for name in changed_names: + indices = f.get_tensor(f"{name}.indices").long() + values = f.get_tensor(f"{name}.values") + snap = self._bf16_snapshot[name].flatten() + snap[indices] = values + self._bf16_snapshot[name] = snap.reshape(self._bf16_snapshot[name].shape) + load_weights([(name, self._bf16_snapshot[name])]) + logger.info( + "Applied delta (step %d, %d params, sparsity=%.4f)", + meta.model_version, + meta.num_changed_params, + meta.sparsity, + ) def shutdown(self) -> None: self._bf16_snapshot = None @@ -137,11 +135,9 @@ def trainer_send_weights( @staticmethod def upload( iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]], - repo_id: str, + bucket_id: str, filename: str, - hf_api: Any, model_version: int = 0, - revision: str = "main", ) -> PatchMetadata | None: """Encode params as safetensors and upload to HF Hub. 
@@ -189,17 +185,11 @@ def upload( ) buf = save(tensors, metadata=meta.to_metadata_dict()) - hf_api.upload_file( - path_or_fileobj=buf, - path_in_repo=filename, - repo_id=repo_id, - revision=revision, - commit_message=f"step {model_version} ({len(changed_names)} params, {len(buf) / 1e6:.1f} MB, sparse={sparse})", - ) + batch_bucket_files(bucket_id, add=[(buf, filename)]) logger.info( "[delta_engine] uploaded %s/%s (%.1f MB, %d params, sparse=%s, sparsity=%.4f)", - repo_id, + bucket_id, filename, len(buf) / 1e6, len(changed_names), From 9f2e55fb4da8328290d16c8995217de85d8cf9f3 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 31 Mar 2026 16:28:07 +0000 Subject: [PATCH 5/5] Remove unnecessary section comments from delta_engine.py --- trl/experimental/async_grpo/delta_engine.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py index 4ae63670c09..d0c33f93bd0 100644 --- a/trl/experimental/async_grpo/delta_engine.py +++ b/trl/experimental/async_grpo/delta_engine.py @@ -199,11 +199,6 @@ def upload( return meta -# --------------------------------------------------------------------------- -# Worker extension — its import triggers engine registration -# --------------------------------------------------------------------------- - - class DeltaWorkerExtension: """vLLM worker extension for the delta weight transfer backend. @@ -218,10 +213,6 @@ class DeltaWorkerExtension: pass -# --------------------------------------------------------------------------- -# Module-level registration -# --------------------------------------------------------------------------- - if "delta" not in WeightTransferEngineFactory._registry: WeightTransferEngineFactory.register_engine("delta", DeltaWeightTransferEngine)
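To close the series, here is a self-contained round-trip of the sparse patch layout
(`{name}.indices` as int32 flat positions, `{name}.values` as bf16) that `upload` writes and
`receive_weights` applies. This is a sketch under stated assumptions, not TRL code: the tensor
name, sizes, and injected change pattern are invented, and only `safetensors.torch.save` /
`load` are assumed:

    import torch
    from safetensors.torch import load, save

    # Two consecutive bf16 snapshots of one parameter (stand-ins for real weights).
    prev = torch.randn(1024).to(torch.bfloat16)
    curr = prev.clone()
    curr[::97] += 0.25  # pretend the optimizer step changed a few elements

    # Encode: store only changed elements, in the {name}.indices / {name}.values layout.
    mask = curr != prev
    indices = mask.nonzero(as_tuple=False).squeeze(1).to(torch.int32)
    tensors = {"w.indices": indices, "w.values": curr[indices.long()]}
    buf = save(tensors, metadata={"sparse": "True"})

    # Apply: direct assignment of the stored bf16 values, no FP arithmetic,
    # so the reconstruction is bit-exact against `curr`.
    patch = load(buf)
    recon = prev.clone()
    recon[patch["w.indices"].long()] = patch["w.values"]
    assert torch.equal(recon, curr)
    print(f"{len(buf)} bytes for {int(mask.sum())}/{prev.numel()} changed elements")

Storing actual values rather than additive deltas is what lets the worker keep a single CPU bf16
snapshot and trust it across an arbitrary chain of anchor and delta files.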