diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py
index 22283d14..378c19b6 100644
--- a/arctic_training/data/sft_factory.py
+++ b/arctic_training/data/sft_factory.py
@@ -424,15 +424,30 @@ def get_assistant_start_end_indices(
ignore_empty_think: bool = False,
) -> List[Tuple[int, int]]:
return_indices = []
+ # Track search position to avoid matching assistant content that appears
+ # earlier in the conversation (e.g., in user context). Process ALL messages
+ # in order to track position, so assistant content is found AFTER the
+ # preceding user message, not at the first occurrence.
+ search_start = 0
for message in messages:
- if message["role"] == "assistant":
- message_text = message["content"]
- if ignore_empty_think:
- message_text = re.sub(r"^\s*\s*", "", message_text)
+ message_text = message["content"]
+ if message["role"] == "assistant" and ignore_empty_think:
+ message_text = re.sub(r"^\s*\s*", "", message_text)
+ # Find this message starting from current position
+ match_index = conversation_text.find(message_text, search_start)
+ if match_index == -1:
+ # Fallback: try searching from the beginning (original behavior)
match_index = conversation_text.find(message_text)
- # start_indices.append(match_index)
- end_indices = match_index + len(message_text)
- return_indices.append((match_index, end_indices))
+ end_index = match_index + len(message_text) if match_index != -1 else -1
+
+ # Only record assistant message ranges
+ if message["role"] == "assistant":
+ return_indices.append((match_index, end_index))
+ # Update search position for next message (track all messages in order)
+ # Use max() to ensure we never go backwards - this handles the case where
+ # fallback search finds content at an earlier position (likely a false match)
+ if match_index != -1:
+ search_start = max(search_start, match_index + len(message_text))
return return_indices
@staticmethod
@@ -446,7 +461,11 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu
conversation_ids["offset_mapping"],
)
):
- if any(id_s >= s and id_e <= e for s, e in assistant_ranges):
+ # Check if the token OVERLAPS any assistant range (rather than requiring full containment).
+ # This handles short assistant content where a token spans wider than the content itself.
+ # Overlap condition: token_start < range_end AND token_end > range_start
+ # Also handle edge case where range is invalid (s == -1 means not found)
+ if any(s != -1 and id_s < e and id_e > s for s, e in assistant_ranges):
pre_output = id_
output.append(id_)
else:
diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py
index b9c2fed2..87119276 100644
--- a/arctic_training/trainer/sft_trainer.py
+++ b/arctic_training/trainer/sft_trainer.py
@@ -46,7 +46,19 @@ def loss(self, batch) -> torch.Tensor:
batch = to_device(batch, self.device)
if self.config.sequence_parallel_size == 1:
- # if model.type=liger is configured - this will use a much more efficient fused
+ labels = batch.get("labels")
+
+ # Check if all labels are masked (would cause NaN loss from CE)
+ if labels is not None and (labels == -100).all():
+ # Run forward without labels to get logits, return differentiable zero
+ batch.pop("labels")
+ outputs = self.model(**batch, use_cache=False)
+ # Fake loss calculation - CE would return NaN, but we return differentiable zero
+ # A normal loss_fn upcasts logits to float, so call .float() here to match that dtype
+ loss = (outputs.logits.sum() * 0.0).float()
+ return loss
+
+ # Normal path: if model.type=liger is configured - this will use a much more efficient fused
# logits+loss liger kernel - using significantly less gpu memory and a bit faster
# compute (liger fused logits+loss kernel does not repeat forward during backward)
outputs = self.model(**batch, use_cache=False)