Changes from all commits (148 commits)
a3c9953
deepspeed moe integration
sfc-gh-sbekman Aug 29, 2025
49df058
add gpt-oss copy
sfc-gh-sbekman Sep 8, 2025
f59b929
Add the moe module layer + ParallelGroup
sfc-gh-reyazda Sep 10, 2025
a7ae5a2
Merge branch 'main' into stas/ds-moe
sfc-gh-sbekman Sep 10, 2025
aec47ec
fix format
sfc-gh-reyazda Sep 10, 2025
ffca4e1
Merge branch 'stas/ds-moe' of https://github.com/snowflakedb/ArcticTr…
sfc-gh-reyazda Sep 10, 2025
7e207a1
add cpu_adam_moe optimizer
sfc-gh-sbekman Sep 11, 2025
2476d73
moe utils
sfc-gh-sbekman Sep 11, 2025
1327c12
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Sep 16, 2025
8da7b5b
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Sep 16, 2025
e4b2657
add moe_gemm
sfc-gh-reyazda Sep 17, 2025
e0b8cc1
Merge branch 'stas/ds-moe' of https://github.com/snowflakedb/ArcticTr…
sfc-gh-reyazda Sep 17, 2025
4eb0825
fix typo
sfc-gh-reyazda Sep 18, 2025
27b1602
fix imports
sfc-gh-reyazda Sep 18, 2025
b5ad32b
fix imports
sfc-gh-reyazda Sep 18, 2025
03309c1
update
sfc-gh-sbekman Sep 19, 2025
a7aff40
hijack mlp/moe + optimizer groups setup
sfc-gh-sbekman Sep 23, 2025
25b881b
top_k
sfc-gh-sbekman Sep 23, 2025
1fd7e10
fix
sfc-gh-sbekman Sep 24, 2025
1de0ec6
add ep group creation + refactoring
sfc-gh-sbekman Sep 25, 2025
0c41479
gate+up->gate_up
sfc-gh-sbekman Sep 26, 2025
92361dd
fix indexing
sfc-gh-reyazda Sep 26, 2025
abb0d9d
Merge branch 'stas/ds-moe' of https://github.com/snowflakedb/ArcticTr…
sfc-gh-reyazda Sep 26, 2025
39fe701
rename some variables to better show their purpose
sfc-gh-reyazda Sep 26, 2025
513a2e4
add the local_expert_size and fix the counts for group-gemm
sfc-gh-reyazda Sep 26, 2025
d5cdd07
use only local experts
sfc-gh-sbekman Sep 26, 2025
c0e2d0e
fix the params and computation for the gated mlp
sfc-gh-reyazda Oct 1, 2025
06ec55d
use the new names
sfc-gh-reyazda Oct 1, 2025
4340857
change the gate and up weight order before concatenation
sfc-gh-reyazda Oct 1, 2025
1915c95
fix
sfc-gh-reyazda Oct 1, 2025
d73af51
fix gpt-oss
sfc-gh-reyazda Oct 1, 2025
0399592
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Oct 2, 2025
d587bf4
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Oct 2, 2025
da5ad46
switch to ds groups + refactor
sfc-gh-sbekman Oct 3, 2025
5629d8f
guard the triton import for the group-gemm
sfc-gh-reyazda Oct 6, 2025
2355da9
Merge branch 'stas/ds-moe' of https://github.com/snowflakedb/ArcticTr…
sfc-gh-reyazda Oct 6, 2025
4446fb1
Merge branch 'main' into stas/ds-moe
sfc-gh-sbekman Oct 6, 2025
0c409eb
format
sfc-gh-sbekman Oct 6, 2025
c1a3219
resync qwen3 and gpt-oss modeling code with the latest transformers
sfc-gh-sbekman Oct 6, 2025
faaeb98
wip
sfc-gh-sbekman Oct 7, 2025
6eced50
rework remapping
sfc-gh-sbekman Oct 8, 2025
2fcfd33
fix accuracy bugs
sfc-gh-reyazda Oct 9, 2025
1118589
Update debug.py
sfc-gh-sbekman Oct 9, 2025
ea8a538
add underflow_overflow debug module + cleanup/restore the original test
sfc-gh-sbekman Oct 9, 2025
def0fa4
update
sfc-gh-sbekman Oct 9, 2025
2382f9e
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Oct 9, 2025
49b10d5
fix missing router gate copy for gpt-oss
sfc-gh-sbekman Oct 10, 2025
01ed80c
cleanup
sfc-gh-sbekman Oct 10, 2025
0328f75
make router_gate a parameter
sfc-gh-sbekman Oct 10, 2025
8345772
add the normalization plus token permutation for the local-experts
sfc-gh-reyazda Oct 10, 2025
958348b
Merge branch 'stas/ds-moe' of https://github.com/snowflakedb/ArcticTr…
sfc-gh-reyazda Oct 10, 2025
0b29428
add router aux loss coefficient
sfc-gh-reyazda Oct 10, 2025
6267388
wandb improvement
sfc-gh-sbekman Oct 10, 2025
a585d16
use the same zero stage
sfc-gh-sbekman Oct 10, 2025
32682ae
improve Makefile
sfc-gh-sbekman Oct 10, 2025
6afcf49
debug-off
sfc-gh-sbekman Oct 10, 2025
389837d
fix gpt-oss use-case
sfc-gh-sbekman Oct 10, 2025
46192bf
cleaner debug
sfc-gh-sbekman Oct 10, 2025
f9a3365
link back modeling copies
sfc-gh-sbekman Oct 10, 2025
d345ce4
add small gpt-oss models
sfc-gh-sbekman Oct 14, 2025
24b0004
sync local gpt-oss copy
sfc-gh-sbekman Oct 14, 2025
bd8c8ef
cleanup
sfc-gh-sbekman Oct 14, 2025
75c920c
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Oct 14, 2025
0f4df65
sort config
sfc-gh-sbekman Oct 14, 2025
063da8b
sort config
sfc-gh-sbekman Oct 14, 2025
9a6360c
use .shape consistently
sfc-gh-sbekman Oct 14, 2025
d522c86
consistent torch.empty dim args
sfc-gh-sbekman Oct 14, 2025
724e3df
fix gpt-oss activation function
sfc-gh-sbekman Oct 14, 2025
b0afc66
exclude router_gate, use cpu_adam_moe
sfc-gh-sbekman Oct 14, 2025
f1d4eea
flatten moe optim group into 1
sfc-gh-sbekman Oct 14, 2025
38fbd58
renames
sfc-gh-sbekman Oct 15, 2025
95ef4b1
project's readme file
sfc-gh-sbekman Oct 15, 2025
06864b4
add the alltoallv function to explicitly define backward
sfc-gh-reyazda Oct 16, 2025
da23d81
Merge branch 'stas/ds-moe' of https://github.com/snowflakedb/ArcticTr…
sfc-gh-reyazda Oct 16, 2025
013618d
fix var name
sfc-gh-sbekman Oct 16, 2025
3e08b60
complete ds integration with EP>1
sfc-gh-sbekman Oct 16, 2025
8a18bc4
extend the test to test 3 iterations w/ bf16 tolerance
sfc-gh-sbekman Oct 16, 2025
166eaa7
add shared-expert; fix modeling qwen
sfc-gh-reyazda Oct 29, 2025
ce2734d
remove the transpose when ep=1
sfc-gh-reyazda Nov 4, 2025
ff1356b
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Nov 4, 2025
c4ab1fd
Merge branch 'stas/ds-moe' of https://github.com/snowflakedb/ArcticTr…
sfc-gh-sbekman Nov 4, 2025
790a675
add repr
sfc-gh-sbekman Nov 4, 2025
89b3867
qwen3-next weight import + test
sfc-gh-sbekman Nov 4, 2025
b038cdb
revert
sfc-gh-sbekman Nov 4, 2025
6978265
move the shared expert-computation after moe-combine
sfc-gh-reyazda Nov 4, 2025
1cb2254
fix repr
sfc-gh-sbekman Nov 5, 2025
b6ce5c7
some models have a different intermediate size for experts than norma…
sfc-gh-sbekman Nov 5, 2025
d6f4677
fix
sfc-gh-sbekman Nov 5, 2025
a5f597e
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Nov 8, 2025
944ca1b
missing git
sfc-gh-sbekman Nov 8, 2025
cc9b967
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Nov 10, 2025
450dcc8
fix
sfc-gh-sbekman Nov 10, 2025
370d1f6
fix
sfc-gh-sbekman Nov 10, 2025
339875d
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Nov 12, 2025
511638d
lots of updates
sfc-gh-sbekman Nov 14, 2025
b4e8e2f
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Nov 14, 2025
f3f23aa
fix transformers to a working version
sfc-gh-sbekman Nov 14, 2025
6080fc0
cleanup
sfc-gh-sbekman Nov 14, 2025
5d292c1
simplify making dense variants; start qwen3-next flops
sfc-gh-sbekman Nov 17, 2025
b9cd573
add support for linear gated delta net attention
sfc-gh-sbekman Nov 17, 2025
956e4cb
sync
sfc-gh-sbekman Nov 18, 2025
ebc7caa
fix
sfc-gh-sbekman Nov 18, 2025
80df63d
faster start with partial model
sfc-gh-sbekman Nov 19, 2025
f490f2d
fix flop counter
sfc-gh-sbekman Nov 19, 2025
9111699
add per component time and token profiler
sfc-gh-sbekman Nov 19, 2025
1676ca2
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Nov 19, 2025
2acb5db
fix
sfc-gh-sbekman Nov 19, 2025
e9e37f6
cleanup
sfc-gh-sbekman Nov 19, 2025
2e0301d
cleanup
sfc-gh-sbekman Nov 19, 2025
e896001
cleanup
sfc-gh-sbekman Nov 19, 2025
4b705d5
fix imports
sfc-gh-sbekman Nov 19, 2025
728165a
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Nov 20, 2025
87da435
adapt to changes
sfc-gh-sbekman Nov 20, 2025
487b435
add missing marker
sfc-gh-sbekman Nov 25, 2025
45bffd8
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Nov 25, 2025
65e5d3b
remap params back to original during ckpt save + test
sfc-gh-sbekman Nov 27, 2025
0aa1b12
remap params back to original during ckpt save + test
sfc-gh-sbekman Nov 27, 2025
0b16f94
add qwen3 next weights export support + test
sfc-gh-sbekman Dec 9, 2025
840c52a
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Dec 10, 2025
1cde944
test tolerance
sfc-gh-sbekman Dec 10, 2025
9eede36
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Dec 10, 2025
e3a2be9
checkpoint save+resume+export
sfc-gh-sbekman Dec 13, 2025
4a1712b
save-resume works
sfc-gh-sbekman Dec 16, 2025
3b505a6
continuous wandb runs
sfc-gh-sbekman Dec 16, 2025
094d1f4
fix
sfc-gh-sbekman Dec 16, 2025
586a2d9
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Dec 19, 2025
f6fd48a
Add the custom ops to optimize moe training performance
sfc-gh-reyazda Mar 27, 2026
242d3a8
Add the custom ops to optimize moe training performance
sfc-gh-reyazda Mar 27, 2026
531eb8d
Add the routing replay logic
sfc-gh-reyazda Mar 27, 2026
0cedc24
Merge remote-tracking branch 'origin/main' into stas/ds-moe
sfc-gh-sbekman Mar 30, 2026
8315805
cleanup
sfc-gh-sbekman Mar 31, 2026
5729a1c
cleanup
sfc-gh-sbekman Mar 31, 2026
bb4b942
fix docs breaking
sfc-gh-mwyatt Apr 1, 2026
5f8dd17
Add moe-training kernels (#367)
sfc-gh-reyazda Apr 3, 2026
cf3430d
Merge branch 'stas/ds-moe' into moe-custom-ops
sfc-gh-reyazda Apr 3, 2026
3e30bf5
Merge branch 'main' into moe-custom-ops
sfc-gh-reyazda Apr 3, 2026
c6fef9f
Add the custom Comm for creating new custom communication collectives
sfc-gh-reyazda Apr 3, 2026
8d2103b
fixes
sfc-gh-reyazda Apr 7, 2026
ee20cec
fix compile issue
sfc-gh-reyazda Apr 8, 2026
bb7d20f
add test_comm
sfc-gh-reyazda Apr 8, 2026
51aefab
add the counts for all2all op
sfc-gh-reyazda Apr 8, 2026
0b878d7
add count info/tensor for the alltoall-v functionality
sfc-gh-reyazda Apr 10, 2026
f74d305
Merge branch 'main' into comm-ops
sfc-gh-reyazda Apr 13, 2026
d0f1613
fix alltoall rcv count
sfc-gh-reyazda Apr 14, 2026
41fabf5
Merge branch 'comm-ops' of https://github.com/snowflakedb/ArcticTrain…
sfc-gh-reyazda Apr 14, 2026
6138225
fixes
Apr 24, 2026
c255f70
some tweaks to prevent OOM, need to work on fixing the gemm kernel
Apr 27, 2026
64789fa
using torch._grouped_mm
Apr 27, 2026
3 changes: 3 additions & 0 deletions arctic_training/config/trainer.py
@@ -406,6 +406,8 @@ def build_deepspeed_config(self) -> Self:
from transformers import AutoConfig

model_config = AutoConfig.from_pretrained(self.model.name_or_path)
model_config = model_config.text_config if hasattr(model_config, "text_config") else model_config

if hasattr(model_config, "hidden_size"):
hidden_size = model_config.hidden_size
elif hasattr(model_config, "hidden_sizes"):
@@ -543,6 +545,7 @@ def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig:

trainer_cls = get_registered_trainer(trainer_type)
config_cls = _get_class_attr_type_hints(trainer_cls, "config")[0]

config = config_cls(**config_dict)

return config
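
The added text_config fallback matters for composite checkpoints (e.g. vision-language models), where language-model hyperparameters such as hidden_size are nested one level down rather than sitting on the top-level config. A minimal sketch of the pattern, with an illustrative (not real) model id:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("some/composite-model")  # illustrative id only
# Composite configs nest the LM settings under text_config; plain LM configs
# expose hidden_size and friends at the top level, so fall through unchanged.
config = config.text_config if hasattr(config, "text_config") else config
hidden_size = getattr(config, "hidden_size", None)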
15 changes: 15 additions & 0 deletions arctic_training/kernels/comm/comm.cpp
@@ -0,0 +1,15 @@
#include "comm.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("init_comm_group", &ds_create_comm, "create comm group");
m.def("barrier", &ds_barrier, "barrier");
m.def("broadcast", &ds_broadcast, "broadcast");
m.def("wait_comm", &wait_comm, "wait on communication event");
m.def("allReduce", &ds_allreduce, "AllReduce");
m.def("alltoall", &ds_alltoall, "AllToAll");
m.def("allGather", &ds_allgather, "AllGather");
m.def("get_nccl_uid", &ds_get_nccl_uid, "Get NCCL UID");
m.def("init_nccl_comm", &ds_create_nccl_comm, "Create NCCL Comm");
}
//////
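
These bindings mirror the declarations in comm.h. A hedged sketch of how the Python side could bootstrap the NCCL communicator with them, based only on the signatures bound above (the actual ArcticTraining call sites may differ): rank 0 fetches the NCCL unique id, shares it over an existing process group, and every rank then initializes the custom communicator.

import torch
import torch.distributed as dist
from arctic_training.op_builder import CommBuilder  # same builder comm.py uses

ds_comm = CommBuilder().load()
ranks = list(range(dist.get_world_size()))

if dist.get_rank() == 0:
    uid = ds_comm.get_nccl_uid()  # CPU uint8 tensor wrapping ncclUniqueId
else:
    uid = torch.empty(128, dtype=torch.uint8)  # sizeof(ncclUniqueId) is 128 bytes
dist.broadcast(uid, src=0)  # piggyback on an existing (e.g. gloo) process group
ds_comm.init_nccl_comm(ranks, dist.get_rank(), uid)
ds_comm.barrier()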
124 changes: 124 additions & 0 deletions arctic_training/kernels/comm/comm.cu
@@ -0,0 +1,124 @@
#include "comm.h"
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include <vector>
#include "context.h"

void ds_create_comm(std::vector<int>& comm_ranks, int rank)
{
CoMMContext::Instance().create_comm_group(comm_ranks, rank);
}

void ds_create_nccl_comm(std::vector<int>& comm_ranks, int rank, torch::Tensor& nccl_uid)
{
CoMMContext::Instance().create_nccl_comm(comm_ranks, rank, nccl_uid.data_ptr());
}

void ds_allreduce(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op)
{
if (async_op) CoMMContext::Instance().SynchComp();
ncclAllReduce(send_buf.data_ptr(),
rcv_buf.data_ptr(),
size,
(send_buf.scalar_type() == at::kFloat ? ncclFloat : (send_buf.scalar_type() == at::kHalf ? ncclHalf : ncclBfloat16)),
ncclSum,
CoMMContext::Instance().GetNCCLComm(),
CoMMContext::Instance().GetCommStream());
}

void ds_allgather(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op)
{
if (async_op) CoMMContext::Instance().SynchComp();
ncclAllGather(send_buf.data_ptr(),
rcv_buf.data_ptr(),
size,
(send_buf.scalar_type() == at::kFloat ? ncclFloat : (send_buf.scalar_type() == at::kHalf ? ncclHalf : ncclBfloat16)),
CoMMContext::Instance().GetNCCLComm(),
CoMMContext::Instance().GetCommStream());
}

void wait_comm() { CoMMContext::Instance().SynchComm(); }

void ds_broadcast(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op)
{
ncclBroadcast(send_buf.data_ptr(),
rcv_buf.data_ptr(),
size,
(send_buf.scalar_type() == at::kFloat ? ncclFloat : (send_buf.scalar_type() == at::kHalf ? ncclHalf : ncclBfloat16)),
0,
CoMMContext::Instance().GetNCCLComm(),
CoMMContext::Instance().GetCommStream());
}

void ds_barrier() { CoMMContext::Instance().barrier(); }

inline size_t wordSize(ncclDataType_t type) {
switch(type) {
case ncclChar:
case ncclUint8:
return 1;
case ncclHalf:
case ncclBfloat16:
return 2;
case ncclInt:
case ncclFloat:
case ncclUint32:
return 4;
case ncclInt64:
case ncclUint64:
case ncclDouble:
return 8;
default: return 0;
}
}

void ncclAlltoAll(void* sendbuff,
void* recvbuff,
int32_t *send_counts,
int32_t *recv_counts,
size_t max_count,
ncclDataType_t type,
const unsigned nRanks,
ncclComm_t comm,
cudaStream_t stream) {

size_t rankOffset = max_count * wordSize(type);

ncclGroupStart();
for (int r=0; r<nRanks; r++) {
ncclSend(((char*)sendbuff)+r*rankOffset, send_counts[r], type, r, comm, stream);
ncclRecv(((char*)recvbuff)+r*rankOffset, recv_counts[r], type, r, comm, stream);
}
ncclGroupEnd();
}

void ds_alltoall(torch::Tensor& send_buf, torch::Tensor& rcv_buf, torch::Tensor& send_counts, torch::Tensor& recv_counts, size_t max_count, bool async_op)
{
ncclAlltoAll(send_buf.data_ptr(),
rcv_buf.data_ptr(),
(int32_t*)send_counts.data_ptr(),
(int32_t*)recv_counts.data_ptr(),
max_count,
(send_buf.scalar_type() == at::kFloat ?
ncclFloat :
(send_buf.scalar_type() == at::kHalf ?
ncclHalf :
(send_buf.scalar_type() == torch::kInt8 ? ncclUint8 : ncclBfloat16))),
CoMMContext::Instance().GetNumRanks(),
CoMMContext::Instance().GetNCCLComm(),
CoMMContext::Instance().GetCommStream());
}

torch::Tensor ds_get_nccl_uid()
{

auto options = at::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCPU)
.requires_grad(false);
auto nccl_uid = CoMMContext::Instance().get_nccl_uid();
auto uid_tensor = torch::from_blob((void*)&nccl_uid, {sizeof(ncclUniqueId)}, options);
return uid_tensor;
}
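
Note the layout contract in ncclAlltoAll above: rankOffset = max_count * wordSize(type), so both buffers are partitioned into fixed slots of max_count elements per peer, and only send_counts[r] / recv_counts[r] elements of each slot actually cross the wire, i.e. a padded all-to-all-v. A small sketch of packing a send buffer under that contract (my reading of the kernel, counts in elements of the tensor dtype):

import torch

world, max_count = 4, 8
send_counts = [3, 8, 1, 5]  # elements destined for each peer; max is 8
send_buf = torch.zeros(world * max_count, device="cuda")
for r, n in enumerate(send_counts):
    # peer r's slot is [r * max_count, (r + 1) * max_count); the tail is padding
    send_buf[r * max_count : r * max_count + n] = torch.randn(n, device="cuda")
recv_buf = torch.empty_like(send_buf)
# ds_alltoall(send_buf, recv_buf, send_counts_t, recv_counts_t, max_count, False)
# then sends counts[r] elements out of each fixed-stride slot; the padding
# never crosses the wire.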
15 changes: 15 additions & 0 deletions arctic_training/kernels/comm/comm.h
@@ -0,0 +1,15 @@

#include <torch/extension.h>
#include <stdint.h>

#include "stdio.h"

void ds_barrier();
void wait_comm();
void ds_create_comm(std::vector<int>& comm_ranks, int rank);
void ds_allreduce(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op);
void ds_allgather(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op);
void ds_broadcast(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op);
void ds_alltoall(torch::Tensor& send_buf, torch::Tensor& rcv_buf, torch::Tensor& send_counts, torch::Tensor& recv_counts, size_t max_count, bool async_op);
torch::Tensor ds_get_nccl_uid();
void ds_create_nccl_comm(std::vector<int>& comm_ranks, int rank, torch::Tensor& nccl_uid);
138 changes: 138 additions & 0 deletions arctic_training/kernels/comm/comm.py
@@ -0,0 +1,138 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from arctic_training.op_builder import CommBuilder

from .layout import Layout

ds_comm = None


class Comm:
current_comm = None

def __init__(self, layout: Layout, local_rank: int):
global ds_comm
if ds_comm is None:
ds_comm = CommBuilder().load()

self.ds_comm = ds_comm
self._layout = layout
self.my_rank = local_rank
self.global_rank = torch.distributed.get_rank()
self.group_size = layout._group_size

self.global_ranks = layout.sibling_ranks(self.global_rank)
self.local_ranks = list(range(self.group_size))
self.rank_map = dict(zip(self.global_ranks, self.local_ranks))
self.counts_pinned_data = torch.empty(1024, dtype=torch.int32, device="cpu", pin_memory=True)
self.recv_counts_pinned_data = torch.empty(1024, dtype=torch.int32, device="cpu", pin_memory=True)

print("Initializing comm ...")

def all_reduce(self, val, inplace=True, async_op=False):
val_sum = val if inplace else torch.empty_like(val)
op = communicate_op(val, val_sum, async_op, op_type="all_reduce")
return val_sum, op

def all_gather(self, val, inplace=True, async_op=False):
val_gather = torch.empty((self.group_size * val.size(0), *val.shape[1:]), device=val.device, dtype=val.dtype)
op = communicate_op(val, val_gather, async_op, op_type="all_gather")
return val_gather, op

def all_to_all(
self, val, counts=None, receive_counts=None, max_count=None, result=None, inplace=True, async_op=False
):
if counts is not None:
if receive_counts is None:
receive_counts = torch.empty_like(counts)
torch.distributed.all_to_all_single(receive_counts, counts)

self.counts_pinned_data[: receive_counts.numel()].copy_(counts)

if max_count is None:
max_count = counts.max()
torch.distributed.all_reduce(max_count, op=torch.distributed.ReduceOp.MAX)

max_count = max_count.item()

receive_counts = self.recv_counts_pinned_data[: receive_counts.numel()].copy_(receive_counts)
counts = self.counts_pinned_data[: counts.numel()]
else:
max_count = val.size(0) // self.group_size
counts = torch.full((self.group_size,), max_count, device="cpu", dtype=torch.int32)
receive_counts = counts

result = result if result is not None else torch.empty_like(val)
op = communicate_op(
val,
result,
async_op,
world_size=self.group_size,
op_type="all_to_all",
send_counts=counts,
recv_counts=receive_counts,
max_count=max_count,
)
return result, op

def broadcast(self, val, inplace=True, async_op=False):
val_bcst = torch.empty_like(val)
op = communicate_op(val, val_bcst, async_op, op_type="broadcast")
return val_bcst, op

def barrier(self):
ds_comm.wait_comm()
ds_comm.barrier()

@classmethod
def get_current_comm(cls):
if cls.current_comm is None:
from arctic_training.kernels.comm.nccl import NcclComm

cls.current_comm = NcclComm()
return cls.current_comm


class communicate_op:
def __init__(
self,
val,
result,
async_op,
world_size=None,
op_type="all_reduce",
send_counts=None,
recv_counts=None,
max_count=None,
):
if op_type == "all_reduce":
ds_comm.allReduce(val, result, val.numel(), async_op)
elif op_type == "all_gather":
ds_comm.allGather(val, result, val.numel(), async_op)
elif op_type == "all_to_all":
ds_comm.alltoall(val, result, send_counts, recv_counts, max_count, async_op)
elif op_type == "broadcast":
ds_comm.broadcast(val, result, val.numel(), async_op)

def wait(self):
ds_comm.wait_comm()


def get_default_comm():
return Comm.get_current_comm()
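
A hedged usage sketch of the wrapper above (construction via get_default_comm assumes NcclComm exposes the same interface as Comm; real call sites in the trainer may differ). Every collective returns a (tensor, handle) pair, and wait() drains the dedicated comm stream:

import torch

comm = get_default_comm()  # NcclComm() singleton per get_current_comm above
x = torch.randn(4096, device="cuda")

summed, op = comm.all_reduce(x, async_op=True)
op.wait()  # blocks until the comm stream has finished

# Variable-count path: per-peer counts are exchanged with all_to_all_single,
# the max is all-reduced, and the padded NCCL alltoall from comm.cu moves the data.
counts = torch.full((comm.group_size,), x.numel() // comm.group_size, dtype=torch.int32, device="cuda")
out, op = comm.all_to_all(x, counts=counts)
op.wait()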