diff --git a/arctic_training/data/chat_markers.py b/arctic_training/data/chat_markers.py new file mode 100644 index 00000000..69947496 --- /dev/null +++ b/arctic_training/data/chat_markers.py @@ -0,0 +1,482 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Chat template marker detection and token-based label masking. + +This module provides utilities for detecting chat template markers (e.g., user/assistant +turn boundaries) and creating labels for SFT training using token-based matching rather +than character position matching. This approach: + +1. Correctly includes end-of-turn tokens in the training signal +2. Is more robust to tokenizer variations +3. Works with any chat template format + +References: +- Unsloth's train_on_responses_only: https://github.com/unslothai/unsloth-zoo +""" + +import re +from dataclasses import dataclass +from typing import Dict +from typing import List +from typing import Optional + +from transformers import PreTrainedTokenizerBase + +from arctic_training.logging import logger + +IGNORE_INDEX = -100 + + +@dataclass +class ChatMarkerConfig: + """Markers for identifying role boundaries in chat templates.""" + + user_start: str # Marks start of user turn + assistant_start: str # Marks start of assistant turn + turn_end: str # Marks end of any turn + system_start: Optional[str] = None # Optional system marker + + +# Pre-defined configurations for popular model families +KNOWN_CHAT_MARKERS: Dict[str, ChatMarkerConfig] = { + # ===== ChatML Format (Qwen, Yi, Arctic, etc.) ===== + "chatml": ChatMarkerConfig( + user_start="<|im_start|>user\n", + assistant_start="<|im_start|>assistant\n", + turn_end="<|im_end|>", + system_start="<|im_start|>system\n", + ), + # ===== Llama 3 / 3.1 / 3.2 ===== + "llama3": ChatMarkerConfig( + user_start="<|start_header_id|>user<|end_header_id|>\n\n", + assistant_start="<|start_header_id|>assistant<|end_header_id|>\n\n", + turn_end="<|eot_id|>", + system_start="<|start_header_id|>system<|end_header_id|>\n\n", + ), + # ===== Llama 2 / Mistral (Instruct format) ===== + "llama2": ChatMarkerConfig( + user_start="[INST] ", + assistant_start=" [/INST] ", + turn_end="", + system_start="<>\n", + ), + # ===== Mistral v0.3+ / Mixtral ===== + "mistral_v3": ChatMarkerConfig( + user_start="[INST]", + assistant_start="[/INST]", + turn_end="", + ), + # ===== Phi-3 / Phi-4 ===== + "phi3": ChatMarkerConfig( + user_start="<|user|>\n", + assistant_start="<|assistant|>\n", + turn_end="<|end|>", + system_start="<|system|>\n", + ), + # ===== Gemma / Gemma 2 ===== + "gemma": ChatMarkerConfig( + user_start="user\n", + assistant_start="model\n", + turn_end="", + ), + # ===== DeepSeek ===== + "deepseek": ChatMarkerConfig( + user_start="User: ", + assistant_start="Assistant: ", + turn_end="<|end▁of▁sentence|>", + ), + # ===== DeepSeek V2/V3 (ChatML-like) ===== + "deepseek_v2": ChatMarkerConfig( + user_start="<|User|>", + assistant_start="<|Assistant|>", + turn_end="<|end_of_sentence|>", + ), + # ===== Vicuna / Alpaca ===== + "vicuna": ChatMarkerConfig( + user_start="USER: ", + assistant_start="ASSISTANT: ", + turn_end="", + ), + # ===== Zephyr (based on Mistral) ===== + "zephyr": ChatMarkerConfig( + user_start="<|user|>\n", + assistant_start="<|assistant|>\n", + turn_end="", + system_start="<|system|>\n", + ), + # ===== Command-R (Cohere) ===== + "command_r": ChatMarkerConfig( + user_start="<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", + assistant_start="<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + turn_end="<|END_OF_TURN_TOKEN|>", + system_start="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", + ), +} + + +def detect_model_family(tokenizer: PreTrainedTokenizerBase) -> Optional[str]: + """ + Detect which model family the tokenizer belongs to. + Returns the key for KNOWN_CHAT_MARKERS or None if unknown. + """ + # Check by model name patterns + model_name = getattr(tokenizer, "name_or_path", "").lower() + + patterns = [ + (r"qwen|arctic", "chatml"), + (r"llama-3|llama3|llama-4|llama4", "llama3"), + (r"llama-2|llama2", "llama2"), + (r"mistral.*v0\.[3-9]|mixtral", "mistral_v3"), + (r"phi-[34]|phi[34]", "phi3"), + (r"gemma", "gemma"), + (r"deepseek.*v[23]|deepseek-v[23]", "deepseek_v2"), + (r"deepseek", "deepseek"), + (r"vicuna|alpaca", "vicuna"), + (r"zephyr", "zephyr"), + (r"command-r|c4ai", "command_r"), + ] + + for pattern, family in patterns: + if re.search(pattern, model_name): + logger.info(f"Detected model family '{family}' from model name: {model_name}") + return family + + # Check by special tokens + special_tokens = set(getattr(tokenizer, "additional_special_tokens", [])) + + if "<|im_start|>" in special_tokens: + logger.info("Detected model family 'chatml' from special tokens") + return "chatml" + if "<|start_header_id|>" in special_tokens: + logger.info("Detected model family 'llama3' from special tokens") + return "llama3" + if "" in special_tokens: + logger.info("Detected model family 'gemma' from special tokens") + return "gemma" + if "<|user|>" in special_tokens and "<|assistant|>" in special_tokens: + logger.info("Detected model family 'phi3' from special tokens") + return "phi3" + + return None # Unknown - will use heuristic + + +def detect_markers_heuristic(tokenizer: PreTrainedTokenizerBase) -> ChatMarkerConfig: + """ + Auto-detect chat markers by applying the template to minimal examples and diffing. + This is the fallback when the model family is unknown. + """ + # Minimal conversations with unique single-char content + user_only = [{"role": "user", "content": "X"}] + user_asst = [{"role": "user", "content": "X"}, {"role": "assistant", "content": "Y"}] + multi_turn = [ + {"role": "user", "content": "X"}, + {"role": "assistant", "content": "Y"}, + {"role": "user", "content": "Z"}, + ] + + try: + t1 = tokenizer.apply_chat_template(user_only, tokenize=False, add_generation_prompt=False) + t2 = tokenizer.apply_chat_template(user_asst, tokenize=False, add_generation_prompt=False) + t3 = tokenizer.apply_chat_template(multi_turn, tokenize=False, add_generation_prompt=False) + except Exception as e: + logger.warning(f"Failed to apply chat template for heuristic detection: {e}") + # Return a safe default that won't match anything - will fall back to char-based + return ChatMarkerConfig( + user_start="<|UNKNOWN_USER|>", + assistant_start="<|UNKNOWN_ASSISTANT|>", + turn_end="<|UNKNOWN_END|>", + ) + + # Extract markers by diffing + # User turn = t1 (contains user start + "X" + turn end + maybe newline) + # Find where X appears + x_pos = t1.find("X") + user_start = t1[:x_pos] if x_pos > 0 else "" + + # Assistant turn = difference between t2 and t1 + # t2 should start with t1's content (possibly without trailing generation prompt) + asst_turn = t2[len(t1) :] if t2.startswith(t1) else t2[t2.find("X") + 1 :] + + # Find Y in assistant turn + y_pos_in_asst = asst_turn.find("Y") + assistant_start = asst_turn[:y_pos_in_asst] if y_pos_in_asst > 0 else "" + + # Turn end = what comes after Y and before the next user turn in t3 + # Find where Y ends and Z's user turn starts + y_pos_t3 = t3.find("Y") + z_pos_t3 = t3.find("Z") + if y_pos_t3 >= 0 and z_pos_t3 > y_pos_t3: + between = t3[y_pos_t3 + 1 : z_pos_t3] + # Turn end is the part before the next user_start + if user_start and user_start in between: + turn_end = between[: between.find(user_start)] + else: + # Just take up to some reasonable delimiter + turn_end = between.split("\n")[0] if "\n" in between else between + else: + # Fallback: look for common end tokens + turn_end = "" + for common_end in ["<|im_end|>", "<|eot_id|>", "", "<|end|>", ""]: + if common_end in t2: + turn_end = common_end + break + + # Clean up extracted markers + user_start = user_start.strip() if user_start else "<|user|>" + assistant_start = assistant_start.strip() if assistant_start else "<|assistant|>" + turn_end = turn_end.strip() if turn_end else "" + + if not turn_end: + logger.warning( + "Could not detect turn_end marker from chat template. " + "Assistant response boundaries may not be correctly identified. " + "Consider explicitly specifying chat_template_family in your config." + ) + + logger.info( + "Heuristically detected chat markers:\n" + f" user_start: {repr(user_start)}\n" + f" assistant_start: {repr(assistant_start)}\n" + f" turn_end: {repr(turn_end)}" + ) + + return ChatMarkerConfig( + user_start=user_start, + assistant_start=assistant_start, + turn_end=turn_end, + ) + + +def get_chat_markers( + tokenizer: PreTrainedTokenizerBase, + chat_template_family: Optional[str] = None, +) -> ChatMarkerConfig: + """ + Get chat markers for the given tokenizer. + Priority: explicit family > detected family > heuristic detection + """ + # 1. Explicit family specified + if chat_template_family is not None: + if chat_template_family in KNOWN_CHAT_MARKERS: + logger.info(f"Using explicitly specified chat template family: {chat_template_family}") + return KNOWN_CHAT_MARKERS[chat_template_family] + else: + raise ValueError( + f"Unknown chat_template_family: {chat_template_family}. " + f"Available options: {list(KNOWN_CHAT_MARKERS.keys())}" + ) + + # 2. Try to detect model family + family = detect_model_family(tokenizer) + if family is not None: + return KNOWN_CHAT_MARKERS[family] + + # 3. Heuristic detection (fallback) + logger.info("Model family not recognized, using heuristic marker detection") + return detect_markers_heuristic(tokenizer) + + +def _find_subsequence(sequence: List[int], subsequence: List[int], start: int = 0) -> int: + """ + Find the starting index of a subsequence within a sequence. + Returns -1 if not found or if subsequence is empty. + """ + if not subsequence: + return -1 # Empty subsequence is considered "not found" + subseq_len = len(subsequence) + for i in range(start, len(sequence) - subseq_len + 1): + if sequence[i : i + subseq_len] == subsequence: + return i + return -1 + + +def _tokenize_marker(marker: str, tokenizer: PreTrainedTokenizerBase) -> List[int]: + """ + Tokenize a marker string, handling tokenizer variations. + Similar to Unsloth's _find_common_token_ids but simplified. + """ + if not marker: + return [] + return tokenizer(marker, add_special_tokens=False).input_ids + + +def _tokenize_marker_without_trailing_whitespace(marker: str, tokenizer: PreTrainedTokenizerBase) -> List[int]: + """ + Tokenize a marker string, stripping trailing whitespace to avoid tokenizer + context sensitivity issues. + + Some tokenizers (e.g., Qwen) tokenize "\\n\\n" differently than "\\n" + "\\n". + When assistant content starts with "\\n", the marker "assistant\\n" + "\\n" becomes + "assistant\\n\\n" which tokenizes differently, causing pattern match failures. + + By stripping trailing whitespace from the marker, we avoid this issue. + """ + if not marker: + return [] + # Strip trailing whitespace (newlines, spaces) to avoid context sensitivity + stripped = marker.rstrip() + return tokenizer(stripped, add_special_tokens=False).input_ids + + +def get_token_based_labels( + input_ids: List[int], + tokenizer: PreTrainedTokenizerBase, + markers: ChatMarkerConfig, +) -> List[int]: + """ + Create labels using token-based pattern matching. + + This approach: + 1. Finds assistant_start marker tokens in input_ids + 2. Marks tokens as trainable from assistant_start until turn_end or next user_start + 3. Includes the turn_end token in the training signal + + Args: + input_ids: The tokenized input sequence + tokenizer: The tokenizer used + markers: Chat template markers + + Returns: + List of labels where trainable tokens have their token ID and masked tokens have -100 + """ + # Tokenize the markers - strip trailing whitespace to avoid context sensitivity + # This handles cases where content starts with newlines causing different tokenization + user_start_ids = _tokenize_marker_without_trailing_whitespace(markers.user_start, tokenizer) + assistant_start_ids = _tokenize_marker_without_trailing_whitespace(markers.assistant_start, tokenizer) + turn_end_ids = _tokenize_marker(markers.turn_end, tokenizer) + + # Initialize all labels as masked + labels = [IGNORE_INDEX] * len(input_ids) + + # Find all assistant turns and mark them as trainable + i = 0 + while i < len(input_ids): + # Look for assistant start marker + if assistant_start_ids: + asst_start_pos = _find_subsequence(input_ids, assistant_start_ids, i) + else: + asst_start_pos = -1 + + if asst_start_pos == -1: + break # No more assistant turns + + # Move past the assistant start marker (don't include marker in labels) + content_start = asst_start_pos + len(assistant_start_ids) + + # Find the end of this assistant turn + # Look for either turn_end or next user_start, whichever comes first + turn_end_pos = len(input_ids) # Default to end of sequence + + if turn_end_ids: + end_pos = _find_subsequence(input_ids, turn_end_ids, content_start) + if end_pos != -1: + # Include the turn_end token in labels + turn_end_pos = end_pos + len(turn_end_ids) + + if user_start_ids: + next_user_pos = _find_subsequence(input_ids, user_start_ids, content_start) + if next_user_pos != -1: + turn_end_pos = min(turn_end_pos, next_user_pos) + + # Mark tokens from content_start to turn_end_pos as trainable + for j in range(content_start, min(turn_end_pos, len(input_ids))): + labels[j] = input_ids[j] + + # Move to search for next assistant turn + i = turn_end_pos + + # Validate that we found trainable tokens + trainable_count = sum(1 for label in labels if label != IGNORE_INDEX) + if trainable_count == 0: + logger.warning( + "No trainable tokens found after label masking. " + "This may cause NaN loss during training. " + "Check that the chat template markers match your model's format." + ) + + return labels + + +def get_token_based_labels_with_ignore_empty_think( + input_ids: List[int], + tokenizer: PreTrainedTokenizerBase, + markers: ChatMarkerConfig, + ignore_empty_think: bool = False, +) -> List[int]: + """ + Create labels using token-based pattern matching, with support for ignore_empty_think. + + When ignore_empty_think is True, empty patterns at the start of + assistant responses are masked (set to IGNORE_INDEX) to prevent the model from + learning to produce empty thinking blocks. + """ + labels = get_token_based_labels(input_ids, tokenizer, markers) + + if not ignore_empty_think: + return labels + + # Handle empty think tags - find and mask them + # Tokenize the empty think pattern + empty_think_pattern = "" + empty_think_ids = _tokenize_marker(empty_think_pattern, tokenizer) + + if not empty_think_ids: + return labels + + # Find and mask empty think patterns that appear at start of assistant responses + # Use same tokenization method as get_token_based_labels() for consistency + assistant_start_ids = _tokenize_marker_without_trailing_whitespace(markers.assistant_start, tokenizer) + user_start_ids = _tokenize_marker_without_trailing_whitespace(markers.user_start, tokenizer) + turn_end_ids = _tokenize_marker(markers.turn_end, tokenizer) + + i = 0 + while i < len(input_ids): + if assistant_start_ids: + asst_start_pos = _find_subsequence(input_ids, assistant_start_ids, i) + else: + break + + if asst_start_pos == -1: + break + + content_start = asst_start_pos + len(assistant_start_ids) + + # Check if empty think pattern is right after assistant start + if input_ids[content_start : content_start + len(empty_think_ids)] == empty_think_ids: + # Mask the empty think tokens + for j in range(content_start, content_start + len(empty_think_ids)): + if j < len(labels): + labels[j] = IGNORE_INDEX + + # Find the end of this assistant turn to skip past it entirely + # This prevents false matches if the response contains marker patterns + turn_end_pos = len(input_ids) # Default to end of sequence + + if turn_end_ids: + end_pos = _find_subsequence(input_ids, turn_end_ids, content_start) + if end_pos != -1: + turn_end_pos = end_pos + len(turn_end_ids) + + if user_start_ids: + next_user_pos = _find_subsequence(input_ids, user_start_ids, content_start) + if next_user_pos != -1: + turn_end_pos = min(turn_end_pos, next_user_pos) + + # Move past the entire assistant turn + i = turn_end_pos + + return labels diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 22283d14..4c876756 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -14,12 +14,10 @@ # limitations under the License. import random -import re from typing import Dict from typing import List from typing import Literal from typing import Optional -from typing import Tuple import numpy as np import torch @@ -33,6 +31,9 @@ from arctic_training.config.data import DataConfig from arctic_training.config.utils import HumanInt +from arctic_training.data.chat_markers import KNOWN_CHAT_MARKERS +from arctic_training.data.chat_markers import get_chat_markers +from arctic_training.data.chat_markers import get_token_based_labels_with_ignore_empty_think from arctic_training.data.factory import DataFactory from arctic_training.data.hf_instruct_source import HFDataSourceInstruct from arctic_training.data.utils import DatasetType @@ -265,6 +266,14 @@ class SFTDataConfig(DataConfig): ignore_empty_think: bool = False """ Whether to mask the empty think tokens preventing the loss of thinking ability.""" + chat_template_family: Optional[str] = None + """ + Explicitly specify the chat template family for label masking. + If None, auto-detection is used based on model name and special tokens. + Available options: chatml, llama3, llama2, mistral_v3, phi3, gemma, + deepseek, deepseek_v2, vicuna, zephyr, command_r + """ + @model_validator(mode="after") def validate_padding(self) -> Self: if self.pad_to == "max_length" and "div_length" in self.model_fields_set: @@ -280,6 +289,16 @@ def validate_padding(self) -> Self: ) return self + @model_validator(mode="after") + def validate_chat_template_family(self) -> Self: + if self.chat_template_family is not None: + if self.chat_template_family not in KNOWN_CHAT_MARKERS: + raise ValueError( + f"Unknown chat_template_family: {self.chat_template_family}. " + f"Available options: {list(KNOWN_CHAT_MARKERS.keys())}" + ) + return self + def filter_dataset_length(self, dataset: DatasetType) -> DatasetType: if not self.config.filter_samples: @@ -375,6 +394,12 @@ def process(self, dataset: DatasetType) -> DatasetType: # datasets.disable_caching() # tmp = tokenize_messages(datasets[0]["messages"][:2], tokenizer, mask_inputs=mask_inputs) # import pdb; pdb.set_trace() + + # Get chat markers once for the entire dataset processing + chat_markers = None + if self.config.mask_inputs: + chat_markers = get_chat_markers(self.tokenizer, self.config.chat_template_family) + return dataset.map( lambda ex: { **self.tokenize_messages( @@ -382,6 +407,7 @@ def process(self, dataset: DatasetType) -> DatasetType: self.tokenizer, mask_inputs=self.config.mask_inputs, ignore_empty_think=self.config.ignore_empty_think, + chat_markers=chat_markers, ) }, remove_columns=dataset.column_names, @@ -396,75 +422,50 @@ def tokenize_messages( tokenizer: PreTrainedTokenizerBase, mask_inputs: bool = True, ignore_empty_think: bool = False, + chat_markers=None, ) -> BatchEncoding: + """ + Tokenize messages and create labels for SFT training. + + Uses token-based pattern matching to identify assistant turns, which: + 1. Correctly includes end-of-turn tokens in the training signal + 2. Is more robust to tokenizer variations + 3. Works with any chat template format + + Args: + messages: List of message dicts with 'role' and 'content' keys + tokenizer: The tokenizer to use + mask_inputs: Whether to mask non-assistant tokens (set to -100) + ignore_empty_think: Whether to mask empty patterns + chat_markers: Pre-computed chat markers (optional, will be detected if None) + + Returns: + BatchEncoding with input_ids, attention_mask, and labels + """ conversation_text = tokenizer.apply_chat_template(conversation=messages, tokenize=False) conversation_ids = tokenizer( conversation_text, - return_offsets_mapping=mask_inputs, add_special_tokens=False, ) if mask_inputs: - assistant_ranges = cls.get_assistant_start_end_indices(messages, conversation_text, ignore_empty_think) - # _ = get_assistant_start_end_indices(messages, conversation_text) - labels = cls.get_masked_labels(conversation_ids, assistant_ranges) + # Get chat markers if not provided + if chat_markers is None: + chat_markers = get_chat_markers(tokenizer) + + # Use token-based label masking + labels = get_token_based_labels_with_ignore_empty_think( + input_ids=conversation_ids["input_ids"], + tokenizer=tokenizer, + markers=chat_markers, + ignore_empty_think=ignore_empty_think, + ) conversation_ids["labels"] = labels - # compare_messages_with_labels(split_list_by_specific_num(conversation_ids["labels"]), messages, tokenizer) - del conversation_ids["offset_mapping"] else: conversation_ids["labels"] = conversation_ids["input_ids"] return conversation_ids - @staticmethod - # this code is adpoted from https://github.com/huggingface/trl/issues/632 (user: Peter-Devine ) - def get_assistant_start_end_indices( - messages: List[Dict[str, str]], - conversation_text: str, - ignore_empty_think: bool = False, - ) -> List[Tuple[int, int]]: - return_indices = [] - for message in messages: - if message["role"] == "assistant": - message_text = message["content"] - if ignore_empty_think: - message_text = re.sub(r"^\s*\s*", "", message_text) - match_index = conversation_text.find(message_text) - # start_indices.append(match_index) - end_indices = match_index + len(message_text) - return_indices.append((match_index, end_indices)) - return return_indices - - @staticmethod - def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tuple[int, int]]) -> List[int]: - pre_output = IGNORE_INDEX - output = [] - - for id_, (id_s, id_e) in list( - zip( - conversation_ids["input_ids"], - conversation_ids["offset_mapping"], - ) - ): - if any(id_s >= s and id_e <= e for s, e in assistant_ranges): - pre_output = id_ - output.append(id_) - else: - # the if-else here is to include the eos token in the loss. - # for instance, the asistent answer is - # <|assistant|> I am good <|user|> xxx - # -100 1 2 3 4 -100 -100 - # after the shift, input_ids = input_ids[:-1], labels = labels[1:] - # 1 2 3 4 -100 -100 - # now the prediction is correct, and the model will be able to predict token - if pre_output != IGNORE_INDEX: - pre_output = IGNORE_INDEX - output.append(id_) - else: - pre_output = IGNORE_INDEX - output.append(IGNORE_INDEX) - return output - def create_dataloader(self, dataset: DatasetType) -> DataLoader: dataloader = super().create_dataloader(dataset) dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config) diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py index b9c2fed2..ca51e0a0 100644 --- a/arctic_training/trainer/sft_trainer.py +++ b/arctic_training/trainer/sft_trainer.py @@ -23,6 +23,7 @@ from arctic_training.checkpoint.ds_engine import DSCheckpointEngine from arctic_training.checkpoint.hf_engine import HFCheckpointEngine from arctic_training.data.sft_factory import SFTDataFactory +from arctic_training.logging import logger from arctic_training.model.hf_factory import HFModelFactory from arctic_training.model.liger_factory import LigerModelFactory from arctic_training.optimizer.adam_factory import CPUAdamOptimizerFactory @@ -81,19 +82,37 @@ def loss(self, batch) -> torch.Tensor: # memory-wise, but which has more compute overhead before backward re-runs forward. The # total memory usage is very similar, but cuda cache flushes earlier if pushing close to # OOM than liger. - if self.config.model.type == "liger": + # + # Note: When in eval mode with SP > 1, liger's fused cross-entropy returns None, so we + # fall back to tiled compute for evaluation even when model.type == "liger". + # Use model_unwrapped.training because self.model is a DeepSpeed engine wrapper. + use_liger_fused_loss = self.config.model.type == "liger" and self.model_unwrapped.training + + if use_liger_fused_loss: # letting liger do fused logits+loss calculation outputs = self.model(**batch, use_cache=False) loss = outputs.loss if loss is None: - # XXX: not sure why this happens with SP>1 and eval-enabled, I checked shift_labels contain valid non -100 tokens - disabling fused_linear_cross_entropy=False in AutoLigerKernelForCausalLM.from_pretrained doesn't help. all works when eval is off. raise ValueError( - "Liger-Kernel failed to compute loss (returned None) - it's known to fail with eval enabled along" - " train steps when SP>1." + "Liger-Kernel failed to compute loss (returned None). This is unexpected during training." ) + # Handle NaN loss from Liger (can happen with all-masked batches) + if torch.isnan(loss) or torch.isinf(loss): + # Check if this is due to all labels being masked + good_tokens = ((shift_labels != -100).view(-1)).sum() + if good_tokens == 0: + logger.warning( + "Batch has no trainable tokens on this SP rank (all labels are -100). " + "Returning zero loss. This may indicate data issues with too many " + "empty/masked outputs or an unfavorable packing distribution." + ) + # Create fresh zero tensor (NaN * 0 = NaN, so we can't use loss * 0) + loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype, requires_grad=True) + # If there are good tokens but still NaN, let it propagate (real numerical issue) + else: # Currently relying on an automatic num_shards derivation based on the goal that it'll # take approximately 1GB of fp32 logits in a shard, could make this configurable if @@ -157,6 +176,19 @@ def fused_logits_loss_fn(model_with_head=None, hidden_states=None, shift_labels= good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=self.sp_group) total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(self.sp_world_size)) total_good_tokens = sum(good_tokens_per_rank) - loss = total_loss / total_good_tokens + + # Protect against division by zero when all tokens are masked + # This can happen with packed samples that have mostly non-assistant content + if total_good_tokens == 0: + logger.warning( + "Batch has no trainable tokens across all SP ranks (all labels are -100). " + "Returning zero loss. This may indicate data issues with too many " + "empty/masked outputs or an unfavorable packing distribution." + ) + # Create fresh zero tensor (total_loss may contain NaN, and NaN * 0 = NaN) + # Use loss.device/dtype since loss is guaranteed to exist from line 174 + loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype, requires_grad=True) + else: + loss = total_loss / total_good_tokens return loss diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 7035a4b3..3ff27be3 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -561,6 +561,8 @@ def evaluate(self) -> None: self.model.eval() with torch.no_grad(): losses = [self.loss(eval_batch).item() for eval_batch in self.eval_batches] + # clear the cache after evaluation + torch.cuda.empty_cache() self.metrics.record("loss/eval", losses) # type: ignore @callback_wrapper("checkpoint")