From e04db0ce156f3afda906b94bc662a32a4578f918 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 31 Mar 2026 15:31:27 +0800 Subject: [PATCH 1/5] [Enhance] Add memory monitoring enhancements for RL training scripts - Introduced new environment variables for RL memory monitoring: XTUNER_RL_MEM_INTERVAL, XTUNER_RL_OBJECT_LIMIT, and XTUNER_RL_OBJECT_TOP_K. - Updated run_rl.sh and run_rl_submit.sh to utilize these new variables for configuring memory monitoring. - Enhanced rl_monitor_actor_memory function to accept additional parameters for object limit and top K objects to monitor. - Added a new summarize_group_payload function in replay_buffer.py to provide detailed statistics on grouped data items. - Implemented memory reference management improvements in controller.py and replay_buffer.py to optimize memory usage during training. These changes aim to improve the flexibility and efficiency of memory monitoring in RL training workflows. --- examples/v1/scripts/run_rl.sh | 5 +- examples/v1/scripts/run_rl_submit.sh | 5 +- xtuner/v1/ray/dataflow/replay_buffer.py | 150 +++++++++++++- xtuner/v1/rl/base/controller.py | 23 ++- xtuner/v1/train/cli/rl.py | 33 +-- xtuner/v1/utils/track_rl_mem.py | 256 ++++++++++++++++++++++-- 6 files changed, 432 insertions(+), 40 deletions(-) diff --git a/examples/v1/scripts/run_rl.sh b/examples/v1/scripts/run_rl.sh index 3fe52ec44..f81f09bb2 100644 --- a/examples/v1/scripts/run_rl.sh +++ b/examples/v1/scripts/run_rl.sh @@ -82,6 +82,9 @@ export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt" if [ "$ACCELERATOR" = "GPU" ]; then # TODO: support NPU RL Memory Monitor export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}" + export XTUNER_RL_MEM_INTERVAL="${XTUNER_RL_MEM_INTERVAL:-60}" + export XTUNER_RL_OBJECT_LIMIT="${XTUNER_RL_OBJECT_LIMIT:-5000}" + export XTUNER_RL_OBJECT_TOP_K="${XTUNER_RL_OBJECT_TOP_K:-10}" fi # 2. Launch Ray cluster @@ -139,4 +142,4 @@ LOG_FILE="${WORK_DIR}/training_log_${current_time}.txt" python xtuner/v1/train/cli/rl.py \ --config $CONFIG_PATH \ - 2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt" \ No newline at end of file + 2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt" diff --git a/examples/v1/scripts/run_rl_submit.sh b/examples/v1/scripts/run_rl_submit.sh index 4d268527d..d870b67a8 100644 --- a/examples/v1/scripts/run_rl_submit.sh +++ b/examples/v1/scripts/run_rl_submit.sh @@ -73,6 +73,9 @@ export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt" if [ "$ACCELERATOR" = "GPU" ]; then # TODO: support NPU RL Memory Monitor export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}" + export XTUNER_RL_MEM_INTERVAL="${XTUNER_RL_MEM_INTERVAL:-60}" + export XTUNER_RL_OBJECT_LIMIT="${XTUNER_RL_OBJECT_LIMIT:-5000}" + export XTUNER_RL_OBJECT_TOP_K="${XTUNER_RL_OBJECT_TOP_K:-10}" fi # 2. 
Launch Ray cluster @@ -157,4 +160,4 @@ if [ "$RAY_RANK" -eq 0 ]; then 2>&1 | tee -a "$LOG_FILE" echo "训练任务提交完成。日志文件: $LOG_FILE" -fi \ No newline at end of file +fi diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index cdf10fd68..add6bdd2e 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -63,6 +63,49 @@ class ReplayMeta: extra_info: Dict[str, Any] = field(default_factory=dict) +def summarize_group_payload(grouped_dataitem: List[RLDataFlowItem]) -> Dict[str, Any]: + summary: Dict[str, Any] = { + "payload_mode": "full", + "observation_count": len(grouped_dataitem), + "response_tokens": 0, + "response_chars": 0, + "versioned_segments": 0, + "versioned_tokens": 0, + "routed_expert_payloads": 0, + "judged_observations": 0, + "has_multimodal_prompt": False, + } + if not grouped_dataitem: + return summary + + first_data = grouped_dataitem[0].data + summary["has_multimodal_prompt"] = bool( + getattr(first_data, "multimodal_train_info", None) and len(first_data.multimodal_train_info) > 0 + ) + + for item in grouped_dataitem: + rollout = item.env.rollout + judger = item.env.judger + response_ids = rollout.response_ids or [] + response_text = rollout.response or "" + versioned_response_ids = rollout.versioned_response_ids or [] + versioned_num_return_tokens = rollout.versioned_num_return_tokens or [] + + summary["response_tokens"] += len(response_ids) + summary["response_chars"] += len(response_text) + summary["versioned_segments"] += len(versioned_response_ids) + if versioned_num_return_tokens: + summary["versioned_tokens"] += sum(versioned_num_return_tokens) + else: + summary["versioned_tokens"] += sum(len(ids) for ids in versioned_response_ids) + if rollout.extra_info.get("routed_experts", None) is not None: + summary["routed_expert_payloads"] += 1 + if judger.uid is not None or judger.reward.get("score", 0.0) != 0.0 or len(judger.extra_info) > 0: + summary["judged_observations"] += 1 + + return summary + + def determine_group_state(group_data_items: List[RLDataFlowItem]) -> RolloutState: """Determines the processing strategy for a group of rollout samples based on their state.""" @@ -113,7 +156,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re observation_refs=observation_refs, state=group_state, version=group_version, - extra_info={}, + extra_info=summarize_group_payload(grouped_dataitem), ) return replay_meta @@ -323,6 +366,87 @@ def __init__(self, replay_buffer_cfg): self.sample_from_aborted_count = 0 self.sample_from_expired_count = 0 + def _free_replay_meta_refs(self, replay_meta: ReplayMeta, include_action_ref: bool = True): + refs = [] + if include_action_ref and replay_meta.action_ref is not None: + refs.append(replay_meta.action_ref) + refs.extend([ref for ref in replay_meta.observation_refs if ref is not None]) + if refs: + ray.internal.free(refs, local_only=False) + + def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutState): + for observation_id in replay_meta.observation_ids: + old_state = self._observations2states.get(observation_id) + if old_state and observation_id in self._states.get(old_state, []): + self._states[old_state].remove(observation_id) + self._observations2states[observation_id] = new_state + if observation_id not in self._states[new_state]: + self._states[new_state].append(observation_id) + replay_meta.state = new_state + + def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: 
RolloutState): + """Keep prompt refs only and drop rollout outputs that will not be reused.""" + old_obs_refs = [ref for ref in replay_meta.observation_refs if ref is not None] + if old_obs_refs: + ray.internal.free(old_obs_refs, local_only=False) + replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids] + replay_meta.extra_info.update( + { + "payload_mode": "prompt_only", + "response_tokens": 0, + "response_chars": 0, + "versioned_segments": 0, + "versioned_tokens": 0, + "routed_expert_payloads": 0, + "judged_observations": 0, + } + ) + self._update_replay_meta_state(replay_meta, new_state) + + def get_storage_stats(self) -> Dict[str, float]: + stats: Dict[str, float] = { + "tracked_actions_count": float(len(self._actions)), + "tracked_roots_count": float(len(self._root2actions)), + "tracked_observations_count": float(len(self._observations)), + "completed_actions_count": float(sum(len(bucket) for bucket in self._completed_actions.values())), + "aborted_actions_count": float(sum(len(bucket) for bucket in self._aborted_actions.values())), + "expired_actions_count": float(len(self._expired_actions)), + "completed_versions_count": float(len(self._completed_actions)), + "aborted_versions_count": float(len(self._aborted_actions)), + "payload_full_actions_count": 0.0, + "payload_prompt_only_actions_count": 0.0, + "payload_full_observations_count": 0.0, + "payload_prompt_only_observations_count": 0.0, + "stored_response_tokens": 0.0, + "stored_response_chars": 0.0, + "stored_versioned_segments": 0.0, + "stored_versioned_tokens": 0.0, + "stored_routed_expert_payloads": 0.0, + "stored_judged_observations": 0.0, + "multimodal_actions_count": 0.0, + } + + for replay_meta in self._actions.values(): + summary = replay_meta.extra_info + observation_count = float(summary.get("observation_count", len(replay_meta.observation_ids))) + if summary.get("payload_mode", "full") == "prompt_only": + stats["payload_prompt_only_actions_count"] += 1.0 + stats["payload_prompt_only_observations_count"] += observation_count + else: + stats["payload_full_actions_count"] += 1.0 + stats["payload_full_observations_count"] += observation_count + + stats["stored_response_tokens"] += float(summary.get("response_tokens", 0)) + stats["stored_response_chars"] += float(summary.get("response_chars", 0)) + stats["stored_versioned_segments"] += float(summary.get("versioned_segments", 0)) + stats["stored_versioned_tokens"] += float(summary.get("versioned_tokens", 0)) + stats["stored_routed_expert_payloads"] += float(summary.get("routed_expert_payloads", 0)) + stats["stored_judged_observations"] += float(summary.get("judged_observations", 0)) + if summary.get("has_multimodal_prompt", False): + stats["multimodal_actions_count"] += 1.0 + + return stats + def add(self, grouped_dataitem: List[RLDataFlowItem]): """Adds a group of data items to the storage. 
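
Note on the stats hunk above: `get_storage_stats()` is computed purely from each group's `extra_info` summary (filled in by `summarize_group_payload`), so the aggregation cost scales with the number of tracked groups, never with the size of the rollout payloads themselves. A minimal standalone sketch of the same fold, assuming summaries are plain dicts shaped like the ones `summarize_group_payload` emits (the `aggregate` helper below is illustrative, not part of this patch):

    from typing import Any, Dict, List

    def aggregate(summaries: List[Dict[str, Any]]) -> Dict[str, float]:
        # Fold per-group summaries into cluster-wide counters, mirroring
        # get_storage_stats(): only lightweight metadata is touched.
        stats = {
            "payload_full_actions_count": 0.0,
            "payload_prompt_only_actions_count": 0.0,
            "stored_response_tokens": 0.0,
        }
        for summary in summaries:
            if summary.get("payload_mode", "full") == "prompt_only":
                stats["payload_prompt_only_actions_count"] += 1.0
            else:
                stats["payload_full_actions_count"] += 1.0
            stats["stored_response_tokens"] += float(summary.get("response_tokens", 0))
        return stats

    print(aggregate([
        {"payload_mode": "full", "response_tokens": 512},
        {"payload_mode": "prompt_only", "response_tokens": 0},
    ]))
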
@@ -426,6 +550,8 @@ def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]: return [] def clear(self): + for replay_meta in self._actions.values(): + self._free_replay_meta_refs(replay_meta) attrs_to_clear = [ "_aborted_actions", "_completed_actions", @@ -699,6 +825,10 @@ def _check_completed_samples_expired(self): for version in expired_versions: bucket = self._completed_actions.pop(version) + for action_id in bucket: + replay_meta = self._actions.get(action_id) + if replay_meta is not None: + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED) self._expired_actions.extend(bucket) self.logger.info( f"Moved {len(bucket)} completed samples with version {version} to expired samples due to exceeding tail_batch_candidate_steps." @@ -709,6 +839,10 @@ def _check_completed_samples_aborted(self): return for version, bucket in self._completed_actions.items(): + for action_id in bucket: + replay_meta = self._actions.get(action_id) + if replay_meta is not None: + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED) self._aborted_actions[0].extend(bucket) self.logger.info( f"Moved {len(bucket)} completed samples with version {version} to aborted samples due to partial rollout disabled." @@ -729,7 +863,9 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta): if state and observation_id in self._states.get(state, []): self._states[state].remove(observation_id) + self._actions.pop(action_id, None) self._action2observations.pop(action_id, None) + self._free_replay_meta_refs(replay_meta) del replay_meta def _clear_meta_for_root(self, replay_meta: ReplayMeta): @@ -747,13 +883,16 @@ def _clear_meta_for_root(self, replay_meta: ReplayMeta): and clear all related actions. """ root_id = replay_meta.root_id + current_action_id = replay_meta.action_id + self._clear_meta_for_actions(replay_meta) if root_id in self._root2actions: for action_id in self._root2actions[root_id]: + if action_id == current_action_id: + continue new_replay_meta = self._actions.pop(action_id, None) if new_replay_meta: self._clear_meta_for_actions(new_replay_meta) del self._root2actions[root_id] - del replay_meta def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): """Checks the rollout state of a ReplayMeta object and inserts its @@ -775,11 +914,14 @@ def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): if state == RolloutState.ABORTED: if self.tail_batch_candidate_steps > 0 and replay_meta.version >= self.tail_batch_candidate_steps: # 过期的数据需要重置状态 + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED) self._expired_actions.append(action_id) self.logger.debug( f"Add expired sample with action_id: {action_id} to _expired_actions because version: {replay_meta.version} >= tail_batch_candidate_steps: {self.tail_batch_candidate_steps}." ) else: + if not self.enable_partial_rollout: + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED) self._aborted_actions[replay_meta.version].append(action_id) self.logger.debug( f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _aborted_actions." 
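
Note on the demotion hunks above: every path that moves a sample to ABORTED or EXPIRED now routes through `_strip_rollout_payload_for_rerun`, which calls `ray.internal.free` to reclaim plasma memory eagerly instead of waiting for distributed ref-counting to notice the refs are dead. A toy illustration of that put/free lifecycle; `ray.internal.free` is an internal Ray API, so the exact behavior of a later `ray.get` on a freed ref (hang vs. error) can vary across Ray versions:

    import numpy as np
    import ray

    ray.init()
    payload_ref = ray.put(np.zeros((1024, 1024), dtype=np.float32))  # ~4 MB in plasma
    assert ray.get(payload_ref).shape == (1024, 1024)                # still resolvable

    # Eagerly drop the stored bytes cluster-wide. The Python ObjectRef
    # object survives, but resolving it afterwards will hang or raise.
    ray.internal.free([payload_ref], local_only=False)
    ray.shutdown()
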
@@ -903,7 +1045,7 @@ def add(self, grouped_dataitem: List[RLDataFlowItem]): self.storage.add(grouped_dataitem) def status(self): - return { + status = { "remain_completed_samples_count": self.storage.completed_samples_count, "remain_aborted_samples_count": self.storage.aborted_samples_count, "remain_expired_samples_count": self.storage.expired_samples_count, @@ -911,6 +1053,8 @@ def status(self): "sample_from_aborted_count": self.storage.sample_from_aborted_count, "sample_from_expired_count": self.storage.sample_from_expired_count, } + status.update(self.storage.get_storage_stats()) + return status def save(self, file_path: Path | str): """Saves the replay buffer's storage to a file. diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index b500b53e4..0f3864044 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -4,6 +4,7 @@ import ray import torch +from ray import ObjectRef from ray.actor import ActorProxy from xtuner.v1.data_proto.sequence_context import SequenceContext @@ -28,6 +29,23 @@ class RawTrainingController: def __init__(self, workers: list[TrainingWorker]) -> None: self.workers = workers + def _collect_object_refs(self, obj, refs: list[ObjectRef]): + if isinstance(obj, ObjectRef): + refs.append(obj) + return + if isinstance(obj, (list, tuple)): + for item in obj: + self._collect_object_refs(item, refs) + + def _free_batch_object_refs(self, data_batches): + refs: list[ObjectRef] = [] + for data in data_batches: + seq_ctx = data["seq_ctx"] + self._collect_object_refs(seq_ctx.pixel_values, refs) + self._collect_object_refs(seq_ctx.rollout_routed_experts, refs) + if refs: + ray.internal.free(refs, local_only=False) + # TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack def _get_pack_infos(self, dataset, num_tokens, target, random=None): inds = list(range(len(dataset))) @@ -260,7 +278,10 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: rollout_idx=rollout_idx, ) ) - log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) + try: + log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) + finally: + self._free_batch_object_refs(packed_data_batches) return log_infos @ray_method diff --git a/xtuner/v1/train/cli/rl.py b/xtuner/v1/train/cli/rl.py index 0a91ee1ed..0788c7c3e 100644 --- a/xtuner/v1/train/cli/rl.py +++ b/xtuner/v1/train/cli/rl.py @@ -20,19 +20,20 @@ ) -def rl_monitor_actor_memory(work_dir, interval: int = 60): - while True: - try: - ray.init(address="auto") - time.sleep(interval) - break - except KeyboardInterrupt: - print("\n监控已停止") - break - except Exception: - print("连接 Ray 集群失败, 等等") +def rl_monitor_actor_memory(work_dir, interval: int = 60, object_limit: int = 5000, top_k: int = 10): + if not ray.is_initialized(): + while True: + try: + ray.init(address="auto") + time.sleep(interval) + break + except KeyboardInterrupt: + print("\n监控已停止") + return + except Exception: + print("连接 Ray 集群失败, 等等") - monitor_actor_memory(work_dir=work_dir, interval=interval) + monitor_actor_memory(work_dir=work_dir, interval=interval, object_limit=object_limit, top_k=top_k) @app.default() @@ -51,7 +52,13 @@ def main( if os.getenv("XTUNER_RL_MEM_DIR"): print("Start to monitor actor memory") - track_thread = threading.Thread(target=rl_monitor_actor_memory, args=(os.getenv("XTUNER_RL_MEM_DIR"),)) + monitor_interval = int(os.getenv("XTUNER_RL_MEM_INTERVAL", "60")) + object_limit = int(os.getenv("XTUNER_RL_OBJECT_LIMIT", "5000")) + top_k = int(os.getenv("XTUNER_RL_OBJECT_TOP_K", "10")) + 
track_thread = threading.Thread( + target=rl_monitor_actor_memory, + args=(os.getenv("XTUNER_RL_MEM_DIR"), monitor_interval, object_limit, top_k), + ) track_thread.daemon = True track_thread.start() diff --git a/xtuner/v1/utils/track_rl_mem.py b/xtuner/v1/utils/track_rl_mem.py index f42d7cce4..c7ee1efb1 100644 --- a/xtuner/v1/utils/track_rl_mem.py +++ b/xtuner/v1/utils/track_rl_mem.py @@ -1,7 +1,10 @@ import argparse +import dataclasses import json import os import time +from collections import defaultdict +from typing import Any import psutil import ray @@ -15,14 +18,163 @@ pynvml = None -def monitor_actor_memory(work_dir: str, interval: int = 60): +def _maybe_init_nvml(): if pynvml is None: - raise ImportError("pynvml 未安装,无法监控 GPU 内存") + return False + try: + pynvml.nvmlInit() + return True + except Exception: + return False + + +def _maybe_shutdown_nvml(initialized: bool): + if initialized: + try: + pynvml.nvmlShutdown() + except Exception: + pass + + +def _state_obj_to_dict(obj: Any) -> dict[str, Any]: + if isinstance(obj, dict): + return obj + if dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj) + if hasattr(obj, "model_dump"): + return obj.model_dump() + if hasattr(obj, "dict"): + return obj.dict() + if hasattr(obj, "__dict__"): + return dict(obj.__dict__) + return {} + + +def _sanitize_tag_component(name: str) -> str: + return name.replace("/", "_").replace(" ", "_").replace(":", "_").replace(".", "_") + + +def _get_object_store_stats(object_limit: int = 5000, top_k: int = 10): + stats: dict[str, Any] = { + "available": False, + "total_objects": 0, + "total_size_mb": 0.0, + "callsite_enabled": 0, + "summary_by": "", + "ref_type_counts": {}, + "task_state_counts": {}, + "top_callsites": [], + "detail_truncated": 0, + "detail_object_count": 0, + "detail_total_size_mb": 0.0, + "detail_ref_type_counts": {}, + "detail_ref_type_size_mb": {}, + "top_call_sites_from_objects": [], + "top_pids": [], + "top_ips": [], + } + + try: + from ray.util import state as ray_state + + summary_raw = ray_state.summarize_objects(timeout=30, raise_on_missing_output=False) + summary_data = _state_obj_to_dict(summary_raw) + stats["available"] = True + stats["total_objects"] = summary_data.get("total_objects", 0) + stats["total_size_mb"] = float(summary_data.get("total_size_mb", 0.0) or 0.0) + stats["callsite_enabled"] = int(bool(summary_data.get("callsite_enabled", False))) + stats["summary_by"] = summary_data.get("summary_by", "") + + ref_type_counts = defaultdict(int) + task_state_counts = defaultdict(int) + callsite_items = [] + for callsite, item in (summary_data.get("summary", {}) or {}).items(): + item_dict = _state_obj_to_dict(item) + callsite_items.append( + { + "callsite": callsite, + "total_size_mb": float(item_dict.get("total_size_mb", 0.0) or 0.0), + "total_objects": int(item_dict.get("total_objects", 0) or 0), + "total_num_workers": int(item_dict.get("total_num_workers", 0) or 0), + "total_num_nodes": int(item_dict.get("total_num_nodes", 0) or 0), + "ref_type_counts": item_dict.get("ref_type_counts", {}) or {}, + "task_state_counts": item_dict.get("task_state_counts", {}) or {}, + } + ) + for ref_type, count in (item_dict.get("ref_type_counts", {}) or {}).items(): + ref_type_counts[str(ref_type)] += int(count) + for task_state, count in (item_dict.get("task_state_counts", {}) or {}).items(): + task_state_counts[str(task_state)] += int(count) + + callsite_items.sort(key=lambda x: (x["total_size_mb"], x["total_objects"]), reverse=True) + stats["top_callsites"] = callsite_items[:top_k] 
+ stats["ref_type_counts"] = dict(ref_type_counts) + stats["task_state_counts"] = dict(task_state_counts) + + try: + object_states = ray_state.list_objects( + limit=object_limit, timeout=30, detail=False, raise_on_missing_output=False + ) + pid_size_mb = defaultdict(float) + pid_count = defaultdict(int) + ip_size_mb = defaultdict(float) + ip_count = defaultdict(int) + ref_type_size_mb = defaultdict(float) + ref_type_count = defaultdict(int) + callsite_size_mb = defaultdict(float) + callsite_count = defaultdict(int) + + object_state_dicts = [_state_obj_to_dict(obj) for obj in object_states] + stats["detail_object_count"] = len(object_state_dicts) + stats["detail_truncated"] = int(len(object_state_dicts) >= object_limit) + + for obj in object_state_dicts: + size_mb = float(obj.get("object_size", 0) or 0) / (1024**2) + pid = str(obj.get("pid", "unknown")) + ip = str(obj.get("ip", "unknown")) + ref_type = str(obj.get("reference_type", "UNKNOWN")) + call_site = str(obj.get("call_site", "unknown")) + + stats["detail_total_size_mb"] += size_mb + pid_size_mb[pid] += size_mb + pid_count[pid] += 1 + ip_size_mb[ip] += size_mb + ip_count[ip] += 1 + ref_type_size_mb[ref_type] += size_mb + ref_type_count[ref_type] += 1 + callsite_size_mb[call_site] += size_mb + callsite_count[call_site] += 1 + + stats["detail_ref_type_counts"] = dict(ref_type_count) + stats["detail_ref_type_size_mb"] = dict(ref_type_size_mb) + stats["top_call_sites_from_objects"] = [ + {"callsite": k, "size_mb": v, "count": callsite_count[k]} + for k, v in sorted(callsite_size_mb.items(), key=lambda item: item[1], reverse=True)[:top_k] + ] + stats["top_pids"] = [ + {"pid": k, "size_mb": v, "count": pid_count[k]} + for k, v in sorted(pid_size_mb.items(), key=lambda item: item[1], reverse=True)[:top_k] + ] + stats["top_ips"] = [ + {"ip": k, "size_mb": v, "count": ip_count[k]} + for k, v in sorted(ip_size_mb.items(), key=lambda item: item[1], reverse=True)[:top_k] + ] + except Exception as e: + stats["detail_error"] = str(e) + + except Exception as e: + stats["error"] = str(e) + + return stats + + +def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int = 5000, top_k: int = 10): print(f"开始监控 Actor 内存使用情况,间隔 {interval} 秒...") print("=" * 80) os.makedirs(f"{work_dir}/tb", exist_ok=True) - f = open(f"{work_dir}/actor_memory.json", "w") + actor_f = open(f"{work_dir}/actor_memory.jsonl", "w", encoding="utf-8") + object_f = open(f"{work_dir}/object_store.jsonl", "w", encoding="utf-8") cluster_resources = ray.cluster_resources() total_gpus = int(cluster_resources.get("GPU", 0)) @@ -35,6 +187,7 @@ def monitor_actor_memory(work_dir: str, interval: int = 60): while True: count += 1 memory_info = {} + object_store_info = {} # 获取所有 Actor actors = ray.state.actors() @@ -54,17 +207,18 @@ def monitor_actor_memory(work_dir: str, interval: int = 60): try: process = psutil.Process(pid) memory_gb = process.memory_info().rss / 1024 / 1024 / 1024 - pynvml.nvmlInit() - device_count = pynvml.nvmlDeviceGetCount() - for i in range(device_count): - handle = pynvml.nvmlDeviceGetHandleByIndex(i) - # 检查该GPU是否被当前进程使用 - compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) - if any(proc.pid == pid for proc in compute_procs): - mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) - gpu_memory_gb = mem_info.used / 1024 / 1024 / 1024 - break - pynvml.nvmlShutdown() + nvml_initialized = _maybe_init_nvml() + if nvml_initialized: + device_count = pynvml.nvmlDeviceGetCount() + for i in range(device_count): + handle = 
pynvml.nvmlDeviceGetHandleByIndex(i) + # 检查该GPU是否被当前进程使用 + compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) + if any(proc.pid == pid for proc in compute_procs): + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + gpu_memory_gb = mem_info.used / 1024 / 1024 / 1024 + break + _maybe_shutdown_nvml(nvml_initialized) except (psutil.NoSuchProcess, psutil.AccessDenied): pass @@ -80,10 +234,18 @@ def monitor_actor_memory(work_dir: str, interval: int = 60): "gpu_mem_gb": [gpu_memory_gb], } + object_store_info = _get_object_store_stats(object_limit=object_limit, top_k=top_k) + object_store_info["time"] = current_time + object_store_info["object_limit"] = object_limit + object_store_info["top_k"] = top_k + # 写入文件 - json.dump(memory_info, f, ensure_ascii=False) - f.write("\n") - f.flush() + json.dump(memory_info, actor_f, ensure_ascii=False) + actor_f.write("\n") + actor_f.flush() + json.dump(object_store_info, object_f, ensure_ascii=False) + object_f.write("\n") + object_f.flush() for actor_name, memory_mb_info in memory_info.items(): if actor_name == "time": @@ -128,13 +290,60 @@ def monitor_actor_memory(work_dir: str, interval: int = 60): global_step=count, ) + tb_writer_list[0].add_scalar( + tag="ray_object_store/total_size_mb", + scalar_value=float(object_store_info.get("total_size_mb", 0.0) or 0.0), + global_step=count, + ) + tb_writer_list[0].add_scalar( + tag="ray_object_store/total_objects", + scalar_value=float(object_store_info.get("total_objects", 0) or 0), + global_step=count, + ) + tb_writer_list[0].add_scalar( + tag="ray_object_store/detail_total_size_mb", + scalar_value=float(object_store_info.get("detail_total_size_mb", 0.0) or 0.0), + global_step=count, + ) + tb_writer_list[0].add_scalar( + tag="ray_object_store/detail_object_count", + scalar_value=float(object_store_info.get("detail_object_count", 0) or 0), + global_step=count, + ) + tb_writer_list[0].add_scalar( + tag="ray_object_store/detail_truncated", + scalar_value=float(object_store_info.get("detail_truncated", 0) or 0), + global_step=count, + ) + + for ref_type, value in (object_store_info.get("ref_type_counts", {}) or {}).items(): + tb_writer_list[0].add_scalar( + tag=f"ray_object_store/ref_type_count/{_sanitize_tag_component(str(ref_type))}", + scalar_value=float(value), + global_step=count, + ) + for ref_type, value in (object_store_info.get("detail_ref_type_size_mb", {}) or {}).items(): + tb_writer_list[0].add_scalar( + tag=f"ray_object_store/ref_type_size_mb/{_sanitize_tag_component(str(ref_type))}", + scalar_value=float(value), + global_step=count, + ) + for task_state, value in (object_store_info.get("task_state_counts", {}) or {}).items(): + tb_writer_list[0].add_scalar( + tag=f"ray_object_store/task_state_count/{_sanitize_tag_component(str(task_state))}", + scalar_value=float(value), + global_step=count, + ) + time.sleep(interval) print(memory_info) + print(object_store_info) except KeyboardInterrupt: print("\n监控已停止") finally: - f.close() + actor_f.close() + object_f.close() for tb_writer in tb_writer_list: tb_writer.close() @@ -143,14 +352,19 @@ def monitor_actor_memory(work_dir: str, interval: int = 60): parser = argparse.ArgumentParser(description="RL MEMORY MONITOR") parser.add_argument("--work_dir", type=str, default="dense_8b") parser.add_argument("--interval", type=int, default=60) + parser.add_argument("--object_limit", type=int, default=5000) + parser.add_argument("--top_k", type=int, default=10) args = parser.parse_args() work_dir = args.work_dir interval = args.interval + object_limit = 
args.object_limit + top_k = args.top_k while True: try: - ray.init(address="auto") - time.sleep(interval) + if not ray.is_initialized(): + ray.init(address="auto") + time.sleep(interval) break except KeyboardInterrupt: print("\n监控已停止") @@ -158,4 +372,4 @@ def monitor_actor_memory(work_dir: str, interval: int = 60): except Exception: print("连接 Ray 集群失败, 等等") - monitor_actor_memory(work_dir=work_dir, interval=interval) + monitor_actor_memory(work_dir=work_dir, interval=interval, object_limit=object_limit, top_k=top_k) From 0a7cf3a287d676324b75782063bfbdf7893f5cbc Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Apr 2026 17:52:30 +0800 Subject: [PATCH 2/5] Restore RL memory scripts to github/main --- examples/v1/scripts/run_rl.sh | 5 +- examples/v1/scripts/run_rl_submit.sh | 5 +- xtuner/v1/utils/track_rl_mem.py | 256 +++------------------------ 3 files changed, 23 insertions(+), 243 deletions(-) diff --git a/examples/v1/scripts/run_rl.sh b/examples/v1/scripts/run_rl.sh index f81f09bb2..3fe52ec44 100644 --- a/examples/v1/scripts/run_rl.sh +++ b/examples/v1/scripts/run_rl.sh @@ -82,9 +82,6 @@ export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt" if [ "$ACCELERATOR" = "GPU" ]; then # TODO: support NPU RL Memory Monitor export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}" - export XTUNER_RL_MEM_INTERVAL="${XTUNER_RL_MEM_INTERVAL:-60}" - export XTUNER_RL_OBJECT_LIMIT="${XTUNER_RL_OBJECT_LIMIT:-5000}" - export XTUNER_RL_OBJECT_TOP_K="${XTUNER_RL_OBJECT_TOP_K:-10}" fi # 2. Launch Ray cluster @@ -142,4 +139,4 @@ LOG_FILE="${WORK_DIR}/training_log_${current_time}.txt" python xtuner/v1/train/cli/rl.py \ --config $CONFIG_PATH \ - 2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt" + 2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt" \ No newline at end of file diff --git a/examples/v1/scripts/run_rl_submit.sh b/examples/v1/scripts/run_rl_submit.sh index d870b67a8..4d268527d 100644 --- a/examples/v1/scripts/run_rl_submit.sh +++ b/examples/v1/scripts/run_rl_submit.sh @@ -73,9 +73,6 @@ export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt" if [ "$ACCELERATOR" = "GPU" ]; then # TODO: support NPU RL Memory Monitor export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}" - export XTUNER_RL_MEM_INTERVAL="${XTUNER_RL_MEM_INTERVAL:-60}" - export XTUNER_RL_OBJECT_LIMIT="${XTUNER_RL_OBJECT_LIMIT:-5000}" - export XTUNER_RL_OBJECT_TOP_K="${XTUNER_RL_OBJECT_TOP_K:-10}" fi # 2. 
Launch Ray cluster @@ -160,4 +157,4 @@ if [ "$RAY_RANK" -eq 0 ]; then 2>&1 | tee -a "$LOG_FILE" echo "训练任务提交完成。日志文件: $LOG_FILE" -fi +fi \ No newline at end of file diff --git a/xtuner/v1/utils/track_rl_mem.py b/xtuner/v1/utils/track_rl_mem.py index c7ee1efb1..f42d7cce4 100644 --- a/xtuner/v1/utils/track_rl_mem.py +++ b/xtuner/v1/utils/track_rl_mem.py @@ -1,10 +1,7 @@ import argparse -import dataclasses import json import os import time -from collections import defaultdict -from typing import Any import psutil import ray @@ -18,163 +15,14 @@ pynvml = None -def _maybe_init_nvml(): +def monitor_actor_memory(work_dir: str, interval: int = 60): if pynvml is None: - return False - try: - pynvml.nvmlInit() - return True - except Exception: - return False - - -def _maybe_shutdown_nvml(initialized: bool): - if initialized: - try: - pynvml.nvmlShutdown() - except Exception: - pass - - -def _state_obj_to_dict(obj: Any) -> dict[str, Any]: - if isinstance(obj, dict): - return obj - if dataclasses.is_dataclass(obj): - return dataclasses.asdict(obj) - if hasattr(obj, "model_dump"): - return obj.model_dump() - if hasattr(obj, "dict"): - return obj.dict() - if hasattr(obj, "__dict__"): - return dict(obj.__dict__) - return {} - - -def _sanitize_tag_component(name: str) -> str: - return name.replace("/", "_").replace(" ", "_").replace(":", "_").replace(".", "_") - - -def _get_object_store_stats(object_limit: int = 5000, top_k: int = 10): - stats: dict[str, Any] = { - "available": False, - "total_objects": 0, - "total_size_mb": 0.0, - "callsite_enabled": 0, - "summary_by": "", - "ref_type_counts": {}, - "task_state_counts": {}, - "top_callsites": [], - "detail_truncated": 0, - "detail_object_count": 0, - "detail_total_size_mb": 0.0, - "detail_ref_type_counts": {}, - "detail_ref_type_size_mb": {}, - "top_call_sites_from_objects": [], - "top_pids": [], - "top_ips": [], - } - - try: - from ray.util import state as ray_state - - summary_raw = ray_state.summarize_objects(timeout=30, raise_on_missing_output=False) - summary_data = _state_obj_to_dict(summary_raw) - stats["available"] = True - stats["total_objects"] = summary_data.get("total_objects", 0) - stats["total_size_mb"] = float(summary_data.get("total_size_mb", 0.0) or 0.0) - stats["callsite_enabled"] = int(bool(summary_data.get("callsite_enabled", False))) - stats["summary_by"] = summary_data.get("summary_by", "") - - ref_type_counts = defaultdict(int) - task_state_counts = defaultdict(int) - callsite_items = [] - for callsite, item in (summary_data.get("summary", {}) or {}).items(): - item_dict = _state_obj_to_dict(item) - callsite_items.append( - { - "callsite": callsite, - "total_size_mb": float(item_dict.get("total_size_mb", 0.0) or 0.0), - "total_objects": int(item_dict.get("total_objects", 0) or 0), - "total_num_workers": int(item_dict.get("total_num_workers", 0) or 0), - "total_num_nodes": int(item_dict.get("total_num_nodes", 0) or 0), - "ref_type_counts": item_dict.get("ref_type_counts", {}) or {}, - "task_state_counts": item_dict.get("task_state_counts", {}) or {}, - } - ) - for ref_type, count in (item_dict.get("ref_type_counts", {}) or {}).items(): - ref_type_counts[str(ref_type)] += int(count) - for task_state, count in (item_dict.get("task_state_counts", {}) or {}).items(): - task_state_counts[str(task_state)] += int(count) - - callsite_items.sort(key=lambda x: (x["total_size_mb"], x["total_objects"]), reverse=True) - stats["top_callsites"] = callsite_items[:top_k] - stats["ref_type_counts"] = dict(ref_type_counts) - stats["task_state_counts"] 
= dict(task_state_counts) - - try: - object_states = ray_state.list_objects( - limit=object_limit, timeout=30, detail=False, raise_on_missing_output=False - ) - pid_size_mb = defaultdict(float) - pid_count = defaultdict(int) - ip_size_mb = defaultdict(float) - ip_count = defaultdict(int) - ref_type_size_mb = defaultdict(float) - ref_type_count = defaultdict(int) - callsite_size_mb = defaultdict(float) - callsite_count = defaultdict(int) - - object_state_dicts = [_state_obj_to_dict(obj) for obj in object_states] - stats["detail_object_count"] = len(object_state_dicts) - stats["detail_truncated"] = int(len(object_state_dicts) >= object_limit) - - for obj in object_state_dicts: - size_mb = float(obj.get("object_size", 0) or 0) / (1024**2) - pid = str(obj.get("pid", "unknown")) - ip = str(obj.get("ip", "unknown")) - ref_type = str(obj.get("reference_type", "UNKNOWN")) - call_site = str(obj.get("call_site", "unknown")) - - stats["detail_total_size_mb"] += size_mb - pid_size_mb[pid] += size_mb - pid_count[pid] += 1 - ip_size_mb[ip] += size_mb - ip_count[ip] += 1 - ref_type_size_mb[ref_type] += size_mb - ref_type_count[ref_type] += 1 - callsite_size_mb[call_site] += size_mb - callsite_count[call_site] += 1 - - stats["detail_ref_type_counts"] = dict(ref_type_count) - stats["detail_ref_type_size_mb"] = dict(ref_type_size_mb) - stats["top_call_sites_from_objects"] = [ - {"callsite": k, "size_mb": v, "count": callsite_count[k]} - for k, v in sorted(callsite_size_mb.items(), key=lambda item: item[1], reverse=True)[:top_k] - ] - stats["top_pids"] = [ - {"pid": k, "size_mb": v, "count": pid_count[k]} - for k, v in sorted(pid_size_mb.items(), key=lambda item: item[1], reverse=True)[:top_k] - ] - stats["top_ips"] = [ - {"ip": k, "size_mb": v, "count": ip_count[k]} - for k, v in sorted(ip_size_mb.items(), key=lambda item: item[1], reverse=True)[:top_k] - ] - except Exception as e: - stats["detail_error"] = str(e) - - except Exception as e: - stats["error"] = str(e) - - return stats - - -def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int = 5000, top_k: int = 10): + raise ImportError("pynvml 未安装,无法监控 GPU 内存") print(f"开始监控 Actor 内存使用情况,间隔 {interval} 秒...") print("=" * 80) os.makedirs(f"{work_dir}/tb", exist_ok=True) - actor_f = open(f"{work_dir}/actor_memory.jsonl", "w", encoding="utf-8") - object_f = open(f"{work_dir}/object_store.jsonl", "w", encoding="utf-8") + f = open(f"{work_dir}/actor_memory.json", "w") cluster_resources = ray.cluster_resources() total_gpus = int(cluster_resources.get("GPU", 0)) @@ -187,7 +35,6 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int = while True: count += 1 memory_info = {} - object_store_info = {} # 获取所有 Actor actors = ray.state.actors() @@ -207,18 +54,17 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int = try: process = psutil.Process(pid) memory_gb = process.memory_info().rss / 1024 / 1024 / 1024 - nvml_initialized = _maybe_init_nvml() - if nvml_initialized: - device_count = pynvml.nvmlDeviceGetCount() - for i in range(device_count): - handle = pynvml.nvmlDeviceGetHandleByIndex(i) - # 检查该GPU是否被当前进程使用 - compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) - if any(proc.pid == pid for proc in compute_procs): - mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) - gpu_memory_gb = mem_info.used / 1024 / 1024 / 1024 - break - _maybe_shutdown_nvml(nvml_initialized) + pynvml.nvmlInit() + device_count = pynvml.nvmlDeviceGetCount() + for i in range(device_count): + handle = 
pynvml.nvmlDeviceGetHandleByIndex(i) + # 检查该GPU是否被当前进程使用 + compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) + if any(proc.pid == pid for proc in compute_procs): + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + gpu_memory_gb = mem_info.used / 1024 / 1024 / 1024 + break + pynvml.nvmlShutdown() except (psutil.NoSuchProcess, psutil.AccessDenied): pass @@ -234,18 +80,10 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int = "gpu_mem_gb": [gpu_memory_gb], } - object_store_info = _get_object_store_stats(object_limit=object_limit, top_k=top_k) - object_store_info["time"] = current_time - object_store_info["object_limit"] = object_limit - object_store_info["top_k"] = top_k - # 写入文件 - json.dump(memory_info, actor_f, ensure_ascii=False) - actor_f.write("\n") - actor_f.flush() - json.dump(object_store_info, object_f, ensure_ascii=False) - object_f.write("\n") - object_f.flush() + json.dump(memory_info, f, ensure_ascii=False) + f.write("\n") + f.flush() for actor_name, memory_mb_info in memory_info.items(): if actor_name == "time": @@ -290,60 +128,13 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int = global_step=count, ) - tb_writer_list[0].add_scalar( - tag="ray_object_store/total_size_mb", - scalar_value=float(object_store_info.get("total_size_mb", 0.0) or 0.0), - global_step=count, - ) - tb_writer_list[0].add_scalar( - tag="ray_object_store/total_objects", - scalar_value=float(object_store_info.get("total_objects", 0) or 0), - global_step=count, - ) - tb_writer_list[0].add_scalar( - tag="ray_object_store/detail_total_size_mb", - scalar_value=float(object_store_info.get("detail_total_size_mb", 0.0) or 0.0), - global_step=count, - ) - tb_writer_list[0].add_scalar( - tag="ray_object_store/detail_object_count", - scalar_value=float(object_store_info.get("detail_object_count", 0) or 0), - global_step=count, - ) - tb_writer_list[0].add_scalar( - tag="ray_object_store/detail_truncated", - scalar_value=float(object_store_info.get("detail_truncated", 0) or 0), - global_step=count, - ) - - for ref_type, value in (object_store_info.get("ref_type_counts", {}) or {}).items(): - tb_writer_list[0].add_scalar( - tag=f"ray_object_store/ref_type_count/{_sanitize_tag_component(str(ref_type))}", - scalar_value=float(value), - global_step=count, - ) - for ref_type, value in (object_store_info.get("detail_ref_type_size_mb", {}) or {}).items(): - tb_writer_list[0].add_scalar( - tag=f"ray_object_store/ref_type_size_mb/{_sanitize_tag_component(str(ref_type))}", - scalar_value=float(value), - global_step=count, - ) - for task_state, value in (object_store_info.get("task_state_counts", {}) or {}).items(): - tb_writer_list[0].add_scalar( - tag=f"ray_object_store/task_state_count/{_sanitize_tag_component(str(task_state))}", - scalar_value=float(value), - global_step=count, - ) - time.sleep(interval) print(memory_info) - print(object_store_info) except KeyboardInterrupt: print("\n监控已停止") finally: - actor_f.close() - object_f.close() + f.close() for tb_writer in tb_writer_list: tb_writer.close() @@ -352,19 +143,14 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int = parser = argparse.ArgumentParser(description="RL MEMORY MONITOR") parser.add_argument("--work_dir", type=str, default="dense_8b") parser.add_argument("--interval", type=int, default=60) - parser.add_argument("--object_limit", type=int, default=5000) - parser.add_argument("--top_k", type=int, default=10) args = parser.parse_args() work_dir = args.work_dir 
interval = args.interval - object_limit = args.object_limit - top_k = args.top_k while True: try: - if not ray.is_initialized(): - ray.init(address="auto") - time.sleep(interval) + ray.init(address="auto") + time.sleep(interval) break except KeyboardInterrupt: print("\n监控已停止") @@ -372,4 +158,4 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int = except Exception: print("连接 Ray 集群失败, 等等") - monitor_actor_memory(work_dir=work_dir, interval=interval, object_limit=object_limit, top_k=top_k) + monitor_actor_memory(work_dir=work_dir, interval=interval) From 7a5f1c3490a9b3ebd6f8976d3e76f9141275e0a5 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Apr 2026 20:17:20 +0800 Subject: [PATCH 3/5] Restore RL CLI memory monitor to github/main --- xtuner/v1/train/cli/rl.py | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/xtuner/v1/train/cli/rl.py b/xtuner/v1/train/cli/rl.py index 0788c7c3e..0a91ee1ed 100644 --- a/xtuner/v1/train/cli/rl.py +++ b/xtuner/v1/train/cli/rl.py @@ -20,20 +20,19 @@ ) -def rl_monitor_actor_memory(work_dir, interval: int = 60, object_limit: int = 5000, top_k: int = 10): - if not ray.is_initialized(): - while True: - try: - ray.init(address="auto") - time.sleep(interval) - break - except KeyboardInterrupt: - print("\n监控已停止") - return - except Exception: - print("连接 Ray 集群失败, 等等") +def rl_monitor_actor_memory(work_dir, interval: int = 60): + while True: + try: + ray.init(address="auto") + time.sleep(interval) + break + except KeyboardInterrupt: + print("\n监控已停止") + break + except Exception: + print("连接 Ray 集群失败, 等等") - monitor_actor_memory(work_dir=work_dir, interval=interval, object_limit=object_limit, top_k=top_k) + monitor_actor_memory(work_dir=work_dir, interval=interval) @app.default() @@ -52,13 +51,7 @@ def main( if os.getenv("XTUNER_RL_MEM_DIR"): print("Start to monitor actor memory") - monitor_interval = int(os.getenv("XTUNER_RL_MEM_INTERVAL", "60")) - object_limit = int(os.getenv("XTUNER_RL_OBJECT_LIMIT", "5000")) - top_k = int(os.getenv("XTUNER_RL_OBJECT_TOP_K", "10")) - track_thread = threading.Thread( - target=rl_monitor_actor_memory, - args=(os.getenv("XTUNER_RL_MEM_DIR"), monitor_interval, object_limit, top_k), - ) + track_thread = threading.Thread(target=rl_monitor_actor_memory, args=(os.getenv("XTUNER_RL_MEM_DIR"),)) track_thread.daemon = True track_thread.start() From f4a5828339fbe5475b940596b0f921e793ee5eff Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Apr 2026 20:52:59 +0800 Subject: [PATCH 4/5] resolve comments --- xtuner/v1/ray/dataflow/replay_buffer.py | 16 +++++++++++++--- xtuner/v1/rl/base/controller.py | 4 ++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index add6bdd2e..9a0ea536e 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -351,6 +351,7 @@ def __init__(self, replay_buffer_cfg): self.enable_partial_rollout: bool = False self.tail_batch_candidate_steps: int = 0 self.tail_batch_trigger_size: int = 0 + self._empty_env_ref: Optional[ObjectRef] = None self._completed_actions: Dict[int, List[int]] = defaultdict(list) self._aborted_actions: Dict[int, List[int]] = defaultdict(list) @@ -366,11 +367,19 @@ def __init__(self, replay_buffer_cfg): self.sample_from_aborted_count = 0 self.sample_from_expired_count = 0 + def _get_empty_env_ref(self) -> ObjectRef: + if self._empty_env_ref is None: + self._empty_env_ref = 
ray.put(RLEnvDataItem()) + return self._empty_env_ref + + def _filter_releasable_refs(self, refs: List[ObjectRef]) -> List[ObjectRef]: + return [ref for ref in refs if ref is not None and ref != self._empty_env_ref] + def _free_replay_meta_refs(self, replay_meta: ReplayMeta, include_action_ref: bool = True): refs = [] if include_action_ref and replay_meta.action_ref is not None: refs.append(replay_meta.action_ref) - refs.extend([ref for ref in replay_meta.observation_refs if ref is not None]) + refs.extend(self._filter_releasable_refs(replay_meta.observation_refs)) if refs: ray.internal.free(refs, local_only=False) @@ -386,10 +395,11 @@ def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutS def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState): """Keep prompt refs only and drop rollout outputs that will not be reused.""" - old_obs_refs = [ref for ref in replay_meta.observation_refs if ref is not None] + old_obs_refs = self._filter_releasable_refs(replay_meta.observation_refs) if old_obs_refs: ray.internal.free(old_obs_refs, local_only=False) - replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids] + empty_env_ref = self._get_empty_env_ref() + replay_meta.observation_refs = [empty_env_ref for _ in replay_meta.observation_ids] replay_meta.extra_info.update( { "payload_mode": "prompt_only", diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index 0f3864044..4045b6eb2 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -36,6 +36,10 @@ def _collect_object_refs(self, obj, refs: list[ObjectRef]): if isinstance(obj, (list, tuple)): for item in obj: self._collect_object_refs(item, refs) + return + if isinstance(obj, dict): + for value in obj.values(): + self._collect_object_refs(value, refs) def _free_batch_object_refs(self, data_batches): refs: list[ObjectRef] = [] From 1272c09f15752fe95fec69b5736c7cd27f258bea Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Apr 2026 20:57:23 +0800 Subject: [PATCH 5/5] Refactor summarize_group_payload to simplify multimodal prompt check and improve docstring formatting in replay_buffer.py --- xtuner/v1/ray/dataflow/replay_buffer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 9a0ea536e..15dae046f 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -79,9 +79,8 @@ def summarize_group_payload(grouped_dataitem: List[RLDataFlowItem]) -> Dict[str, return summary first_data = grouped_dataitem[0].data - summary["has_multimodal_prompt"] = bool( - getattr(first_data, "multimodal_train_info", None) and len(first_data.multimodal_train_info) > 0 - ) + multimodal_train_info = getattr(first_data, "multimodal_train_info", None) + summary["has_multimodal_prompt"] = bool(multimodal_train_info) for item in grouped_dataitem: rollout = item.env.rollout @@ -394,7 +393,8 @@ def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutS replay_meta.state = new_state def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState): - """Keep prompt refs only and drop rollout outputs that will not be reused.""" + """Keep prompt refs only and drop rollout outputs that will not be + reused.""" old_obs_refs = self._filter_releasable_refs(replay_meta.observation_refs) if old_obs_refs: ray.internal.free(old_obs_refs, 
local_only=False)
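
Appendix note: patch 4 above effectively interns a sentinel — one `ray.put(RLEnvDataItem())` is created lazily and shared by every stripped observation slot, so the free paths must filter it out via `_filter_releasable_refs`, or a later strip would free an object that other groups still reference. A compact sketch of that pattern with a generic stand-in class (`Sentinel` below is illustrative; in the patch the sentinel is an empty `RLEnvDataItem`):

    from typing import List, Optional
    import ray

    class Sentinel:
        """Stand-in for the empty RLEnvDataItem payload."""

    ray.init()
    _empty_ref: Optional[ray.ObjectRef] = None

    def get_empty_ref() -> ray.ObjectRef:
        global _empty_ref
        if _empty_ref is None:        # put exactly once, then reuse everywhere
            _empty_ref = ray.put(Sentinel())
        return _empty_ref

    def releasable(refs: List[Optional[ray.ObjectRef]]) -> List[ray.ObjectRef]:
        # Never pass the shared sentinel to ray.internal.free: every other
        # stripped ReplayMeta still points at the same ObjectRef.
        return [r for r in refs if r is not None and r != get_empty_ref()]

    obs_refs = [ray.put(Sentinel()), get_empty_ref(), None]
    ray.internal.free(releasable(obs_refs), local_only=False)  # frees only the first
    ray.shutdown()
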