-
Notifications
You must be signed in to change notification settings - Fork 414
fix rl mem leak #1646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
fix rl mem leak #1646
Changes from all commits
e04db0c
0a7cf3a
7a5f1c3
f4a5828
1272c09
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,6 +63,48 @@ class ReplayMeta: | |
| extra_info: Dict[str, Any] = field(default_factory=dict) | ||
|
|
||
|
|
||
| def summarize_group_payload(grouped_dataitem: List[RLDataFlowItem]) -> Dict[str, Any]: | ||
| summary: Dict[str, Any] = { | ||
| "payload_mode": "full", | ||
| "observation_count": len(grouped_dataitem), | ||
| "response_tokens": 0, | ||
| "response_chars": 0, | ||
| "versioned_segments": 0, | ||
| "versioned_tokens": 0, | ||
| "routed_expert_payloads": 0, | ||
| "judged_observations": 0, | ||
| "has_multimodal_prompt": False, | ||
| } | ||
| if not grouped_dataitem: | ||
| return summary | ||
|
|
||
| first_data = grouped_dataitem[0].data | ||
| multimodal_train_info = getattr(first_data, "multimodal_train_info", None) | ||
| summary["has_multimodal_prompt"] = bool(multimodal_train_info) | ||
|
|
||
| for item in grouped_dataitem: | ||
| rollout = item.env.rollout | ||
| judger = item.env.judger | ||
| response_ids = rollout.response_ids or [] | ||
| response_text = rollout.response or "" | ||
| versioned_response_ids = rollout.versioned_response_ids or [] | ||
| versioned_num_return_tokens = rollout.versioned_num_return_tokens or [] | ||
|
|
||
| summary["response_tokens"] += len(response_ids) | ||
| summary["response_chars"] += len(response_text) | ||
| summary["versioned_segments"] += len(versioned_response_ids) | ||
| if versioned_num_return_tokens: | ||
| summary["versioned_tokens"] += sum(versioned_num_return_tokens) | ||
| else: | ||
| summary["versioned_tokens"] += sum(len(ids) for ids in versioned_response_ids) | ||
| if rollout.extra_info.get("routed_experts", None) is not None: | ||
| summary["routed_expert_payloads"] += 1 | ||
| if judger.uid is not None or judger.reward.get("score", 0.0) != 0.0 or len(judger.extra_info) > 0: | ||
| summary["judged_observations"] += 1 | ||
|
|
||
| return summary | ||
|
|
||
|
|
||
| def determine_group_state(group_data_items: List[RLDataFlowItem]) -> RolloutState: | ||
| """Determines the processing strategy for a group of rollout samples based | ||
| on their state.""" | ||
|
|
@@ -113,7 +155,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re | |
| observation_refs=observation_refs, | ||
| state=group_state, | ||
| version=group_version, | ||
| extra_info={}, | ||
| extra_info=summarize_group_payload(grouped_dataitem), | ||
| ) | ||
| return replay_meta | ||
|
|
||
|
|
@@ -308,6 +350,7 @@ def __init__(self, replay_buffer_cfg): | |
| self.enable_partial_rollout: bool = False | ||
| self.tail_batch_candidate_steps: int = 0 | ||
| self.tail_batch_trigger_size: int = 0 | ||
| self._empty_env_ref: Optional[ObjectRef] = None | ||
|
|
||
| self._completed_actions: Dict[int, List[int]] = defaultdict(list) | ||
| self._aborted_actions: Dict[int, List[int]] = defaultdict(list) | ||
|
|
@@ -323,6 +366,97 @@ def __init__(self, replay_buffer_cfg): | |
| self.sample_from_aborted_count = 0 | ||
| self.sample_from_expired_count = 0 | ||
|
|
||
| def _get_empty_env_ref(self) -> ObjectRef: | ||
| if self._empty_env_ref is None: | ||
| self._empty_env_ref = ray.put(RLEnvDataItem()) | ||
| return self._empty_env_ref | ||
|
|
||
| def _filter_releasable_refs(self, refs: List[ObjectRef]) -> List[ObjectRef]: | ||
| return [ref for ref in refs if ref is not None and ref != self._empty_env_ref] | ||
|
|
||
| def _free_replay_meta_refs(self, replay_meta: ReplayMeta, include_action_ref: bool = True): | ||
| refs = [] | ||
| if include_action_ref and replay_meta.action_ref is not None: | ||
| refs.append(replay_meta.action_ref) | ||
| refs.extend(self._filter_releasable_refs(replay_meta.observation_refs)) | ||
| if refs: | ||
| ray.internal.free(refs, local_only=False) | ||
|
|
||
| def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutState): | ||
| for observation_id in replay_meta.observation_ids: | ||
| old_state = self._observations2states.get(observation_id) | ||
| if old_state and observation_id in self._states.get(old_state, []): | ||
| self._states[old_state].remove(observation_id) | ||
| self._observations2states[observation_id] = new_state | ||
| if observation_id not in self._states[new_state]: | ||
| self._states[new_state].append(observation_id) | ||
| replay_meta.state = new_state | ||
|
|
||
| def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState): | ||
| """Keep prompt refs only and drop rollout outputs that will not be | ||
| reused.""" | ||
| old_obs_refs = self._filter_releasable_refs(replay_meta.observation_refs) | ||
| if old_obs_refs: | ||
| ray.internal.free(old_obs_refs, local_only=False) | ||
| empty_env_ref = self._get_empty_env_ref() | ||
| replay_meta.observation_refs = [empty_env_ref for _ in replay_meta.observation_ids] | ||
| replay_meta.extra_info.update( | ||
| { | ||
| "payload_mode": "prompt_only", | ||
| "response_tokens": 0, | ||
| "response_chars": 0, | ||
| "versioned_segments": 0, | ||
| "versioned_tokens": 0, | ||
| "routed_expert_payloads": 0, | ||
| "judged_observations": 0, | ||
| } | ||
| ) | ||
| self._update_replay_meta_state(replay_meta, new_state) | ||
|
|
||
| def get_storage_stats(self) -> Dict[str, float]: | ||
| stats: Dict[str, float] = { | ||
| "tracked_actions_count": float(len(self._actions)), | ||
| "tracked_roots_count": float(len(self._root2actions)), | ||
| "tracked_observations_count": float(len(self._observations)), | ||
| "completed_actions_count": float(sum(len(bucket) for bucket in self._completed_actions.values())), | ||
| "aborted_actions_count": float(sum(len(bucket) for bucket in self._aborted_actions.values())), | ||
| "expired_actions_count": float(len(self._expired_actions)), | ||
| "completed_versions_count": float(len(self._completed_actions)), | ||
| "aborted_versions_count": float(len(self._aborted_actions)), | ||
| "payload_full_actions_count": 0.0, | ||
| "payload_prompt_only_actions_count": 0.0, | ||
| "payload_full_observations_count": 0.0, | ||
| "payload_prompt_only_observations_count": 0.0, | ||
| "stored_response_tokens": 0.0, | ||
| "stored_response_chars": 0.0, | ||
| "stored_versioned_segments": 0.0, | ||
| "stored_versioned_tokens": 0.0, | ||
| "stored_routed_expert_payloads": 0.0, | ||
| "stored_judged_observations": 0.0, | ||
| "multimodal_actions_count": 0.0, | ||
| } | ||
|
|
||
| for replay_meta in self._actions.values(): | ||
| summary = replay_meta.extra_info | ||
| observation_count = float(summary.get("observation_count", len(replay_meta.observation_ids))) | ||
| if summary.get("payload_mode", "full") == "prompt_only": | ||
| stats["payload_prompt_only_actions_count"] += 1.0 | ||
| stats["payload_prompt_only_observations_count"] += observation_count | ||
| else: | ||
| stats["payload_full_actions_count"] += 1.0 | ||
| stats["payload_full_observations_count"] += observation_count | ||
|
|
||
| stats["stored_response_tokens"] += float(summary.get("response_tokens", 0)) | ||
| stats["stored_response_chars"] += float(summary.get("response_chars", 0)) | ||
| stats["stored_versioned_segments"] += float(summary.get("versioned_segments", 0)) | ||
| stats["stored_versioned_tokens"] += float(summary.get("versioned_tokens", 0)) | ||
| stats["stored_routed_expert_payloads"] += float(summary.get("routed_expert_payloads", 0)) | ||
| stats["stored_judged_observations"] += float(summary.get("judged_observations", 0)) | ||
| if summary.get("has_multimodal_prompt", False): | ||
| stats["multimodal_actions_count"] += 1.0 | ||
|
|
||
| return stats | ||
|
|
||
| def add(self, grouped_dataitem: List[RLDataFlowItem]): | ||
| """Adds a group of data items to the storage. | ||
|
|
||
|
|
@@ -426,6 +560,8 @@ def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]: | |
| return [] | ||
|
|
||
| def clear(self): | ||
| for replay_meta in self._actions.values(): | ||
| self._free_replay_meta_refs(replay_meta) | ||
| attrs_to_clear = [ | ||
| "_aborted_actions", | ||
| "_completed_actions", | ||
|
|
@@ -699,6 +835,10 @@ def _check_completed_samples_expired(self): | |
|
|
||
| for version in expired_versions: | ||
| bucket = self._completed_actions.pop(version) | ||
| for action_id in bucket: | ||
| replay_meta = self._actions.get(action_id) | ||
| if replay_meta is not None: | ||
| self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED) | ||
| self._expired_actions.extend(bucket) | ||
| self.logger.info( | ||
| f"Moved {len(bucket)} completed samples with version {version} to expired samples due to exceeding tail_batch_candidate_steps." | ||
|
|
@@ -709,6 +849,10 @@ def _check_completed_samples_aborted(self): | |
| return | ||
|
|
||
| for version, bucket in self._completed_actions.items(): | ||
| for action_id in bucket: | ||
| replay_meta = self._actions.get(action_id) | ||
| if replay_meta is not None: | ||
| self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED) | ||
| self._aborted_actions[0].extend(bucket) | ||
| self.logger.info( | ||
| f"Moved {len(bucket)} completed samples with version {version} to aborted samples due to partial rollout disabled." | ||
|
|
@@ -729,7 +873,9 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta): | |
| if state and observation_id in self._states.get(state, []): | ||
| self._states[state].remove(observation_id) | ||
|
|
||
| self._actions.pop(action_id, None) | ||
| self._action2observations.pop(action_id, None) | ||
| self._free_replay_meta_refs(replay_meta) | ||
| del replay_meta | ||
|
|
||
| def _clear_meta_for_root(self, replay_meta: ReplayMeta): | ||
|
|
@@ -747,13 +893,16 @@ def _clear_meta_for_root(self, replay_meta: ReplayMeta): | |
| and clear all related actions. | ||
| """ | ||
| root_id = replay_meta.root_id | ||
| current_action_id = replay_meta.action_id | ||
| self._clear_meta_for_actions(replay_meta) | ||
| if root_id in self._root2actions: | ||
| for action_id in self._root2actions[root_id]: | ||
| if action_id == current_action_id: | ||
| continue | ||
| new_replay_meta = self._actions.pop(action_id, None) | ||
| if new_replay_meta: | ||
| self._clear_meta_for_actions(new_replay_meta) | ||
|
Comment on lines
899
to
904
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Claude: Nit: Not a bug (the second pop returns for action_id in self._root2actions[root_id]:
if action_id == current_action_id:
continue
new_replay_meta = self._actions.get(action_id)
if new_replay_meta:
self._clear_meta_for_actions(new_replay_meta) |
||
| del self._root2actions[root_id] | ||
| del replay_meta | ||
|
|
||
| def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): | ||
| """Checks the rollout state of a ReplayMeta object and inserts its | ||
|
|
@@ -775,11 +924,14 @@ def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta): | |
| if state == RolloutState.ABORTED: | ||
| if self.tail_batch_candidate_steps > 0 and replay_meta.version >= self.tail_batch_candidate_steps: | ||
| # 过期的数据需要重置状态 | ||
| self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED) | ||
| self._expired_actions.append(action_id) | ||
| self.logger.debug( | ||
| f"Add expired sample with action_id: {action_id} to _expired_actions because version: {replay_meta.version} >= tail_batch_candidate_steps: {self.tail_batch_candidate_steps}." | ||
| ) | ||
| else: | ||
| if not self.enable_partial_rollout: | ||
| self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED) | ||
| self._aborted_actions[replay_meta.version].append(action_id) | ||
| self.logger.debug( | ||
| f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _aborted_actions." | ||
|
|
@@ -903,14 +1055,16 @@ def add(self, grouped_dataitem: List[RLDataFlowItem]): | |
| self.storage.add(grouped_dataitem) | ||
|
|
||
| def status(self): | ||
| return { | ||
| status = { | ||
| "remain_completed_samples_count": self.storage.completed_samples_count, | ||
| "remain_aborted_samples_count": self.storage.aborted_samples_count, | ||
| "remain_expired_samples_count": self.storage.expired_samples_count, | ||
| "sample_from_dataset_count": self.sample_from_dataset_count, | ||
| "sample_from_aborted_count": self.storage.sample_from_aborted_count, | ||
| "sample_from_expired_count": self.storage.sample_from_expired_count, | ||
| } | ||
| status.update(self.storage.get_storage_stats()) | ||
| return status | ||
|
|
||
| def save(self, file_path: Path | str): | ||
| """Saves the replay buffer's storage to a file. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
|
|
||
| import ray | ||
| import torch | ||
| from ray import ObjectRef | ||
| from ray.actor import ActorProxy | ||
|
|
||
| from xtuner.v1.data_proto.sequence_context import SequenceContext | ||
|
|
@@ -28,6 +29,27 @@ class RawTrainingController: | |
| def __init__(self, workers: list[TrainingWorker]) -> None: | ||
| self.workers = workers | ||
|
|
||
| def _collect_object_refs(self, obj, refs: list[ObjectRef]): | ||
| if isinstance(obj, ObjectRef): | ||
| refs.append(obj) | ||
| return | ||
|
Comment on lines
29
to
+35
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude: Warning: Not necessarily a problem today since |
||
| if isinstance(obj, (list, tuple)): | ||
| for item in obj: | ||
| self._collect_object_refs(item, refs) | ||
| return | ||
| if isinstance(obj, dict): | ||
| for value in obj.values(): | ||
| self._collect_object_refs(value, refs) | ||
|
|
||
| def _free_batch_object_refs(self, data_batches): | ||
| refs: list[ObjectRef] = [] | ||
| for data in data_batches: | ||
| seq_ctx = data["seq_ctx"] | ||
| self._collect_object_refs(seq_ctx.pixel_values, refs) | ||
| self._collect_object_refs(seq_ctx.rollout_routed_experts, refs) | ||
| if refs: | ||
| ray.internal.free(refs, local_only=False) | ||
|
|
||
| # TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack | ||
| def _get_pack_infos(self, dataset, num_tokens, target, random=None): | ||
| inds = list(range(len(dataset))) | ||
|
|
@@ -260,7 +282,10 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: | |
| rollout_idx=rollout_idx, | ||
| ) | ||
| ) | ||
| log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) | ||
| try: | ||
| log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) | ||
| finally: | ||
| self._free_batch_object_refs(packed_data_batches) | ||
| return log_infos | ||
|
|
||
| @ray_method | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Claude: Nit:
`summarize_group_payload` is a public (module-level) function. Per project standards it should have a docstring with Google Style format, parameter types, and return type.