Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions xtuner/v1/model/compose/intern_s1/modeling_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from xtuner.v1.ops.act_fn import get_act_fn
from xtuner.v1.utils import get_logger
from xtuner.v1.module import AttnOutputs
import os
from xtuner.v1.utils.activation_offload import async_save_on_cpu

DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()
Expand Down Expand Up @@ -230,6 +232,7 @@ def __init__(self, config: InternS1VisionConfig) -> None:
dpr = np.linspace(0.0, float(config.drop_path_rate), int(config.num_hidden_layers))
self.layer = nn.ModuleList([
InternS1VisionLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
self.offload_stream = torch.cuda.Stream()

def forward(
self,
Expand All @@ -241,8 +244,17 @@ def forward(
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore

hidden_states = layer_module(hidden_states)
if int(os.getenv("XTUNER_ACTIVATION_OFFLOAD", "0")) == 1:
with async_save_on_cpu(
h2d_stream=self.offload_stream,
d2h_stream=self.offload_stream,
block_idx=int(i),
group="vision",
custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(),
):
hidden_states = layer_module(hidden_states)
else:
hidden_states = layer_module(hidden_states)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions xtuner/v1/model/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def _micro_batch_forward(
h2d_stream=self.offload_stream,
d2h_stream=self.offload_stream,
block_idx=layer_idx - self.config.first_k_dense_replace,
depth=len(self.layers) - self.config.first_k_dense_replace,
group="text",
custom_check_fn=lambda x: x.data_ptr()
in [hidden_states.data_ptr() for hidden_states in hidden_states_list],
prefetch=True,
Expand Down Expand Up @@ -579,7 +579,7 @@ def _forward(
h2d_stream=self.offload_stream,
d2h_stream=self.offload_stream,
block_idx=int(idx),
depth=len(self.layers),
group="text",
custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(),
):
layer_results = decoder_layer(
Expand Down
44 changes: 26 additions & 18 deletions xtuner/v1/utils/activation_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def __init__(self):
self._block_tensor_nums = {} # offload tensors per block

def get_cnt(self, block_idx):
prev_block_idx = None if self._block_idx == -1 else self._block_idx
after_block = False

if block_idx > self._block_idx:
self._block_tensor_nums[block_idx] = 1
if block_idx != 0:
Expand All @@ -43,7 +45,7 @@ def get_cnt(self, block_idx):
self._block_tensor_nums = {block_idx: 1}

offload_tensor_key = f"{self._block_idx}_{self._block_tensor_nums[self._block_idx] - 1}"
return offload_tensor_key, after_block
return offload_tensor_key, after_block, prev_block_idx

def get_prefetch_keys(self, block_idx, tensor_idx):
prefetch_block_idx = max((idx for idx in self._block_tensor_nums.keys() if idx < block_idx), default=None)
Expand Down Expand Up @@ -193,11 +195,13 @@ def __init__(self, check=False):
self.items = {}
self.check = check
self.device_item = []
self.getcnt = GetCnt()
self.getcnt = {}
self.may_npu_tensors = {}

def get_cnt(self, block_idx):
return self.getcnt.get_cnt(block_idx)
def get_cnt(self, block_idx, group="default"):
if group not in self.getcnt:
self.getcnt[group] = GetCnt()
return self.getcnt[group].get_cnt(block_idx)

def assert_exist(self, key):
if key not in self.items:
Expand Down Expand Up @@ -249,16 +253,17 @@ def get(self, key):
self.may_npu_tensors.update({key: self.items.pop(key)})
return act

def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream):
prefetch_keys = self.getcnt.get_prefetch_keys(block_idx, tensor_idx)
def prefetch_get(self, block_idx, tensor_idx, h2d_stream, d2h_stream, group="default"):
if group not in self.getcnt:
return
prefetch_keys = self.getcnt[group].get_prefetch_keys(block_idx, tensor_idx)
for prefetch_key in prefetch_keys:
if self.exist(prefetch_key):
prefetch_swap_tensor = self.get(prefetch_key)
full_key = f"{group}_{prefetch_key}"
if self.exist(full_key):
prefetch_swap_tensor = self.get(full_key)
h2d_stream.wait_stream(d2h_stream)
prefetch_swap_tensor.prefetch_launch_h2d(h2d_stream, True)
# prefetch_swap_tensor.tensor.record_stream(h2d_stream)
else:
torch.distributed.breakpoint()

def empty(self):
return len(self.items) == 0
Expand Down Expand Up @@ -291,7 +296,8 @@ def __init__(
h2d_stream: torch.cuda.Stream,
d2h_stream: torch.cuda.Stream,
block_idx: int,
depth: int,
depth: int | None = None,
group: str = "default",
custom_check_fn=None,
prefetch=True,
) -> None:
Expand All @@ -302,19 +308,21 @@ def _pack_to_cpu(tensor):
if (custom_check_fn is not None) and (not custom_check_fn(tensor)):
return tensor

key, after_block = OffloadManager().get_cnt(block_idx)
key, after_block, prev_block_idx = OffloadManager().get_cnt(block_idx, group=group)

if after_block:
OffloadManager().del_npu_tensor(f"{block_idx - 1}_", d2h_stream)
if after_block and (prev_block_idx is not None):
OffloadManager().del_npu_tensor(f"{group}_{prev_block_idx}_", d2h_stream)

swap_tensor = SwapTensor(tensor, key)
full_key = f"{group}_{key}"

if block_idx <= depth - 1:
should_offload = depth is None or block_idx <= depth - 1
if should_offload:
working_stream = torch.cuda.current_stream()
d2h_stream.wait_stream(working_stream)
swap_tensor.launch_d2h(d2h_stream)

OffloadManager().put(key, swap_tensor)
OffloadManager().put(full_key, swap_tensor)
return swap_tensor

def _unpack_from_cpu(swap_tensor) -> torch.Tensor:
Expand All @@ -328,14 +336,14 @@ def _unpack_from_cpu(swap_tensor) -> torch.Tensor:

block_idx, tensor_idx = swap_tensor.key.split("_")

OffloadManager().del_may_npu_tensor(f"{int(block_idx) + 1}_", h2d_stream)
OffloadManager().del_may_npu_tensor(f"{group}_{int(block_idx) + 1}_", h2d_stream)
swap_tensor.launch_h2d(h2d_stream, True, working_stream)
# if block_idx in ["0", "2", "3"]:
# if block_idx in ["0"]:
# torch.cuda.synchronize()

if prefetch and block_idx != 0:
OffloadManager().prefetch_get(int(block_idx), int(tensor_idx), h2d_stream, d2h_stream)
OffloadManager().prefetch_get(int(block_idx), int(tensor_idx), h2d_stream, d2h_stream, group=group)

# if block_idx in ["0"] and tensor_idx == "1":
# swap_tensor.load()
Expand Down
Loading