Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def add_parser_api_server():
hf_overrides = ArgumentHelper.hf_overrides(pt_group)
disable_metrics = ArgumentHelper.disable_metrics(pt_group)
dp = ArgumentHelper.dp(pt_group)
ArgumentHelper.ep(pt_group)
ep = ArgumentHelper.ep(pt_group)
ArgumentHelper.enable_microbatch(pt_group)
ArgumentHelper.enable_eplb(pt_group)
ArgumentHelper.role(pt_group)
Expand All @@ -148,6 +148,7 @@ def add_parser_api_server():
tb_group._group_actions.append(hf_overrides)
tb_group._group_actions.append(disable_metrics)
tb_group._group_actions.append(dp)
tb_group._group_actions.append(ep)
ArgumentHelper.cp(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.num_tokens_per_iter(tb_group)
Expand Down Expand Up @@ -255,6 +256,7 @@ def api_server(args):
tp=args.tp,
dp=args.dp,
cp=args.cp,
ep=args.ep,
nnodes=args.nnodes,
node_rank=args.node_rank,
dist_init_addr=args.dist_init_addr,
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ class TurbomindEngineConfig:
tp: int = 1
dp: int = 1
cp: int = 1
ep: int = 1
device_num: int = None
attn_tp_size: int = None
attn_cp_size: int = None
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ModelConfig:
attn_tp_size: int = 1
attn_cp_size: int = 1
mlp_tp_size: int = 1
ep_size: int = 1
model_format: str = 'hf'
expert_num: list[int] = field(default_factory=list)
expert_router_bias: bool = False
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def get_tm_model(model_path,
tm_cfg.model_config.attn_cp_size = engine_config.attn_cp_size
if engine_config.mlp_tp_size is not None:
tm_cfg.model_config.mlp_tp_size = engine_config.mlp_tp_size
tm_cfg.model_config.ep_size = engine_config.ep

output_model = OUTPUT_MODELS.get(output_model_name)(input_model=input_model,
cfg=tm_cfg,
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/turbomind/deploy/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class Ffn(Module):

def __init__(self, model: BaseOutputModel):
self.model = model
self.tp = model.mlp_tp_size
self.tp = model.mlp_tp_size if model.model_config.ep_size == 1 else 1
# inter_sizes in config are padded and may be different from what's
# in the weights
self.inter_size = model.model_config.inter_size
Expand Down
21 changes: 20 additions & 1 deletion lmdeploy/turbomind/turbomind.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,26 @@ def complete_parallel_config(cfg: TurbomindEngineConfig):

def update_parallel_config(cfg: TurbomindEngineConfig):
cfg.device_num = len(cfg.devices) * cfg.nnodes if cfg.devices else cfg.device_num
if not complete_parallel_config(cfg):
if not complete_parallel_config(cfg) and cfg.ep > 1:
if cfg.communicator in ['cuda-ipc', 'native']:
assert cfg.nnodes == 1, 'TurboMind does not support multi-node with ep > 1'
total = cfg.dp * cfg.ep
if not cfg.device_num:
count = torch.cuda.device_count() * cfg.nnodes
if total < count:
count = total
cfg.device_num = count
assert total % cfg.device_num == 0
overlap = total // cfg.device_num
attn_dp_size = overlap
inner_tp_size = cfg.ep // overlap
cfg.outer_dp_size = cfg.dp // overlap
cfg.attn_dp_size = overlap // cfg.nnodes
cfg.attn_tp_size = inner_tp_size // cfg.cp
cfg.attn_cp_size = cfg.cp
cfg.mlp_dp_size = 1
cfg.mlp_tp_size = cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size
Comment on lines +89 to +108
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EP path can compute attn_dp_size = overlap // cfg.nnodes, which becomes 0 for common multi-node cases (e.g., overlap==1 and nnodes>1), violating later invariants and producing invalid parallel config. Since device_num already accounts for nnodes, avoid dividing overlap by nnodes here (or otherwise ensure attn_dp_size>=1 with a correct derivation).

Copilot uses AI. Check for mistakes.
elif not complete_parallel_config(cfg):
total = cfg.dp * cfg.tp
if not cfg.device_num:
count = torch.cuda.device_count() * cfg.nnodes
Expand Down
84 changes: 84 additions & 0 deletions src/turbomind/comm/device_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,58 @@
#include <optional>
#include <vector>

#include <cuda_runtime.h>

#include "src/turbomind/comm/host_comm.h"
#include "src/turbomind/core/buffer.h"
#include "src/turbomind/core/tensor.h"

namespace turbomind::comm {

// Static configuration for expert-parallel (EP) communication, passed once to
// DeviceCommImpl::InitializeEp before any Dispatch/Combine calls.
// NOTE(review): field semantics inferred from names — confirm against callers.
struct EpConfig {
    int num_nodes;               // number of participating nodes
    int num_experts;             // total expert count across all EP ranks
    int hidden;                  // hidden dimension of token activations
    int ll_max_tokens_per_rank;  // low-latency ("ll") mode: max tokens buffered per rank
};

// Selects the EP dispatch/combine path for a call (see EpDispatchInput::mode).
// Unscoped enum: enumerators are visible at namespace scope.
enum EpMode
{
    kNull,            // no EP mode selected / EP inactive
    kHighThroughput,  // throughput-oriented path — presumably batched; confirm in backend
    kLowLatency,      // latency-oriented path (uses EpConfig::ll_max_tokens_per_rank)
};

// Inputs to DeviceCommImpl::Dispatch. All members are references: the caller
// retains ownership and the referents must outlive the call.
struct EpDispatchInput {
    EpMode&                 mode;          // requested EP mode; non-const, so the backend may update it — TODO confirm
    core::Tensor&           x;             // token activations to route to experts
    core::Tensor_<float>&   topk_weights;  // per-token top-k routing weights
    core::Tensor_<int64_t>& topk_idx;      // per-token top-k expert indices
};

// Results of DeviceCommImpl::Dispatch.
// NOTE(review): mixes owned values (out_x, out_topk_weights, handle) with
// reference members (f2n..offsets) that alias caller/backend-owned buffers;
// the referents must outlive this struct — confirm intended ownership.
struct EpDispatchOutput {
    core::Tensor out_x;             // activations received for local experts
    core::Tensor out_topk_weights;  // routing weights aligned with out_x rows
    core::Buffer_<int>& f2n;        // index-mapping buffer — semantics per backend layout, TODO document
    core::Buffer_<int>& f2E;        // index-mapping buffer (token -> expert?) — TODO document
    core::Buffer_<int>& en2f;       // inverse mapping buffer — TODO document
    core::Buffer_<int>& offsets;    // per-expert segment offsets — TODO confirm

    std::vector<core::Tensor> handle;  // opaque state to pass back via EpCombineInput::handle

    int out_token_num;         // number of tokens in out_x
    int out_expert_token_num;  // token count over local experts — TODO confirm exact meaning
};

// Inputs to DeviceCommImpl::Combine (the inverse of Dispatch). Reference
// members must outlive the call; optional members may be absent.
struct EpCombineInput {
    EpMode& mode;                       // EP mode used for the matching Dispatch
    core::Tensor& x;                    // per-expert outputs to gather back to token order
    std::vector<core::Tensor>& handle;  // opaque state produced by Dispatch (EpDispatchOutput::handle)
    std::optional<core::Tensor> topk_weights;  // optional routing weights for weighted combine
    std::optional<core::Tensor> topk_idx;      // optional expert indices — TODO confirm when required
};
Comment on lines +52 to +58
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EpCombineInput uses std::optional, but this header doesn't include <optional>, which will cause compilation errors depending on include order. Add #include <optional> (and keep headers self-contained).

Copilot uses AI. Check for mistakes.

// Result of DeviceCommImpl::Combine: combined activations back in the
// caller's original token order.
struct EpCombineOutput {
    core::Tensor out_x;
};

enum QueryAttr
{
kHasAllGather2D
Expand Down Expand Up @@ -117,6 +166,41 @@ class DeviceCommImpl {
{
throw std::runtime_error("not implemented");
}

    // Variable-count ("V") reduce-scatter over communication group `group` on
    // `stream`. `counts` is presumably indexed by rank within the group —
    // confirm against the concrete backend implementations.
    // Base-class default: unimplemented stub; backends must override.
    virtual void ReduceScatterV(const void* sendbuff,  //
                                void* recvbuff,
                                const size_t* counts,
                                DataType type,
                                int group,
                                cudaStream_t stream)
    {
        throw std::runtime_error("not implemented");
    }

    // Variable-count ("V") all-gather over communication group `group` on
    // `stream`; `counts` presumably gives each rank's element count — confirm
    // against the concrete backend implementations.
    // Base-class default: unimplemented stub; backends must override.
    virtual void AllGatherV(const void* sendbuff,  //
                            void* recvbuff,
                            const size_t* counts,
                            DataType type,
                            int group,
                            cudaStream_t stream)
    {
        throw std::runtime_error("not implemented");
    }

    // One-time setup of expert-parallel communication state from `config`,
    // to be called before Dispatch/Combine. Base-class default: stub that
    // signals EP is unsupported by this backend.
    virtual void InitializeEp(const EpConfig& config)
    {
        throw std::runtime_error("ep not implemented");
    }

    // EP dispatch step: routes the tokens described by `input` across the EP
    // ranks of `group` and fills `output` (see EpDispatchInput/Output for the
    // data contract). Base-class default: unimplemented stub.
    virtual void Dispatch(const EpDispatchInput& input, EpDispatchOutput& output, int group)
    {
        throw std::runtime_error("not implemented");
    }

    // EP combine step (inverse of Dispatch): gathers per-expert results from
    // the EP ranks of `group` back into token order in `output`, using the
    // handle produced by the matching Dispatch. Base-class default:
    // unimplemented stub.
    virtual void Combine(const EpCombineInput& input, EpCombineOutput& output, int group)
    {
        throw std::runtime_error("not implemented");
    }
};

class DeviceComm {
Expand Down
21 changes: 19 additions & 2 deletions src/turbomind/comm/nccl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,25 @@

cmake_minimum_required(VERSION 3.11)

add_library(nccl_comm STATIC nccl.cu)
target_link_libraries(nccl_comm PRIVATE rms_norm core ${NCCL_LIBRARIES} logger)
# Sources of the vendored DeepEP expert-parallel communication kernels.
set(DEEP_EP_SOURCE_FILES
    deep_ep/deep_ep.cpp
    deep_ep/gin_backend.cu
    deep_ep/kernels/runtime.cu
    deep_ep/kernels/layout.cu
    deep_ep/kernels/intranode.cu
    deep_ep/kernels/internode.cu
    deep_ep/kernels/internode_ll.cu
)

add_library(deepep STATIC ${DEEP_EP_SOURCE_FILES})
target_link_libraries(deepep PRIVATE ${NCCL_LIBRARIES} CUDA::cudart)
# NOTE(review): hard-coded to SM90 (Hopper) only — builds for other GPU
# architectures will not get these kernels; confirm intended and/or gate on
# CMAKE_CUDA_ARCHITECTURES.
set_property(TARGET deepep PROPERTY CUDA_ARCHITECTURES 90)
target_include_directories(deepep PRIVATE ${NCCL_INCLUDE_DIRS})
set_property(TARGET deepep PROPERTY POSITION_INDEPENDENT_CODE ON)
# Resolve relocatable device code at link time (required for cross-TU device
# symbols in the DeepEP kernels).
set_property(TARGET deepep PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

# NCCL-backed communicator; nccl_ep.cu wires the DeepEP dispatch/combine path.
add_library(nccl_comm STATIC nccl.cu nccl_ep.cu)
target_link_libraries(nccl_comm PRIVATE rms_norm core ${NCCL_LIBRARIES} logger deepep)
target_include_directories(nccl_comm PRIVATE ${NCCL_INCLUDE_DIRS})

set_property(TARGET nccl_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
Expand Down
Loading
Loading