Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 157 additions & 3 deletions xtuner/v1/ray/dataflow/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,48 @@ class ReplayMeta:
extra_info: Dict[str, Any] = field(default_factory=dict)


def summarize_group_payload(grouped_dataitem: List[RLDataFlowItem]) -> Dict[str, Any]:
    """Build a lightweight accounting summary for a group of rollout samples.

    The summary is stored in ``ReplayMeta.extra_info`` so the replay buffer can
    report payload statistics without dereferencing the Ray object store.

    Args:
        grouped_dataitem (List[RLDataFlowItem]): All samples belonging to one
            action group. May be empty.

    Returns:
        Dict[str, Any]: Counters describing the group's stored payload:
            ``payload_mode`` ("full" until the payload is stripped),
            ``observation_count``, ``response_tokens``, ``response_chars``,
            ``versioned_segments``, ``versioned_tokens``,
            ``routed_expert_payloads``, ``judged_observations`` and
            ``has_multimodal_prompt``.
    """
    summary: Dict[str, Any] = {
        "payload_mode": "full",
        "observation_count": len(grouped_dataitem),
        "response_tokens": 0,
        "response_chars": 0,
        "versioned_segments": 0,
        "versioned_tokens": 0,
        "routed_expert_payloads": 0,
        "judged_observations": 0,
        "has_multimodal_prompt": False,
    }
    if not grouped_dataitem:
        return summary

    first_data = grouped_dataitem[0].data
    # Only the first sample is inspected: all samples in a group share the prompt.
    multimodal_train_info = getattr(first_data, "multimodal_train_info", None)
    summary["has_multimodal_prompt"] = bool(multimodal_train_info)

    for item in grouped_dataitem:
        rollout = item.env.rollout
        judger = item.env.judger
        response_ids = rollout.response_ids or []
        response_text = rollout.response or ""
        versioned_response_ids = rollout.versioned_response_ids or []
        versioned_num_return_tokens = rollout.versioned_num_return_tokens or []

        summary["response_tokens"] += len(response_ids)
        summary["response_chars"] += len(response_text)
        summary["versioned_segments"] += len(versioned_response_ids)
        if versioned_num_return_tokens:
            summary["versioned_tokens"] += sum(versioned_num_return_tokens)
        else:
            # Fall back to counting ids directly when per-segment totals are absent.
            summary["versioned_tokens"] += sum(len(ids) for ids in versioned_response_ids)
        if rollout.extra_info.get("routed_experts", None) is not None:
            summary["routed_expert_payloads"] += 1
        # A sample counts as "judged" if the judger left any trace: uid, score, or extras.
        if judger.uid is not None or judger.reward.get("score", 0.0) != 0.0 or len(judger.extra_info) > 0:
            summary["judged_observations"] += 1

    return summary


def determine_group_state(group_data_items: List[RLDataFlowItem]) -> RolloutState:
"""Determines the processing strategy for a group of rollout samples based
on their state."""
Expand Down Expand Up @@ -113,7 +155,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re
observation_refs=observation_refs,
state=group_state,
version=group_version,
extra_info={},
extra_info=summarize_group_payload(grouped_dataitem),
)
return replay_meta

Expand Down Expand Up @@ -308,6 +350,7 @@ def __init__(self, replay_buffer_cfg):
self.enable_partial_rollout: bool = False
self.tail_batch_candidate_steps: int = 0
self.tail_batch_trigger_size: int = 0
self._empty_env_ref: Optional[ObjectRef] = None

self._completed_actions: Dict[int, List[int]] = defaultdict(list)
self._aborted_actions: Dict[int, List[int]] = defaultdict(list)
Expand All @@ -323,6 +366,97 @@ def __init__(self, replay_buffer_cfg):
self.sample_from_aborted_count = 0
self.sample_from_expired_count = 0

def _get_empty_env_ref(self) -> ObjectRef:
    """Return a cached object-store ref to an empty env payload, creating it once."""
    cached = self._empty_env_ref
    if cached is None:
        cached = ray.put(RLEnvDataItem())
        self._empty_env_ref = cached
    return cached

def _filter_releasable_refs(self, refs: List[ObjectRef]) -> List[ObjectRef]:
return [ref for ref in refs if ref is not None and ref != self._empty_env_ref]

def _free_replay_meta_refs(self, replay_meta: ReplayMeta, include_action_ref: bool = True):
    """Release the Ray object-store payloads tracked by *replay_meta*.

    Args:
        replay_meta (ReplayMeta): Metadata whose action/observation refs are freed.
        include_action_ref (bool): When True, the prompt (action) ref is freed too.
    """
    pending = []
    action_ref = replay_meta.action_ref
    if include_action_ref and action_ref is not None:
        pending.append(action_ref)
    pending += self._filter_releasable_refs(replay_meta.observation_refs)
    if not pending:
        return
    # local_only=False so copies on every node are released, not just this one.
    ray.internal.free(pending, local_only=False)

def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutState):
for observation_id in replay_meta.observation_ids:
old_state = self._observations2states.get(observation_id)
if old_state and observation_id in self._states.get(old_state, []):
self._states[old_state].remove(observation_id)
self._observations2states[observation_id] = new_state
if observation_id not in self._states[new_state]:
self._states[new_state].append(observation_id)
replay_meta.state = new_state

def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState):
    """Keep prompt refs only and drop rollout outputs that will not be
    reused.

    Frees the observation payloads in the object store, points every
    observation slot at the shared empty-env placeholder, zeroes the payload
    summary, and transitions the meta into *new_state*.
    """
    releasable = self._filter_releasable_refs(replay_meta.observation_refs)
    if releasable:
        ray.internal.free(releasable, local_only=False)
    placeholder = self._get_empty_env_ref()
    # One shared placeholder ref per slot; nothing left to free later.
    replay_meta.observation_refs = [placeholder] * len(replay_meta.observation_ids)
    reset_summary = {
        "payload_mode": "prompt_only",
        "response_tokens": 0,
        "response_chars": 0,
        "versioned_segments": 0,
        "versioned_tokens": 0,
        "routed_expert_payloads": 0,
        "judged_observations": 0,
    }
    replay_meta.extra_info.update(reset_summary)
    self._update_replay_meta_state(replay_meta, new_state)

def get_storage_stats(self) -> Dict[str, float]:
    """Aggregate buffer-wide payload statistics from the per-action summaries.

    Returns:
        Dict[str, float]: Counts of tracked actions/roots/observations, per-state
        bucket sizes, full vs prompt-only payload breakdowns, stored token/char
        totals, and multimodal action counts. All values are floats so the dict
        can be logged as homogeneous metrics.
    """
    stats: Dict[str, float] = {
        "tracked_actions_count": float(len(self._actions)),
        "tracked_roots_count": float(len(self._root2actions)),
        "tracked_observations_count": float(len(self._observations)),
        "completed_actions_count": float(sum(len(bucket) for bucket in self._completed_actions.values())),
        "aborted_actions_count": float(sum(len(bucket) for bucket in self._aborted_actions.values())),
        "expired_actions_count": float(len(self._expired_actions)),
        "completed_versions_count": float(len(self._completed_actions)),
        "aborted_versions_count": float(len(self._aborted_actions)),
        "payload_full_actions_count": 0.0,
        "payload_prompt_only_actions_count": 0.0,
        "payload_full_observations_count": 0.0,
        "payload_prompt_only_observations_count": 0.0,
        "stored_response_tokens": 0.0,
        "stored_response_chars": 0.0,
        "stored_versioned_segments": 0.0,
        "stored_versioned_tokens": 0.0,
        "stored_routed_expert_payloads": 0.0,
        "stored_judged_observations": 0.0,
        "multimodal_actions_count": 0.0,
    }

    # Map output metric -> key inside each action's extra_info summary.
    accumulated = (
        ("stored_response_tokens", "response_tokens"),
        ("stored_response_chars", "response_chars"),
        ("stored_versioned_segments", "versioned_segments"),
        ("stored_versioned_tokens", "versioned_tokens"),
        ("stored_routed_expert_payloads", "routed_expert_payloads"),
        ("stored_judged_observations", "judged_observations"),
    )

    for replay_meta in self._actions.values():
        summary = replay_meta.extra_info
        obs_count = float(summary.get("observation_count", len(replay_meta.observation_ids)))
        # Older metas without a summary default to "full" payloads.
        prefix = (
            "payload_prompt_only"
            if summary.get("payload_mode", "full") == "prompt_only"
            else "payload_full"
        )
        stats[prefix + "_actions_count"] += 1.0
        stats[prefix + "_observations_count"] += obs_count
        for metric_key, summary_key in accumulated:
            stats[metric_key] += float(summary.get(summary_key, 0))
        if summary.get("has_multimodal_prompt", False):
            stats["multimodal_actions_count"] += 1.0

    return stats

def add(self, grouped_dataitem: List[RLDataFlowItem]):
"""Adds a group of data items to the storage.

Expand Down Expand Up @@ -426,6 +560,8 @@ def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]:
return []

def clear(self):
for replay_meta in self._actions.values():
self._free_replay_meta_refs(replay_meta)
attrs_to_clear = [
"_aborted_actions",
"_completed_actions",
Expand Down Expand Up @@ -699,6 +835,10 @@ def _check_completed_samples_expired(self):

for version in expired_versions:
bucket = self._completed_actions.pop(version)
for action_id in bucket:
replay_meta = self._actions.get(action_id)
if replay_meta is not None:
self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED)
self._expired_actions.extend(bucket)
self.logger.info(
f"Moved {len(bucket)} completed samples with version {version} to expired samples due to exceeding tail_batch_candidate_steps."
Expand All @@ -709,6 +849,10 @@ def _check_completed_samples_aborted(self):
return

for version, bucket in self._completed_actions.items():
for action_id in bucket:
replay_meta = self._actions.get(action_id)
if replay_meta is not None:
self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED)
self._aborted_actions[0].extend(bucket)
self.logger.info(
f"Moved {len(bucket)} completed samples with version {version} to aborted samples due to partial rollout disabled."
Expand All @@ -729,7 +873,9 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta):
if state and observation_id in self._states.get(state, []):
self._states[state].remove(observation_id)

self._actions.pop(action_id, None)
self._action2observations.pop(action_id, None)
self._free_replay_meta_refs(replay_meta)
del replay_meta

def _clear_meta_for_root(self, replay_meta: ReplayMeta):
Expand All @@ -747,13 +893,16 @@ def _clear_meta_for_root(self, replay_meta: ReplayMeta):
and clear all related actions.
"""
root_id = replay_meta.root_id
current_action_id = replay_meta.action_id
self._clear_meta_for_actions(replay_meta)
if root_id in self._root2actions:
for action_id in self._root2actions[root_id]:
if action_id == current_action_id:
continue
new_replay_meta = self._actions.pop(action_id, None)
if new_replay_meta:
self._clear_meta_for_actions(new_replay_meta)
Comment on lines 899 to 904
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: Nit: _clear_meta_for_actions now pops from self._actions internally (line 866 — good, makes it a proper "single source of truth" for cleanup). However, this loop still does self._actions.pop(action_id, None) before calling _clear_meta_for_actions(new_replay_meta), resulting in a redundant double-pop on self._actions for the same key.

Not a bug (the second pop returns None harmlessly), but it's confusing for readers since _clear_meta_for_actions is documented as "the single source of truth for deleting an action." Consider letting _clear_meta_for_actions own the pop:

for action_id in self._root2actions[root_id]:
    if action_id == current_action_id:
        continue
    new_replay_meta = self._actions.get(action_id)
    if new_replay_meta:
        self._clear_meta_for_actions(new_replay_meta)

del self._root2actions[root_id]
del replay_meta

def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta):
"""Checks the rollout state of a ReplayMeta object and inserts its
Expand All @@ -775,11 +924,14 @@ def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta):
if state == RolloutState.ABORTED:
if self.tail_batch_candidate_steps > 0 and replay_meta.version >= self.tail_batch_candidate_steps:
# 过期的数据需要重置状态
self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED)
self._expired_actions.append(action_id)
self.logger.debug(
f"Add expired sample with action_id: {action_id} to _expired_actions because version: {replay_meta.version} >= tail_batch_candidate_steps: {self.tail_batch_candidate_steps}."
)
else:
if not self.enable_partial_rollout:
self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED)
self._aborted_actions[replay_meta.version].append(action_id)
self.logger.debug(
f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _aborted_actions."
Expand Down Expand Up @@ -903,14 +1055,16 @@ def add(self, grouped_dataitem: List[RLDataFlowItem]):
self.storage.add(grouped_dataitem)

def status(self):
    """Report sampling counters together with storage-level payload stats.

    Returns:
        dict: Remaining sample counts per state, per-source sampling counters,
        plus everything reported by ``storage.get_storage_stats()``.
    """
    storage = self.storage
    report = {
        "remain_completed_samples_count": storage.completed_samples_count,
        "remain_aborted_samples_count": storage.aborted_samples_count,
        "remain_expired_samples_count": storage.expired_samples_count,
        "sample_from_dataset_count": self.sample_from_dataset_count,
        "sample_from_aborted_count": storage.sample_from_aborted_count,
        "sample_from_expired_count": storage.sample_from_expired_count,
    }
    report.update(storage.get_storage_stats())
    return report

def save(self, file_path: Path | str):
"""Saves the replay buffer's storage to a file.
Expand Down
27 changes: 26 additions & 1 deletion xtuner/v1/rl/base/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import ray
import torch
from ray import ObjectRef
from ray.actor import ActorProxy

from xtuner.v1.data_proto.sequence_context import SequenceContext
Expand All @@ -28,6 +29,27 @@ class RawTrainingController:
def __init__(self, workers: list[TrainingWorker]) -> None:
self.workers = workers

def _collect_object_refs(self, obj, refs: list[ObjectRef]):
    """Recursively gather every Ray ``ObjectRef`` reachable from *obj*.

    Traverses lists, tuples, and dict *values* (dict keys are not inspected);
    any other type is ignored. Found refs are appended to *refs* in traversal
    order.

    Args:
        obj: Arbitrary object that may be, or contain, ``ObjectRef`` instances.
        refs (list[ObjectRef]): Output accumulator, mutated in place.
    """
    if isinstance(obj, ObjectRef):
        refs.append(obj)
        return
    if isinstance(obj, (list, tuple)):
        for item in obj:
            self._collect_object_refs(item, refs)
        return
    if isinstance(obj, dict):
        for value in obj.values():
            self._collect_object_refs(value, refs)

def _free_batch_object_refs(self, data_batches):
    """Free the object-store payloads referenced by each batch's ``seq_ctx``.

    Only ``pixel_values`` and ``rollout_routed_experts`` are scanned, since
    those are the fields that may carry ``ObjectRef`` payloads.
    """
    pending: list[ObjectRef] = []
    for entry in data_batches:
        ctx = entry["seq_ctx"]
        self._collect_object_refs(ctx.pixel_values, pending)
        self._collect_object_refs(ctx.rollout_routed_experts, pending)
    if not pending:
        return
    # local_only=False releases copies cluster-wide, not just on this node.
    ray.internal.free(pending, local_only=False)

# TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack
def _get_pack_infos(self, dataset, num_tokens, target, random=None):
inds = list(range(len(dataset)))
Expand Down Expand Up @@ -260,7 +282,10 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx:
rollout_idx=rollout_idx,
)
)
log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
try:
log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
finally:
self._free_batch_object_refs(packed_data_batches)
return log_infos

@ray_method
Expand Down
Loading