|
3 | 3 | from typing import Annotated, Any, Literal, TypeVar |
4 | 4 |
|
5 | 5 | import torch |
6 | | -import torch.distributed as dist |
7 | 6 | import torch.nn as nn |
8 | 7 | from cyclopts import Parameter |
9 | 8 | from pydantic import BaseModel, ConfigDict |
10 | 9 | from torch.distributed.device_mesh import DeviceMesh |
11 | | -from torch.distributed.nn.functional import all_reduce |
12 | | -from typing_extensions import Self |
13 | | - |
14 | | -from xtuner.v1.loss.utils import sp_split |
15 | | - |
16 | | -from .chunk_loss import ChunkLoss |
17 | 10 |
|
18 | 11 |
|
19 | 12 | # Do loss calibration among dp, sp and grad accumulation: |
|
46 | 39 |
|
47 | 40 |
|
48 | 41 | class BaseLossKwargs(BaseModel): |
49 | | - """Everything needed to compute the loss.""" |
50 | | - |
51 | | - model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True) |
52 | | - shifted_labels: torch.Tensor |
| 42 | + """Everything needed to compute the loss. |
53 | 43 |
|
54 | | - def sp_split(self, sp_mesh: DeviceMesh) -> Self: |
55 | | - self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100) |
56 | | - return self |
| 44 | + Subclasses should implement sp_split() and to() methods if they contain tensors that need to be split across |
| 45 | + sequence parallel mesh or moved to device. |
| 46 | + """ |
57 | 47 |
|
58 | | - def to(self, device: torch.device | str) -> Self: |
59 | | - self.shifted_labels = self.shifted_labels.to(device) |
60 | | - return self |
| 48 | + model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True) |
61 | 49 |
|
62 | 50 | def chunk(self, chunk_size) -> list["BaseLossKwargs"]: |
63 | 51 | tensor_fields: dict[str, tuple[torch.Tensor, ...]] = {} |
@@ -114,10 +102,13 @@ class BaseLossConfig(BaseModel): |
114 | 102 | chunk_size: Annotated[int | None, Parameter(help="chunk size when mode is chunk")] = 1024 |
115 | 103 |
|
116 | 104 | @property |
| 105 | + @abstractmethod |
117 | 106 | def loss_ctx_cls(self) -> type["BaseLossContext"]: |
118 | 107 | raise NotImplementedError |
119 | 108 |
|
| 109 | + # TODO: private property maybe not a good idea |
120 | 110 | @property |
| 111 | + @abstractmethod |
121 | 112 | def _loss_kwargs_cls(self) -> type["BaseLossKwargs"]: |
122 | 113 | raise NotImplementedError |
123 | 114 |
|
@@ -160,72 +151,10 @@ def __init__(self, loss_cfg: BaseLossConfig, loss_kwargs: BaseLossKwargs): |
160 | 151 | self._batch_size = 1 |
161 | 152 |
|
162 | 153 | @staticmethod |
163 | | - @abstractmethod |
164 | | - def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]: ... |
165 | | - |
166 | | - @abstractmethod |
167 | | - def loss_fn( |
168 | | - self, |
169 | | - hidden_states: torch.Tensor, |
170 | | - head_weight: torch.Tensor, |
171 | | - head_bias: torch.Tensor | None, |
172 | | - loss_kwargs: BaseLossKwargs, |
173 | | - ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: |
174 | | - """Step 2.a and 2.b in the loss calculation.""" |
175 | | - ... |
176 | | - |
177 | | - def eager_mode( |
178 | | - self, |
179 | | - hidden_states: torch.Tensor, |
180 | | - head_weight: torch.Tensor, |
181 | | - head_bias: torch.Tensor | None, |
182 | | - loss_kwargs: BaseLossKwargs, |
183 | | - ): |
184 | | - return self.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs) |
185 | | - |
186 | | - def chunk_mode( |
187 | | - self, |
188 | | - hidden_states: torch.Tensor, |
189 | | - head_weight: torch.Tensor, |
190 | | - head_bias: torch.Tensor | None, |
191 | | - loss_kwargs: BaseLossKwargs, |
192 | | - ): |
193 | | - assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode" |
194 | | - |
195 | | - chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size) |
196 | | - loss, extra_info = ChunkLoss.apply( |
197 | | - hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size |
198 | | - ) |
199 | | - return loss, (None, extra_info) |
200 | | - |
201 | | - def forward( |
202 | | - self, |
203 | | - hidden_states: torch.Tensor, |
204 | | - head_weight: torch.Tensor, |
205 | | - head_bias: torch.Tensor | None = None, |
206 | | - ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: |
207 | | - from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo |
208 | | - |
209 | | - assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward" |
210 | | - if head_bias is not None: |
211 | | - raise NotImplementedError("Loss does not support head_bias yet.") |
212 | | - |
213 | | - if self.loss_cfg.mode == "eager": |
214 | | - loss, (logits, extra_info) = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) |
215 | | - else: |
216 | | - loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) |
217 | | - |
218 | | - # TODO: yanhuida, should be removed |
219 | | - if not isinstance(extra_info, ModelForwardExtraLogInfo): |
220 | | - extra_info = ModelForwardExtraLogInfo(extra_info) |
221 | | - |
222 | | - extra_info["local_base_loss"] = loss.detach().clone() |
223 | | - |
224 | | - # Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support |
225 | | - if dist.is_initialized(): |
226 | | - loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD) |
227 | | - |
228 | | - return loss, (logits, extra_info) |
| 154 | + def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]: |
| 155 | + for ctx in loss_ctx_list: |
| 156 | + ctx._batch_size = len(loss_ctx_list) |
| 157 | + return loss_ctx_list |
229 | 158 |
|
230 | 159 | @classmethod |
231 | 160 | def cat(cls: type[_BaseLossContextT], chunks: list[_BaseLossContextT]) -> _BaseLossContextT: |
|
0 commit comments