diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index cdf10fd68..15dae046f 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -63,6 +63,48 @@ class ReplayMeta: extra_info: Dict[str, Any] = field(default_factory=dict) +def summarize_group_payload(grouped_dataitem: List[RLDataFlowItem]) -> Dict[str, Any]: + summary: Dict[str, Any] = { + "payload_mode": "full", + "observation_count": len(grouped_dataitem), + "response_tokens": 0, + "response_chars": 0, + "versioned_segments": 0, + "versioned_tokens": 0, + "routed_expert_payloads": 0, + "judged_observations": 0, + "has_multimodal_prompt": False, + } + if not grouped_dataitem: + return summary + + first_data = grouped_dataitem[0].data + multimodal_train_info = getattr(first_data, "multimodal_train_info", None) + summary["has_multimodal_prompt"] = bool(multimodal_train_info) + + for item in grouped_dataitem: + rollout = item.env.rollout + judger = item.env.judger + response_ids = rollout.response_ids or [] + response_text = rollout.response or "" + versioned_response_ids = rollout.versioned_response_ids or [] + versioned_num_return_tokens = rollout.versioned_num_return_tokens or [] + + summary["response_tokens"] += len(response_ids) + summary["response_chars"] += len(response_text) + summary["versioned_segments"] += len(versioned_response_ids) + if versioned_num_return_tokens: + summary["versioned_tokens"] += sum(versioned_num_return_tokens) + else: + summary["versioned_tokens"] += sum(len(ids) for ids in versioned_response_ids) + if rollout.extra_info.get("routed_experts", None) is not None: + summary["routed_expert_payloads"] += 1 + if judger.uid is not None or judger.reward.get("score", 0.0) != 0.0 or len(judger.extra_info) > 0: + summary["judged_observations"] += 1 + + return summary + + def determine_group_state(group_data_items: List[RLDataFlowItem]) -> RolloutState: """Determines the processing strategy for a group of rollout samples based on their state.""" @@ -113,7 +155,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re observation_refs=observation_refs, state=group_state, version=group_version, - extra_info={}, + extra_info=summarize_group_payload(grouped_dataitem), ) return replay_meta @@ -308,6 +350,7 @@ def __init__(self, replay_buffer_cfg): self.enable_partial_rollout: bool = False self.tail_batch_candidate_steps: int = 0 self.tail_batch_trigger_size: int = 0 + self._empty_env_ref: Optional[ObjectRef] = None self._completed_actions: Dict[int, List[int]] = defaultdict(list) self._aborted_actions: Dict[int, List[int]] = defaultdict(list) @@ -323,6 +366,97 @@ def __init__(self, replay_buffer_cfg): self.sample_from_aborted_count = 0 self.sample_from_expired_count = 0 + def _get_empty_env_ref(self) -> ObjectRef: + if self._empty_env_ref is None: + self._empty_env_ref = ray.put(RLEnvDataItem()) + return self._empty_env_ref + + def _filter_releasable_refs(self, refs: List[ObjectRef]) -> List[ObjectRef]: + return [ref for ref in refs if ref is not None and ref != self._empty_env_ref] + + def _free_replay_meta_refs(self, replay_meta: ReplayMeta, include_action_ref: bool = True): + refs = [] + if include_action_ref and replay_meta.action_ref is not None: + refs.append(replay_meta.action_ref) + refs.extend(self._filter_releasable_refs(replay_meta.observation_refs)) + if refs: + ray.internal.free(refs, local_only=False) + + def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutState): + for observation_id in replay_meta.observation_ids: + old_state = self._observations2states.get(observation_id) + if old_state and observation_id in self._states.get(old_state, []): + self._states[old_state].remove(observation_id) + self._observations2states[observation_id] = new_state + if observation_id not in self._states[new_state]: + self._states[new_state].append(observation_id) + replay_meta.state = new_state + + def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState): + """Keep prompt refs only and drop rollout outputs that will not be + reused.""" + old_obs_refs = self._filter_releasable_refs(replay_meta.observation_refs) + if old_obs_refs: + ray.internal.free(old_obs_refs, local_only=False) + empty_env_ref = self._get_empty_env_ref() + replay_meta.observation_refs = [empty_env_ref for _ in replay_meta.observation_ids] + replay_meta.extra_info.update( + { + "payload_mode": "prompt_only", + "response_tokens": 0, + "response_chars": 0, + "versioned_segments": 0, + "versioned_tokens": 0, + "routed_expert_payloads": 0, + "judged_observations": 0, + } + ) + self._update_replay_meta_state(replay_meta, new_state) + + def get_storage_stats(self) -> Dict[str, float]: + stats: Dict[str, float] = { + "tracked_actions_count": float(len(self._actions)), + "tracked_roots_count": float(len(self._root2actions)), + "tracked_observations_count": float(len(self._observations)), + "completed_actions_count": float(sum(len(bucket) for bucket in self._completed_actions.values())), + "aborted_actions_count": float(sum(len(bucket) for bucket in self._aborted_actions.values())), + "expired_actions_count": float(len(self._expired_actions)), + "completed_versions_count": float(len(self._completed_actions)), + "aborted_versions_count": float(len(self._aborted_actions)), + "payload_full_actions_count": 0.0, + "payload_prompt_only_actions_count": 0.0, + "payload_full_observations_count": 0.0, + "payload_prompt_only_observations_count": 0.0, + "stored_response_tokens": 0.0, + "stored_response_chars": 0.0, + "stored_versioned_segments": 0.0, + "stored_versioned_tokens": 0.0, + "stored_routed_expert_payloads": 0.0, + "stored_judged_observations": 0.0, + "multimodal_actions_count": 0.0, + } + + for replay_meta in self._actions.values(): + summary = replay_meta.extra_info + observation_count = float(summary.get("observation_count", len(replay_meta.observation_ids))) + if summary.get("payload_mode", "full") == "prompt_only": + stats["payload_prompt_only_actions_count"] += 1.0 + stats["payload_prompt_only_observations_count"] += observation_count + else: + stats["payload_full_actions_count"] += 1.0 + stats["payload_full_observations_count"] += observation_count + + stats["stored_response_tokens"] += float(summary.get("response_tokens", 0)) + stats["stored_response_chars"] += float(summary.get("response_chars", 0)) + stats["stored_versioned_segments"] += float(summary.get("versioned_segments", 0)) + stats["stored_versioned_tokens"] += float(summary.get("versioned_tokens", 0)) + stats["stored_routed_expert_payloads"] += float(summary.get("routed_expert_payloads", 0)) + stats["stored_judged_observations"] += float(summary.get("judged_observations", 0)) + if summary.get("has_multimodal_prompt", False): + stats["multimodal_actions_count"] += 1.0 + + return stats + def add(self, grouped_dataitem: List[RLDataFlowItem]): """Adds a group of data items to the storage. @@ -426,6 +560,8 @@ def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]: return [] def clear(self): + for replay_meta in self._actions.values(): + self._free_replay_meta_refs(replay_meta) attrs_to_clear = [ "_aborted_actions", "_completed_actions", @@ -699,6 +835,10 @@ def _check_completed_samples_expired(self): for version in expired_versions: bucket = self._completed_actions.pop(version) + for action_id in bucket: + replay_meta = self._actions.get(action_id) + if replay_meta is not None: + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED) self._expired_actions.extend(bucket) self.logger.info( f"Moved {len(bucket)} completed samples with version {version} to expired samples due to exceeding tail_batch_candidate_steps." @@ -709,6 +849,10 @@ def _check_completed_samples_aborted(self): return for version, bucket in self._completed_actions.items(): + for action_id in bucket: + replay_meta = self._actions.get(action_id) + if replay_meta is not None: + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED) self._aborted_actions[0].extend(bucket) self.logger.info( f"Moved {len(bucket)} completed samples with version {version} to aborted samples due to partial rollout disabled." @@ -729,7 +873,9 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta): if state and observation_id in self._states.get(state, []): self._states[state].remove(observation_id) + self._actions.pop(action_id, None) self._action2observations.pop(action_id, None) + self._free_replay_meta_refs(replay_meta) del replay_meta def _clear_meta_for_root(self, replay_meta: ReplayMeta): @@ -747,13 +893,16 @@ def _clear_meta_for_root(self, replay_meta: ReplayMeta): and clear all related actions. """ root_id = replay_meta.root_id + current_action_id = replay_meta.action_id + self._clear_meta_for_actions(replay_meta) if root_id in self._root2actions: for action_id in self._root2actions[root_id]: + if action_id == current_action_id: + continue new_replay_meta = self._actions.pop(action_id, None) if new_replay_meta: self._clear_meta_for_actions(new_replay_meta) del self._root2actions[root_id] - del replay_meta def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): """Checks the rollout state of a ReplayMeta object and inserts its @@ -775,11 +924,14 @@ def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): if state == RolloutState.ABORTED: if self.tail_batch_candidate_steps > 0 and replay_meta.version >= self.tail_batch_candidate_steps: # 过期的数据需要重置状态 + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED) self._expired_actions.append(action_id) self.logger.debug( f"Add expired sample with action_id: {action_id} to _expired_actions because version: {replay_meta.version} >= tail_batch_candidate_steps: {self.tail_batch_candidate_steps}." ) else: + if not self.enable_partial_rollout: + self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED) self._aborted_actions[replay_meta.version].append(action_id) self.logger.debug( f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _aborted_actions." @@ -903,7 +1055,7 @@ def add(self, grouped_dataitem: List[RLDataFlowItem]): self.storage.add(grouped_dataitem) def status(self): - return { + status = { "remain_completed_samples_count": self.storage.completed_samples_count, "remain_aborted_samples_count": self.storage.aborted_samples_count, "remain_expired_samples_count": self.storage.expired_samples_count, @@ -911,6 +1063,8 @@ def status(self): "sample_from_aborted_count": self.storage.sample_from_aborted_count, "sample_from_expired_count": self.storage.sample_from_expired_count, } + status.update(self.storage.get_storage_stats()) + return status def save(self, file_path: Path | str): """Saves the replay buffer's storage to a file. diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index b500b53e4..4045b6eb2 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -4,6 +4,7 @@ import ray import torch +from ray import ObjectRef from ray.actor import ActorProxy from xtuner.v1.data_proto.sequence_context import SequenceContext @@ -28,6 +29,27 @@ class RawTrainingController: def __init__(self, workers: list[TrainingWorker]) -> None: self.workers = workers + def _collect_object_refs(self, obj, refs: list[ObjectRef]): + if isinstance(obj, ObjectRef): + refs.append(obj) + return + if isinstance(obj, (list, tuple)): + for item in obj: + self._collect_object_refs(item, refs) + return + if isinstance(obj, dict): + for value in obj.values(): + self._collect_object_refs(value, refs) + + def _free_batch_object_refs(self, data_batches): + refs: list[ObjectRef] = [] + for data in data_batches: + seq_ctx = data["seq_ctx"] + self._collect_object_refs(seq_ctx.pixel_values, refs) + self._collect_object_refs(seq_ctx.rollout_routed_experts, refs) + if refs: + ray.internal.free(refs, local_only=False) + # TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack def _get_pack_infos(self, dataset, num_tokens, target, random=None): inds = list(range(len(dataset))) @@ -260,7 +282,10 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: rollout_idx=rollout_idx, ) ) - log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) + try: + log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) + finally: + self._free_batch_object_refs(packed_data_batches) return log_infos @ray_method