Skip to content
Open
35 changes: 27 additions & 8 deletions arctic_training/data/sft_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,30 @@ def get_assistant_start_end_indices(
ignore_empty_think: bool = False,
) -> List[Tuple[int, int]]:
return_indices = []
# Track search position to avoid matching assistant content that appears
# earlier in the conversation (e.g., in user context). Process ALL messages
# in order to track position, so assistant content is found AFTER the
# preceding user message, not at the first occurrence.
search_start = 0
for message in messages:
if message["role"] == "assistant":
message_text = message["content"]
if ignore_empty_think:
message_text = re.sub(r"^<think>\s*</think>\s*", "", message_text)
message_text = message["content"]
if message["role"] == "assistant" and ignore_empty_think:
message_text = re.sub(r"^<think>\s*</think>\s*", "", message_text)
# Find this message starting from current position
match_index = conversation_text.find(message_text, search_start)
if match_index == -1:
# Fallback: try searching from the beginning (original behavior)
match_index = conversation_text.find(message_text)
# start_indices.append(match_index)
end_indices = match_index + len(message_text)
return_indices.append((match_index, end_indices))
end_index = match_index + len(message_text) if match_index != -1 else -1

# Only record assistant message ranges
if message["role"] == "assistant":
return_indices.append((match_index, end_index))
# Update search position for next message (track all messages in order)
# Use max() to ensure we never go backwards - this handles the case where
# fallback search finds content at an earlier position (likely a false match)
if match_index != -1:
search_start = max(search_start, match_index + len(message_text))
return return_indices

@staticmethod
Expand All @@ -446,7 +461,11 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu
conversation_ids["offset_mapping"],
)
):
if any(id_s >= s and id_e <= e for s, e in assistant_ranges):
# Check if token OVERLAPS with any assistant range (not fully contained).
# This handles short assistant content where tokens span wider than the content.
# Overlap condition: token_start < range_end AND token_end > range_start
# Also handle edge case where range is invalid (s == -1 means not found)
if any(s != -1 and id_s < e and id_e > s for s, e in assistant_ranges):
pre_output = id_
output.append(id_)
else:
Expand Down
14 changes: 13 additions & 1 deletion arctic_training/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,19 @@ def loss(self, batch) -> torch.Tensor:
batch = to_device(batch, self.device)

if self.config.sequence_parallel_size == 1:
# if model.type=liger is configured - this will use a much more efficient fused
labels = batch.get("labels")

# Check if all labels are masked (would cause NaN loss from CE)
if labels is not None and (labels == -100).all():
# Run forward without labels to get logits, return differentiable zero
batch.pop("labels")
outputs = self.model(**batch, use_cache=False)
# Fake loss calculation - CE would return NaN, but we return differentiable zero
# A normal loss_fn upcasts logits to float so match it
loss = (outputs.logits.sum() * 0.0).float()
return loss

# Normal path: if model.type=liger is configured - this will use a much more efficient fused
# logits+loss liger kernel - using significantly less gpu memory and a bit faster
# compute (liger fused logits+loss kernel does not repeat forward during backward)
outputs = self.model(**batch, use_cache=False)
Expand Down