Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xtuner/v1/module/decoder_layer/moe_decoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from xtuner.v1.module.rope import RopeScalingConfig
from xtuner.v1.ops.act_fn import get_act_fn
from xtuner.v1.utils import ForwardState
from xtuner.v1.utils.misc import run_gc_once

from ..linear import build_linear

Expand Down Expand Up @@ -452,6 +453,7 @@ def _forward(
residual=residual,
shared_experts_out=shared_experts_out,
)
run_gc_once()
return hidden_states, router_results["logits"], router_results["router_weights"]

def _micro_batch_forward(
Expand Down Expand Up @@ -586,6 +588,7 @@ def _micro_batch_forward(

router_logits = [router_results["logits"] for router_results in router_results_list]
router_weights = [router_results["router_weights"] for router_results in router_results_list]
run_gc_once()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please test directly on the NPU with the main branch to see whether the memory leak still occurs. If it does, we need to figure out why reference-count-based reclamation is failing.

return tuple(hidden_states_out_list + router_logits + router_weights)

def _pre_moe_forward(
Expand Down
13 changes: 12 additions & 1 deletion xtuner/v1/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import gc
import os
import sys
import threading
from functools import reduce
from functools import cache, reduce
from math import lcm
from multiprocessing import resource_tracker as _mprt
from multiprocessing import shared_memory as _mpshm
Expand Down Expand Up @@ -214,3 +215,13 @@ def clean_param_name(name: str) -> str:
if "_orig_mod." in name:
name = name.replace("_orig_mod.", "")
return name


@cache
def run_gc_once() -> None:
    """Run a single full garbage collection, at most once per process.

    This is a workaround for sporadic memory leaks that were observed after
    the first forward pass of the MoE decoder layer on certain devices:
    forcing one collection releases the extra memory that would otherwise
    stay occupied.

    The ``functools.cache`` decorator is what provides the "once" behavior.
    The function takes no arguments, so the first call executes the body and
    stores its ``None`` result; every later call is answered from the cache
    without collecting again.
    """
    # Generation 2 is the default of ``gc.collect()``, i.e. a full collection.
    gc.collect(2)
    return None
Loading