Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 184 additions & 21 deletions xtuner/v1/module/dispatcher/agrs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Literal, TypeAlias, cast

import torch
Expand All @@ -7,12 +8,14 @@
AsyncCollectiveTensor,
all_gather_tensor,
all_gather_tensor_autograd,
all_to_all_single_autograd,
reduce_scatter_tensor,
reduce_scatter_tensor_autograd,
)
from typing_extensions import override

from xtuner.v1.ops import permute, unpermute
from xtuner.v1.ops.comm import AllGatherManager, ReduceScatterManager, SymmBufferManager
from xtuner.v1.utils import copy_method_signature, get_device, get_logger

from . import XTUNER_DISPATCHER_DEBUG
Expand All @@ -31,6 +34,16 @@
logger = get_logger()


USE_CUSTOM_AG = int(os.getenv("XTUNER_USE_CUSTOM_AG_IN_DISPATCHER", 0)) == 1
USE_CUSTOM_RS = int(os.getenv("XTUNER_USE_CUSTOM_RS_IN_DISPATCHER", 0)) == 1

ag_symm = None
rs_symm = None
ag_manager = None
rs_manager = None
rs_event = None


MoEAGRSHandle = tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]


Expand Down Expand Up @@ -106,9 +119,39 @@ def forward(
with torch.cuda.stream(comm_stream):
comm_stream.wait_event(forward_previous_event)

dispatched_hidden_states = all_gather_tensor_autograd(hidden_states, gather_dim=0, group=process_group)
if isinstance(dispatched_hidden_states, AsyncCollectiveTensor):
dispatched_hidden_states = dispatched_hidden_states.wait()
if USE_CUSTOM_AG:
global ag_manager, ag_symm
if ag_symm is None:
ag_symm = SymmBufferManager(int(os.getenv("SYMM_BUF_SIZE", 0)), num_buffers=1)
if ag_manager is None:
# agrs dispatcher do not need to select comm sm
ag_manager = AllGatherManager(num_buffers=1, select_comm_sm=False)
send_bytes = hidden_states.element_size() * hidden_states.numel()
recv_bytes = send_bytes * process_group.size()
recv_numel = hidden_states.numel() * process_group.size()

ag_manager.prepare_allgather_objects(
send_bytes=send_bytes,
process_group=process_group,
all_gather_stream=comm_stream,
barrier_all=False, # intra-node comm, no need to barrier across nodes
)
device = hidden_states.device
dtype = hidden_states.dtype
combined_grad_out_symm = ag_symm.get_buffer(bytes=recv_bytes, device=device)
combined_grad_out_symm = combined_grad_out_symm.view(dtype)[:recv_numel]

ag_manager.execute_allgather(
send_bytes=send_bytes,
all_gather_output=combined_grad_out_symm,
all_gather_input=hidden_states,
process_group=process_group,
)
dispatched_hidden_states = combined_grad_out_symm.view(-1, *hidden_states.shape[1:])
else:
dispatched_hidden_states = all_gather_tensor_autograd(hidden_states, gather_dim=0, group=process_group)
if isinstance(dispatched_hidden_states, AsyncCollectiveTensor):
dispatched_hidden_states = dispatched_hidden_states.wait()

# topk_ids (seq, topk)
topk_ids = topk_ids.T.flatten()
Expand Down Expand Up @@ -141,24 +184,77 @@ def forward(

@staticmethod
def backward(
ctx, grad_output: torch.Tensor, *args
) -> tuple[torch.Tensor | None, None, None, None, None, None, None, None, None]:
ctx, grad_output: torch.Tensor, grad_topk_ids: torch.Tensor | None, grad_topk_weights: torch.Tensor
) -> tuple[torch.Tensor | None, None, torch.Tensor | None, None, None, None, None, None, None]:
world_size = dist.get_world_size(group=ctx.process_group)
if world_size == 1:
return grad_output, None, None, None, None, None, None, None, None

with torch.cuda.stream(ctx.comm_stream):
# ctx.comm_stream.wait_stream(compute_stream)
if ctx.backward_previous_event is not None:
ctx.comm_stream.wait_event(ctx.backward_previous_event)
combined_grad_output = reduce_scatter_tensor(
grad_output, reduceOp="sum", scatter_dim=0, group=ctx.process_group
)

if USE_CUSTOM_RS:
global rs_manager, rs_symm
if rs_symm is None:
rs_symm = SymmBufferManager(int(os.getenv("SYMM_BUF_SIZE", 0)), num_buffers=1)
if rs_manager is None:
rs_manager = ReduceScatterManager(num_buffers=1, select_comm_sm=False)

process_group = ctx.process_group
comm_stream = ctx.comm_stream

send_bytes = grad_output.element_size() * grad_output.numel()
recv_bytes = send_bytes // process_group.size()
send_numel = grad_output.numel()

rs_manager.prepare_reducescatter_objects(
recv_bytes=recv_bytes,
process_group=process_group,
reduce_scatter_stream=comm_stream,
barrier_all=False, # intra-node comm, no need to barrier across nodes
)
device = grad_output.device
dtype = grad_output.dtype
symm_input = rs_symm.get_buffer(bytes=send_bytes, device=device)
symm_input = symm_input.view(dtype)[:send_numel]
symm_input = symm_input.view(grad_output.shape)

reduce_scatter_output_numel = recv_bytes // symm_input.element_size()
reduce_output = symm_input.new_empty((reduce_scatter_output_numel,))
combined_grad_output = rs_manager.execute_reducescatter(
recv_bytes=recv_bytes,
reduce_scatter_output=reduce_output,
reduce_scatter_input=symm_input,
reduce_scatter_group=process_group,
reduce_scatter_reduce_op=dist.ReduceOp.SUM,
)
combined_grad_output = combined_grad_output.view(-1, *grad_output.shape[1:])
else:
combined_grad_output = reduce_scatter_tensor(
grad_output, reduceOp="sum", scatter_dim=0, group=ctx.process_group
)

grad_output.record_stream(ctx.comm_stream)
combined_grad_output.record_stream(ctx.comm_stream)

world_size = dist.get_world_size(group=ctx.process_group)
grad_topk_weights = grad_topk_weights.view(-1)
combined_grad_topk_weights = torch.empty_like(grad_topk_weights)
dist.all_to_all_single(
combined_grad_topk_weights,
grad_topk_weights,
group=ctx.process_group,
)
combined_grad_topk_weights = combined_grad_topk_weights.view(world_size, -1)
combined_grad_topk_weights = combined_grad_topk_weights.T.contiguous()

# grad_topk_weights and combined_grad_topk_weights must record_stream, this is very important
grad_topk_weights.record_stream(ctx.comm_stream)
combined_grad_topk_weights.record_stream(ctx.comm_stream)
if ctx.backward_finished_event is not None:
ctx.backward_finished_event.record(ctx.comm_stream)
return combined_grad_output, None, None, None, None, None, None, None, None
return combined_grad_output, None, combined_grad_topk_weights, None, None, None, None, None, None


_async_dispatch = copy_method_signature(_AsyncDispatch.forward)(_AsyncDispatch.apply)
Expand All @@ -179,11 +275,46 @@ def forward(
with torch.cuda.stream(comm_stream):
comm_stream.wait_event(forward_previous_event)

combined_hidden_states = reduce_scatter_tensor_autograd(
hidden_states, reduceOp="sum", scatter_dim=0, group=process_group
)
if isinstance(combined_hidden_states, AsyncCollectiveTensor):
combined_hidden_states = combined_hidden_states.wait()
if USE_CUSTOM_RS:
global rs_manager, rs_symm, rs_event
if rs_symm is None:
rs_symm = SymmBufferManager(int(os.getenv("SYMM_BUF_SIZE", 0)), num_buffers=1)
if rs_manager is None:
rs_manager = ReduceScatterManager(num_buffers=1, select_comm_sm=False)

send_bytes = hidden_states.element_size() * hidden_states.numel()
recv_bytes = send_bytes // process_group.size()
send_numel = hidden_states.numel()

rs_manager.prepare_reducescatter_objects(
recv_bytes=recv_bytes,
process_group=process_group,
reduce_scatter_stream=comm_stream,
barrier_all=False, # intra-node comm, no need to barrier across nodes
)
device = hidden_states.device
dtype = hidden_states.dtype
symm_input = rs_symm.get_buffer(bytes=send_bytes, device=device)
symm_input = symm_input.view(dtype)[:send_numel]
symm_input = symm_input.view(hidden_states.shape)

reduce_scatter_output_numel = recv_bytes // symm_input.element_size()
reduce_output = symm_input.new_empty((reduce_scatter_output_numel,))
rs_manager.execute_reducescatter(
recv_bytes=recv_bytes,
reduce_scatter_output=reduce_output,
reduce_scatter_input=symm_input,
reduce_scatter_group=process_group,
reduce_scatter_reduce_op=dist.ReduceOp.SUM,
)
combined_hidden_states = reduce_output.view(-1, *hidden_states.shape[1:])
rs_event = comm_stream.record_event()
else:
combined_hidden_states = reduce_scatter_tensor_autograd(
hidden_states, reduceOp="sum", scatter_dim=0, group=process_group
)
if isinstance(combined_hidden_states, AsyncCollectiveTensor):
Comment on lines +279 to +316
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Warning — Fragile shared-buffer coupling between _AsyncCombine.forward and UnpermuteMoE_topK.forward.

This reduce-scatter path obtains a symmetric buffer via rs_symm.get_buffer() but never copies hidden_states into it. The correctness relies on an implicit invariant: hidden_states was already written into this exact buffer by _unpermute_inplace in UnpermuteMoE_topK.forward, and rs_symm has num_buffers=1 so the same buffer is returned.

This is extremely fragile — any code change that:

  • changes num_buffers from 1
  • inserts another rs_symm.get_buffer() call between unpermute and combine
  • reorders the combine/dispatch calls

...will silently produce incorrect results (reduce-scatter on stale/garbage data) with no error.

At minimum, add an assertion to verify the buffer identity, e.g.:

assert symm_input.data_ptr() == hidden_states.data_ptr(), \
    "RS symm buffer mismatch: hidden_states must reside in the symmetric buffer"

Or better, refactor to explicitly pass the buffer through the call chain instead of relying on global singleton state.

combined_hidden_states = combined_hidden_states.wait()

forward_finished_event.record(comm_stream)

Expand All @@ -205,7 +336,40 @@ def backward(
with torch.cuda.stream(ctx.comm_stream):
if ctx.backward_previous_event is not None:
ctx.comm_stream.wait_event(ctx.backward_previous_event)
combined_grad_output = all_gather_tensor(grad_output, gather_dim=0, group=ctx.process_group)

if USE_CUSTOM_AG:
global ag_manager, ag_symm
if ag_symm is None:
ag_symm = SymmBufferManager(int(os.getenv("SYMM_BUF_SIZE", 0)), num_buffers=1)
if ag_manager is None:
# agrs dispatcher do not need to select comm sm
ag_manager = AllGatherManager(num_buffers=1, select_comm_sm=False)

send_bytes = grad_output.element_size() * grad_output.numel()
recv_bytes = send_bytes * ctx.process_group.size()
recv_numel = grad_output.numel() * ctx.process_group.size()

ag_manager.prepare_allgather_objects(
send_bytes=send_bytes,
process_group=ctx.process_group,
all_gather_stream=ctx.comm_stream,
barrier_all=False, # intra-node comm, no need to barrier across nodes
)
device = grad_output.device
dtype = grad_output.dtype
combined_grad_out_symm = ag_symm.get_buffer(bytes=recv_bytes, device=device)
combined_grad_out_symm = combined_grad_out_symm.view(dtype)[:recv_numel]

ag_manager.execute_allgather(
send_bytes=send_bytes,
all_gather_output=combined_grad_out_symm,
all_gather_input=grad_output,
process_group=ctx.process_group,
)
combined_grad_output = combined_grad_out_symm.view(-1, *grad_output.shape[1:])
else:
combined_grad_output = all_gather_tensor(grad_output, gather_dim=0, group=ctx.process_group)

grad_output.record_stream(ctx.comm_stream)
combined_grad_output.record_stream(ctx.comm_stream)

Expand Down Expand Up @@ -313,10 +477,10 @@ def dispatch(
)
dispatched_topk_ids = dispatched_topk_ids.view(-1, 1)
topk_weights = topk_weights.T.flatten()
dispatched_topk_weights = torch.empty_like(topk_weights)
dist.all_to_all_single(
dispatched_topk_weights,
dispatched_topk_weights = all_to_all_single_autograd(
topk_weights,
output_split_sizes=None,
input_split_sizes=None,
group=self._process_group,
)
dispatched_topk_weights = dispatched_topk_weights.view(-1, 1)
Expand Down Expand Up @@ -467,13 +631,12 @@ def combine(
else:
forward_finished_event = None
backward_previous_event = None
hidden_states = pre_combined["hidden_states"] # .float()
hidden_states = pre_combined["hidden_states"]
combined_hidden_states = reduce_scatter_tensor_autograd(
hidden_states, reduceOp="sum", scatter_dim=0, group=self._process_group
)
if isinstance(combined_hidden_states, AsyncCollectiveTensor):
combined_hidden_states = combined_hidden_states.wait()
# combined_hidden_states = combined_hidden_states.bfloat16()

return MoEAGRSCombineResult(
hidden_states=combined_hidden_states,
Expand Down
10 changes: 9 additions & 1 deletion xtuner/v1/ops/comm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from .all_to_all import all_to_all_single_autograd, ulysses_all_to_all
from .nvls_agrs import AllGatherManager, ReduceScatterManager, SymmBufferManager
from .sequence_parallel import split_for_sequence_parallel


__all__ = ["all_to_all_single_autograd", "ulysses_all_to_all", "split_for_sequence_parallel"]
__all__ = [
"all_to_all_single_autograd",
"ulysses_all_to_all",
"split_for_sequence_parallel",
"SymmBufferManager",
"AllGatherManager",
"ReduceScatterManager",
]
Loading