From 50ac82331ed81932238b9defcd0cabb4ba783cec Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Fri, 19 Dec 2025 09:13:40 -0800 Subject: [PATCH 1/8] Debugging val/loss for sql autocompletion that results in NaN --- arctic_training/config/ac_sql_sft.yml | 70 +++++++++++++++++++++++ arctic_training/data/sft_factory.py | 80 ++++++++++++++++++++++++--- arctic_training/trainer/trainer.py | 24 +++++++- 3 files changed, 166 insertions(+), 8 deletions(-) create mode 100644 arctic_training/config/ac_sql_sft.yml diff --git a/arctic_training/config/ac_sql_sft.yml b/arctic_training/config/ac_sql_sft.yml new file mode 100644 index 00000000..19ceac35 --- /dev/null +++ b/arctic_training/config/ac_sql_sft.yml @@ -0,0 +1,70 @@ +type: sft +micro_batch_size: 1 +gradient_accumulation_steps: 6 +epochs: 2 +activation_checkpoint_cpu_offload: true + +model: + type: liger + name_or_path: /modeling-checkpoints/zheweiyao/persist_for_snowflake_text2sql/mid-training-sqlite-13A-32B-packing-H200-global_step_28740 + attn_implementation: flash_attention_3 + +data: + sources: + # TODO: Update path once dataset is generated using csv_to_huggingface_dataset.py + - type: huggingface_instruct + name_or_path: /home/yak/data/input + role_mapping: + user: messages.role.user + assistant: messages.role.assistant + + eval_sources: + # TODO: Update path once validation dataset is generated using csv_to_huggingface_dataset.py + - type: huggingface_instruct + name_or_path: /home/yak/data/validation + role_mapping: + user: messages.role.user + assistant: messages.role.assistant + + num_proc: 16 + max_length: 32768 + cache_dir: /checkpoint/aponnusamy/data-cache + pack_samples: true + ignore_empty_think: true + +# Evaluation interval - required when eval_sources is provided +eval_interval: 150 + +deepspeed: + zero_optimization: + stage: 3 + allgather_bucket_size: 500000000 + reduce_bucket_size: 250000000 + memory_efficient_linear: true + +optimizer: + weight_decay: 0.01 + learning_rate: 1e-5 + +scheduler: + type: huggingface + name: cosine_with_min_lr + warmup_ratio: 0.1 + scheduler_specific_kwargs: + min_lr_rate: 0.1 + +checkpoint: + - type: huggingface + save_every_n_steps: 150 + save_end_of_training: true + output_dir: /modeling-checkpoints/aponnusamy/checkpoints/qwen3-32b-ac-sft + +wandb: + enable: true + project: snowflakesql_sft_training + name: qwen3-32b-ac-sft + +logger: + level: WARNING + output_dir: logs + print_output_ranks: [0] diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 22283d14..60ae2852 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import os import random import re from typing import Dict @@ -37,8 +39,14 @@ from arctic_training.data.hf_instruct_source import HFDataSourceInstruct from arctic_training.data.utils import DatasetType +logger = logging.getLogger(__name__) + IGNORE_INDEX = -100 +# Debug flag for detailed logging of label masking issues +# Set environment variable DEBUG_LABEL_MASKING=1 to enable +DEBUG_LABEL_MASKING = os.environ.get("DEBUG_LABEL_MASKING", "0") == "1" + # this function is modified from TRL trl.trainer.utils.py def pad( @@ -409,6 +417,25 @@ def tokenize_messages( # _ = get_assistant_start_end_indices(messages, conversation_text) labels = cls.get_masked_labels(conversation_ids, assistant_ranges) conversation_ids["labels"] = labels + + # Debug logging for label masking issues + if DEBUG_LABEL_MASKING: + non_masked = sum(1 for l in labels if l != IGNORE_INDEX) + total = len(labels) + if non_masked == 0: + logger.warning( + "ALL LABELS MASKED! This will cause NaN loss.\n conversation_text length:" + f" {len(conversation_text)}\n assistant_ranges: {assistant_ranges}\n num_messages:" + f" {len(messages)}\n assistant_contents:" + f" {[m['content'][:50] + '...' for m in messages if m['role'] == 'assistant']}" + ) + elif non_masked < 5: + logger.warning( + f"Very few non-masked labels ({non_masked}/{total}).\n assistant_ranges: {assistant_ranges}\n" + " assistant_contents:" + f" {[m['content'][:50] + '...' for m in messages if m['role'] == 'assistant']}" + ) + # compare_messages_with_labels(split_list_by_specific_num(conversation_ids["labels"]), messages, tokenizer) del conversation_ids["offset_mapping"] else: @@ -424,15 +451,54 @@ def get_assistant_start_end_indices( ignore_empty_think: bool = False, ) -> List[Tuple[int, int]]: return_indices = [] + # Track search position to avoid matching assistant content that appears + # earlier in the conversation (e.g., in user context). Process ALL messages + # in order to track position, so assistant content is found AFTER the + # preceding user message, not at the first occurrence. + search_start = 0 for message in messages: - if message["role"] == "assistant": - message_text = message["content"] - if ignore_empty_think: - message_text = re.sub(r"^\s*\s*", "", message_text) + message_text = message["content"] + original_text = message_text # Keep for debug logging + if message["role"] == "assistant" and ignore_empty_think: + message_text = re.sub(r"^\s*\s*", "", message_text) + # Find this message starting from current position + match_index = conversation_text.find(message_text, search_start) + if match_index == -1: + # Fallback: try searching from the beginning (original behavior) match_index = conversation_text.find(message_text) - # start_indices.append(match_index) - end_indices = match_index + len(message_text) - return_indices.append((match_index, end_indices)) + if DEBUG_LABEL_MASKING and message["role"] == "assistant": + if match_index == -1: + logger.warning( + "Assistant content NOT FOUND in conversation_text!\n" + f" role: {message['role']}\n" + f" original_content: {repr(original_text[:100])}...\n" + f" search_text: {repr(message_text[:100])}...\n" + f" search_start was: {search_start}\n" + f" conversation_text length: {len(conversation_text)}" + ) + else: + logger.warning( + "Assistant content found via fallback (from position 0)!\n" + f" content: {repr(message_text[:50])}...\n" + f" found at: {match_index}, search_start was: {search_start}" + ) + end_index = match_index + len(message_text) if match_index != -1 else -1 + + if DEBUG_LABEL_MASKING and message["role"] == "assistant": + conv_len = len(conversation_text) + position_pct = (match_index / conv_len * 100) if match_index != -1 else -1 + logger.info( + f"Assistant range: ({match_index}, {end_index}) - " + f"{position_pct:.1f}% into conversation (len={conv_len}), " + f"content: {repr(message_text[:30])}..." + ) + + # Only record assistant message ranges + if message["role"] == "assistant": + return_indices.append((match_index, end_index)) + # Update search position for next message (track all messages in order) + if match_index != -1: + search_start = match_index + len(message_text) return return_indices @staticmethod diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 7035a4b3..c2094bb5 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -559,8 +559,30 @@ def evaluate(self) -> None: Evaluation loop. Measures the model's performance on the evaluation dataset. """ self.model.eval() + losses = [] with torch.no_grad(): - losses = [self.loss(eval_batch).item() for eval_batch in self.eval_batches] + for batch_idx, eval_batch in enumerate(self.eval_batches): + loss = self.loss(eval_batch) + loss_value = loss.item() + losses.append(loss_value) + + # Debug NaN losses + if torch.isnan(loss) or torch.isinf(loss): + logger.warning( + f"NaN/Inf eval loss detected at batch {batch_idx}!\n" + f" loss: {loss_value}\n" + f" batch keys: {eval_batch.keys()}\n" + f" input_ids shape: {eval_batch['input_ids'].shape if 'input_ids' in eval_batch else 'N/A'}\n" + f" labels shape: {eval_batch['labels'].shape if 'labels' in eval_batch else 'N/A'}" + ) + if "labels" in eval_batch: + labels = eval_batch["labels"] + non_masked = (labels != -100).sum().item() + total = labels.numel() + logger.warning( + f" non-masked labels: {non_masked}/{total} ({100*non_masked/total:.1f}%)\n" + f" all labels are -100: {non_masked == 0}" + ) self.metrics.record("loss/eval", losses) # type: ignore @callback_wrapper("checkpoint") From e9be1e4d1a33bccb6c5a46ed48a3505ca4897df0 Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Fri, 19 Dec 2025 10:08:08 -0800 Subject: [PATCH 2/8] Debugging val/loss for sql autocompletion that results in NaN --- arctic_training/data/sft_factory.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 60ae2852..2dfa7ecc 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -512,7 +512,11 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu conversation_ids["offset_mapping"], ) ): - if any(id_s >= s and id_e <= e for s, e in assistant_ranges): + # Check if token OVERLAPS with any assistant range (not fully contained). + # This handles short assistant content where tokens span wider than the content. + # Overlap condition: token_start < range_end AND token_end > range_start + # Also handle edge case where range is invalid (s == -1 means not found) + if any(s != -1 and id_s < e and id_e > s for s, e in assistant_ranges): pre_output = id_ output.append(id_) else: From 59f84cc2a08316620ceb718f734c19568e19a2b7 Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Fri, 19 Dec 2025 10:43:27 -0800 Subject: [PATCH 3/8] Debugging val/loss for sql autocompletion that results in NaN --- arctic_training/data/sft_factory.py | 4 +++- arctic_training/trainer/trainer.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 2dfa7ecc..62df60be 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -497,8 +497,10 @@ def get_assistant_start_end_indices( if message["role"] == "assistant": return_indices.append((match_index, end_index)) # Update search position for next message (track all messages in order) + # Use max() to ensure we never go backwards - this handles the case where + # fallback search finds content at an earlier position (likely a false match) if match_index != -1: - search_start = match_index + len(message_text) + search_start = max(search_start, match_index + len(message_text)) return return_indices @staticmethod diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index c2094bb5..dd372f32 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -579,8 +579,9 @@ def evaluate(self) -> None: labels = eval_batch["labels"] non_masked = (labels != -100).sum().item() total = labels.numel() + pct = (100 * non_masked / total) if total > 0 else 0.0 logger.warning( - f" non-masked labels: {non_masked}/{total} ({100*non_masked/total:.1f}%)\n" + f" non-masked labels: {non_masked}/{total} ({pct:.1f}%)\n" f" all labels are -100: {non_masked == 0}" ) self.metrics.record("loss/eval", losses) # type: ignore From c2d3fb24cc7e933a58f7c35ed827c3e7b7029112 Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Mon, 5 Jan 2026 10:29:53 -0800 Subject: [PATCH 4/8] Debugging val/loss for sql autocompletion that results in NaN --- arctic_training/config/ac_sql_sft.yml | 70 --------------------------- 1 file changed, 70 deletions(-) delete mode 100644 arctic_training/config/ac_sql_sft.yml diff --git a/arctic_training/config/ac_sql_sft.yml b/arctic_training/config/ac_sql_sft.yml deleted file mode 100644 index 19ceac35..00000000 --- a/arctic_training/config/ac_sql_sft.yml +++ /dev/null @@ -1,70 +0,0 @@ -type: sft -micro_batch_size: 1 -gradient_accumulation_steps: 6 -epochs: 2 -activation_checkpoint_cpu_offload: true - -model: - type: liger - name_or_path: /modeling-checkpoints/zheweiyao/persist_for_snowflake_text2sql/mid-training-sqlite-13A-32B-packing-H200-global_step_28740 - attn_implementation: flash_attention_3 - -data: - sources: - # TODO: Update path once dataset is generated using csv_to_huggingface_dataset.py - - type: huggingface_instruct - name_or_path: /home/yak/data/input - role_mapping: - user: messages.role.user - assistant: messages.role.assistant - - eval_sources: - # TODO: Update path once validation dataset is generated using csv_to_huggingface_dataset.py - - type: huggingface_instruct - name_or_path: /home/yak/data/validation - role_mapping: - user: messages.role.user - assistant: messages.role.assistant - - num_proc: 16 - max_length: 32768 - cache_dir: /checkpoint/aponnusamy/data-cache - pack_samples: true - ignore_empty_think: true - -# Evaluation interval - required when eval_sources is provided -eval_interval: 150 - -deepspeed: - zero_optimization: - stage: 3 - allgather_bucket_size: 500000000 - reduce_bucket_size: 250000000 - memory_efficient_linear: true - -optimizer: - weight_decay: 0.01 - learning_rate: 1e-5 - -scheduler: - type: huggingface - name: cosine_with_min_lr - warmup_ratio: 0.1 - scheduler_specific_kwargs: - min_lr_rate: 0.1 - -checkpoint: - - type: huggingface - save_every_n_steps: 150 - save_end_of_training: true - output_dir: /modeling-checkpoints/aponnusamy/checkpoints/qwen3-32b-ac-sft - -wandb: - enable: true - project: snowflakesql_sft_training - name: qwen3-32b-ac-sft - -logger: - level: WARNING - output_dir: logs - print_output_ranks: [0] From 9e90bdb87f1e9c0936a414deeda1f3def61a4d67 Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Mon, 5 Jan 2026 10:50:50 -0800 Subject: [PATCH 5/8] Debugging val/loss for sql autocompletion that results in NaN --- arctic_training/data/sft_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 62df60be..68f7379a 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -420,7 +420,7 @@ def tokenize_messages( # Debug logging for label masking issues if DEBUG_LABEL_MASKING: - non_masked = sum(1 for l in labels if l != IGNORE_INDEX) + non_masked = sum(1 for label in labels if label != IGNORE_INDEX) total = len(labels) if non_masked == 0: logger.warning( From a7eb093114e813e6184253f484d5651eb01d56ab Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Mon, 5 Jan 2026 12:29:55 -0800 Subject: [PATCH 6/8] Debugging val/loss for sql autocompletion that results in NaN --- arctic_training/trainer/sft_trainer.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py index b9c2fed2..21dff261 100644 --- a/arctic_training/trainer/sft_trainer.py +++ b/arctic_training/trainer/sft_trainer.py @@ -46,7 +46,19 @@ def loss(self, batch) -> torch.Tensor: batch = to_device(batch, self.device) if self.config.sequence_parallel_size == 1: - # if model.type=liger is configured - this will use a much more efficient fused + labels = batch.get("labels") + + # Check if all labels are masked (would cause NaN loss from CE) + if labels is not None and (labels == -100).all(): + # Run forward without labels to get logits, return differentiable zero + batch_without_labels = {k: v for k, v in batch.items() if k != "labels"} + outputs = self.model(**batch_without_labels, use_cache=False) + # Fake loss calculation - CE would return NaN, but we return differentiable zero + # A normal loss_fn upcasts logits to float so match it + loss = (outputs.logits.sum() * 0.0).float() + return loss + + # Normal path: if model.type=liger is configured - this will use a much more efficient fused # logits+loss liger kernel - using significantly less gpu memory and a bit faster # compute (liger fused logits+loss kernel does not repeat forward during backward) outputs = self.model(**batch, use_cache=False) From b289f7eb7d4416940e1f8b4093cde0a35ea84fe6 Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Mon, 5 Jan 2026 13:12:23 -0800 Subject: [PATCH 7/8] Debugging val/loss for sql autocompletion that results in NaN - Code review comments --- arctic_training/data/sft_factory.py | 13 +++++-------- arctic_training/trainer/sft_trainer.py | 4 ++-- arctic_training/trainer/trainer.py | 25 +------------------------ 3 files changed, 8 insertions(+), 34 deletions(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 68f7379a..45c3f0a2 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import os import random import re @@ -39,8 +38,6 @@ from arctic_training.data.hf_instruct_source import HFDataSourceInstruct from arctic_training.data.utils import DatasetType -logger = logging.getLogger(__name__) - IGNORE_INDEX = -100 # Debug flag for detailed logging of label masking issues @@ -423,14 +420,14 @@ def tokenize_messages( non_masked = sum(1 for label in labels if label != IGNORE_INDEX) total = len(labels) if non_masked == 0: - logger.warning( + print( "ALL LABELS MASKED! This will cause NaN loss.\n conversation_text length:" f" {len(conversation_text)}\n assistant_ranges: {assistant_ranges}\n num_messages:" f" {len(messages)}\n assistant_contents:" f" {[m['content'][:50] + '...' for m in messages if m['role'] == 'assistant']}" ) elif non_masked < 5: - logger.warning( + print( f"Very few non-masked labels ({non_masked}/{total}).\n assistant_ranges: {assistant_ranges}\n" " assistant_contents:" f" {[m['content'][:50] + '...' for m in messages if m['role'] == 'assistant']}" @@ -468,7 +465,7 @@ def get_assistant_start_end_indices( match_index = conversation_text.find(message_text) if DEBUG_LABEL_MASKING and message["role"] == "assistant": if match_index == -1: - logger.warning( + print( "Assistant content NOT FOUND in conversation_text!\n" f" role: {message['role']}\n" f" original_content: {repr(original_text[:100])}...\n" @@ -477,7 +474,7 @@ def get_assistant_start_end_indices( f" conversation_text length: {len(conversation_text)}" ) else: - logger.warning( + print( "Assistant content found via fallback (from position 0)!\n" f" content: {repr(message_text[:50])}...\n" f" found at: {match_index}, search_start was: {search_start}" @@ -487,7 +484,7 @@ def get_assistant_start_end_indices( if DEBUG_LABEL_MASKING and message["role"] == "assistant": conv_len = len(conversation_text) position_pct = (match_index / conv_len * 100) if match_index != -1 else -1 - logger.info( + print( f"Assistant range: ({match_index}, {end_index}) - " f"{position_pct:.1f}% into conversation (len={conv_len}), " f"content: {repr(message_text[:30])}..." diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py index 21dff261..87119276 100644 --- a/arctic_training/trainer/sft_trainer.py +++ b/arctic_training/trainer/sft_trainer.py @@ -51,8 +51,8 @@ def loss(self, batch) -> torch.Tensor: # Check if all labels are masked (would cause NaN loss from CE) if labels is not None and (labels == -100).all(): # Run forward without labels to get logits, return differentiable zero - batch_without_labels = {k: v for k, v in batch.items() if k != "labels"} - outputs = self.model(**batch_without_labels, use_cache=False) + batch.pop("labels") + outputs = self.model(**batch, use_cache=False) # Fake loss calculation - CE would return NaN, but we return differentiable zero # A normal loss_fn upcasts logits to float so match it loss = (outputs.logits.sum() * 0.0).float() diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index dd372f32..7035a4b3 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -559,31 +559,8 @@ def evaluate(self) -> None: Evaluation loop. Measures the model's performance on the evaluation dataset. """ self.model.eval() - losses = [] with torch.no_grad(): - for batch_idx, eval_batch in enumerate(self.eval_batches): - loss = self.loss(eval_batch) - loss_value = loss.item() - losses.append(loss_value) - - # Debug NaN losses - if torch.isnan(loss) or torch.isinf(loss): - logger.warning( - f"NaN/Inf eval loss detected at batch {batch_idx}!\n" - f" loss: {loss_value}\n" - f" batch keys: {eval_batch.keys()}\n" - f" input_ids shape: {eval_batch['input_ids'].shape if 'input_ids' in eval_batch else 'N/A'}\n" - f" labels shape: {eval_batch['labels'].shape if 'labels' in eval_batch else 'N/A'}" - ) - if "labels" in eval_batch: - labels = eval_batch["labels"] - non_masked = (labels != -100).sum().item() - total = labels.numel() - pct = (100 * non_masked / total) if total > 0 else 0.0 - logger.warning( - f" non-masked labels: {non_masked}/{total} ({pct:.1f}%)\n" - f" all labels are -100: {non_masked == 0}" - ) + losses = [self.loss(eval_batch).item() for eval_batch in self.eval_batches] self.metrics.record("loss/eval", losses) # type: ignore @callback_wrapper("checkpoint") From d99936700118e1c0127fe2a3c0b3c742308381fc Mon Sep 17 00:00:00 2001 From: Arun Ponnusamy Date: Mon, 5 Jan 2026 13:16:58 -0800 Subject: [PATCH 8/8] Debugging val/loss for sql autocompletion that results in NaN - Code review comments --- arctic_training/data/sft_factory.py | 50 ----------------------------- 1 file changed, 50 deletions(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 45c3f0a2..378c19b6 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import random import re from typing import Dict @@ -40,10 +39,6 @@ IGNORE_INDEX = -100 -# Debug flag for detailed logging of label masking issues -# Set environment variable DEBUG_LABEL_MASKING=1 to enable -DEBUG_LABEL_MASKING = os.environ.get("DEBUG_LABEL_MASKING", "0") == "1" - # this function is modified from TRL trl.trainer.utils.py def pad( @@ -414,25 +409,6 @@ def tokenize_messages( # _ = get_assistant_start_end_indices(messages, conversation_text) labels = cls.get_masked_labels(conversation_ids, assistant_ranges) conversation_ids["labels"] = labels - - # Debug logging for label masking issues - if DEBUG_LABEL_MASKING: - non_masked = sum(1 for label in labels if label != IGNORE_INDEX) - total = len(labels) - if non_masked == 0: - print( - "ALL LABELS MASKED! This will cause NaN loss.\n conversation_text length:" - f" {len(conversation_text)}\n assistant_ranges: {assistant_ranges}\n num_messages:" - f" {len(messages)}\n assistant_contents:" - f" {[m['content'][:50] + '...' for m in messages if m['role'] == 'assistant']}" - ) - elif non_masked < 5: - print( - f"Very few non-masked labels ({non_masked}/{total}).\n assistant_ranges: {assistant_ranges}\n" - " assistant_contents:" - f" {[m['content'][:50] + '...' for m in messages if m['role'] == 'assistant']}" - ) - # compare_messages_with_labels(split_list_by_specific_num(conversation_ids["labels"]), messages, tokenizer) del conversation_ids["offset_mapping"] else: @@ -455,7 +431,6 @@ def get_assistant_start_end_indices( search_start = 0 for message in messages: message_text = message["content"] - original_text = message_text # Keep for debug logging if message["role"] == "assistant" and ignore_empty_think: message_text = re.sub(r"^\s*\s*", "", message_text) # Find this message starting from current position @@ -463,33 +438,8 @@ def get_assistant_start_end_indices( if match_index == -1: # Fallback: try searching from the beginning (original behavior) match_index = conversation_text.find(message_text) - if DEBUG_LABEL_MASKING and message["role"] == "assistant": - if match_index == -1: - print( - "Assistant content NOT FOUND in conversation_text!\n" - f" role: {message['role']}\n" - f" original_content: {repr(original_text[:100])}...\n" - f" search_text: {repr(message_text[:100])}...\n" - f" search_start was: {search_start}\n" - f" conversation_text length: {len(conversation_text)}" - ) - else: - print( - "Assistant content found via fallback (from position 0)!\n" - f" content: {repr(message_text[:50])}...\n" - f" found at: {match_index}, search_start was: {search_start}" - ) end_index = match_index + len(message_text) if match_index != -1 else -1 - if DEBUG_LABEL_MASKING and message["role"] == "assistant": - conv_len = len(conversation_text) - position_pct = (match_index / conv_len * 100) if match_index != -1 else -1 - print( - f"Assistant range: ({match_index}, {end_index}) - " - f"{position_pct:.1f}% into conversation (len={conv_len}), " - f"content: {repr(message_text[:30])}..." - ) - # Only record assistant message ranges if message["role"] == "assistant": return_indices.append((match_index, end_index))