diff --git a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py
index 4e72751da..700794cfe 100644
--- a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py
+++ b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py
@@ -37,6 +37,7 @@
 from xtuner.v1.module.rope import RopeScalingConfig
 from xtuner.v1.ops.act_fn import get_act_fn
 from xtuner.v1.utils import ForwardState
+from xtuner.v1.utils.misc import run_gc_once

 from ..linear import build_linear

@@ -452,6 +453,7 @@ def _forward(
             residual=residual,
             shared_experts_out=shared_experts_out,
         )
+        run_gc_once()
         return hidden_states, router_results["logits"], router_results["router_weights"]

     def _micro_batch_forward(
@@ -586,6 +588,7 @@ def _micro_batch_forward(
         router_logits = [router_results["logits"] for router_results in router_results_list]
         router_weights = [router_results["router_weights"] for router_results in router_results_list]

+        run_gc_once()
         return tuple(hidden_states_out_list + router_logits + router_weights)

     def _pre_moe_forward(
diff --git a/xtuner/v1/utils/misc.py b/xtuner/v1/utils/misc.py
index a588e2f1a..87246f968 100644
--- a/xtuner/v1/utils/misc.py
+++ b/xtuner/v1/utils/misc.py
@@ -1,7 +1,8 @@
+import gc
 import os
 import sys
 import threading
-from functools import reduce
+from functools import cache, reduce
 from math import lcm
 from multiprocessing import resource_tracker as _mprt
 from multiprocessing import shared_memory as _mpshm
@@ -214,3 +215,13 @@ def clean_param_name(name: str) -> str:
     if "_orig_mod." in name:
         name = name.replace("_orig_mod.", "")
     return name
+
+
+@cache
+def run_gc_once() -> None:
+    """Workaround for sporadic memory leaks observed after the first forward
+    pass of the MoE decoder layer on certain devices.
+
+    Forces a single garbage collection pass so the leaked memory is reclaimed.
+    """
+    gc.collect()
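
Note on the mechanism (a standalone sketch, not part of the patch): functools.cache memoizes
by argument tuple, and a zero-argument function has exactly one cache entry, so the decorated
body executes at most once per process. Every later call is a cache hit that skips the body.
A minimal illustration of the behavior run_gc_once relies on:

    import gc
    from functools import cache

    @cache
    def run_gc_once() -> None:
        """First call runs gc.collect(); subsequent calls are cache hits and no-ops."""
        gc.collect()

    run_gc_once()                    # miss: body executes, gc.collect() runs
    run_gc_once()                    # hit: body is skipped entirely
    print(run_gc_once.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)

One consequence of this design: because the cache is per-process, each rank in a distributed
run performs its own single collection, which is what the decoder-layer call sites need.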