diff --git a/pyproject.toml b/pyproject.toml
index f4831765478..15b0add31dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ requires-python = ">=3.10"
 dependencies = [
     "accelerate>=1.4.0",
     "datasets>=4.7.0",  # Support Json type and on_mixed_types="use_json"
+    "huggingface-hub>=0.36.2",
     "packaging>20.0",
     "transformers>=4.56.2",
 ]
diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py
index 2afd760e7fc..775c6811824 100644
--- a/trl/experimental/async_grpo/async_grpo_config.py
+++ b/trl/experimental/async_grpo/async_grpo_config.py
@@ -185,6 +185,38 @@ class AsyncGRPOConfig(_BaseConfig):
         metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."},
     )
 
+    # Delta weight sync
+    delta_sync_enabled: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable delta-compressed weight synchronization. Instead of transferring all "
+            "weights over NCCL, encode only changed bf16 weights as sparse safetensors patches, "
+            "upload them to the Hugging Face Hub (Xet storage), and signal vLLM to fetch and apply them."
+        },
+    )
+    delta_sync_repo_id: str | None = field(
+        default=None,
+        metadata={
+            "help": "Hugging Face Hub repository for storing delta weight patches and anchor "
+            "checkpoints (e.g. 'user/training-run-xyz'). Required when delta_sync_enabled=True. "
+            "The repo is created automatically if it does not exist."
+        },
+    )
+    delta_sync_anchor_interval: int = field(
+        default=10,
+        metadata={
+            "help": "Save a full bf16 anchor checkpoint every N weight sync steps. Between anchors, "
+            "only sparse delta patches are saved. The Fireworks blog uses N=25; the PULSE paper uses N=50."
+        },
+    )
+    delta_sync_verify_checksum: bool = field(
+        default=True,
+        metadata={
+            "help": "Verify the SHA256 checksum after applying each delta patch on the vLLM server. "
+            "Adds overhead per sync but guarantees bit-exact reconstruction."
+        },
+    )
+
     # Parameters that control the logging
     log_completions: bool = field(
         default=False,
@@ -201,6 +233,9 @@ class AsyncGRPOConfig(_BaseConfig):
     def __post_init__(self):
         super().__post_init__()
 
+        if self.delta_sync_enabled and self.delta_sync_repo_id is None:
+            raise ValueError("delta_sync_repo_id is required when delta_sync_enabled=True")
+
         # Accelerator config: required for the async IterableDataset-backed dataloader to work correctly.
         # split_batches=True and dispatch_batches=True ensure that the main process drives the dataloader
         # and batches are broadcast to other processes rather than each process pulling independently.
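Reviewer note: a minimal usage sketch of the new config surface. It assumes `AsyncGRPOConfig` is re-exported from the package root (not shown in this diff); every value besides the `delta_sync_*` fields is a placeholder.

```python
from trl.experimental.async_grpo import AsyncGRPOConfig  # assumed re-export

# Enable delta sync: delta_sync_repo_id is mandatory once delta_sync_enabled=True
# (enforced by the new __post_init__ check above).
args = AsyncGRPOConfig(
    output_dir="outputs/run-xyz",                 # placeholder
    delta_sync_enabled=True,
    delta_sync_repo_id="user/training-run-xyz",   # created automatically if missing
    delta_sync_anchor_interval=10,                # full bf16 anchor every 10 syncs
    delta_sync_verify_checksum=True,              # bit-exact verification on the vLLM side
)
```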
diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py
index aca72c73596..6fc16829db2 100644
--- a/trl/experimental/async_grpo/async_grpo_trainer.py
+++ b/trl/experimental/async_grpo/async_grpo_trainer.py
@@ -35,6 +35,7 @@
 
 from .async_grpo_config import AsyncGRPOConfig
 from .async_rollout_worker import AsyncRolloutWorker
+from .weight_diff import BF16ChangeDetector
 
 
 logger = get_logger(__name__)
@@ -379,6 +380,9 @@ def __init__(
                 weight_names=weight_names,
                 weight_dtype_names=weight_dtype_names,
                 weight_shapes=weight_shapes,
+                delta_sync_enabled=self.args.delta_sync_enabled,
+                delta_sync_repo_id=self.args.delta_sync_repo_id,
+                delta_sync_anchor_interval=self.args.delta_sync_anchor_interval,
             )
             self.rollout_queue = self.rollout_worker.rollout_buffer
         else:
@@ -388,6 +392,9 @@ def __init__(
         # Add callbacks
         self.add_callback(StepIntervalCallback(self._sync_weight, self.args.weight_sync_steps))
 
+        # ULP change detector for diagnostic logging (delta sync only)
+        self._change_detector: BF16ChangeDetector | None = None
+
     def get_train_dataloader(self) -> DataLoader:
         if self.accelerator.is_main_process:
             dataset = RolloutQueueDataset(
@@ -562,43 +569,115 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
         self._metrics[mode].clear()
 
     def _streaming_iter(self):
-        # Iterate parameters one at a time. For FSDP2 (DTensor), full_tensor() all-gathers just this parameter across
-        # FSDP ranks, then frees it once the generator advances — avoiding materializing the full model in memory.
+        """Yield ``(name, tensor, mask)`` tuples.
+
+        - No change detector (NCCL path or first sync): all params, ``mask=None``.
+        - Change detector active: only changed params with element-level masks.
+
+        The anchor/delta decision is NOT made here — the rollout worker handles that.
+        """
+        if self._change_detector is None or not self._change_detector._validated_masks:
+            for name, param in self.model.named_parameters():
+                name = name.removeprefix("module.")
+                yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None
+            return
+
+        masks = self._change_detector._validated_masks
+        total, yielded = 0, 0
         for name, param in self.model.named_parameters():
-            name = name.removeprefix("module.")  # DDP/FSDP1 wrapping
-            full = param.full_tensor() if isinstance(param, DTensor) else param.detach()
-            yield name, full
+            total += 1
+            name = name.removeprefix("module.")
+            mask = masks.get(name)
+            if mask is None or not mask.any():
+                continue
+            yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), mask
+            yielded += 1
+        logger.info(f"Delta: {yielded}/{total} params changed")
 
     def _sync_weight(self):
+        # Lazy-init the ULP change detector for diagnostic logging (delta sync only) because the
+        # optimizer only exists after Trainer creates it inside super()._inner_training_loop().
+        if (
+            self.args.delta_sync_enabled
+            and self._change_detector is None
+            and hasattr(self, "optimizer")
+            and self.optimizer is not None
+        ):
+            # TODO(@aminediro): check this works with FSDP2
+            # Unwrap AcceleratedOptimizer to get the native PyTorch optimizer
+            # (register_step_pre_hook requires torch.optim.Optimizer internals)
+            raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer)
+            self._change_detector = BF16ChangeDetector(self.model, raw_optimizer)
+
+        if (
+            self.args.delta_sync_enabled
+            and self._change_detector is not None
+            and self._change_detector._validated_masks
+            and self.accelerator.is_main_process
+        ):
+            total_changed = 0
+            total_elements = 0
+            for mask in self._change_detector._validated_masks.values():
+                total_changed += mask.sum().item()
+                total_elements += mask.numel()
+            sparsity = 1.0 - total_changed / max(total_elements, 1)
+            self._metrics["train"]["delta/sparsity"].append(sparsity)
+            self._metrics["train"]["delta/total_changed"].append(total_changed)
+            self._metrics["train"]["delta/total_elements"].append(total_elements)
+            logger.info(f"Delta: {total_changed}/{total_elements} elements changed (sparsity={sparsity:.4%})")
+
         t0 = time.time()
-        logger.info("Weight sync: pausing vLLM...")
+        is_delta = self.args.delta_sync_enabled
+
+        if is_delta:
+            # Phase 1: Upload to HF Hub while inference continues
+            logger.info("Weight sync: uploading to HF Hub (inference still running)...")
+            if self.accelerator.is_main_process and self.rollout_worker:
+                self.rollout_worker.send_weights(self._streaming_iter())
+            else:
+                for _ in self._streaming_iter():
+                    pass
+            self.accelerator.wait_for_everyone()
+            t_upload = time.time()
+            logger.info(f"Weight sync: upload took {t_upload - t0:.1f}s, now pausing vLLM...")
+
+        # Phase 2: Pause inference
         if self.accelerator.is_main_process and self.rollout_worker:
             self.rollout_worker.pause()
         t_pause = time.time()
-        logger.info(f"Weight sync: pause took {t_pause - t0:.1f}s, waiting for all ranks...")
         self.accelerator.wait_for_everyone()
         t_barrier = time.time()
-        logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)")
-        if self.accelerator.is_main_process and self.rollout_worker:
-            self.rollout_worker.send_weights(self._streaming_iter())
+        if is_delta:
+            # Phase 3: Signal vLLM to fetch the already-uploaded weights
+            logger.info(f"Weight sync: signaling vLLM to apply... (pause took {t_pause - t_upload:.1f}s)")
+            if self.accelerator.is_main_process and self.rollout_worker:
+                self.rollout_worker.send_weights(iter([]))
         else:
-            # Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2.
-            for _ in self._streaming_iter():
-                pass
+            # NCCL: transfer all weights directly
+            logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)")
+            if self.accelerator.is_main_process and self.rollout_worker:
+                self.rollout_worker.send_weights(self._streaming_iter())
+            else:
+                for _ in self._streaming_iter():
+                    pass
+
         t_transfer = time.time()
         self.accelerator.wait_for_everyone()
-        logger.info(f"Weight sync: resuming vLLM... (transfer took {t_transfer - t_barrier:.1f}s)")
+        # Phase 4: Resume
+        logger.info(f"Weight sync: resuming vLLM... (apply took {t_transfer - t_barrier:.1f}s)")
         if self.accelerator.is_main_process and self.rollout_worker:
             self.rollout_worker.resume()
             self.model_version += 1
             self.rollout_worker.update_model_version(self.model_version)
         weight_sync_time_s = time.time() - t0
         self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s)
-        logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s")
+        logger.info(
+            f"Weight sync: done. Total {weight_sync_time_s:.1f}s (inference paused {t_transfer - t_pause:.1f}s)"
+        )
 
     def _inner_training_loop(self, *args, **kwargs):
         # Start the rollout worker here (not in __init__) so that checkpoint loading in Trainer.train()
@@ -613,3 +692,6 @@
         finally:
             if self.accelerator.is_main_process and self.rollout_worker:
                 self.rollout_worker.stop()
+            if self._change_detector is not None:
+                self._change_detector.close()
+                self._change_detector = None
diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py
index 3d7350d5a71..2f09d826d42 100644
--- a/trl/experimental/async_grpo/async_rollout_worker.py
+++ b/trl/experimental/async_grpo/async_rollout_worker.py
@@ -14,6 +14,7 @@
 
 import asyncio
 import inspect
+import itertools
 import queue
 import threading
 import time
@@ -26,12 +27,15 @@
 import requests
 from accelerate.logging import get_logger
 from datasets import Dataset
+from huggingface_hub import create_bucket
 from transformers import AutoTokenizer
 
 from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response
 from trl.import_utils import is_vllm_available
 from trl.trainer.utils import print_prompt_completions_sample
 
+from .delta_engine import DeltaWeightTransferEngine
+
 
 if is_vllm_available(min_version="0.17.1"):
     from vllm.distributed.weight_transfer.nccl_engine import NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine
@@ -102,11 +106,15 @@ def __init__(
         weight_names: list[str] | None = None,
         weight_dtype_names: list[str] | None = None,
         weight_shapes: list[list[int]] | None = None,
+        delta_sync_enabled: bool = False,
+        delta_sync_repo_id: str | None = None,
+        delta_sync_anchor_interval: int = 10,
     ):
         if not is_vllm_available(min_version="0.17.1"):
             raise ImportError(
                 "vLLM >= 0.17.1 is required to use AsyncRolloutWorker. Install it with: pip install 'vllm>=0.17.1'"
             )
+        self.delta_sync_enabled = delta_sync_enabled
         self.model_name = model_name
         self.max_tool_calling_iterations = max_tool_calling_iterations
         self.dataset = dataset
@@ -171,7 +179,10 @@ def __init__(
         self.model_version = 0
         self.session = None
 
-        # Wait for the vLLM server and initialize NCCL weight transfer.
+        self._delta_sync_repo_id = delta_sync_repo_id
+        self._delta_sync_anchor_interval = delta_sync_anchor_interval
+
+        # Wait for the vLLM server and initialize weight transfer.
         self._wait_for_server_ready_sync(timeout_s=self.server_timeout)
         self._init_weight_transfer()
 
@@ -199,6 +210,18 @@ def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_interval_s
             time.sleep(poll_interval_s)
 
     def _init_weight_transfer(self) -> None:
+        if self.delta_sync_enabled:
+            create_bucket(self._delta_sync_repo_id, exist_ok=True)
+            self._delta_model_version = 0
+            self._delta_pending_update_info: dict | None = None
+            requests.post(
+                f"{self.vllm_server_url}/init_weight_transfer_engine",
+                json={"init_info": {}},
+                timeout=120,
+            )
+            logger.info("Init delta weight transfer with HF Hub repo %s", self._delta_sync_repo_id)
+            return
+
         response = requests.get(f"{self.vllm_server_url}/get_world_size")
         inference_world_size = response.json()["world_size"]
         world_size = inference_world_size + 1
@@ -287,6 +310,58 @@ def resume(self) -> None:
         logger.debug(f"[weight_sync] resume HTTP took {time.time() - t0:.1f}s")
 
     def send_weights(self, iterator) -> None:
+        if self.delta_sync_enabled:
+            self._send_weights_delta(iterator)
+        else:
+            self._send_weights_nccl(iterator)
+
+    def _send_weights_delta(self, iterator) -> None:
+        """Delta sync via an HF Hub bucket.
+
+        - Non-empty iterator: upload (anchor or delta, based on step count).
+        - Empty iterator + pending info: signal vLLM to apply.
+        - Empty iterator + nothing pending: no-op.
+        """
+        first = next(iterator, None)
+
+        # Apply phase (empty iterator)
+        if first is None:
+            if self._delta_pending_update_info is not None:
+                resp = requests.post(
+                    f"{self.vllm_server_url}/update_weights",
+                    json={"update_info": self._delta_pending_update_info},
+                    timeout=300,
+                )
+                resp.raise_for_status()
+                self._delta_pending_update_info = None
+            return
+
+        # Upload phase
+        self._delta_model_version += 1
+        is_anchor = self._delta_model_version == 1 or self._delta_model_version % self._delta_sync_anchor_interval == 0
+
+        full_iter = itertools.chain([first], iterator)
+        if is_anchor:
+            # Force full tensors — strip masks
+            full_iter = ((name, tensor, None) for name, tensor, _mask in full_iter)
+
+        subdir = "anchors" if is_anchor else "deltas"
+        filename = f"{subdir}/step_{self._delta_model_version:06d}.safetensors"
+        meta = DeltaWeightTransferEngine.upload(
+            iterator=full_iter,
+            bucket_id=self._delta_sync_repo_id,
+            filename=filename,
+            model_version=self._delta_model_version,
+        )
+        if meta is not None:
+            self._delta_pending_update_info = {
+                "repo_id": self._delta_sync_repo_id,
+                "filename": filename,
+                "is_checkpoint_format": True,
+            }
+
+    def _send_weights_nccl(self, iterator) -> None:
+        """NCCL sync: broadcast all params via NCCL + signal /update_weights."""
         if self.model_update_group is None:
             return
         t0 = time.time()
@@ -299,7 +374,7 @@ def send_weights(self, iterator) -> None:
         logger.debug(f"[weight_sync] /update_weights POST sent ({time.time() - t0:.1f}s)")
         t_nccl = time.time()
         NCCLWeightTransferEngine.trainer_send_weights(
-            iterator=iterator,
+            iterator=((name, tensor) for name, tensor, _mask in iterator),
             trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, packed=True),
         )
         logger.debug(f"[weight_sync] NCCL transfer took {time.time() - t_nccl:.1f}s")
@@ -356,7 +431,7 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None:
             # Current assumption: reset side effects matter, return value is ignored.
             self.environments[slot].reset(**row)
 
-            logger.info(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}")
+            logger.debug(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}")
             task = asyncio.create_task(
                 self._generate_one(pending_groups[group_id].prompt, tool_dict=self._sync_tool_dicts[slot])
             )
diff --git a/trl/experimental/async_grpo/delta_engine.py b/trl/experimental/async_grpo/delta_engine.py
new file mode 100644
index 00000000000..d0c33f93bd0
--- /dev/null
+++ b/trl/experimental/async_grpo/delta_engine.py
@@ -0,0 +1,220 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+import logging
+import tempfile
+from collections.abc import Callable, Iterator
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from huggingface_hub import batch_bucket_files, download_bucket_files
+from safetensors import safe_open
+from safetensors.torch import save
+from vllm.config.parallel import ParallelConfig
+from vllm.config.weight_transfer import WeightTransferConfig
+from vllm.distributed.weight_transfer.base import (
+    WeightTransferEngine,
+    WeightTransferInitInfo,
+    WeightTransferUpdateInfo,
+)
+from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory
+
+from .weight_diff import PatchMetadata
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DeltaWeightTransferInitInfo(WeightTransferInitInfo):
+    pass
+
+
+@dataclass
+class DeltaWeightTransferUpdateInfo(WeightTransferUpdateInfo):
+    """Metadata sent via ``/update_weights`` — just bucket coordinates."""
+
+    repo_id: str = ""  # bucket_id
+    filename: str = ""
+
+
+class DeltaWeightTransferEngine(WeightTransferEngine[DeltaWeightTransferInitInfo, DeltaWeightTransferUpdateInfo]):
+    """Weight transfer engine that uses HF Hub (Xet) as the data plane.
+
+    Worker side: downloads safetensors from Hub, feeds to ``load_weights``.
+    Trainer side: uploads changed params as safetensors to Hub.
+    """
+
+    init_info_cls = DeltaWeightTransferInitInfo
+    update_info_cls = DeltaWeightTransferUpdateInfo
+
+    def __init__(self, config: WeightTransferConfig, parallel_config: ParallelConfig) -> None:
+        super().__init__(config, parallel_config)
+        # TODO: might be able to eliminate completely
+        # CPU-side bf16 snapshot — needed because vLLM's load_weights expects full
+        # tensors, so we must reconstruct them from sparse (indices, values) patches.
+        # Kept on CPU to avoid GPU memory overhead (~2 bytes/param, e.g. ~1.2 GB for a 0.6B model).
+        self._bf16_snapshot: dict[str, torch.Tensor] | None = None
+
+    def init_transfer_engine(self, init_info: DeltaWeightTransferInitInfo) -> None:
+        pass
+
+    def receive_weights(
+        self,
+        update_info: DeltaWeightTransferUpdateInfo,
+        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
+    ) -> None:
+        """Download safetensors from Hub and feed to ``load_weights``.
+
+        Handles two formats based on the ``sparse`` metadata flag:
+
+        - **Full** (first sync): keys are param names → feed directly to load_weights,
+          build snapshot for future sparse applies.
+        - **Sparse** (subsequent): keys are ``{name}.indices`` + ``{name}.values`` →
+          apply to snapshot, feed reconstructed full tensors to load_weights.
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            local_path = f"{tmpdir}/weights.safetensors"
+            download_bucket_files(
+                update_info.repo_id,
+                files=[(update_info.filename, local_path)],
+            )
+
+            with safe_open(local_path, framework="pt", device="cpu") as f:
+                meta = PatchMetadata.from_metadata_dict(f.metadata())
+
+                if not meta.sparse:
+                    self._bf16_snapshot = {}
+                    for name in f.keys():
+                        tensor = f.get_tensor(name)
+                        self._bf16_snapshot[name] = tensor.to(torch.bfloat16).clone()
+                        load_weights([(name, tensor)])
+                    logger.info("Applied anchor (step %d, %d params)", meta.model_version, meta.num_changed_params)
+                else:
+                    changed_names = json.loads(meta.changed_params)
+                    for name in changed_names:
+                        indices = f.get_tensor(f"{name}.indices").long()
+                        values = f.get_tensor(f"{name}.values")
+                        snap = self._bf16_snapshot[name].flatten()
+                        snap[indices] = values
+                        self._bf16_snapshot[name] = snap.reshape(self._bf16_snapshot[name].shape)
+                        load_weights([(name, self._bf16_snapshot[name])])
+                    logger.info(
+                        "Applied delta (step %d, %d params, sparsity=%.4f)",
+                        meta.model_version,
+                        meta.num_changed_params,
+                        meta.sparsity,
+                    )
+
+    def shutdown(self) -> None:
+        self._bf16_snapshot = None
+
+    @staticmethod
+    def trainer_send_weights(
+        iterator: Iterator[tuple[str, torch.Tensor]],
+        trainer_args: dict[str, Any] | Any,
+    ) -> None:
+        """Not used directly — the rollout worker manages upload + signaling."""
+        raise NotImplementedError("Use AsyncRolloutWorker._send_weights_delta instead")
+
+    @staticmethod
+    def upload(
+        iterator: Iterator[tuple[str, torch.Tensor, torch.Tensor | None]],
+        bucket_id: str,
+        filename: str,
+        model_version: int = 0,
+    ) -> PatchMetadata | None:
+        """Encode params as safetensors and upload to HF Hub.
+
+        Each item is ``(name, tensor, mask)``:
+
+        - ``mask is None``: full tensor stored as ``name`` (anchor).
+        - ``mask`` provided: sparse encoding — only changed elements stored
+          as ``{name}.indices`` (int32) + ``{name}.values`` (bf16).
+
+        Returns :class:`PatchMetadata` or ``None`` if the iterator was empty.
+ """ + tensors: dict[str, torch.Tensor] = {} + changed_names: list[str] = [] + total_changed = 0 + total_elements = 0 + sparse = False + + for name, tensor, mask in iterator: + bf16 = tensor.to(torch.bfloat16).cpu() + total_elements += bf16.numel() + if mask is None: + tensors[name] = bf16.clone() + changed_names.append(name) + total_changed += bf16.numel() + else: + sparse = True + indices = mask.flatten().nonzero(as_tuple=False).squeeze(1).to(torch.int32) + values = bf16.flatten()[indices.long()] + tensors[f"{name}.indices"] = indices + tensors[f"{name}.values"] = values + changed_names.append(name) + total_changed += len(indices) + + if not tensors: + return None + + meta = PatchMetadata( + sparse=sparse, + model_version=model_version, + num_changed_params=len(changed_names), + total_changed_elements=total_changed, + total_elements=total_elements, + sparsity=1.0 - total_changed / max(total_elements, 1), + changed_params=json.dumps(changed_names), + ) + buf = save(tensors, metadata=meta.to_metadata_dict()) + + batch_bucket_files(bucket_id, add=[(buf, filename)]) + + logger.info( + "[delta_engine] uploaded %s/%s (%.1f MB, %d params, sparse=%s, sparsity=%.4f)", + bucket_id, + filename, + len(buf) / 1e6, + len(changed_names), + sparse, + meta.sparsity, + ) + return meta + + +class DeltaWorkerExtension: + """vLLM worker extension for the delta weight transfer backend. + + This class is intentionally minimal. Its import (via ``--worker-extension-cls``) + registers the engine and overrides the ``"nccl"`` factory entry. + + ``backend`` must be ``"nccl"`` in the CLI (pydantic ``Literal`` validation). + This module overrides the ``"nccl"`` factory entry so that the actual engine + created is ``DeltaWeightTransferEngine``. + """ + + pass + + +if "delta" not in WeightTransferEngineFactory._registry: + WeightTransferEngineFactory.register_engine("delta", DeltaWeightTransferEngine) + +# Override "nccl" so --weight-transfer-config '{"backend":"nccl"}' creates our engine. +WeightTransferEngineFactory._registry["nccl"] = lambda: DeltaWeightTransferEngine diff --git a/trl/experimental/async_grpo/weight_diff.py b/trl/experimental/async_grpo/weight_diff.py new file mode 100644 index 00000000000..c53f578aea2 --- /dev/null +++ b/trl/experimental/async_grpo/weight_diff.py @@ -0,0 +1,125 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Delta-compressed weight synchronization utilities. + +- ``BF16ChangeDetector``: hooks into the optimizer to detect which bf16 elements + actually changed after each step. +- ``PatchMetadata``: structured metadata stored in safetensors headers for both + anchor (full) and delta (sparse) weight files. +""" + +from __future__ import annotations + +import logging +from dataclasses import asdict, dataclass + +import torch + + +logger = logging.getLogger(__name__) + + +class BF16ChangeDetector: + """Detects which bf16 weights actually changed across an optimizer step. 
+
+    Hooks into the optimizer via ``register_step_pre_hook`` / ``register_step_post_hook``
+    (PyTorch >= 2.1). Snapshots bf16 values before the step and compares after.
+
+    ``_validated_masks[name]`` is a boolean tensor with True for each element that changed.
+    """
+
+    def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
+        self._validated_masks: dict[str, torch.Tensor] = {}
+        self._pre_step_bf16: dict[str, torch.Tensor] = {}
+
+        # Match model param names to optimizer param objects via data_ptr()
+        # (id() doesn't work because Accelerate wraps params as different objects).
+        model_params = {p.data_ptr(): name.removeprefix("module.") for name, p in model.named_parameters()}
+        self._param_id_to_name: dict[int, str] = {}
+        for group in optimizer.param_groups:
+            for p in group["params"]:
+                name = model_params.get(p.data_ptr())
+                if name is not None:
+                    self._param_id_to_name[id(p)] = name
+
+        logger.info(
+            "BF16ChangeDetector: matched %d/%d optimizer params",
+            len(self._param_id_to_name),
+            sum(1 for _ in model.named_parameters()),
+        )
+
+        self._pre_hook_handle = optimizer.register_step_pre_hook(self._pre_step_hook)
+        self._post_hook_handle = optimizer.register_step_post_hook(self._post_step_hook)
+
+    def _pre_step_hook(self, optimizer, args, kwargs) -> None:
+        self._pre_step_bf16.clear()
+        for group in optimizer.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                name = self._param_id_to_name.get(id(p))
+                if name is None:
+                    continue
+                self._pre_step_bf16[name] = p.detach().to(torch.bfloat16).cpu().clone()
+
+    def _post_step_hook(self, optimizer, args, kwargs) -> None:
+        self._validated_masks.clear()
+        for group in optimizer.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                name = self._param_id_to_name.get(id(p))
+                if name is None or name not in self._pre_step_bf16:
+                    continue
+                self._validated_masks[name] = p.detach().to(torch.bfloat16).cpu() != self._pre_step_bf16[name]
+
+    def close(self):
+        self._pre_hook_handle.remove()
+        self._post_hook_handle.remove()
+
+
+@dataclass
+class PatchMetadata:
+    format: str = "sparse_weight_patch"
+    version: str = "1"
+    sparse: bool = False
+    model_version: int = 0
+    num_changed_params: int = 0
+    total_changed_elements: int = 0
+    total_elements: int = 0
+    sparsity: float = 0.0
+    changed_params: str = "[]"
+
+    def to_metadata_dict(self) -> dict[str, str]:
+        return {k: str(v) for k, v in asdict(self).items()}
+
+    @classmethod
+    def from_metadata_dict(cls, d: dict[str, str]) -> PatchMetadata:
+        field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
+        kwargs = {}
+        for k, v in d.items():
+            if k not in field_types:
+                continue
+            ft = field_types[k]
+            if ft == "int":
+                kwargs[k] = int(v)
+            elif ft == "float":
+                kwargs[k] = float(v)
+            elif ft == "bool":
+                kwargs[k] = v == "True"
+            else:
+                kwargs[k] = v
+        return cls(**kwargs)
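Reviewer note: a self-contained round-trip of the sparse patch format that `upload` writes and `receive_weights` applies. This is a sketch using only `torch` and `safetensors`; the parameter name `"w"` and the perturbed values are made up.

```python
import json

import torch
from safetensors.torch import load, save

# Encode: pretend the optimizer changed two elements of a bf16 param "w".
old = torch.randn(4, 4).to(torch.bfloat16)
new = old.clone()
new[0, 0] += 0.5
new[2, 3] -= 0.25

mask = new != old  # what BF16ChangeDetector produces per param
indices = mask.flatten().nonzero(as_tuple=False).squeeze(1).to(torch.int32)
values = new.flatten()[indices.long()]
buf = save(
    {"w.indices": indices, "w.values": values},
    metadata={"sparse": "True", "changed_params": json.dumps(["w"])},
)

# Decode: apply the patch to a bf16 snapshot, as receive_weights does.
tensors = load(buf)
snap = old.flatten().clone()
snap[tensors["w.indices"].long()] = tensors["w.values"]
assert torch.equal(snap.reshape(old.shape), new)  # bit-exact reconstruction
```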
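And a usage sketch for `BF16ChangeDetector` itself. It assumes the import path introduced by this diff and peeks at the private `_validated_masks` attribute for illustration; the toy model, input, and learning rate are arbitrary.

```python
import torch

from trl.experimental.async_grpo.weight_diff import BF16ChangeDetector

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
detector = BF16ChangeDetector(model, opt)  # registers pre/post step hooks

loss = model(torch.randn(2, 8)).sum()
loss.backward()
opt.step()  # hooks fire here: snapshot bf16 before the step, diff after

# _validated_masks now holds one bool mask per param that received a grad.
for name, mask in detector._validated_masks.items():
    frac = mask.float().mean().item()
    print(f"{name}: {frac:.2%} of elements changed in bf16")
detector.close()  # remove the optimizer hooks
```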