From 88a5fee80ebe5a94631904a2d7cae60baa037844 Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Wed, 7 Jan 2026 10:45:24 -0800 Subject: [PATCH 1/6] Fix NaN loss in SFT by using token-based label masking Problem: - Character-based text matching for assistant content could find wrong positions when short responses appeared earlier in conversation - End-of-turn tokens (e.g., <|im_end|>) were not included in labels - These issues caused all labels to be masked, resulting in NaN loss Solution: - Replace character-position matching with token pattern matching - Find assistant_start marker tokens and mark everything until turn_end - Add pre-defined configs for popular models (Qwen/ChatML, Llama3, etc.) - Auto-detect model family from tokenizer name and special tokens - Heuristic fallback for unknown models New files: - arctic_training/data/chat_markers.py: Chat marker detection and token-based labels Config options: - chat_template_family: Explicitly specify model family (optional, auto-detected) Available: chatml, llama3, llama2, mistral_v3, phi3, gemma, deepseek, deepseek_v2, vicuna, zephyr, command_r --- arctic_training/data/chat_markers.py | 462 +++++++++++++++++++++++++++ arctic_training/data/sft_factory.py | 115 +++---- 2 files changed, 520 insertions(+), 57 deletions(-) create mode 100644 arctic_training/data/chat_markers.py diff --git a/arctic_training/data/chat_markers.py b/arctic_training/data/chat_markers.py new file mode 100644 index 00000000..7881ec7f --- /dev/null +++ b/arctic_training/data/chat_markers.py @@ -0,0 +1,462 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Chat template marker detection and token-based label masking. + +This module provides utilities for detecting chat template markers (e.g., user/assistant +turn boundaries) and creating labels for SFT training using token-based matching rather +than character position matching. This approach: + +1. Correctly includes end-of-turn tokens in the training signal +2. Is more robust to tokenizer variations +3. Works with any chat template format + +References: +- Unsloth's train_on_responses_only: https://github.com/unslothai/unsloth-zoo +""" + +import re +from dataclasses import dataclass +from typing import Dict +from typing import List +from typing import Optional + +from transformers import PreTrainedTokenizerBase + +from arctic_training.logging import logger + +IGNORE_INDEX = -100 + + +@dataclass +class ChatMarkerConfig: + """Markers for identifying role boundaries in chat templates.""" + + user_start: str # Marks start of user turn + assistant_start: str # Marks start of assistant turn + turn_end: str # Marks end of any turn + system_start: Optional[str] = None # Optional system marker + + +# Pre-defined configurations for popular model families +KNOWN_CHAT_MARKERS: Dict[str, ChatMarkerConfig] = { + # ===== ChatML Format (Qwen, Yi, Arctic, etc.) ===== + "chatml": ChatMarkerConfig( + user_start="<|im_start|>user\n", + assistant_start="<|im_start|>assistant\n", + turn_end="<|im_end|>", + system_start="<|im_start|>system\n", + ), + # ===== Llama 3 / 3.1 / 3.2 ===== + "llama3": ChatMarkerConfig( + user_start="<|start_header_id|>user<|end_header_id|>\n\n", + assistant_start="<|start_header_id|>assistant<|end_header_id|>\n\n", + turn_end="<|eot_id|>", + system_start="<|start_header_id|>system<|end_header_id|>\n\n", + ), + # ===== Llama 2 / Mistral (Instruct format) ===== + "llama2": ChatMarkerConfig( + user_start="[INST] ", + assistant_start=" [/INST] ", + turn_end="", + system_start="<>\n", + ), + # ===== Mistral v0.3+ / Mixtral ===== + "mistral_v3": ChatMarkerConfig( + user_start="[INST]", + assistant_start="[/INST]", + turn_end="", + ), + # ===== Phi-3 / Phi-4 ===== + "phi3": ChatMarkerConfig( + user_start="<|user|>\n", + assistant_start="<|assistant|>\n", + turn_end="<|end|>", + system_start="<|system|>\n", + ), + # ===== Gemma / Gemma 2 ===== + "gemma": ChatMarkerConfig( + user_start="user\n", + assistant_start="model\n", + turn_end="", + ), + # ===== DeepSeek ===== + "deepseek": ChatMarkerConfig( + user_start="User: ", + assistant_start="Assistant: ", + turn_end="<|end▁of▁sentence|>", + ), + # ===== DeepSeek V2/V3 (ChatML-like) ===== + "deepseek_v2": ChatMarkerConfig( + user_start="<|User|>", + assistant_start="<|Assistant|>", + turn_end="<|end_of_sentence|>", + ), + # ===== Vicuna / Alpaca ===== + "vicuna": ChatMarkerConfig( + user_start="USER: ", + assistant_start="ASSISTANT: ", + turn_end="", + ), + # ===== Zephyr (based on Mistral) ===== + "zephyr": ChatMarkerConfig( + user_start="<|user|>\n", + assistant_start="<|assistant|>\n", + turn_end="", + system_start="<|system|>\n", + ), + # ===== Command-R (Cohere) ===== + "command_r": ChatMarkerConfig( + user_start="<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", + assistant_start="<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + turn_end="<|END_OF_TURN_TOKEN|>", + system_start="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", + ), +} + + +def detect_model_family(tokenizer: PreTrainedTokenizerBase) -> Optional[str]: + """ + Detect which model family the tokenizer belongs to. + Returns the key for KNOWN_CHAT_MARKERS or None if unknown. + """ + # Check by model name patterns + model_name = getattr(tokenizer, "name_or_path", "").lower() + + patterns = [ + (r"qwen|arctic", "chatml"), + (r"llama-3|llama3|llama-4|llama4", "llama3"), + (r"llama-2|llama2", "llama2"), + (r"mistral.*v0\.[3-9]|mixtral", "mistral_v3"), + (r"phi-[34]|phi[34]", "phi3"), + (r"gemma", "gemma"), + (r"deepseek.*v[23]|deepseek-v[23]", "deepseek_v2"), + (r"deepseek", "deepseek"), + (r"vicuna|alpaca", "vicuna"), + (r"zephyr", "zephyr"), + (r"command-r|c4ai", "command_r"), + ] + + for pattern, family in patterns: + if re.search(pattern, model_name): + logger.info(f"Detected model family '{family}' from model name: {model_name}") + return family + + # Check by special tokens + special_tokens = set(getattr(tokenizer, "additional_special_tokens", [])) + + if "<|im_start|>" in special_tokens: + logger.info("Detected model family 'chatml' from special tokens") + return "chatml" + if "<|start_header_id|>" in special_tokens: + logger.info("Detected model family 'llama3' from special tokens") + return "llama3" + if "" in special_tokens: + logger.info("Detected model family 'gemma' from special tokens") + return "gemma" + if "<|user|>" in special_tokens and "<|assistant|>" in special_tokens: + logger.info("Detected model family 'phi3' from special tokens") + return "phi3" + + return None # Unknown - will use heuristic + + +def detect_markers_heuristic(tokenizer: PreTrainedTokenizerBase) -> ChatMarkerConfig: + """ + Auto-detect chat markers by applying the template to minimal examples and diffing. + This is the fallback when the model family is unknown. + """ + # Minimal conversations with unique single-char content + user_only = [{"role": "user", "content": "X"}] + user_asst = [{"role": "user", "content": "X"}, {"role": "assistant", "content": "Y"}] + multi_turn = [ + {"role": "user", "content": "X"}, + {"role": "assistant", "content": "Y"}, + {"role": "user", "content": "Z"}, + ] + + try: + t1 = tokenizer.apply_chat_template(user_only, tokenize=False, add_generation_prompt=False) + t2 = tokenizer.apply_chat_template(user_asst, tokenize=False, add_generation_prompt=False) + t3 = tokenizer.apply_chat_template(multi_turn, tokenize=False, add_generation_prompt=False) + except Exception as e: + logger.warning(f"Failed to apply chat template for heuristic detection: {e}") + # Return a safe default that won't match anything - will fall back to char-based + return ChatMarkerConfig( + user_start="<|UNKNOWN_USER|>", + assistant_start="<|UNKNOWN_ASSISTANT|>", + turn_end="<|UNKNOWN_END|>", + ) + + # Extract markers by diffing + # User turn = t1 (contains user start + "X" + turn end + maybe newline) + # Find where X appears + x_pos = t1.find("X") + user_start = t1[:x_pos] if x_pos > 0 else "" + + # Assistant turn = difference between t2 and t1 + # t2 should start with t1's content (possibly without trailing generation prompt) + asst_turn = t2[len(t1) :] if t2.startswith(t1) else t2[t2.find("X") + 1 :] + + # Find Y in assistant turn + y_pos_in_asst = asst_turn.find("Y") + assistant_start = asst_turn[:y_pos_in_asst] if y_pos_in_asst > 0 else "" + + # Turn end = what comes after Y and before the next user turn in t3 + # Find where Y ends and Z's user turn starts + y_pos_t3 = t3.find("Y") + z_pos_t3 = t3.find("Z") + if y_pos_t3 >= 0 and z_pos_t3 > y_pos_t3: + between = t3[y_pos_t3 + 1 : z_pos_t3] + # Turn end is the part before the next user_start + if user_start and user_start in between: + turn_end = between[: between.find(user_start)] + else: + # Just take up to some reasonable delimiter + turn_end = between.split("\n")[0] if "\n" in between else between + else: + # Fallback: look for common end tokens + turn_end = "" + for common_end in ["<|im_end|>", "<|eot_id|>", "", "<|end|>", ""]: + if common_end in t2: + turn_end = common_end + break + + # Clean up extracted markers + user_start = user_start.strip() if user_start else "<|user|>" + assistant_start = assistant_start.strip() if assistant_start else "<|assistant|>" + turn_end = turn_end.strip() if turn_end else "" + + if not turn_end: + logger.warning( + "Could not detect turn_end marker from chat template. " + "Assistant response boundaries may not be correctly identified. " + "Consider explicitly specifying chat_template_family in your config." + ) + + logger.info( + "Heuristically detected chat markers:\n" + f" user_start: {repr(user_start)}\n" + f" assistant_start: {repr(assistant_start)}\n" + f" turn_end: {repr(turn_end)}" + ) + + return ChatMarkerConfig( + user_start=user_start, + assistant_start=assistant_start, + turn_end=turn_end, + ) + + +def get_chat_markers( + tokenizer: PreTrainedTokenizerBase, + chat_template_family: Optional[str] = None, +) -> ChatMarkerConfig: + """ + Get chat markers for the given tokenizer. + Priority: explicit family > detected family > heuristic detection + """ + # 1. Explicit family specified + if chat_template_family is not None: + if chat_template_family in KNOWN_CHAT_MARKERS: + logger.info(f"Using explicitly specified chat template family: {chat_template_family}") + return KNOWN_CHAT_MARKERS[chat_template_family] + else: + raise ValueError( + f"Unknown chat_template_family: {chat_template_family}. " + f"Available options: {list(KNOWN_CHAT_MARKERS.keys())}" + ) + + # 2. Try to detect model family + family = detect_model_family(tokenizer) + if family is not None: + return KNOWN_CHAT_MARKERS[family] + + # 3. Heuristic detection (fallback) + logger.info("Model family not recognized, using heuristic marker detection") + return detect_markers_heuristic(tokenizer) + + +def _find_subsequence(sequence: List[int], subsequence: List[int], start: int = 0) -> int: + """ + Find the starting index of a subsequence within a sequence. + Returns -1 if not found or if subsequence is empty. + """ + if not subsequence: + return -1 # Empty subsequence is considered "not found" + subseq_len = len(subsequence) + for i in range(start, len(sequence) - subseq_len + 1): + if sequence[i : i + subseq_len] == subsequence: + return i + return -1 + + +def _tokenize_marker(marker: str, tokenizer: PreTrainedTokenizerBase) -> List[int]: + """ + Tokenize a marker string, handling tokenizer variations. + Similar to Unsloth's _find_common_token_ids but simplified. + """ + if not marker: + return [] + return tokenizer(marker, add_special_tokens=False).input_ids + + +def get_token_based_labels( + input_ids: List[int], + tokenizer: PreTrainedTokenizerBase, + markers: ChatMarkerConfig, +) -> List[int]: + """ + Create labels using token-based pattern matching. + + This approach: + 1. Finds assistant_start marker tokens in input_ids + 2. Marks tokens as trainable from assistant_start until turn_end or next user_start + 3. Includes the turn_end token in the training signal + + Args: + input_ids: The tokenized input sequence + tokenizer: The tokenizer used + markers: Chat template markers + + Returns: + List of labels where trainable tokens have their token ID and masked tokens have -100 + """ + # Tokenize the markers + user_start_ids = _tokenize_marker(markers.user_start, tokenizer) + assistant_start_ids = _tokenize_marker(markers.assistant_start, tokenizer) + turn_end_ids = _tokenize_marker(markers.turn_end, tokenizer) + + # Initialize all labels as masked + labels = [IGNORE_INDEX] * len(input_ids) + + # Find all assistant turns and mark them as trainable + i = 0 + while i < len(input_ids): + # Look for assistant start marker + if assistant_start_ids: + asst_start_pos = _find_subsequence(input_ids, assistant_start_ids, i) + else: + asst_start_pos = -1 + + if asst_start_pos == -1: + break # No more assistant turns + + # Move past the assistant start marker (don't include marker in labels) + content_start = asst_start_pos + len(assistant_start_ids) + + # Find the end of this assistant turn + # Look for either turn_end or next user_start, whichever comes first + turn_end_pos = len(input_ids) # Default to end of sequence + + if turn_end_ids: + end_pos = _find_subsequence(input_ids, turn_end_ids, content_start) + if end_pos != -1: + # Include the turn_end token in labels + turn_end_pos = end_pos + len(turn_end_ids) + + if user_start_ids: + next_user_pos = _find_subsequence(input_ids, user_start_ids, content_start) + if next_user_pos != -1: + turn_end_pos = min(turn_end_pos, next_user_pos) + + # Mark tokens from content_start to turn_end_pos as trainable + for j in range(content_start, min(turn_end_pos, len(input_ids))): + labels[j] = input_ids[j] + + # Move to search for next assistant turn + i = turn_end_pos + + # Validate that we found trainable tokens + trainable_count = sum(1 for label in labels if label != IGNORE_INDEX) + if trainable_count == 0: + logger.warning( + "No trainable tokens found after label masking. " + "This may cause NaN loss during training. " + "Check that the chat template markers match your model's format." + ) + + return labels + + +def get_token_based_labels_with_ignore_empty_think( + input_ids: List[int], + tokenizer: PreTrainedTokenizerBase, + markers: ChatMarkerConfig, + ignore_empty_think: bool = False, +) -> List[int]: + """ + Create labels using token-based pattern matching, with support for ignore_empty_think. + + When ignore_empty_think is True, empty patterns at the start of + assistant responses are masked (set to IGNORE_INDEX) to prevent the model from + learning to produce empty thinking blocks. + """ + labels = get_token_based_labels(input_ids, tokenizer, markers) + + if not ignore_empty_think: + return labels + + # Handle empty think tags - find and mask them + # Tokenize the empty think pattern + empty_think_pattern = "" + empty_think_ids = _tokenize_marker(empty_think_pattern, tokenizer) + + if not empty_think_ids: + return labels + + # Find and mask empty think patterns that appear at start of assistant responses + assistant_start_ids = _tokenize_marker(markers.assistant_start, tokenizer) + user_start_ids = _tokenize_marker(markers.user_start, tokenizer) + turn_end_ids = _tokenize_marker(markers.turn_end, tokenizer) + + i = 0 + while i < len(input_ids): + if assistant_start_ids: + asst_start_pos = _find_subsequence(input_ids, assistant_start_ids, i) + else: + break + + if asst_start_pos == -1: + break + + content_start = asst_start_pos + len(assistant_start_ids) + + # Check if empty think pattern is right after assistant start + if input_ids[content_start : content_start + len(empty_think_ids)] == empty_think_ids: + # Mask the empty think tokens + for j in range(content_start, content_start + len(empty_think_ids)): + if j < len(labels): + labels[j] = IGNORE_INDEX + + # Find the end of this assistant turn to skip past it entirely + # This prevents false matches if the response contains marker patterns + turn_end_pos = len(input_ids) # Default to end of sequence + + if turn_end_ids: + end_pos = _find_subsequence(input_ids, turn_end_ids, content_start) + if end_pos != -1: + turn_end_pos = end_pos + len(turn_end_ids) + + if user_start_ids: + next_user_pos = _find_subsequence(input_ids, user_start_ids, content_start) + if next_user_pos != -1: + turn_end_pos = min(turn_end_pos, next_user_pos) + + # Move past the entire assistant turn + i = turn_end_pos + + return labels diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 22283d14..4c876756 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -14,12 +14,10 @@ # limitations under the License. import random -import re from typing import Dict from typing import List from typing import Literal from typing import Optional -from typing import Tuple import numpy as np import torch @@ -33,6 +31,9 @@ from arctic_training.config.data import DataConfig from arctic_training.config.utils import HumanInt +from arctic_training.data.chat_markers import KNOWN_CHAT_MARKERS +from arctic_training.data.chat_markers import get_chat_markers +from arctic_training.data.chat_markers import get_token_based_labels_with_ignore_empty_think from arctic_training.data.factory import DataFactory from arctic_training.data.hf_instruct_source import HFDataSourceInstruct from arctic_training.data.utils import DatasetType @@ -265,6 +266,14 @@ class SFTDataConfig(DataConfig): ignore_empty_think: bool = False """ Whether to mask the empty think tokens preventing the loss of thinking ability.""" + chat_template_family: Optional[str] = None + """ + Explicitly specify the chat template family for label masking. + If None, auto-detection is used based on model name and special tokens. + Available options: chatml, llama3, llama2, mistral_v3, phi3, gemma, + deepseek, deepseek_v2, vicuna, zephyr, command_r + """ + @model_validator(mode="after") def validate_padding(self) -> Self: if self.pad_to == "max_length" and "div_length" in self.model_fields_set: @@ -280,6 +289,16 @@ def validate_padding(self) -> Self: ) return self + @model_validator(mode="after") + def validate_chat_template_family(self) -> Self: + if self.chat_template_family is not None: + if self.chat_template_family not in KNOWN_CHAT_MARKERS: + raise ValueError( + f"Unknown chat_template_family: {self.chat_template_family}. " + f"Available options: {list(KNOWN_CHAT_MARKERS.keys())}" + ) + return self + def filter_dataset_length(self, dataset: DatasetType) -> DatasetType: if not self.config.filter_samples: @@ -375,6 +394,12 @@ def process(self, dataset: DatasetType) -> DatasetType: # datasets.disable_caching() # tmp = tokenize_messages(datasets[0]["messages"][:2], tokenizer, mask_inputs=mask_inputs) # import pdb; pdb.set_trace() + + # Get chat markers once for the entire dataset processing + chat_markers = None + if self.config.mask_inputs: + chat_markers = get_chat_markers(self.tokenizer, self.config.chat_template_family) + return dataset.map( lambda ex: { **self.tokenize_messages( @@ -382,6 +407,7 @@ def process(self, dataset: DatasetType) -> DatasetType: self.tokenizer, mask_inputs=self.config.mask_inputs, ignore_empty_think=self.config.ignore_empty_think, + chat_markers=chat_markers, ) }, remove_columns=dataset.column_names, @@ -396,75 +422,50 @@ def tokenize_messages( tokenizer: PreTrainedTokenizerBase, mask_inputs: bool = True, ignore_empty_think: bool = False, + chat_markers=None, ) -> BatchEncoding: + """ + Tokenize messages and create labels for SFT training. + + Uses token-based pattern matching to identify assistant turns, which: + 1. Correctly includes end-of-turn tokens in the training signal + 2. Is more robust to tokenizer variations + 3. Works with any chat template format + + Args: + messages: List of message dicts with 'role' and 'content' keys + tokenizer: The tokenizer to use + mask_inputs: Whether to mask non-assistant tokens (set to -100) + ignore_empty_think: Whether to mask empty patterns + chat_markers: Pre-computed chat markers (optional, will be detected if None) + + Returns: + BatchEncoding with input_ids, attention_mask, and labels + """ conversation_text = tokenizer.apply_chat_template(conversation=messages, tokenize=False) conversation_ids = tokenizer( conversation_text, - return_offsets_mapping=mask_inputs, add_special_tokens=False, ) if mask_inputs: - assistant_ranges = cls.get_assistant_start_end_indices(messages, conversation_text, ignore_empty_think) - # _ = get_assistant_start_end_indices(messages, conversation_text) - labels = cls.get_masked_labels(conversation_ids, assistant_ranges) + # Get chat markers if not provided + if chat_markers is None: + chat_markers = get_chat_markers(tokenizer) + + # Use token-based label masking + labels = get_token_based_labels_with_ignore_empty_think( + input_ids=conversation_ids["input_ids"], + tokenizer=tokenizer, + markers=chat_markers, + ignore_empty_think=ignore_empty_think, + ) conversation_ids["labels"] = labels - # compare_messages_with_labels(split_list_by_specific_num(conversation_ids["labels"]), messages, tokenizer) - del conversation_ids["offset_mapping"] else: conversation_ids["labels"] = conversation_ids["input_ids"] return conversation_ids - @staticmethod - # this code is adpoted from https://github.com/huggingface/trl/issues/632 (user: Peter-Devine ) - def get_assistant_start_end_indices( - messages: List[Dict[str, str]], - conversation_text: str, - ignore_empty_think: bool = False, - ) -> List[Tuple[int, int]]: - return_indices = [] - for message in messages: - if message["role"] == "assistant": - message_text = message["content"] - if ignore_empty_think: - message_text = re.sub(r"^\s*\s*", "", message_text) - match_index = conversation_text.find(message_text) - # start_indices.append(match_index) - end_indices = match_index + len(message_text) - return_indices.append((match_index, end_indices)) - return return_indices - - @staticmethod - def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tuple[int, int]]) -> List[int]: - pre_output = IGNORE_INDEX - output = [] - - for id_, (id_s, id_e) in list( - zip( - conversation_ids["input_ids"], - conversation_ids["offset_mapping"], - ) - ): - if any(id_s >= s and id_e <= e for s, e in assistant_ranges): - pre_output = id_ - output.append(id_) - else: - # the if-else here is to include the eos token in the loss. - # for instance, the asistent answer is - # <|assistant|> I am good <|user|> xxx - # -100 1 2 3 4 -100 -100 - # after the shift, input_ids = input_ids[:-1], labels = labels[1:] - # 1 2 3 4 -100 -100 - # now the prediction is correct, and the model will be able to predict token - if pre_output != IGNORE_INDEX: - pre_output = IGNORE_INDEX - output.append(id_) - else: - pre_output = IGNORE_INDEX - output.append(IGNORE_INDEX) - return output - def create_dataloader(self, dataset: DatasetType) -> DataLoader: dataloader = super().create_dataloader(dataset) dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config) From 0126d19853397b23518b4fd1115b143fdabec59d Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Fri, 9 Jan 2026 23:25:26 -0800 Subject: [PATCH 2/6] Fix NaN loss with all-masked batches in sequence parallel mode - Add protection against division by zero when total_good_tokens is 0 - Handle NaN loss from Liger kernel on all-masked batches - Add warnings to help diagnose data issues with empty/masked outputs - This fixes NaN losses that occur when packed samples have mostly non-assistant content distributed across sequence parallel ranks --- arctic_training/trainer/sft_trainer.py | 41 ++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py index b9c2fed2..e98dc013 100644 --- a/arctic_training/trainer/sft_trainer.py +++ b/arctic_training/trainer/sft_trainer.py @@ -23,6 +23,7 @@ from arctic_training.checkpoint.ds_engine import DSCheckpointEngine from arctic_training.checkpoint.hf_engine import HFCheckpointEngine from arctic_training.data.sft_factory import SFTDataFactory +from arctic_training.logging import logger from arctic_training.model.hf_factory import HFModelFactory from arctic_training.model.liger_factory import LigerModelFactory from arctic_training.optimizer.adam_factory import CPUAdamOptimizerFactory @@ -81,19 +82,37 @@ def loss(self, batch) -> torch.Tensor: # memory-wise, but which has more compute overhead before backward re-runs forward. The # total memory usage is very similar, but cuda cache flushes earlier if pushing close to # OOM than liger. - if self.config.model.type == "liger": + # + # Note: When in eval mode with SP > 1, liger's fused cross-entropy returns None, so we + # fall back to tiled compute for evaluation even when model.type == "liger". + # Use model_unwrapped.training because self.model is a DeepSpeed engine wrapper. + use_liger_fused_loss = self.config.model.type == "liger" and self.model_unwrapped.training + + if use_liger_fused_loss: # letting liger do fused logits+loss calculation outputs = self.model(**batch, use_cache=False) loss = outputs.loss if loss is None: - # XXX: not sure why this happens with SP>1 and eval-enabled, I checked shift_labels contain valid non -100 tokens - disabling fused_linear_cross_entropy=False in AutoLigerKernelForCausalLM.from_pretrained doesn't help. all works when eval is off. raise ValueError( - "Liger-Kernel failed to compute loss (returned None) - it's known to fail with eval enabled along" - " train steps when SP>1." + "Liger-Kernel failed to compute loss (returned None). This is unexpected during training." ) + # Handle NaN loss from Liger (can happen with all-masked batches) + if torch.isnan(loss) or torch.isinf(loss): + # Check if this is due to all labels being masked + good_tokens = ((shift_labels != -100).view(-1)).sum() + if good_tokens == 0: + logger.warning( + "Batch has no trainable tokens on this SP rank (all labels are -100). " + "Returning zero loss. This may indicate data issues with too many " + "empty/masked outputs or an unfavorable packing distribution." + ) + # Return differentiable zero loss for this rank + loss = loss * 0.0 + 0.0 # Preserves gradient graph but returns 0 + # If there are good tokens but still NaN, let it propagate (real numerical issue) + else: # Currently relying on an automatic num_shards derivation based on the goal that it'll # take approximately 1GB of fp32 logits in a shard, could make this configurable if @@ -157,6 +176,18 @@ def fused_logits_loss_fn(model_with_head=None, hidden_states=None, shift_labels= good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=self.sp_group) total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(self.sp_world_size)) total_good_tokens = sum(good_tokens_per_rank) - loss = total_loss / total_good_tokens + + # Protect against division by zero when all tokens are masked + # This can happen with packed samples that have mostly non-assistant content + if total_good_tokens == 0: + logger.warning( + "Batch has no trainable tokens across all SP ranks (all labels are -100). " + "Returning zero loss. This may indicate data issues with too many " + "empty/masked outputs or an unfavorable packing distribution." + ) + # Return differentiable zero loss - the batch has no trainable tokens + loss = total_loss * 0.0 + else: + loss = total_loss / total_good_tokens return loss From c2fce148df98cbe08dc2faa115b5220ca007a086 Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Sat, 10 Jan 2026 17:51:56 -0800 Subject: [PATCH 3/6] fix(sft): add robust NaN handling and fix marker tokenization consistency ## NaN Loss Handling - Add NaN/Inf checking to SP==1 path that was missing protection (lines 49-72 in sft_trainer.py) - Ensure zero loss is connected to computation graph to maintain proper gradient flow for DeepSpeed optimizer state - Use outputs.logits.sum() * 0.0 pattern when available, fallback to torch.zeros with requires_grad=True ## Marker Tokenization Fix - Fix inconsistency where get_token_based_labels_with_ignore_empty_think used _tokenize_marker() while get_token_based_labels() used _tokenize_marker_without_trailing_whitespace() - This caused different content_start positions, potentially masking wrong token ranges when ignore_empty_think=True ## Other - Add torch.cuda.empty_cache() after evaluation to free memory --- arctic_training/data/chat_markers.py | 31 +++++++++++++++---- arctic_training/trainer/sft_trainer.py | 41 +++++++++++++++++++++++--- arctic_training/trainer/trainer.py | 2 ++ 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/arctic_training/data/chat_markers.py b/arctic_training/data/chat_markers.py index 7881ec7f..93191a28 100644 --- a/arctic_training/data/chat_markers.py +++ b/arctic_training/data/chat_markers.py @@ -314,6 +314,24 @@ def _tokenize_marker(marker: str, tokenizer: PreTrainedTokenizerBase) -> List[in return tokenizer(marker, add_special_tokens=False).input_ids +def _tokenize_marker_without_trailing_whitespace(marker: str, tokenizer: PreTrainedTokenizerBase) -> List[int]: + """ + Tokenize a marker string, stripping trailing whitespace to avoid tokenizer + context sensitivity issues. + + Some tokenizers (e.g., Qwen) tokenize "\\n\\n" differently than "\\n" + "\\n". + When assistant content starts with "\\n", the marker "assistant\\n" + "\\n" becomes + "assistant\\n\\n" which tokenizes differently, causing pattern match failures. + + By stripping trailing whitespace from the marker, we avoid this issue. + """ + if not marker: + return [] + # Strip trailing whitespace (newlines, spaces) to avoid context sensitivity + stripped = marker.rstrip() + return tokenizer(stripped, add_special_tokens=False).input_ids + + def get_token_based_labels( input_ids: List[int], tokenizer: PreTrainedTokenizerBase, @@ -335,9 +353,10 @@ def get_token_based_labels( Returns: List of labels where trainable tokens have their token ID and masked tokens have -100 """ - # Tokenize the markers - user_start_ids = _tokenize_marker(markers.user_start, tokenizer) - assistant_start_ids = _tokenize_marker(markers.assistant_start, tokenizer) + # Tokenize the markers - strip trailing whitespace to avoid context sensitivity + # This handles cases where content starts with newlines causing different tokenization + user_start_ids = _tokenize_marker_without_trailing_whitespace(markers.user_start, tokenizer) + assistant_start_ids = _tokenize_marker_without_trailing_whitespace(markers.assistant_start, tokenizer) turn_end_ids = _tokenize_marker(markers.turn_end, tokenizer) # Initialize all labels as masked @@ -419,8 +438,10 @@ def get_token_based_labels_with_ignore_empty_think( return labels # Find and mask empty think patterns that appear at start of assistant responses - assistant_start_ids = _tokenize_marker(markers.assistant_start, tokenizer) - user_start_ids = _tokenize_marker(markers.user_start, tokenizer) + # IMPORTANT: Use the same tokenization as get_token_based_labels() to ensure + # content_start positions match where tokens were actually marked as trainable + assistant_start_ids = _tokenize_marker_without_trailing_whitespace(markers.assistant_start, tokenizer) + user_start_ids = _tokenize_marker_without_trailing_whitespace(markers.user_start, tokenizer) turn_end_ids = _tokenize_marker(markers.turn_end, tokenizer) i = 0 diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py index e98dc013..9dd9841e 100644 --- a/arctic_training/trainer/sft_trainer.py +++ b/arctic_training/trainer/sft_trainer.py @@ -52,6 +52,25 @@ def loss(self, batch) -> torch.Tensor: # compute (liger fused logits+loss kernel does not repeat forward during backward) outputs = self.model(**batch, use_cache=False) loss = outputs.loss + + # Handle NaN loss (can happen with all-masked batches where all labels are -100) + if torch.isnan(loss) or torch.isinf(loss): + labels = batch.get("labels") + if labels is not None: + good_tokens = ((labels != -100).view(-1)).sum() + if good_tokens == 0: + logger.warning( + "Batch has no trainable tokens (all labels are -100). " + "Returning zero loss. This may indicate data issues with " + "chat template marker mismatches or empty outputs." + ) + # Create a zero loss connected to the computation graph + if hasattr(outputs, "logits") and outputs.logits is not None: + loss = (outputs.logits.sum() * 0.0).to(loss.dtype) + else: + loss = torch.zeros(1, device=loss.device, dtype=loss.dtype, requires_grad=True)[0] + # If there are good tokens but still NaN, let it propagate (real numerical issue) + return loss # Ulysses SP expectations: @@ -109,8 +128,15 @@ def loss(self, batch) -> torch.Tensor: "Returning zero loss. This may indicate data issues with too many " "empty/masked outputs or an unfavorable packing distribution." ) - # Return differentiable zero loss for this rank - loss = loss * 0.0 + 0.0 # Preserves gradient graph but returns 0 + # We need a zero loss that's still connected to the computation graph + # to ensure proper gradient flow for DeepSpeed. Use outputs.logits if available, + # otherwise fall back to a differentiable zero from the loss tensor. + if hasattr(outputs, "logits") and outputs.logits is not None: + loss = (outputs.logits.sum() * 0.0).to(loss.dtype) + else: + # Fallback: use torch.zeros_like to get correct device/dtype + # and add to a scalar from the loss to maintain some graph connection + loss = torch.zeros(1, device=loss.device, dtype=loss.dtype, requires_grad=True)[0] # If there are good tokens but still NaN, let it propagate (real numerical issue) else: @@ -185,8 +211,15 @@ def fused_logits_loss_fn(model_with_head=None, hidden_states=None, shift_labels= "Returning zero loss. This may indicate data issues with too many " "empty/masked outputs or an unfavorable packing distribution." ) - # Return differentiable zero loss - the batch has no trainable tokens - loss = total_loss * 0.0 + # We need a zero loss that's connected to the computation graph. + # The 'loss' variable from earlier is connected to the graph, so we use it. + # If loss is NaN (from all-masked Liger batch), we use a fallback. + if torch.isnan(loss) or torch.isinf(loss): + # Fallback: create a zero that's at least on the correct device + loss = torch.zeros(1, device=loss.device, dtype=loss.dtype, requires_grad=True)[0] + else: + # Use the existing loss (which is connected to graph) multiplied by 0 + loss = loss * 0.0 else: loss = total_loss / total_good_tokens diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 7035a4b3..3ff27be3 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -561,6 +561,8 @@ def evaluate(self) -> None: self.model.eval() with torch.no_grad(): losses = [self.loss(eval_batch).item() for eval_batch in self.eval_batches] + # clear the cache after evaluation + torch.cuda.empty_cache() self.metrics.record("loss/eval", losses) # type: ignore @callback_wrapper("checkpoint") From b9333b136fd163d1ca05cff6f891c7c532e110bf Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Sat, 10 Jan 2026 17:53:42 -0800 Subject: [PATCH 4/6] fix(chat_markers): fix heuristic marker extraction including previous turn's end marker When detect_markers_heuristic falls back to extracting markers by finding 'X' in the conversation, the asst_turn variable was including the previous turn's end marker (e.g., '<|im_end|><|im_start|>assistant' instead of just '<|im_start|>assistant'). Fix: In the fallback case, look for known end markers in the segment between X and Y, and extract only the portion after the end marker as the assistant_start. This prevents pattern matching failures that would result in all labels being masked to -100 and NaN loss during training. --- arctic_training/data/chat_markers.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/arctic_training/data/chat_markers.py b/arctic_training/data/chat_markers.py index 93191a28..c29405f6 100644 --- a/arctic_training/data/chat_markers.py +++ b/arctic_training/data/chat_markers.py @@ -209,7 +209,29 @@ def detect_markers_heuristic(tokenizer: PreTrainedTokenizerBase) -> ChatMarkerCo # Assistant turn = difference between t2 and t1 # t2 should start with t1's content (possibly without trailing generation prompt) - asst_turn = t2[len(t1) :] if t2.startswith(t1) else t2[t2.find("X") + 1 :] + if t2.startswith(t1): + asst_turn = t2[len(t1) :] + else: + # Fallback: find Y in t2 and extract the marker before it + # This avoids including the previous turn's end marker + y_pos_t2 = t2.find("Y") + x_pos_t2 = t2.find("X") + if y_pos_t2 > x_pos_t2 >= 0: + # Extract from after X's position to Y, then find where assistant marker starts + # Look for common assistant markers in the segment between X and Y + segment = t2[x_pos_t2 + 1 : y_pos_t2] + # Find where a new role marker likely starts (after any turn end token) + for end_marker in ["<|im_end|>", "<|eot_id|>", "", "<|end|>", ""]: + if end_marker in segment: + # Assistant start is everything after the end marker (plus any newline) + marker_end = segment.find(end_marker) + len(end_marker) + asst_turn = segment[marker_end:].lstrip("\n") + "Y" + break + else: + # No known end marker found, use original fallback + asst_turn = segment + "Y" + else: + asst_turn = t2[t2.find("X") + 1 :] # Find Y in assistant turn y_pos_in_asst = asst_turn.find("Y") From 1224804c6f632392da5522b7cc312aaf399588f6 Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Mon, 12 Jan 2026 10:20:25 -0800 Subject: [PATCH 5/6] Simplify token-based label masking and zero loss handling - chat_markers.py: Remove over-engineered assistant turn detection fallback logic; use simple conditional instead of complex marker heuristics - chat_markers.py: Use consistent _tokenize_marker() for assistant/user start markers in get_token_based_labels_with_ignore_empty_think() - sft_trainer.py: Remove redundant NaN loss handling in non-SP forward path - sft_trainer.py: Fix zero loss creation by using fresh tensor instead of loss * 0.0 (since NaN * 0 = NaN, multiplying NaN loss by zero still fails) --- arctic_training/data/chat_markers.py | 30 ++---------------- arctic_training/trainer/sft_trainer.py | 42 +++----------------------- 2 files changed, 8 insertions(+), 64 deletions(-) diff --git a/arctic_training/data/chat_markers.py b/arctic_training/data/chat_markers.py index c29405f6..f6384e91 100644 --- a/arctic_training/data/chat_markers.py +++ b/arctic_training/data/chat_markers.py @@ -209,29 +209,7 @@ def detect_markers_heuristic(tokenizer: PreTrainedTokenizerBase) -> ChatMarkerCo # Assistant turn = difference between t2 and t1 # t2 should start with t1's content (possibly without trailing generation prompt) - if t2.startswith(t1): - asst_turn = t2[len(t1) :] - else: - # Fallback: find Y in t2 and extract the marker before it - # This avoids including the previous turn's end marker - y_pos_t2 = t2.find("Y") - x_pos_t2 = t2.find("X") - if y_pos_t2 > x_pos_t2 >= 0: - # Extract from after X's position to Y, then find where assistant marker starts - # Look for common assistant markers in the segment between X and Y - segment = t2[x_pos_t2 + 1 : y_pos_t2] - # Find where a new role marker likely starts (after any turn end token) - for end_marker in ["<|im_end|>", "<|eot_id|>", "", "<|end|>", ""]: - if end_marker in segment: - # Assistant start is everything after the end marker (plus any newline) - marker_end = segment.find(end_marker) + len(end_marker) - asst_turn = segment[marker_end:].lstrip("\n") + "Y" - break - else: - # No known end marker found, use original fallback - asst_turn = segment + "Y" - else: - asst_turn = t2[t2.find("X") + 1 :] + asst_turn = t2[len(t1) :] if t2.startswith(t1) else t2[t2.find("X") + 1 :] # Find Y in assistant turn y_pos_in_asst = asst_turn.find("Y") @@ -460,10 +438,8 @@ def get_token_based_labels_with_ignore_empty_think( return labels # Find and mask empty think patterns that appear at start of assistant responses - # IMPORTANT: Use the same tokenization as get_token_based_labels() to ensure - # content_start positions match where tokens were actually marked as trainable - assistant_start_ids = _tokenize_marker_without_trailing_whitespace(markers.assistant_start, tokenizer) - user_start_ids = _tokenize_marker_without_trailing_whitespace(markers.user_start, tokenizer) + assistant_start_ids = _tokenize_marker(markers.assistant_start, tokenizer) + user_start_ids = _tokenize_marker(markers.user_start, tokenizer) turn_end_ids = _tokenize_marker(markers.turn_end, tokenizer) i = 0 diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py index 9dd9841e..ca51e0a0 100644 --- a/arctic_training/trainer/sft_trainer.py +++ b/arctic_training/trainer/sft_trainer.py @@ -52,25 +52,6 @@ def loss(self, batch) -> torch.Tensor: # compute (liger fused logits+loss kernel does not repeat forward during backward) outputs = self.model(**batch, use_cache=False) loss = outputs.loss - - # Handle NaN loss (can happen with all-masked batches where all labels are -100) - if torch.isnan(loss) or torch.isinf(loss): - labels = batch.get("labels") - if labels is not None: - good_tokens = ((labels != -100).view(-1)).sum() - if good_tokens == 0: - logger.warning( - "Batch has no trainable tokens (all labels are -100). " - "Returning zero loss. This may indicate data issues with " - "chat template marker mismatches or empty outputs." - ) - # Create a zero loss connected to the computation graph - if hasattr(outputs, "logits") and outputs.logits is not None: - loss = (outputs.logits.sum() * 0.0).to(loss.dtype) - else: - loss = torch.zeros(1, device=loss.device, dtype=loss.dtype, requires_grad=True)[0] - # If there are good tokens but still NaN, let it propagate (real numerical issue) - return loss # Ulysses SP expectations: @@ -128,15 +109,8 @@ def loss(self, batch) -> torch.Tensor: "Returning zero loss. This may indicate data issues with too many " "empty/masked outputs or an unfavorable packing distribution." ) - # We need a zero loss that's still connected to the computation graph - # to ensure proper gradient flow for DeepSpeed. Use outputs.logits if available, - # otherwise fall back to a differentiable zero from the loss tensor. - if hasattr(outputs, "logits") and outputs.logits is not None: - loss = (outputs.logits.sum() * 0.0).to(loss.dtype) - else: - # Fallback: use torch.zeros_like to get correct device/dtype - # and add to a scalar from the loss to maintain some graph connection - loss = torch.zeros(1, device=loss.device, dtype=loss.dtype, requires_grad=True)[0] + # Create fresh zero tensor (NaN * 0 = NaN, so we can't use loss * 0) + loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype, requires_grad=True) # If there are good tokens but still NaN, let it propagate (real numerical issue) else: @@ -211,15 +185,9 @@ def fused_logits_loss_fn(model_with_head=None, hidden_states=None, shift_labels= "Returning zero loss. This may indicate data issues with too many " "empty/masked outputs or an unfavorable packing distribution." ) - # We need a zero loss that's connected to the computation graph. - # The 'loss' variable from earlier is connected to the graph, so we use it. - # If loss is NaN (from all-masked Liger batch), we use a fallback. - if torch.isnan(loss) or torch.isinf(loss): - # Fallback: create a zero that's at least on the correct device - loss = torch.zeros(1, device=loss.device, dtype=loss.dtype, requires_grad=True)[0] - else: - # Use the existing loss (which is connected to graph) multiplied by 0 - loss = loss * 0.0 + # Create fresh zero tensor (total_loss may contain NaN, and NaN * 0 = NaN) + # Use loss.device/dtype since loss is guaranteed to exist from line 174 + loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype, requires_grad=True) else: loss = total_loss / total_good_tokens From 2729c8d59898ae3d429e001a966c5f030b776d20 Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Mon, 12 Jan 2026 19:55:12 +0000 Subject: [PATCH 6/6] fix(chat_markers): ensure consistent marker tokenization in empty think handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use _tokenize_marker_without_trailing_whitespace() for assistant_start and user_start markers to match the tokenization method used in get_token_based_labels(), preventing inconsistencies in label masking. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 --- arctic_training/data/chat_markers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/arctic_training/data/chat_markers.py b/arctic_training/data/chat_markers.py index f6384e91..69947496 100644 --- a/arctic_training/data/chat_markers.py +++ b/arctic_training/data/chat_markers.py @@ -438,8 +438,9 @@ def get_token_based_labels_with_ignore_empty_think( return labels # Find and mask empty think patterns that appear at start of assistant responses - assistant_start_ids = _tokenize_marker(markers.assistant_start, tokenizer) - user_start_ids = _tokenize_marker(markers.user_start, tokenizer) + # Use same tokenization method as get_token_based_labels() for consistency + assistant_start_ids = _tokenize_marker_without_trailing_whitespace(markers.assistant_start, tokenizer) + user_start_ids = _tokenize_marker_without_trailing_whitespace(markers.user_start, tokenizer) turn_end_ids = _tokenize_marker(markers.turn_end, tokenizer) i = 0