diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py
index 8fbb4c88a..55d28fc1f 100644
--- a/xtuner/v1/ray/config/worker.py
+++ b/xtuner/v1/ray/config/worker.py
@@ -143,6 +143,13 @@ class RolloutConfig(BaseModel):
             help="Number of GPUs allocated for each inference engine in the rollout worker.",
         ),
     ] = 1
+    data_parallel_size: Annotated[
+        int,
+        Parameter(
+            group=infer_group,
+            help="Number of GPUs allocated for processing data batches in parallel (Data Parallelism).",
+        ),
+    ] = 1
     expert_parallel_size: Annotated[
         int,
         Parameter(
diff --git a/xtuner/v1/ray/rollout/vllm.py b/xtuner/v1/ray/rollout/vllm.py
index 400db51ae..8fd077114 100644
--- a/xtuner/v1/ray/rollout/vllm.py
+++ b/xtuner/v1/ray/rollout/vllm.py
@@ -1,18 +1,132 @@
+import asyncio
+import os
+import traceback
 from argparse import Namespace
 from typing import Any, Dict, List, Union
 
-import uvloop
+import ray
+import requests
+import torch
 from vllm.entrypoints.openai.api_server import run_server
-from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
+from vllm.entrypoints.utils import cli_env_setup
 from vllm.utils import FlexibleArgumentParser
 
+from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem, RolloutState
 from xtuner.v1.ray.config import RolloutConfig
+from xtuner.v1.ray.rollout.worker import RolloutWorker
+from xtuner.v1.utils.device import get_device, get_torch_device_module
 
-from .worker import RolloutWorker
+DEVICE = get_device()
+DEVICE_MODULE = get_torch_device_module()
 
-def run_vllm_server_wrapper(server_args):
-    uvloop.run(run_server(server_args))
+
+
+def stateless_init_process_group(master_address, master_port, rank, world_size, device):
+    """VLLM provides `StatelessProcessGroup` to create a process group without
+    considering the global process group in torch.distributed.
+
+    It is recommended to create a `StatelessProcessGroup` and then initialize
+    the data-plane communication (NCCL) between the external training
+    processes and the vLLM workers.
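+
+    Example (a minimal sketch; the address, port, world size, and the
+    ``broadcast`` call mirror vLLM's RLHF weight-sync example and are
+    illustrative only)::
+
+        # training side joins as rank 0; each vLLM worker joins with its own rank
+        comm = stateless_init_process_group("10.0.0.1", 29600, rank=0, world_size=2, device="cuda:0")
+        comm.broadcast(weight_tensor, src=0, stream=torch.cuda.current_stream())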
+ """ + from vllm.distributed.utils import StatelessProcessGroup + + pg = StatelessProcessGroup.create(host=master_address, port=master_port, rank=rank, world_size=world_size) + + if DEVICE == "npu": + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator + + pynccl = PyHcclCommunicator(pg, device=device) + else: + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl + + +class WorkerWrap: + def init_process_group( + self, master_address, master_port, rank_offset, world_size, group_name, backend="hccl", use_ray=False + ): + """Init torch process group for model weights update.""" + assert torch.distributed.is_initialized(), "default torch process group must be initialized" + assert group_name != "", "group name must not be empty" + + rank = torch.distributed.get_rank() + rank_offset + self._model_update_with_ray = use_ray + if use_ray: + import ray.util.collective as collective + + collective.init_collective_group(world_size=world_size, rank=rank, backend=backend, group_name=group_name) + self._model_update_group = group_name + else: + self._model_update_group = stateless_init_process_group( + master_address, + master_port, + rank, + world_size, + self.device, + ) + print( + f"init_process_group: master_address={master_address}, master_port={master_port}, ", + f"rank={rank}, world_size={world_size}, group_name={group_name}", + ) + + def update_weight_npu_ipc(self, data): + import base64 + import json + from multiprocessing.reduction import ForkingPickler + + if isinstance(data, str): + data = json.loads(data) + + def _construct(item): + func, args = item + args = list(args) + args[6] = DEVICE_MODULE.current_device() + return func(*args) + + serialized_data = data["serialized_named_tensors"] + if isinstance(serialized_data, list): + serialized_data = serialized_data[self.global_rank] + weights = ForkingPickler.loads(base64.b64decode(serialized_data)) + weights = [(k, _construct(v)) for k, v in weights] + DEVICE_MODULE.synchronize() + self.model_runner.model.load_weights(weights=weights) + del weights + DEVICE_MODULE.synchronize() + DEVICE_MODULE.empty_cache() + + def get_worker_pids(self): + current_pid = os.getpid() + return current_pid + + +@ray.remote +class VllmServerWrapper: + def __init__(self, server_namespace: Namespace): + cli_env_setup() + server_args = getattr(server_namespace, "args", Namespace()) + env = getattr(server_namespace, "env", {}) + for k, v in env.items(): + os.environ[k] = str(v) + try: + asyncio.run(run_server(server_args)) + except Exception as e: + error_msg = f"Failed to start server in VllmServerWrapper: {type(e).__name__}: {str(e)}" + stack_trace = traceback.format_exc() + print(error_msg) + print(stack_trace) + raise # Re-raise the exception to prevent silent failure + + def actor_health(self): + return "healthy" + + +# Add a dummy task. 
+def run_vllm_server_wrapper(server_namespace: Namespace):
+    return ray.get(VllmServerWrapper.remote(server_namespace).actor_health.remote())  # type: ignore
 
 
 class vLLMWorker(RolloutWorker):
@@ -26,15 +140,25 @@ def __init__(
         self,
         accelerator: str = "GPU",
     ):
         super().__init__(config, rank, master_addr, master_port, world_size, accelerator)
-        self.server_func = run_vllm_server_wrapper
         self.router_func = ""
+        self.server_func = run_vllm_server_wrapper
         self.endpoints["health_generate"] = "health"
+        self.endpoints["v1/chat/completions"] = "v1/chat/completions"
         self.endpoints["generate"] = "v1/chat/completions"
-        self.endpoints["output_ids"] = "output_ids"
-        self.endpoints["response"] = "text"
         self.endpoints["sleep"] = "sleep"
-        self.endpoints["wake_up"] = "wakeup"
+        self.endpoints["wake_up"] = "wake_up"
+        self.endpoints["models"] = "models"
+        self.endpoints["update_weights"] = "update_weights"
+        # self.endpoints['abort_request'] = "abort_request"
         self.api_keys = self.config.api_key
+        self.model_name = self.config.model_name
+        self.enable_return_routed_experts = self.config.enable_return_routed_experts
+        self.dp_size = self.config.data_parallel_size
+        assert self.dp_size > 0, "data_parallel_size must be > 0"
+        assert self.config.tensor_parallel_size % self.dp_size == 0, (
+            f"tensor_parallel_size ({self.config.tensor_parallel_size}) must be divisible by data_parallel_size ({self.dp_size})"
+        )
+        self.tp_size = self.config.tensor_parallel_size // self.dp_size
 
     async def _create_request(
         self,
@@ -47,45 +171,88 @@ async def _create_request(
         extra_params: dict,
         extra_info: dict,
     ):
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_keys}",  # 如果需要鉴权
-        }
+        stream = extra_params["stream"]
+        headers = {"Content-Type": "application/json"}
+
+        if "image_data" in extra_info:
+            if not isinstance(prompt, list):
+                raise ValueError("image_data requires prompt to be a list of messages")
+
+            # Rewrite image_url entries to local file:// paths so the vLLM server can load them directly.
+            image_index = 0
+            for message in prompt:
+                if not isinstance(message, dict):
+                    continue
+                if message.get("role") == "user":
+                    new_content = []
+                    for content_part in message.get("content", []):
+                        if not isinstance(content_part, dict):
+                            new_content.append(content_part)
+                            continue
+                        if content_part.get("type") == "image_url":
+                            content_part["image_url"]["url"] = f"file://{extra_info['image_data'][image_index]}"
+                            content_part["image_url"].pop("image_wh", None)
+                            image_index += 1
+                            new_content.append(content_part)
+                        else:
+                            new_content.append(content_part)
+
+                    message["content"] = new_content
+
+            assert image_index == len(extra_info["image_data"]), (
+                f"Expected {len(extra_info['image_data'])} images, but processed {image_index}."
+ ) + payload = { "model": self.config.model_path, "messages": prompt, - "stream": True, + "stream": stream, } - payload.update(sample_params) - payload.update(extra_params) + if "train_prompt_ids" in extra_info: + payload["input_ids"] = extra_info["train_prompt_ids"] + + vllm_sample_params = self._transform_sample_params(sample_params, extra_params) + payload.update(vllm_sample_params) return await self._safe_post_request(url, headers, payload) + def _transform_sample_params(self, sample_params: Dict, extra_params: Dict = {}): + import copy + + vllm_sample_params = copy.deepcopy(sample_params) + if extra_params: + vllm_sample_params.update(extra_params) + if "stops" in vllm_sample_params: + vllm_sample_params["stop"] = vllm_sample_params.pop("stops") + if "no_stop_trim" in vllm_sample_params: + vllm_sample_params["include_stop_str_in_output"] = vllm_sample_params.pop("no_stop_trim") + if "top_logprobs" in vllm_sample_params and "return_logprob" in vllm_sample_params: + vllm_sample_params["logprobs"] = vllm_sample_params.pop("return_logprob") + return vllm_sample_params + def get_logprobs(self, input_ids, sampling_params): pass def generate(self, input_ids, sampling_params): pass - def sleep(self, level=1, tags: List[str] | None = None): - import requests - + def sleep(self, level=1): url = f"{self.server_url}/{self.endpoints['sleep']}" - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_keys}"} - data = {"tags": tags} - response = requests.post(url, headers=headers, json=data) + headers = {"Content-Type": "application/json"} + params = {} + params["level"] = level + response = requests.post(url, headers=headers, params=params) assert response.status_code == 200, response.status_code - return response.json() + return response.text def wake_up(self, tags: List[str] | None = None): - import requests - url = f"{self.server_url}/{self.endpoints['wake_up']}" - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_keys}"} - data = {"tags": tags} - response = requests.post(url, headers=headers, json=data) + headers = {"Content-Type": "application/json"} + params = {} + if tags is not None: + params["tags"] = tags + response = requests.post(url, headers=headers, params=params) assert response.status_code == 200, response.status_code - return response.json() + return response.text def pause_generation(self): pass @@ -93,31 +260,138 @@ def pause_generation(self): def continue_generation(self): pass - def update_weights(self, ipc_handles): - # todo - pass + def onload_weights(self): + """Onloads the model weights by waking up the model.""" + return self.wake_up(tags=["weights"]) - def reset_prefix_cache(self): - # todo - pass + def onload_kvcache(self): + """Onloads the KV cache by waking up the model.""" + return self.wake_up(tags=["kv_cache"]) + + def offload(self): + """Offloads the model weights and KV cache.""" + return self.sleep(level=2) + + def reset_prefix_cache(self, tags: List[str] | None = None): + raise NotImplementedError("The 'reset_prefix_cache' API is not yet implemented in the vLLM server.") def _transform_rollout_config_to_server_configs(self) -> Namespace: # use vllm FlexibleArgumentParser to parse the config # and return the args as the default server config # vllm server_args: vllm/vllm/engine/arg_utils.py - parser = FlexibleArgumentParser() + parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) - args = parser.parse_args([]) - 
-        args.__dict__.update(vars(self.config))
-
-        args.host = self.host
-        args.port = self.server_port
-        args.model = self.config.model_path
-        args.disable_log_requests = True
-        args.disable_log_stats = True
-        args.tensor_parallel_size = self.config.tensor_parallel_size
-        if args.expert_parallel_size > 1:
-            args.tensor_parallel_size = self.config.expert_parallel_size
-            args.enable_expert_parallel = True
-        args.max_model_length = self.config.context_length
-        return args
+        args_ = parser.parse_args([])
+
+        args = {}
+        args["host"] = self.host
+        args["port"] = self.server_port
+        args["api_key"] = self.api_keys
+        args["api_keys"] = self.api_keys
+        args["model"] = self.config.model_path
+        args["log_level"] = "info"
+        args["data_parallel_size"] = self.dp_size
+        args["tensor_parallel_size"] = self.tp_size
+        args["enable_expert_parallel"] = False
+
+        args["distributed_executor_backend"] = "ray"
+        args["max_model_len"] = self.config.context_length
+        args["enforce_eager"] = False
+        args["enable_sleep_mode"] = True
+        args["worker_extension_cls"] = "xtuner.v1.ray.rollout.vllm.WorkerWrap"
+        args["trust_remote_code"] = True
+        args["enable_prefix_caching"] = False
+        args["allowed_local_media_path"] = "/"
+        args["mm_processor_cache_gb"] = 0
+        args["max_num_batched_tokens"] = 4096
+        args["max_num_seqs"] = self.config.rollout_max_batch_size_per_instance // self.dp_size
+        args["block_size"] = 128
+        args["gpu_memory_utilization"] = self.config.gpu_memory_utilization
+        args["compilation_config"] = {
+            "cudagraph_capture_sizes": [16, 12, 8, 4, 2, 1],
+            "cudagraph_mode": "FULL_DECODE_ONLY",
+        }
+        args["additional_config"] = {"enable_cpu_binding": True}
+        args["limit_mm_per_prompt"] = {"image": 10, "video": 0}
+        args["enable_log_requests"] = False
+        args["uvicorn_log_level"] = "error"
+        env = {
+            "VLLM_VERSION": "0.11.0",
+            "TASK_QUEUE_ENABLE": "0",
+            "CPU_AFFINITY_CONF": "2",
+            "VLLM_USE_V1": "1",
+            "VLLM_RAY_PER_WORKER_GPUS": "0.1",
+            "VLLM_RAY_BUNDLE_INDICES": ",".join(map(str, self.engine_bundle_idxs)),
+            "VLLM_MONITOR": "1",
+            "VLLM_ACCU_MONITOR": "0",
+            "CUSTOM_SCHEDULE_KV_LIMIT": "0.9",
+            "HCCL_BUFFSIZE": "512",
+            "VLLM_ASCEND_ENABLE_FLASHCOMM1": "0",
+            "SHM_BARRIER": "true",
+            "USE_TOKEN_IN": "1",
+            "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15",
+            "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1",
+            "HCCL_CONNECT_TIMEOUT": "7200",
+            "HCCL_OP_EXPANSION_MODE": "AIV",
+            "INTERNS1_VIT_USE_TP": "1",
+            "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": "1",
+            "VLLM_SERVER_DEV_MODE": "1",
+            "VLLM_ASCEND_ENABLE_NZ": "0",
+        }
+
+        # Apply extra_rollout_config overrides for vLLM parameters (prefix: "vllm_")
+        extra_cfg = getattr(self.config, "extra_rollout_config", None) or {}
+        for key, value in extra_cfg.items():
+            if key.startswith("vllm_"):
+                real_key = key[5:]
+                args[real_key] = value
+
+        args_.__dict__.update(args)
+        validate_parsed_serve_args(args_)
+
+        return Namespace(
+            args=args_,
+            env=env,
+            api_key=self.api_keys,
+            api_keys=self.api_keys,
+            ray_runtime_env={"env_vars": env},
+        )
+
+    async def _handle_stream_response(self, uid, sample_params, extra_params, response) -> RLRolloutResponseItem:
+        raise NotImplementedError
+
+    async def _handle_non_stream_response(
+        self, root_id, action_id, sample_params, extra_params, response, input_extra_info
+    ) -> RLRolloutResponseItem:
+        uid = action_id
+        last_token_ids = []
+        last_logprobs = []
+
+        response = response.json()["choices"][0]
+        if "logprobs" in response:
+            last_token_ids = response["token_ids"]
+            last_logprobs = [item["logprob"] for item in response["logprobs"]["content"]]
+            assert len(last_token_ids) == len(last_logprobs)
+            assert len(last_token_ids) <= sample_params["max_tokens"], (
+                f"Generation length exceeds limit: generated {len(last_token_ids)}, limit {sample_params['max_tokens']}"
+            )
+        last_trajectory = response["message"]["content"]
+        finish_reason = response["finish_reason"]
+        if finish_reason == "abort" and not self.receive_abort_request.is_set():
+            self.receive_abort_request.set()
+            self.logger.info(f"Setting receive_abort_request to True for rank {self.rank}")
+
+        if finish_reason != "abort" and (len(last_token_ids) == 0 or len(last_logprobs) == 0):
+            self.logger.error(f"Invalid rollout response for request {uid}: {response}")
+            return RLRolloutResponseItem(state=RolloutState.SKIPPED)
+
+        rollout_response = RLRolloutResponseItem(
+            response=last_trajectory,
+            response_ids=last_token_ids if len(last_token_ids) > 0 else None,
+            num_return_tokens=len(last_token_ids) if len(last_token_ids) > 0 else None,
+            finish_reason=finish_reason,
+            logprobs=last_logprobs,
+            state=RolloutState.ABORTED if finish_reason == "abort" else RolloutState.COMPLETED,
+        )
+
+        return rollout_response
diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index 855bc589a..6d8a27270 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -842,6 +842,8 @@ def update_rollout_info(
         self.rollout_cfg_info["api_key"] = rollout_config.api_key
         if os.environ.get("XTUNER_USE_SGLANG", "0") == "1":
             self.rollout_cfg_info["backend"] = "sglang"
+        elif os.environ.get("XTUNER_USE_VLLM", "0") == "1":
+            self.rollout_cfg_info["backend"] = "vllm"
         else:
             self.rollout_cfg_info["backend"] = (rollout_config.extra_rollout_config or dict()).get(
                 "lmdeploy_backend", "pytorch"
@@ -926,7 +928,7 @@ def _update_weights_hf_generator(self, submodule=None, final_update=True):
                 self.request_update_params(state_dict, finished=False)
                 del state_dict, name_list, param_list
 
-        if self.rollout_cfg_info["backend"] == "pytorch" and final_update:
+        if self.rollout_cfg_info["backend"] in ("pytorch", "vllm") and final_update:
             self.request_update_params({}, finished=True)
         dist.barrier()
 
@@ -1031,7 +1033,7 @@ def get_params(tensor_list, name_list, save_dtype):
             state_dict = dict(zip(name_list, fsdp_unshard_tensor_list))
             self.request_update_params(state_dict)
 
-        if self.rollout_cfg_info["backend"] == "pytorch":
+        if self.rollout_cfg_info["backend"] in ("pytorch", "vllm"):
             self.request_update_params({}, finished=True)
         dist.barrier()
 
@@ -1234,6 +1236,42 @@ def request_update_params(self, state_dict, finished=False):
         if self.rollout_url is None:
             self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip")
             return
+
+        if self.rollout_cfg_info["backend"] == "vllm":
+
+            def serialize_state_dict(state_dict: dict) -> str:
+                import base64
+                from io import BytesIO
+                from multiprocessing.reduction import ForkingPickler
+
+                from torch.multiprocessing.reductions import reduce_tensor
+
+                data = [(k, reduce_tensor(v)) for k, v in state_dict.items()]
+                buf = BytesIO()
+                ForkingPickler(buf).dump(data)
+                buf.seek(0)
+                return base64.b64encode(buf.read()).decode("utf-8")
+
+            serialized_data = [None] * self.rollout_cfg_info["tp"]
+            dist.gather_object(
+                serialize_state_dict(state_dict),
+                serialized_data if dist.get_rank() == head_rank else None,
+                dst=head_rank,
+                group=cpu_group,
+            )
+            if dist.get_rank() == head_rank:
+                headers = {
+                    "Content-Type": "application/json",
+                }
+                data_ = json.dumps(dict(serialized_named_tensors=serialized_data, finished=finished))
+                data = dict(method="update_weight_npu_ipc", args=[data_])
+                response = requests.post(f"{self.rollout_url}/collective_rpc", headers=headers, json=data)
+                assert response.status_code == 200, f"response.status_code = {response.status_code}"
+
+            if finished:
+                dist.barrier(group=cpu_group)
+            return
+
         if self.rollout_cfg_info["backend"] == "pytorch":
             # TODO(chenchiyu): remove lmdeploy related code
             from lmdeploy.utils import serialize_state_dict
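Note on the vLLM weight-update path added above: the trainer serializes each shard's state dict with `reduce_tensor` + `ForkingPickler`, base64-encodes it, gathers one string per TP rank, and posts the result to the server's `/collective_rpc` endpoint, where `WorkerWrap.update_weight_npu_ipc` rebuilds the tensors. Below is a minimal, CPU-only sketch of that framing; the tensor name and values are placeholders, and the real payload carries device IPC handles that are only valid between processes on the same node.

    import base64
    import json
    from io import BytesIO
    from multiprocessing.reduction import ForkingPickler

    import torch
    from torch.multiprocessing.reductions import reduce_tensor

    # Trainer side: mirror serialize_state_dict() on a toy CPU state dict.
    state_dict = {"model.layers.0.weight": torch.zeros(2, 2)}
    reduced = [(name, reduce_tensor(tensor)) for name, tensor in state_dict.items()]
    buf = BytesIO()
    ForkingPickler(buf).dump(reduced)
    encoded = base64.b64encode(buf.getvalue()).decode("utf-8")

    # One entry per TP rank is gathered before posting; "finished" mirrors the real payload.
    payload = json.dumps({"serialized_named_tensors": [encoded], "finished": True})
    body = {"method": "update_weight_npu_ipc", "args": [payload]}
    # requests.post(f"{rollout_url}/collective_rpc", json=body)  # rollout_url is illustrative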