diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 22283d14..378c19b6 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -424,15 +424,30 @@ def get_assistant_start_end_indices( ignore_empty_think: bool = False, ) -> List[Tuple[int, int]]: return_indices = [] + # Track search position to avoid matching assistant content that appears + # earlier in the conversation (e.g., in user context). Process ALL messages + # in order to track position, so assistant content is found AFTER the + # preceding user message, not at the first occurrence. + search_start = 0 for message in messages: - if message["role"] == "assistant": - message_text = message["content"] - if ignore_empty_think: - message_text = re.sub(r"^\s*\s*", "", message_text) + message_text = message["content"] + if message["role"] == "assistant" and ignore_empty_think: + message_text = re.sub(r"^\s*\s*", "", message_text) + # Find this message starting from current position + match_index = conversation_text.find(message_text, search_start) + if match_index == -1: + # Fallback: try searching from the beginning (original behavior) match_index = conversation_text.find(message_text) - # start_indices.append(match_index) - end_indices = match_index + len(message_text) - return_indices.append((match_index, end_indices)) + end_index = match_index + len(message_text) if match_index != -1 else -1 + + # Only record assistant message ranges + if message["role"] == "assistant": + return_indices.append((match_index, end_index)) + # Update search position for next message (track all messages in order) + # Use max() to ensure we never go backwards - this handles the case where + # fallback search finds content at an earlier position (likely a false match) + if match_index != -1: + search_start = max(search_start, match_index + len(message_text)) return return_indices @staticmethod @@ -446,7 +461,11 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu conversation_ids["offset_mapping"], ) ): - if any(id_s >= s and id_e <= e for s, e in assistant_ranges): + # Check if token OVERLAPS with any assistant range (not fully contained). + # This handles short assistant content where tokens span wider than the content. + # Overlap condition: token_start < range_end AND token_end > range_start + # Also handle edge case where range is invalid (s == -1 means not found) + if any(s != -1 and id_s < e and id_e > s for s, e in assistant_ranges): pre_output = id_ output.append(id_) else: diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py index b9c2fed2..87119276 100644 --- a/arctic_training/trainer/sft_trainer.py +++ b/arctic_training/trainer/sft_trainer.py @@ -46,7 +46,19 @@ def loss(self, batch) -> torch.Tensor: batch = to_device(batch, self.device) if self.config.sequence_parallel_size == 1: - # if model.type=liger is configured - this will use a much more efficient fused + labels = batch.get("labels") + + # Check if all labels are masked (would cause NaN loss from CE) + if labels is not None and (labels == -100).all(): + # Run forward without labels to get logits, return differentiable zero + batch.pop("labels") + outputs = self.model(**batch, use_cache=False) + # Fake loss calculation - CE would return NaN, but we return differentiable zero + # A normal loss_fn upcasts logits to float so match it + loss = (outputs.logits.sum() * 0.0).float() + return loss + + # Normal path: if model.type=liger is configured - this will use a much more efficient fused # logits+loss liger kernel - using significantly less gpu memory and a bit faster # compute (liger fused logits+loss kernel does not repeat forward during backward) outputs = self.model(**batch, use_cache=False)