diff --git a/arctic_training/config/trainer.py b/arctic_training/config/trainer.py index 7a9f94c6..85d27cfd 100644 --- a/arctic_training/config/trainer.py +++ b/arctic_training/config/trainer.py @@ -406,6 +406,8 @@ def build_deepspeed_config(self) -> Self: from transformers import AutoConfig model_config = AutoConfig.from_pretrained(self.model.name_or_path) + model_config = model_config.text_config if hasattr(model_config, "text_config") else model_config + if hasattr(model_config, "hidden_size"): hidden_size = model_config.hidden_size elif hasattr(model_config, "hidden_sizes"): @@ -543,6 +545,7 @@ def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig: trainer_cls = get_registered_trainer(trainer_type) config_cls = _get_class_attr_type_hints(trainer_cls, "config")[0] + config = config_cls(**config_dict) return config diff --git a/arctic_training/kernels/comm/comm.cpp b/arctic_training/kernels/comm/comm.cpp new file mode 100644 index 00000000..5e02f047 --- /dev/null +++ b/arctic_training/kernels/comm/comm.cpp @@ -0,0 +1,15 @@ +#include "comm.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("init_comm_group", &ds_create_comm, "create comm group"); + m.def("barrier", &ds_barrier, "barrier"); + m.def("broadcast", &ds_broadcast, "broadcast"); + m.def("wait_comm", &wait_comm, "wait on communication event"); + m.def("allReduce", &ds_allreduce, "AllReduce"); + m.def("alltoall", &ds_alltoall, "AllToAll"); + m.def("allGather", &ds_allgather, "AllGather"); + m.def("get_nccl_uid", &ds_get_nccl_uid, "Get NCCL UID"); + m.def("init_nccl_comm", &ds_create_nccl_comm, "Create NCCL Comm"); +} +////// diff --git a/arctic_training/kernels/comm/comm.cu b/arctic_training/kernels/comm/comm.cu new file mode 100644 index 00000000..a04334ca --- /dev/null +++ b/arctic_training/kernels/comm/comm.cu @@ -0,0 +1,124 @@ +#include "comm.h" +#include +#include + +#include +#include "context.h" + +void ds_create_comm(std::vector& comm_ranks, int rank) +{ + CoMMContext::Instance().create_comm_group(comm_ranks, rank); +} + +void ds_create_nccl_comm(std::vector& comm_ranks, int rank, torch::Tensor& nccl_uid) +{ + CoMMContext::Instance().create_nccl_comm(comm_ranks, rank, nccl_uid.data_ptr()); +} + +void ds_allreduce(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op) +{ + if (async_op) CoMMContext::Instance().SynchComp(); + ncclAllReduce(send_buf.data_ptr(), + rcv_buf.data_ptr(), + size, + (send_buf.scalar_type() == at::kFloat ? ncclFloat : (send_buf.scalar_type() == at::kHalf ? ncclHalf : ncclBfloat16)), + ncclSum, + CoMMContext::Instance().GetNCCLComm(), + CoMMContext::Instance().GetCommStream()); +} + +void ds_allgather(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op) +{ + if (async_op) CoMMContext::Instance().SynchComp(); + ncclAllGather(send_buf.data_ptr(), + rcv_buf.data_ptr(), + size, + (send_buf.scalar_type() == at::kFloat ? ncclFloat : (send_buf.scalar_type() == at::kHalf ? ncclHalf : ncclBfloat16)), + CoMMContext::Instance().GetNCCLComm(), + CoMMContext::Instance().GetCommStream()); +} + +void wait_comm() { CoMMContext::Instance().SynchComm(); } + +void ds_broadcast(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op) +{ + ncclBroadcast(send_buf.data_ptr(), + rcv_buf.data_ptr(), + size, + (send_buf.scalar_type() == at::kFloat ? ncclFloat : (send_buf.scalar_type() == at::kHalf ? ncclHalf : ncclBfloat16)), + 0, + CoMMContext::Instance().GetNCCLComm(), + CoMMContext::Instance().GetCommStream()); +} + +void ds_barrier() { CoMMContext::Instance().barrier(); } + +inline size_t wordSize(ncclDataType_t type) { + switch(type) { + case ncclChar: + case ncclUint8: + return 1; + case ncclHalf: + case ncclBfloat16: + return 2; + case ncclInt: + case ncclFloat: + case ncclUint32: + return 4; + case ncclInt64: + case ncclUint64: + case ncclDouble: + return 8; + default: return 0; + } +} + +void ncclAlltoAll(void* sendbuff, + void* recvbuff, + int32_t *send_counts, + int32_t *recv_counts, + size_t max_count, + ncclDataType_t type, + const unsigned nRanks, + ncclComm_t comm, + cudaStream_t stream) { + + size_t rankOffset = max_count * wordSize(type); + + ncclGroupStart(); + for (int r=0; r +#include + +#include "stdio.h" + +void ds_barrier(); +void wait_comm(); +void ds_create_comm(std::vector& comm_ranks, int rank); +void ds_allreduce(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op); +void ds_allgather(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op); +void ds_broadcast(torch::Tensor& send_buf, torch::Tensor& rcv_buf, int size, bool async_op); +void ds_alltoall(torch::Tensor& send_buf, torch::Tensor& rcv_buf, torch::Tensor& send_counts, torch::Tensor& recv_counts, size_t max_count, bool async_op); +torch::Tensor ds_get_nccl_uid(); +void ds_create_nccl_comm(std::vector& comm_ranks, int rank, torch::Tensor& nccl_uid); diff --git a/arctic_training/kernels/comm/comm.py b/arctic_training/kernels/comm/comm.py new file mode 100644 index 00000000..f8550a17 --- /dev/null +++ b/arctic_training/kernels/comm/comm.py @@ -0,0 +1,138 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from arctic_training.op_builder import CommBuilder + +from .layout import Layout + +ds_comm = None + + +class Comm: + current_comm = None + + def __init__(self, layout: Layout, local_rank: int): + global ds_comm + if ds_comm is None: + ds_comm = CommBuilder().load() + + self.ds_comm = ds_comm + self._layout = layout + self.my_rank = local_rank + self.global_rank = torch.distributed.get_rank() + self.group_size = layout._group_size + + self.global_ranks = layout.sibling_ranks(self.global_rank) + self.local_ranks = list(range(self.group_size)) + self.rank_map = dict(zip(self.global_ranks, self.local_ranks)) + self.counts_pinned_data = torch.empty(1024, dtype=torch.int32, device="cpu", pin_memory=True) + self.recv_counts_pinned_data = torch.empty(1024, dtype=torch.int32, device="cpu", pin_memory=True) + + print("Initializing comm ...") + + def all_reduce(self, val, inplace=True, async_op=False): + val_sum = val if inplace else torch.empty_like(val) + op = communicate_op(val, val_sum, async_op, op_type="all_reduce") + return val_sum, op + + def all_gather(self, val, inplace=True, async_op=False): + val_gather = torch.empty((self.group_size * val.size(0), *val.shape[1:]), device=val.device, dtype=val.dtype) + op = communicate_op(val, val_gather, async_op, op_type="all_gather") + return val_gather, op + + def all_to_all( + self, val, counts=None, receive_counts=None, max_count=None, result=None, inplace=True, async_op=False + ): + if counts is not None: + if receive_counts is None: + receive_counts = torch.empty_like(counts) + torch.distributed.all_to_all_single(receive_counts, counts) + + self.counts_pinned_data[: receive_counts.numel()].copy_(counts) + + if max_count is None: + max_count = counts.max() + torch.distributed.all_reduce(max_count, op=torch.distributed.ReduceOp.MAX) + + max_count = max_count.item() + + receive_counts = self.recv_counts_pinned_data[: receive_counts.numel()].copy_(receive_counts) + counts = self.counts_pinned_data[: counts.numel()] + else: + max_count = val.size(0) // self.group_size + counts = torch.full((self.group_size,), max_count, device="cpu", dtype=torch.int32) + receive_counts = counts + + result = result if result is not None else torch.empty_like(val) + op = communicate_op( + val, + result, + async_op, + world_size=self.group_size, + op_type="all_to_all", + send_counts=counts, + recv_counts=receive_counts, + max_count=max_count, + ) + return result, op + + def broadcast(self, val, inplace=True, async_op=False): + val_bcst = torch.empty_like(val) + op = communicate_op(val, val_bcst, async_op, op_type="broadcast") + return val_bcst, op + + def barrier(self): + ds_comm.wait_comm() + ds_comm.barrier() + + @classmethod + def get_current_comm(cls): + if cls.current_comm is None: + from arctic_training.kernels.comm.nccl import NcclComm + + cls.current_comm = NcclComm() + return cls.current_comm + + +class communicate_op: + def __init__( + self, + val, + result, + async_op, + world_size=None, + op_type="all_reduce", + send_counts=None, + recv_counts=None, + max_count=None, + ): + if op_type == "all_reduce": + ds_comm.allReduce(val, result, val.numel(), async_op) + elif op_type == "all_gather": + ds_comm.allGather(val, result, val.numel(), async_op) + elif op_type == "all_to_all": + ds_comm.alltoall(val, result, send_counts, recv_counts, max_count, async_op) + elif op_type == "broadcast": + ds_comm.broadcast(val, result, val.numel(), async_op) + + def wait(self): + ds_comm.wait_comm() + + +def get_default_comm(): + return Comm.get_current_comm() diff --git a/arctic_training/kernels/comm/context.h b/arctic_training/kernels/comm/context.h new file mode 100644 index 00000000..5611f71d --- /dev/null +++ b/arctic_training/kernels/comm/context.h @@ -0,0 +1,184 @@ +#pragma once + +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define WARP_SIZE 32 + +#define CUDA_CHECK(callstr) \ + { \ + cudaError_t error_code = callstr; \ + if (error_code != cudaSuccess) { \ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ + assert(0); \ + } \ + } + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + +#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \ + for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y) + +#define DS_CUDA_NUM_THREADS 512 +#define DS_MAXIMUM_NUM_BLOCKS 262144 + +inline int DS_GET_BLOCKS(const int N) +{ + return std::max( + std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), + // Use at least 1 block, since CUDA does not allow empty block + 1); +} + +class CoMMContext { +public: + CoMMContext() + : _workspace(nullptr), + _seed(42), + _curr_offset(0), + _comm_stream(0), + _comp_stream(0), + _comm_created(false) + { + curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(_gen, 123); + if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { + auto message = std::string("Fail to create cublas handle."); + std::cerr << message << std::endl; + throw std::runtime_error(message); + } + cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); + cudaEventCreate(&_comp_event, (cudaEventDisableTiming | cudaEventBlockingSync)); + cudaEventCreate(&_comm_event, (cudaEventDisableTiming | cudaEventBlockingSync)); + } + + virtual ~CoMMContext() + { + cublasDestroy(_cublasHandle); + cudaFree(_workspace); + ncclCommDestroy(_nccl_comm); + cudaEventDestroy(_comp_event); + cudaEventDestroy(_comm_event); + } + + static CoMMContext& Instance() + { + static CoMMContext _ctx; + return _ctx; + } + + void create_comm_group(std::vector comm_ranks, int rank) + { + // + } + + inline ncclUniqueId get_nccl_uid() + { + ncclUniqueId _nccl_uid; + ncclGetUniqueId(&_nccl_uid); + return _nccl_uid; + } + + void create_nccl_comm(std::vector comm_ranks, int rank, void* nccl_uid_ptr){ + + unsigned num_ranks = comm_ranks.size(); + ncclUniqueId _nccl_uid = *((ncclUniqueId*)nccl_uid_ptr); + _nranks = num_ranks; + ncclCommInitRank(&_nccl_comm, num_ranks, _nccl_uid, rank); + printf("********** nccl comm: %p \n", _nccl_comm); + + } + inline ncclComm_t GetNCCLComm() { return _nccl_comm; } + + inline unsigned GetNumRanks() const { return _nranks; } + + inline void barrier() { + // + } + + inline void SynchComp() + { + cudaEventRecord(_comp_event, _comp_stream); + cudaStreamWaitEvent(_comm_stream, _comp_event, 0); + } + inline void SynchComm() + { + cudaEventRecord(_comm_event, _comm_stream); + cudaStreamWaitEvent(_comp_stream, _comm_event, 0); + } + void GenWorkSpace(size_t size) + { + if (!_workspace) { + assert(_workspace == nullptr); + cudaMalloc(&_workspace, size); + } else if (_workSpaceSize < size) { + cudaFree(_workspace); + cudaMalloc(&_workspace, size); + } + + _workSpaceSize = size; + } + + void* GetWorkSpace() { return _workspace; } + + curandGenerator_t& GetRandGenerator() { return _gen; } + + cudaStream_t GetCurrentStream() + { + // get current pytorch stream. + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + return stream; + } + + cudaStream_t GetCommStream(bool async_op = false) + { + // if (!_comm_stream) + // _comm_stream = async_op ? at::cuda::getStreamFromPool(true) + // : at::cuda::getCurrentCUDAStream(); + return at::cuda::getCurrentCUDAStream(); //_comm_stream; + } + + cublasHandle_t GetCublasHandle() { return _cublasHandle; } + + std::pair IncrementOffset(uint64_t offset_inc) + { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + return std::pair(_seed, offset); + } + + void SetSeed(uint64_t new_seed) { _seed = new_seed; } + +private: + curandGenerator_t _gen; + cublasHandle_t _cublasHandle; + cudaEvent_t _comp_event; + cudaEvent_t _comm_event; + + void* _workspace; + uint64_t _seed; + uint64_t _curr_offset; + size_t _workSpaceSize; + cudaStream_t _comp_stream; + cudaStream_t _comm_stream; + ncclComm_t _nccl_comm; + unsigned _nranks; + bool _comm_created; +}; diff --git a/arctic_training/kernels/comm/layout.py b/arctic_training/kernels/comm/layout.py new file mode 100644 index 00000000..99042e18 --- /dev/null +++ b/arctic_training/kernels/comm/layout.py @@ -0,0 +1,36 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Layout: + def __init__(self, g_size=None, stride=1, world_size=1): + + num_groups = world_size // g_size + self._group_size = g_size + self._sibling_ranks = [] + for gid in range(num_groups): + if stride == 1: + self._sibling_ranks.append([gid * g_size + i for i in range(g_size)]) + else: + self._sibling_ranks.append([gid % stride + i * stride for i in range(g_size)]) + + def sibling_ranks(self, rank): + for sranks in self._sibling_ranks: + if rank in sranks: + break + return sranks + + def parent_rank(self, rank): + return self.sibling_ranks(rank)[0] diff --git a/arctic_training/kernels/comm/nccl.py b/arctic_training/kernels/comm/nccl.py new file mode 100644 index 00000000..7f8daba3 --- /dev/null +++ b/arctic_training/kernels/comm/nccl.py @@ -0,0 +1,66 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch + +from .comm import Comm +from .layout import Layout + + +class NcclComm(Comm): + + def __init__(self, layout: Layout, set_device: bool = True): + + super().__init__(layout, int(os.getenv("LOCAL_RANK", "0"))) + + dist_world_size = torch.distributed.get_world_size() + + self.comm_group = None + self.parent_rank = 0 + if self.group_size < dist_world_size: + self.parent_rank = self._layout.parent_rank(self.global_rank) + for sranks in self._layout._sibling_ranks: + comm_group = torch.distributed.new_group(sranks) + if self.global_rank in sranks: + self.comm_group = comm_group + + if set_device: + torch.cuda.set_device(f"cuda:{self.my_rank}") + + nccl_uid = torch.tensor([torch.cuda.nccl.unique_id()], dtype=torch.uint8, device=torch.cuda.current_device()) + torch.distributed.broadcast(nccl_uid, self.parent_rank, group=self.comm_group) + self.ds_comm.init_nccl_comm(self.local_ranks, self.rank_map[self.global_rank], nccl_uid.to("cpu").squeeze(0)) + + +class MPIComm(Comm): + + def __init__(self, layout=None, set_device=True): + from mpi4py import MPI + + local_rank = int(MPI.COMM_WORLD.Get_rank()) + super().__init__(layout, local_rank) + if set_device: + torch.cuda.set_device(f"cuda:{self.my_rank}") + self.ds_comm.init_comm_group(self.global_ranks, self.rank_map[self.global_rank]) + + +def create_comm(layout=None, set_device=True, backend="nccl"): + return ( + NcclComm(layout=layout, set_device=set_device) + if backend == "nccl" + else MPIComm(layout=layout, set_device=set_device) + ) diff --git a/arctic_training/kernels/comm/test_comm.py b/arctic_training/kernels/comm/test_comm.py new file mode 100644 index 00000000..1397aed5 --- /dev/null +++ b/arctic_training/kernels/comm/test_comm.py @@ -0,0 +1,149 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import torch +from deepspeed.comm import init_distributed + +from arctic_training.kernels.comm.layout import Layout +from arctic_training.kernels.comm.nccl import create_comm + + +def execute(val, is_torch=True, comm=None, op="AllReduce", counts=None): + if is_torch: + if op == "AllReduce": + torch.distributed.all_reduce(val) + return val + elif op == "AllGather": + world_size = torch.distributed.get_world_size() + result = torch.empty((world_size * val.shape[0],) + val.shape[1:]) + torch.distributed.all_gather_into_tensor(result, val) + return result + elif op == "AlltoAll": + result = torch.empty_like(val) + torch.distributed.all_to_all_single(result, val) + return result + else: + if op == "AllReduce": + val, _ = comm.all_reduce(val) + return val + elif op == "AllGather": + val, _ = comm.all_gather(val) + return val + elif op == "AlltoAll": + if counts is not None: + val, _ = comm.all_to_all(val, counts) + else: + val, _ = comm.all_to_all(val) + return val + + +def init_comm(): + init_distributed(dist_backend="nccl") + global_rank = torch.distributed.get_rank() + group_stride = 1 + group_size = torch.distributed.get_world_size() // group_stride + # gid = global_rank // group_size + comm = create_comm(Layout(group_size, group_stride, world_size=torch.distributed.get_world_size())) + val = torch.arange(group_size, dtype=torch.bfloat16, device=torch.cuda.current_device()) + + # test to see all is working + val, _ = comm.all_to_all(val) + print(f"[{global_rank}]: alltoall -> {val}") + # exit() + return comm + + +def run_nccl_comm_test(comm, op="AllReduce"): + + global_rank = torch.distributed.get_rank() + + inp = torch.randn(8 * 8192 * 4096, dtype=torch.half, device=torch.cuda.current_device()) + # weight1 = torch.randn(4096, 4096, dtype=torch.half, device=torch.cuda.current_device()) + # weight2 = torch.randn(4096, 16384, dtype=torch.half, device=torch.cuda.current_device()) + + # Correctness check + inp_torch = inp.clone() + inp_custom = inp.clone() + torch_val = execute(inp_torch, is_torch=True, op=op) + custom_val = execute(inp_custom, is_torch=False, comm=comm, op=op) + if torch.allclose(torch_val, custom_val): + if global_rank == 0: + print(f"Correctness check passed for {op}") + else: + raise AssertionError(f"Results do not match for {op} on rank {global_rank}") + + if op == "AlltoAll": + counts = [8192 * 4096] * 8 + # counts = [ + # 768 * 4096, + # 1000 * 4096, + # 512 * 4096, + # 1024 * 4096, + # 768 * 4096, + # 1000 * 4096, + # 512 * 4096, + # 1024 * 4096, + # ] # [6144 * 4096] * 8 + counts = torch.tensor(counts, dtype=torch.int32, device=torch.cuda.current_device()) + + def func(*args, **kwargs): + if op == "AlltoAll": + kwargs["counts"] = counts + + return execute(*args, **kwargs) + + for _ in range(10): + # val = torch.matmul(inp.view(-1, 4096), weight1) + func(inp, is_torch=False, comm=comm, op=op) + # out = torch.matmul(val.view(-1, 4096), weight2) + torch.cuda.synchronize() + start = time.time() + for _ in range(1000): + # val = torch.matmul(inp.view(-1, 4096), weight1) + func(inp, is_torch=False, comm=comm, op=op) + # out = torch.matmul(val.view(-1, 4096), weight2) + torch.cuda.synchronize() + end = time.time() + ds_time = end - start + if global_rank == 0: + print(f"------------------- ds_comm execution time for {op}: {end - start} ms -------------------") + # print(f'[{global_rank}]: {end - start} ms') + + for _ in range(10): + # val = torch.matmul(inp.view(-1, 4096), weight1) + func(inp, is_torch=True, op=op) + # out = torch.matmul(val.view(-1, 4096), weight2) + + torch.cuda.synchronize() + start = time.time() + for _ in range(1000): + # val = torch.matmul(inp.view(-1, 4096), weight1) + func(inp, is_torch=True, op=op) + # out = torch.matmul(val.view(-1, 4096), weight2) + torch.cuda.synchronize() + end = time.time() + pt_time = end - start + if global_rank == 0: + print(f"------------------- torch execution time for {op}: {end - start} ms -------------------") + print(f"speedup: {pt_time / ds_time}x") + # print(f'[{global_rank}]: {end - start} ms') + + +comm = init_comm() +run_nccl_comm_test(comm=comm, op="AllReduce") +run_nccl_comm_test(comm=comm, op="AlltoAll") +torch.distributed.destroy_process_group() diff --git a/arctic_training/kernels/moe_ops/moe_scatter/moe_scatter.py b/arctic_training/kernels/moe_ops/moe_scatter/moe_scatter.py index c2cbccc6..cb486a1a 100644 --- a/arctic_training/kernels/moe_ops/moe_scatter/moe_scatter.py +++ b/arctic_training/kernels/moe_ops/moe_scatter/moe_scatter.py @@ -25,12 +25,12 @@ def forward(ctx, inf_module, expert_cumsum, mapped_slots, activations, expert_co ctx.n_experts = expert_counts.shape[0] ctx.save_for_backward(mapped_slots, assignments) n_tokens, n_top_k = assignments.shape - num_experts = expert_counts.shape[0] + # num_experts = expert_counts.shape[0] max_capacity_per_expert = expert_counts.max() - torch.distributed.all_reduce(max_capacity_per_expert, op=torch.distributed.ReduceOp.MAX) + # torch.distributed.all_reduce(max_capacity_per_expert, op=torch.distributed.ReduceOp.MAX) moe_input = torch.zeros( - max_capacity_per_expert * num_experts, - # n_tokens * n_top_k, + # max_capacity_per_expert * num_experts, + n_tokens * n_top_k, activations.shape[1], device=activations.device, dtype=activations.dtype, @@ -46,6 +46,8 @@ def forward(ctx, inf_module, expert_cumsum, mapped_slots, activations, expert_co offsets, max_capacity_per_expert, ) + # print(f'Max capacity per expert: {max_capacity_per_expert} expert_counts: {expert_counts}') + # exit(0) return moe_input, expert_cumsum, mapped_slots, max_capacity_per_expert @staticmethod diff --git a/arctic_training/kernels/moe_ops/moe_scatter/moe_scatter_cuda.cu b/arctic_training/kernels/moe_ops/moe_scatter/moe_scatter_cuda.cu index 02000cac..26f3b919 100644 --- a/arctic_training/kernels/moe_ops/moe_scatter/moe_scatter_cuda.cu +++ b/arctic_training/kernels/moe_ops/moe_scatter/moe_scatter_cuda.cu @@ -168,7 +168,7 @@ __global__ void moe_scatter_kernel(T* moe_input, assignments, \ expert_counts, \ offsets, \ - max_capacity_per_expert, \ + nullptr, \ n_channels, \ n_experts); \ break; diff --git a/arctic_training/model/moe/alltoall.py b/arctic_training/model/moe/alltoall.py index a5eb4fef..9cfe25a9 100644 --- a/arctic_training/model/moe/alltoall.py +++ b/arctic_training/model/moe/alltoall.py @@ -50,9 +50,11 @@ class AlltoAllFunction(torch.autograd.Function): @staticmethod def forward(ctx, group, x): x = x.contiguous() - y = torch.empty_like(x) - dist.all_to_all_single(y, x, group=group) - return y + # y = torch.empty_like(x) + ctx.group = group + # dist.all_to_all_single(y, x, group=group) + return x + # return y @staticmethod def backward(ctx, grad_output): @@ -61,3 +63,30 @@ def backward(ctx, grad_output): def AlltoAll(*args, **kwargs): return AlltoAllFunction.apply(*args, **kwargs) + + +class CustomAlltoAllFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, comm, x, counts, max_count): + ctx.comm = comm + + receive_counts = torch.empty_like(counts) + torch.distributed.all_to_all_single(receive_counts, counts) + + y = comm.all_to_all(x, counts=counts, receive_counts=receive_counts, max_count=max_count)[0] + + ctx.save_for_backward(receive_counts, max_count, counts) + return y, receive_counts + + @staticmethod + def backward(ctx, *grad_outputs): + receive_counts, max_count, counts = ctx.saved_tensors + grad_output = grad_outputs[0].contiguous() + grad_input = ctx.comm.all_to_all( + grad_output, counts=receive_counts, receive_counts=counts, max_count=max_count + )[0] + return (None, grad_input, None, None) + + +def CustomAlltoAll(*args, **kwargs): + return CustomAlltoAllFunction.apply(*args, **kwargs) diff --git a/arctic_training/model/moe/moe.py b/arctic_training/model/moe/moe.py index 9756491d..6b87e317 100644 --- a/arctic_training/model/moe/moe.py +++ b/arctic_training/model/moe/moe.py @@ -24,10 +24,15 @@ from arctic_training.debug.utils import pr0 # noqa, currently using it on/off a lot from arctic_training.debug.utils import see_memory_usage +from arctic_training.kernels.comm.layout import Layout +from arctic_training.kernels.comm.nccl import create_comm from arctic_training.model.moe.alltoall import AlltoAll from arctic_training.model.moe.alltoall import AlltoAllV +from arctic_training.model.moe.alltoall import CustomAlltoAll from arctic_training.model.moe.timers import SynchronizedWallClockTimerSimple +_custom_comm = None + # from arctic_training.debug.utils import pr0 @dataclass(kw_only=True) @@ -52,6 +57,7 @@ class MoEConfig: def torch_group_gemm_fn(A, B, rows_cumsum): + return torch._grouped_mm(A, B, offs=rows_cumsum.to(torch.int32)) C = torch.zeros((rows_cumsum[-1], B.shape[-1]), device=A.device, dtype=A.dtype) for i in range(len(rows_cumsum)): start = 0 if i == 0 else rows_cumsum[i - 1] @@ -66,6 +72,8 @@ class ArcticMoE(nn.Module): config: MoEConfig object """ + layer_id = 0 + def __init__(self, config: MoEConfig): super(ArcticMoE, self).__init__() self._config = config @@ -79,7 +87,8 @@ def __init__(self, config: MoEConfig): self.num_experts = config.num_experts self.top_k = config.top_k self.use_custom_kernel = config.use_custom_moe_kernel - + ArcticMoE.layer_id = ArcticMoE.layer_id + 1 + self.layerId = ArcticMoE.layer_id # profiler self.timers = SynchronizedWallClockTimerSimple() self.wall_clock_breakdown = False @@ -147,6 +156,10 @@ def __init__(self, config: MoEConfig): self.moe_scatter = RaggedMoEScatterModule() self.moe_gather = RaggedMoEGatherModule(normalize_scores=self._config.normalize_topk_scores) self.topk_kernel = RaggedTopKGatingModule() + global _custom_comm + if _custom_comm is None: + _custom_comm = create_comm(Layout(self.ep_size, stride=1, world_size=dist.get_world_size())) + self.custom_comm = _custom_comm self.enable_routing_replay = config.enable_routing_replay @@ -250,16 +263,28 @@ def forward(self, hidden_states, routing_replay_assignments: torch.Tensor = None self.timers.stop("router") see_memory_usage("after router", force=False) - if self.ep_size > 1 and not self.use_custom_kernel: - self.timers.start("a2a-v1") - moe_input = self.alltoall_V(moe_input, expert_token_count, expert_token_rcv_count) - self.timers.stop("a2a-v1") + # hack to disable custom kernel for the rest of the forward pass to get a sense of the breakdown without the custom kernel - + # will be removing this soon once we have more confidence in the custom kernel performance and correctness + prev_custom_kernel_flag = self.use_custom_kernel + self.use_custom_kernel = False - self.timers.start("a2a-v1 transpose") - moe_input, expert_token_count_cumsum, expert_token_count_transposed = self.local_ep_transpose( - moe_input, expert_token_rcv_count - ) - self.timers.stop("a2a-v1 transpose") + if self.ep_size > 1: + if self.use_custom_kernel: + if False: + moe_input, expert_token_rcv_count = CustomAlltoAll( + self.custom_comm, moe_input, expert_token_count, max_capacity_per_expert + ) + moe_input = AlltoAll(self.ep_group, moe_input) + else: + self.timers.start("a2a-v1") + moe_input = self.alltoall_V(moe_input, expert_token_count, expert_token_rcv_count) + self.timers.stop("a2a-v1") + + self.timers.start("a2a-v1 transpose") + moe_input, expert_token_count_cumsum, expert_token_count_transposed = self.local_ep_transpose( + moe_input, expert_token_rcv_count + ) + self.timers.stop("a2a-v1 transpose") else: expert_token_count_cumsum = expert_token_count.cumsum(0) @@ -275,6 +300,11 @@ def forward(self, hidden_states, routing_replay_assignments: torch.Tensor = None if self.ep_size > 1: self.timers.start("a2a-v2 permute") if self.use_custom_kernel: + # custom kernel already outputs in the correct order for the alltoall, so we can skip the local transpose and just do alltoall directly + if False: + moe_output = CustomAlltoAll( + self.custom_comm, moe_output, expert_token_rcv_count, max_capacity_per_expert + ) moe_output = AlltoAll(self.ep_group, moe_output) else: if self.wall_clock_breakdown: @@ -294,6 +324,8 @@ def forward(self, hidden_states, routing_replay_assignments: torch.Tensor = None self.timers.stop("a2a-v2") see_memory_usage("after alltoall_V", force=False) + self.use_custom_kernel = prev_custom_kernel_flag + self.timers.start("moe-combine") if self.use_custom_kernel: output = self.CustomMoEGather(moe_output, token_mapped_slots, scores) @@ -422,7 +454,12 @@ def CustomTopKRouter(self, logits, hidden_states, routing_replay_assignments): self.expert_counts, assignments, offsets, logits ) - expert_token_rcv_count = self.expert_counts + if self.ep_size == 1: + expert_token_rcv_count = self.expert_counts + else: + expert_token_rcv_count = torch.empty_like(self.expert_counts) + # with torch.cuda.stream(self.comm_stream): + dist.all_to_all_single(expert_token_rcv_count, self.expert_counts, group=self.ep_group) (moe_input, self.expert_cumsum, mapped_slots, max_capacity_per_expert) = self.moe_scatter( self.expert_cumsum, mapped_slots, hidden_states, self.expert_counts, assignments, offsets @@ -431,7 +468,7 @@ def CustomTopKRouter(self, logits, hidden_states, routing_replay_assignments): return scores, moe_input, mapped_slots, self.expert_counts, expert_token_rcv_count, max_capacity_per_expert def torch_bached_gemm(self, input, max_capacity_per_expert): - input = input.reshape(-1, max_capacity_per_expert, input.shape[-1]) + input = input.reshape(-1, self.ep_size * max_capacity_per_expert, input.shape[-1]) intermediate = torch.bmm( input, self.expert_gate_up, diff --git a/arctic_training/model/moe/utils.py b/arctic_training/model/moe/utils.py index f50a77b0..4a4d48cf 100644 --- a/arctic_training/model/moe/utils.py +++ b/arctic_training/model/moe/utils.py @@ -14,6 +14,7 @@ # limitations under the License. import copy +import gc import os import re import time @@ -258,6 +259,10 @@ def copy_weights(from_name, to_param, local_expert_indices, transpose=False): # putting the gate and up weigths in every-other order to match arctic-moe style gate_up = torch.stack((gate_stacked, up_stacked), dim=-1).view(*up_stacked.shape[:-1], -1).contiguous() + gate_stacked = None # free memory + up_stacked = None # free memory + gc.collect() + torch.cuda.empty_cache() # pr0(f"{gate_up.shape=}", force=True) # pr0(f"{arctic_moe.expert_gate_up.shape=}", force=True) arctic_moe.expert_gate_up.copy_(gate_up) diff --git a/arctic_training/op_builder/__init__.py b/arctic_training/op_builder/__init__.py index b2f25e75..d5db8399 100644 --- a/arctic_training/op_builder/__init__.py +++ b/arctic_training/op_builder/__init__.py @@ -13,4 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .comm import CommBuilder from .moe_ops import RaggedOpsBuilder diff --git a/arctic_training/op_builder/comm.py b/arctic_training/op_builder/comm.py new file mode 100644 index 00000000..5d8da1b0 --- /dev/null +++ b/arctic_training/op_builder/comm.py @@ -0,0 +1,111 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from .builder import CUDAOpBuilder +from .builder import installed_cuda_version + + +class CommBuilder(CUDAOpBuilder): + BUILD_VAR = "AT_BUILD_COMM" + NAME = "COMM_OPS" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f"arctic_training.{self.NAME}" + + def is_compatible(self, verbose=False): + try: + import torch + except ImportError: + if verbose: + self.warning("Please install torch if trying to pre-compile arctic_training kernels") + return False + + cuda_okay = True + if torch.cuda.is_available(): # ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split(".")[0]) + cuda_capability = torch.cuda.get_device_properties(0).major # ignore-cuda + if cuda_capability < 6: + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in [cc.split(".") for cc in ccs]: + if int(cc[0]) >= 8: + # Blocked flash has a dependency on Ampere + newer + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ai_path = self._src_path("arctic_training") + return "arctic_training" if os.path.isdir(ai_path) else ".." + + def sources(self): + sources = [ + "arctic_training/kernels/comm/comm.cu", + "arctic_training/kernels/comm/comm.cpp", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + return [] + + def include_paths(self): + include_dirs = ["arctic_training/kernels/includes", "arctic_training/kernels/comm"] + prefix = self.get_prefix() + includes = [os.path.join(prefix, include_dir) for include_dir in include_dirs] + + return includes diff --git a/arctic_training/trainer/causal_trainer.py b/arctic_training/trainer/causal_trainer.py index 0fd2c9b3..ee0cb0da 100644 --- a/arctic_training/trainer/causal_trainer.py +++ b/arctic_training/trainer/causal_trainer.py @@ -19,7 +19,9 @@ from arctic_training.data.causal_factory import CausalDataFactory from arctic_training.model.hf_factory import HFModelFactory from arctic_training.model.liger_factory import LigerModelFactory +from arctic_training.optimizer.adam_factory import CPUAdamMoEOptimizerFactory from arctic_training.optimizer.adam_factory import CPUAdamOptimizerFactory +from arctic_training.optimizer.adam_factory import FusedAdamMoEOptimizerFactory from arctic_training.optimizer.adam_factory import FusedAdamOptimizerFactory from arctic_training.scheduler.hf_factory import HFSchedulerFactory from arctic_training.tokenizer.hf_factory import HFTokenizerFactory @@ -33,7 +35,9 @@ class CausalTrainer(Trainer): data_factory: CausalDataFactory model_factory: Union[HFModelFactory, LigerModelFactory] checkpoint_engine: Union[DSCheckpointEngine, HFCheckpointEngine] - optimizer_factory: Union[FusedAdamOptimizerFactory, CPUAdamOptimizerFactory] + optimizer_factory: Union[ + FusedAdamOptimizerFactory, FusedAdamMoEOptimizerFactory, CPUAdamOptimizerFactory, CPUAdamMoEOptimizerFactory + ] scheduler_factory: Union[HFSchedulerFactory] tokenizer_factory: Union[HFTokenizerFactory] diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 168d490b..3d9e72fe 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -38,6 +38,7 @@ from tqdm import tqdm from transformers import set_seed from transformers.integrations.deepspeed import HfDeepSpeedConfig +from wandb import util as wandb_util # type: ignore from wandb.sdk.wandb_run import Run as WandbRun from arctic_training.callback.logging import post_loss_log_cb