Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ requires-python = ">=3.10"
dependencies = [
"accelerate>=1.4.0",
"datasets>=4.7.0", # Support Json type and on_mixed_types="use_json"
"huggingface-hub>=0.36.2",
"packaging>20.0",
"transformers>=4.56.2",
]
Expand Down
35 changes: 35 additions & 0 deletions trl/experimental/async_grpo/async_grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,38 @@ class AsyncGRPOConfig(_BaseConfig):
metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."},
)

# Delta weight sync
delta_sync_enabled: bool = field(
default=False,
metadata={
"help": "Enable delta-compressed weight synchronization. Instead of transferring all "
"weights over NCCL, encode only changed bf16 weights as sparse safetensors patches, "
"upload them to HuggingFace Hub (Xet storage), and signal vLLM to fetch and apply them."
},
)
delta_sync_repo_id: str | None = field(
default=None,
metadata={
"help": "HuggingFace Hub repository for storing delta weight patches and anchor "
"checkpoints (e.g. 'user/training-run-xyz'). Required when delta_sync_enabled=True. "
"The repo is created automatically if it does not exist."
},
)
delta_sync_anchor_interval: int = field(
default=10,
metadata={
"help": "Save a full bf16 anchor checkpoint every N weight sync steps. Between anchors "
"only sparse delta patches are saved. Fireworks blog uses N=25, PULSE paper uses N=50."
},
)
delta_sync_verify_checksum: bool = field(
default=True,
metadata={
"help": "Verify SHA256 checksum after applying each delta patch on the vLLM server. "
"Adds overhead per sync but guarantees bit-exact reconstruction."
},
)

# Parameters that control the logging
log_completions: bool = field(
default=False,
Expand All @@ -201,6 +233,9 @@ class AsyncGRPOConfig(_BaseConfig):
def __post_init__(self):
super().__post_init__()

if self.delta_sync_enabled and self.delta_sync_repo_id is None:
raise ValueError("delta_sync_repo_id is required when delta_sync_enabled=True")

# Accelerator config: required for the async IterableDataset-backed dataloader to work correctly.
# split_batches=True and dispatch_batches=True ensure that the main process drives the dataloader
# and batches are broadcast to other processes rather than each process pulling independently.
Expand Down
112 changes: 97 additions & 15 deletions trl/experimental/async_grpo/async_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from .async_grpo_config import AsyncGRPOConfig
from .async_rollout_worker import AsyncRolloutWorker
from .weight_diff import BF16ChangeDetector


logger = get_logger(__name__)
Expand Down Expand Up @@ -379,6 +380,9 @@ def __init__(
weight_names=weight_names,
weight_dtype_names=weight_dtype_names,
weight_shapes=weight_shapes,
delta_sync_enabled=self.args.delta_sync_enabled,
delta_sync_repo_id=self.args.delta_sync_repo_id,
delta_sync_anchor_interval=self.args.delta_sync_anchor_interval,
)
self.rollout_queue = self.rollout_worker.rollout_buffer
else:
Expand All @@ -388,6 +392,9 @@ def __init__(
# Add callbacks
self.add_callback(StepIntervalCallback(self._sync_weight, self.args.weight_sync_steps))

# ULP change detector for diagnostic logging (delta sync only)
self._change_detector: BF16ChangeDetector | None = None

def get_train_dataloader(self) -> DataLoader:
if self.accelerator.is_main_process:
dataset = RolloutQueueDataset(
Expand Down Expand Up @@ -562,43 +569,115 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
self._metrics[mode].clear()

def _streaming_iter(self):
# Iterate parameters one at a time. For FSDP2 (DTensor), full_tensor() all-gathers just this parameter across
# FSDP ranks, then frees it once the generator advances — avoiding materializing the full model in memory.
"""Yield ``(name, tensor, mask)`` tuples.

- No change detector (NCCL path or first sync): all params, ``mask=None``.
- Change detector active: only changed params with element-level masks.

The anchor/delta decision is NOT made here — the rollout worker handles that.
"""
if self._change_detector is None or not self._change_detector._validated_masks:
for name, param in self.model.named_parameters():
name = name.removeprefix("module.")
yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), None
return

masks = self._change_detector._validated_masks
total, yielded = 0, 0
for name, param in self.model.named_parameters():
name = name.removeprefix("module.") # DDP/FSDP1 wrapping
full = param.full_tensor() if isinstance(param, DTensor) else param.detach()
yield name, full
total += 1
name = name.removeprefix("module.")
mask = masks.get(name)
if mask is None or not mask.any():
continue
yield name, (param.full_tensor() if isinstance(param, DTensor) else param.detach()), mask
yielded += 1
logger.info(f"Delta: {yielded}/{total} params changed")

def _sync_weight(self):
# Lazy-init ULP detector for diagnostic logging (delta sync only) bc
# Optimizer only exists after Trainer creates it inside super()._inner_training_loop().
if (
self.args.delta_sync_enabled
and self._change_detector is None
and hasattr(self, "optimizer")
and self.optimizer is not None
):
# TODO(@aminediro): check this works with FSDP2
# Unwrap AcceleratedOptimizer to get the native PyTorch optimizer
# (register_step_pre_hook requires torch.optim.Optimizer internals)
raw_optimizer = getattr(self.optimizer, "optimizer", self.optimizer)
self._change_detector = BF16ChangeDetector(self.model, raw_optimizer)

if (
self.args.delta_sync_enabled
and self._change_detector is not None
and self._change_detector._validated_masks
and self.accelerator.is_main_process
):
total_changed = 0
total_elements = 0
for mask in self._change_detector._validated_masks.values():
total_changed += mask.sum().item()
total_elements += mask.numel()
sparsity = 1.0 - total_changed / max(total_elements, 1)
self._metrics["train"]["delta/sparsity"].append(sparsity)
self._metrics["train"]["delta/total_changed"].append(total_changed)
self._metrics["train"]["delta/total_elements"].append(total_elements)
logger.info(f"Delta: {total_changed}/{total_elements} elements changed (sparsity={sparsity:.4%})")

t0 = time.time()
logger.info("Weight sync: pausing vLLM...")
is_delta = self.args.delta_sync_enabled

if is_delta:
# Phase 1: Upload to HF Hub while inference continues
logger.info("Weight sync: uploading to HF Hub (inference still running)...")
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.send_weights(self._streaming_iter())
else:
for _ in self._streaming_iter():
pass
self.accelerator.wait_for_everyone()
t_upload = time.time()
logger.info(f"Weight sync: upload took {t_upload - t0:.1f}s, now pausing vLLM...")

# Phase 2: Pause inference
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.pause()
t_pause = time.time()
logger.info(f"Weight sync: pause took {t_pause - t0:.1f}s, waiting for all ranks...")

self.accelerator.wait_for_everyone()
t_barrier = time.time()

logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)")
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.send_weights(self._streaming_iter())
if is_delta:
# Phase 3: Signal vLLM to fetch the already-uploaded weights
logger.info(f"Weight sync: signaling vLLM to apply... (pause took {t_pause - t_upload:.1f}s)")
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.send_weights(iter([]))
else:
# Non-rank-0 processes must still participate in full_tensor() collectives for FSDP2.
for _ in self._streaming_iter():
pass
# NCCL: transfer all weights directly
logger.info(f"Weight sync: transferring weights... (barrier took {t_barrier - t_pause:.1f}s)")
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.send_weights(self._streaming_iter())
else:
for _ in self._streaming_iter():
pass

t_transfer = time.time()

self.accelerator.wait_for_everyone()

logger.info(f"Weight sync: resuming vLLM... (transfer took {t_transfer - t_barrier:.1f}s)")
# Phase 4: Resume
logger.info(f"Weight sync: resuming vLLM... (apply took {t_transfer - t_barrier:.1f}s)")
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.resume()
self.model_version += 1
self.rollout_worker.update_model_version(self.model_version)
weight_sync_time_s = time.time() - t0
self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s)
logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s")
logger.info(
f"Weight sync: done. Total {weight_sync_time_s:.1f}s (inference paused {t_transfer - t_pause:.1f}s)"
)

def _inner_training_loop(self, *args, **kwargs):
# Start the rollout worker here (not in __init__) so that checkpoint loading in Trainer.train()
Expand All @@ -613,3 +692,6 @@ def _inner_training_loop(self, *args, **kwargs):
finally:
if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.stop()
if self._change_detector is not None:
self._change_detector.close()
self._change_detector = None
81 changes: 78 additions & 3 deletions trl/experimental/async_grpo/async_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import asyncio
import inspect
import itertools
import queue
import threading
import time
Expand All @@ -26,12 +27,15 @@
import requests
from accelerate.logging import get_logger
from datasets import Dataset
from huggingface_hub import create_bucket
from transformers import AutoTokenizer

from trl.chat_template_utils import add_response_schema, get_training_chat_template, parse_response
from trl.import_utils import is_vllm_available
from trl.trainer.utils import print_prompt_completions_sample

from .delta_engine import DeltaWeightTransferEngine


if is_vllm_available(min_version="0.17.1"):
from vllm.distributed.weight_transfer.nccl_engine import NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine
Expand Down Expand Up @@ -102,11 +106,15 @@ def __init__(
weight_names: list[str] | None = None,
weight_dtype_names: list[str] | None = None,
weight_shapes: list[list[int]] | None = None,
delta_sync_enabled: bool = False,
delta_sync_repo_id: str | None = None,
delta_sync_anchor_interval: int = 10,
):
if not is_vllm_available(min_version="0.17.1"):
raise ImportError(
"vLLM >= 0.17.1 is required to use AsyncRolloutWorker. Install it with: pip install 'vllm>=0.17.1'"
)
self.delta_sync_enabled = delta_sync_enabled
self.model_name = model_name
self.max_tool_calling_iterations = max_tool_calling_iterations
self.dataset = dataset
Expand Down Expand Up @@ -171,7 +179,10 @@ def __init__(
self.model_version = 0
self.session = None

# Wait for the vLLM server and initialize NCCL weight transfer.
self._delta_sync_repo_id = delta_sync_repo_id
self._delta_sync_anchor_interval = delta_sync_anchor_interval

# Wait for the vLLM server and initialize weight transfer.
self._wait_for_server_ready_sync(timeout_s=self.server_timeout)
self._init_weight_transfer()

Expand Down Expand Up @@ -199,6 +210,18 @@ def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_interval_s:
time.sleep(poll_interval_s)

def _init_weight_transfer(self) -> None:
if self.delta_sync_enabled:
create_bucket(self._delta_sync_repo_id, exist_ok=True)
self._delta_model_version = 0
self._delta_pending_update_info: dict | None = None
requests.post(
f"{self.vllm_server_url}/init_weight_transfer_engine",
json={"init_info": {}},
timeout=120,
)
logger.info("Init delta weight transfer with HF Hub repo %s", self._delta_sync_repo_id)
return

response = requests.get(f"{self.vllm_server_url}/get_world_size")
inference_world_size = response.json()["world_size"]
world_size = inference_world_size + 1
Expand Down Expand Up @@ -287,6 +310,58 @@ def resume(self) -> None:
logger.debug(f"[weight_sync] resume HTTP took {time.time() - t0:.1f}s")

def send_weights(self, iterator) -> None:
if self.delta_sync_enabled:
self._send_weights_delta(iterator)
else:
self._send_weights_nccl(iterator)

def _send_weights_delta(self, iterator) -> None:
"""Delta sync via HF Bucket.

- Non-empty iterator: upload (anchor or delta based on step count).
- Empty iterator + pending info: signal vLLM to apply.
- Empty iterator + nothing pending: no-op.
"""
first = next(iterator, None)

# (empty iterator)
if first is None:
if self._delta_pending_update_info is not None:
resp = requests.post(
f"{self.vllm_server_url}/update_weights",
json={"update_info": self._delta_pending_update_info},
timeout=300,
)
resp.raise_for_status()
self._delta_pending_update_info = None
return

# Upload phase
self._delta_model_version += 1
is_anchor = self._delta_model_version == 1 or self._delta_model_version % self._delta_sync_anchor_interval == 0

full_iter = itertools.chain([first], iterator)
if is_anchor:
# Force full tensors — strip masks
full_iter = ((name, tensor, None) for name, tensor, _mask in full_iter)

subdir = "anchors" if is_anchor else "deltas"
filename = f"{subdir}/step_{self._delta_model_version:06d}.safetensors"
meta = DeltaWeightTransferEngine.upload(
iterator=full_iter,
bucket_id=self._delta_sync_repo_id,
filename=filename,
model_version=self._delta_model_version,
)
if meta is not None:
self._delta_pending_update_info = {
"repo_id": self._delta_sync_repo_id,
"filename": filename,
"is_checkpoint_format": True,
}

def _send_weights_nccl(self, iterator) -> None:
"""NCCL sync: broadcast all params via NCCL + signal /update_weights."""
if self.model_update_group is None:
return
t0 = time.time()
Expand All @@ -299,7 +374,7 @@ def send_weights(self, iterator) -> None:
logger.debug(f"[weight_sync] /update_weights POST sent ({time.time() - t0:.1f}s)")
t_nccl = time.time()
NCCLWeightTransferEngine.trainer_send_weights(
iterator=iterator,
iterator=((name, tensor) for name, tensor, _mask in iterator),
trainer_args=NCCLTrainerSendWeightsArgs(group=self.model_update_group, packed=True),
)
logger.debug(f"[weight_sync] NCCL transfer took {time.time() - t_nccl:.1f}s")
Expand Down Expand Up @@ -356,7 +431,7 @@ async def _generate_loop(self, stop_event: asyncio.Event) -> None:
# Current assumption: reset side effects matter, return value is ignored.
self.environments[slot].reset(**row)

logger.info(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}")
logger.debug(f"[slot] assigned slot={slot} group={group_id} free_after={len(free_slots)}")
task = asyncio.create_task(
self._generate_one(pending_groups[group_id].prompt, tool_dict=self._sync_tool_dicts[slot])
)
Expand Down
Loading
Loading