Skip to content
Open
88 changes: 80 additions & 8 deletions arctic_training/data/sft_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import random
import re
from typing import Dict
Expand All @@ -37,8 +39,14 @@
from arctic_training.data.hf_instruct_source import HFDataSourceInstruct
from arctic_training.data.utils import DatasetType

logger = logging.getLogger(__name__)
Comment thread
sfc-gh-aponnusamy marked this conversation as resolved.
Outdated

IGNORE_INDEX = -100

# Debug flag for detailed logging of label masking issues
# Set environment variable DEBUG_LABEL_MASKING=1 to enable
DEBUG_LABEL_MASKING = os.environ.get("DEBUG_LABEL_MASKING", "0") == "1"


# this function is modified from TRL trl.trainer.utils.py
def pad(
Expand Down Expand Up @@ -409,6 +417,25 @@ def tokenize_messages(
# _ = get_assistant_start_end_indices(messages, conversation_text)
labels = cls.get_masked_labels(conversation_ids, assistant_ranges)
conversation_ids["labels"] = labels

# Debug logging for label masking issues
if DEBUG_LABEL_MASKING:
non_masked = sum(1 for label in labels if label != IGNORE_INDEX)
total = len(labels)
if non_masked == 0:
logger.warning(
"ALL LABELS MASKED! This will cause NaN loss.\n conversation_text length:"
f" {len(conversation_text)}\n assistant_ranges: {assistant_ranges}\n num_messages:"
f" {len(messages)}\n assistant_contents:"
f" {[m['content'][:50] + '...' for m in messages if m['role'] == 'assistant']}"
)
elif non_masked < 5:
logger.warning(
f"Very few non-masked labels ({non_masked}/{total}).\n assistant_ranges: {assistant_ranges}\n"
" assistant_contents:"
f" {[m['content'][:50] + '...' for m in messages if m['role'] == 'assistant']}"
)

# compare_messages_with_labels(split_list_by_specific_num(conversation_ids["labels"]), messages, tokenizer)
del conversation_ids["offset_mapping"]
else:
Expand All @@ -424,15 +451,56 @@ def get_assistant_start_end_indices(
ignore_empty_think: bool = False,
) -> List[Tuple[int, int]]:
return_indices = []
# Track search position to avoid matching assistant content that appears
# earlier in the conversation (e.g., in user context). Process ALL messages
# in order to track position, so assistant content is found AFTER the
# preceding user message, not at the first occurrence.
search_start = 0
for message in messages:
if message["role"] == "assistant":
message_text = message["content"]
if ignore_empty_think:
message_text = re.sub(r"^<think>\s*</think>\s*", "", message_text)
message_text = message["content"]
original_text = message_text # Keep for debug logging
if message["role"] == "assistant" and ignore_empty_think:
message_text = re.sub(r"^<think>\s*</think>\s*", "", message_text)
# Find this message starting from current position
match_index = conversation_text.find(message_text, search_start)
if match_index == -1:
# Fallback: try searching from the beginning (original behavior)
match_index = conversation_text.find(message_text)
# start_indices.append(match_index)
end_indices = match_index + len(message_text)
return_indices.append((match_index, end_indices))
if DEBUG_LABEL_MASKING and message["role"] == "assistant":
if match_index == -1:
logger.warning(
"Assistant content NOT FOUND in conversation_text!\n"
f" role: {message['role']}\n"
f" original_content: {repr(original_text[:100])}...\n"
f" search_text: {repr(message_text[:100])}...\n"
f" search_start was: {search_start}\n"
f" conversation_text length: {len(conversation_text)}"
)
else:
logger.warning(
"Assistant content found via fallback (from position 0)!\n"
f" content: {repr(message_text[:50])}...\n"
f" found at: {match_index}, search_start was: {search_start}"
)
end_index = match_index + len(message_text) if match_index != -1 else -1

if DEBUG_LABEL_MASKING and message["role"] == "assistant":
conv_len = len(conversation_text)
position_pct = (match_index / conv_len * 100) if match_index != -1 else -1
logger.info(
Comment thread
sfc-gh-aponnusamy marked this conversation as resolved.
Outdated
f"Assistant range: ({match_index}, {end_index}) - "
f"{position_pct:.1f}% into conversation (len={conv_len}), "
f"content: {repr(message_text[:30])}..."
)

# Only record assistant message ranges
if message["role"] == "assistant":
return_indices.append((match_index, end_index))
# Update search position for next message (track all messages in order)
# Use max() to ensure we never go backwards - this handles the case where
# fallback search finds content at an earlier position (likely a false match)
if match_index != -1:
search_start = max(search_start, match_index + len(message_text))
return return_indices

@staticmethod
Expand All @@ -446,7 +514,11 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu
conversation_ids["offset_mapping"],
)
):
if any(id_s >= s and id_e <= e for s, e in assistant_ranges):
# Check if token OVERLAPS with any assistant range (not fully contained).
# This handles short assistant content where tokens span wider than the content.
# Overlap condition: token_start < range_end AND token_end > range_start
# Also handle edge case where range is invalid (s == -1 means not found)
if any(s != -1 and id_s < e and id_e > s for s, e in assistant_ranges):
pre_output = id_
output.append(id_)
else:
Expand Down
14 changes: 13 additions & 1 deletion arctic_training/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,19 @@ def loss(self, batch) -> torch.Tensor:
batch = to_device(batch, self.device)

if self.config.sequence_parallel_size == 1:
# if model.type=liger is configured - this will use a much more efficient fused
labels = batch.get("labels")

# Check if all labels are masked (would cause NaN loss from CE)
if labels is not None and (labels == -100).all():
# Run forward without labels to get logits, return differentiable zero
batch_without_labels = {k: v for k, v in batch.items() if k != "labels"}
Comment thread
sfc-gh-aponnusamy marked this conversation as resolved.
Outdated
outputs = self.model(**batch_without_labels, use_cache=False)
# Fake loss calculation - CE would return NaN, but we return differentiable zero
# A normal loss_fn upcasts logits to float so match it
loss = (outputs.logits.sum() * 0.0).float()
return loss

# Normal path: if model.type=liger is configured - this will use a much more efficient fused
# logits+loss liger kernel - using significantly less gpu memory and a bit faster
# compute (liger fused logits+loss kernel does not repeat forward during backward)
outputs = self.model(**batch, use_cache=False)
Expand Down
25 changes: 24 additions & 1 deletion arctic_training/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,31 @@ def evaluate(self) -> None:
Evaluation loop. Measures the model's performance on the evaluation dataset.
"""
self.model.eval()
losses = []
with torch.no_grad():
losses = [self.loss(eval_batch).item() for eval_batch in self.eval_batches]
for batch_idx, eval_batch in enumerate(self.eval_batches):
loss = self.loss(eval_batch)
loss_value = loss.item()
losses.append(loss_value)

# Debug NaN losses
if torch.isnan(loss) or torch.isinf(loss):
logger.warning(
f"NaN/Inf eval loss detected at batch {batch_idx}!\n"
f" loss: {loss_value}\n"
f" batch keys: {eval_batch.keys()}\n"
f" input_ids shape: {eval_batch['input_ids'].shape if 'input_ids' in eval_batch else 'N/A'}\n"
f" labels shape: {eval_batch['labels'].shape if 'labels' in eval_batch else 'N/A'}"
)
if "labels" in eval_batch:
labels = eval_batch["labels"]
non_masked = (labels != -100).sum().item()
total = labels.numel()
pct = (100 * non_masked / total) if total > 0 else 0.0
logger.warning(
f" non-masked labels: {non_masked}/{total} ({pct:.1f}%)\n"
f" all labels are -100: {non_masked == 0}"
)
Comment thread
sfc-gh-aponnusamy marked this conversation as resolved.
Outdated
self.metrics.record("loss/eval", losses) # type: ignore

@callback_wrapper("checkpoint")
Expand Down