From f8e57d99c8de531fa10fa5cf9f5a3e87e63330bd Mon Sep 17 00:00:00 2001 From: zhewei yao Date: Mon, 2 Feb 2026 22:59:37 +0000 Subject: [PATCH 1/3] adding input output format --- .../data/hf_input_output_source.py | 67 +++++++++ arctic_training/data/sft_factory.py | 131 ++++++++++++++---- tests/data/test_sft_factory.py | 112 +++++++++++++++ 3 files changed, 285 insertions(+), 25 deletions(-) create mode 100644 arctic_training/data/hf_input_output_source.py diff --git a/arctic_training/data/hf_input_output_source.py b/arctic_training/data/hf_input_output_source.py new file mode 100644 index 00000000..d4265d62 --- /dev/null +++ b/arctic_training/data/hf_input_output_source.py @@ -0,0 +1,67 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Dict + +from pydantic import Field + +from arctic_training.data.hf_source import HFDataSource +from arctic_training.data.hf_source import HFDataSourceConfig +from arctic_training.data.utils import DatasetType + + +class HFDataSourceConfigInputOutput(HFDataSourceConfig): + input_key: str = Field(default="input") + """The field name for the input/instruction part of the data.""" + + output_key: str = Field(default="output") + """The field name for the output/response part of the data.""" + + +class HFDataSourceInputOutput(HFDataSource): + """ + DataSource class for datasets with INPUT/OUTPUT format. + + This source handles datasets where: + - Data contains separate INPUT and OUTPUT fields + - The full sequence is INPUT + OUTPUT + - Only the OUTPUT portion is used for training (INPUT is masked) + """ + + name = "huggingface_input_output" + config: HFDataSourceConfigInputOutput + + def post_load_callback(self, dataset: DatasetType) -> DatasetType: + def process_example(example: Dict[str, Any]) -> Dict[str, Any]: + input_text = example.get(self.config.input_key, "") + output_text = example.get(self.config.output_key, "") + + if not output_text: + raise ValueError( + f"Output field '{self.config.output_key}' is empty or missing. " + f"Available keys: {list(example.keys())}" + ) + + return { + "input": str(input_text), + "output": str(output_text), + } + + return dataset.map( + process_example, + num_proc=self.data_factory.config.num_proc, + desc=f"Loading {self.config.name_or_path}", + ) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 22283d14..7dd27d73 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -34,6 +34,7 @@ from arctic_training.config.data import DataConfig from arctic_training.config.utils import HumanInt from arctic_training.data.factory import DataFactory +from arctic_training.data.hf_input_output_source import HFDataSourceInputOutput from arctic_training.data.hf_instruct_source import HFDataSourceInstruct from arctic_training.data.utils import DatasetType @@ -265,6 +266,9 @@ class SFTDataConfig(DataConfig): ignore_empty_think: bool = False """ Whether to mask the empty think tokens preventing the loss of thinking ability.""" + data_format: Literal["messages", "input_output"] = "messages" + """ The format of the input data: 'messages' for chat format, 'input_output' for INPUT/OUTPUT format. """ + @model_validator(mode="after") def validate_padding(self) -> Self: if self.pad_to == "max_length" and "div_length" in self.model_fields_set: @@ -357,37 +361,65 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: class SFTDataFactory(DataFactory): name = "sft" config: SFTDataConfig - default_source_cls = HFDataSourceInstruct callbacks = [ ("post-load", filter_dataset_length), ("post-load", pack_dataset), ] + @property + def default_source_cls(self): + """Return the appropriate data source class based on data_format config.""" + if self.config.data_format == "input_output": + return HFDataSourceInputOutput + else: + return HFDataSourceInstruct + def process(self, dataset: DatasetType) -> DatasetType: - if "messages" not in dataset.column_names: - raise ValueError("Dataset must have 'messages' column to tokenize for SFTDataFactory.") - dataset = dataset.select_columns(["messages"]) - # sft based tokenization, - # we assume the messages are in the format of: - # {'role': '...', 'content': '...'} - # datasets = datasets.select(range(100, 1100)) - dataset = dataset.select(range(len(dataset))) - # datasets.disable_caching() - # tmp = tokenize_messages(datasets[0]["messages"][:2], tokenizer, mask_inputs=mask_inputs) - # import pdb; pdb.set_trace() - return dataset.map( - lambda ex: { - **self.tokenize_messages( - ex["messages"], - self.tokenizer, - mask_inputs=self.config.mask_inputs, - ignore_empty_think=self.config.ignore_empty_think, - ) - }, - remove_columns=dataset.column_names, - num_proc=self.config.num_proc, - desc="Tokenizing messages", - ) + if self.config.data_format == "messages": + if "messages" not in dataset.column_names: + raise ValueError("Dataset must have 'messages' column to tokenize for SFTDataFactory with 'messages' format.") + dataset = dataset.select_columns(["messages"]) + # sft based tokenization, + # we assume the messages are in the format of: + # {'role': '...', 'content': '...'} + # datasets = datasets.select(range(100, 1100)) + dataset = dataset.select(range(len(dataset))) + # datasets.disable_caching() + # tmp = tokenize_messages(datasets[0]["messages"][:2], tokenizer, mask_inputs=mask_inputs) + # import pdb; pdb.set_trace() + return dataset.map( + lambda ex: { + **self.tokenize_messages( + ex["messages"], + self.tokenizer, + mask_inputs=self.config.mask_inputs, + ignore_empty_think=self.config.ignore_empty_think, + ) + }, + remove_columns=dataset.column_names, + num_proc=self.config.num_proc, + desc="Tokenizing messages", + ) + elif self.config.data_format == "input_output": + if "input" not in dataset.column_names or "output" not in dataset.column_names: + raise ValueError("Dataset must have 'input' and 'output' columns to tokenize for SFTDataFactory with 'input_output' format.") + dataset = dataset.select_columns(["input", "output"]) + dataset = dataset.select(range(len(dataset))) + return dataset.map( + lambda ex: { + **self.tokenize_input_output( + ex["input"], + ex["output"], + self.tokenizer, + mask_inputs=self.config.mask_inputs, + ) + }, + remove_columns=dataset.column_names, + num_proc=self.config.num_proc, + desc="Tokenizing input/output", + ) + else: + raise ValueError(f"Unknown data_format: {self.config.data_format}. Valid values are 'messages' and 'input_output'.") @classmethod def tokenize_messages( @@ -416,6 +448,55 @@ def tokenize_messages( return conversation_ids + @classmethod + def tokenize_input_output( + cls, + input_text: str, + output_text: str, + tokenizer: PreTrainedTokenizerBase, + mask_inputs: bool = True, + ) -> BatchEncoding: + """ + Tokenize INPUT and OUTPUT format where only OUTPUT is trained. + + Args: + input_text: The input/instruction text (will be masked in labels) + output_text: The output/response text (will be trained) + tokenizer: The tokenizer to use + mask_inputs: Whether to mask the input portion in labels + + Returns: + BatchEncoding with input_ids, attention_mask, and labels + """ + # Tokenize INPUT and OUTPUT separately to avoid boundary issues + input_encoding = tokenizer( + input_text, + add_special_tokens=False, + ) + output_encoding = tokenizer( + output_text, + add_special_tokens=False, + ) + + # Concatenate input_ids from both encodings + input_ids = input_encoding["input_ids"] + output_encoding["input_ids"] + + # Concatenate attention_mask from both encodings + attention_mask = input_encoding["attention_mask"] + output_encoding["attention_mask"] + + if mask_inputs: + # Create labels: mask INPUT portion, keep OUTPUT portion + labels = [IGNORE_INDEX] * len(input_encoding["input_ids"]) + output_encoding["input_ids"] + else: + # Train on the full sequence + labels = input_ids + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + @staticmethod # this code is adpoted from https://github.com/huggingface/trl/issues/632 (user: Peter-Devine ) def get_assistant_start_end_indices( diff --git a/tests/data/test_sft_factory.py b/tests/data/test_sft_factory.py index 8b267207..f5c6ce88 100644 --- a/tests/data/test_sft_factory.py +++ b/tests/data/test_sft_factory.py @@ -17,6 +17,9 @@ from typing import List import pytest +from datasets import Dataset + +from arctic_training.data.sft_factory import IGNORE_INDEX, SFTDataFactory from .utils import create_data_factory @@ -108,3 +111,112 @@ def test_pad_to(model_name: str, tmp_path: Path): dataloader, _ = data_factory() for batch in dataloader: assert batch["input_ids"].shape[1] % 256 == 0, "Incorrect padded sequence length" + + +def test_tokenize_input_output(model_name: str): + """Test the tokenize_input_output method for INPUT/OUTPUT format.""" + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + + input_text = "Translate to French: Hello, how are you?" + output_text = " Bonjour, comment allez-vous?" + + # Test with masking enabled (default) + result_masked = SFTDataFactory.tokenize_input_output( + input_text=input_text, output_text=output_text, tokenizer=tokenizer, mask_inputs=True + ) + + # Verify that we have the expected keys + assert "input_ids" in result_masked + assert "labels" in result_masked + assert "attention_mask" in result_masked + + # Verify that input_ids contains the concatenated sequence (tokenized separately) + input_encoding = tokenizer(input_text, add_special_tokens=False) + output_encoding = tokenizer(output_text, add_special_tokens=False) + expected_input_ids = input_encoding["input_ids"] + output_encoding["input_ids"] + assert result_masked["input_ids"] == expected_input_ids + + # Verify that INPUT portion is masked in labels + input_length = len(input_encoding["input_ids"]) + + # Check that first input_length tokens are masked + assert all(label == IGNORE_INDEX for label in result_masked["labels"][:input_length]) + + # Check that remaining tokens match OUTPUT encoding (not full input_ids) + assert result_masked["labels"][input_length:] == output_encoding["input_ids"] + + # Test without masking + result_unmasked = SFTDataFactory.tokenize_input_output( + input_text=input_text, output_text=output_text, tokenizer=tokenizer, mask_inputs=False + ) + + # Verify that labels equal input_ids when masking is disabled + assert result_unmasked["labels"] == result_unmasked["input_ids"] + + # Verify that input_ids are the same in both cases + assert result_masked["input_ids"] == result_unmasked["input_ids"] + + +def test_input_output_data_format(model_name: str, tmp_path: Path): + """Test SFT data factory with input_output format using a synthetic dataset.""" + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token + + # Create a synthetic dataset with input/output format + synthetic_data = Dataset.from_dict( + { + "input": [ + "Translate to French: Hello", + "What is 2+2?", + "Write a haiku about coding", + ], + "output": [" Bonjour", " The answer is 4.", " Code flows like water,\nBugs rise and fall away,\nPeace in the debug."], + } + ) + + # Save dataset to disk + dataset_path = tmp_path / "test_input_output_dataset" + synthetic_data.save_to_disk(str(dataset_path)) + + # Create data factory with input_output format + data_factory = create_data_factory( + model_name=model_name, + data_config_kwargs=dict( + type="sft", + data_format="input_output", + sources=[ + { + "type": "huggingface_input_output", + "name_or_path": str(dataset_path), + "input_key": "input", + "output_key": "output", + } + ], + cache_dir=tmp_path, + max_length=512, + batch_size=2, + ), + ) + + # Get dataloader + dataloader, _ = data_factory() + + # Verify that data is loaded correctly + batch = next(iter(dataloader)) + + assert "input_ids" in batch + assert "labels" in batch + assert "attention_mask" in batch or "position_ids" in batch + + # Verify that some labels are masked (IGNORE_INDEX) + has_masked_labels = any((batch["labels"] == IGNORE_INDEX).any() for _ in range(len(batch["labels"]))) + assert has_masked_labels, "Expected some labels to be masked (IGNORE_INDEX)" + + # Verify that not all labels are masked + has_unmasked_labels = any((batch["labels"] != IGNORE_INDEX).any() for _ in range(len(batch["labels"]))) + assert has_unmasked_labels, "Expected some labels to be unmasked (trainable)" From 96fd65da846e10a7f65ed0f39ef170b195789eee Mon Sep 17 00:00:00 2001 From: zhewei yao Date: Tue, 3 Feb 2026 01:20:19 +0000 Subject: [PATCH 2/3] fix print and number of training steps --- arctic_training/trainer/trainer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 7035a4b3..5be667b2 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -364,8 +364,7 @@ def training_horizon(self) -> int: if self.config.train_iters: return self.config.train_iters - # XXX: this was incorrect for GAS - return self.config.epochs * len(self.train_dataloader) # // self.config.gradient_accumulation_steps + return self.config.epochs * len(self.train_dataloader) // self.config.gradient_accumulation_steps @callback_wrapper("loss") @abstractmethod @@ -484,10 +483,10 @@ def epoch(self) -> None: self.metrics.restart_timer("iter") - if self.config.train_log_iter_interval != 0: - self.metrics.print_summary() - if self.gas_boundary: + if self.config.train_log_iter_interval != 0: + self.metrics.print_summary() + if ( self.global_rank == 0 and self.config.train_log_iter_interval != 0 From 95d33103cfe6cab9db2cf063f401d35c0ad07dde Mon Sep 17 00:00:00 2001 From: zhewei yao Date: Tue, 3 Feb 2026 19:01:45 +0000 Subject: [PATCH 3/3] fixing loss cal --- arctic_training/metrics.py | 5 ++++- arctic_training/trainer/sft_trainer.py | 24 +++++++++++++++++++++++- arctic_training/trainer/trainer.py | 6 +++--- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/arctic_training/metrics.py b/arctic_training/metrics.py index 4e2885e8..621f6f58 100644 --- a/arctic_training/metrics.py +++ b/arctic_training/metrics.py @@ -108,7 +108,7 @@ def get_value(self, key: str) -> Union[int, float]: """Returns the value stored in the metrics dictionary for the given key.""" return self.values[key] - def print_summary(self, prefix: str = "train") -> None: + def print_summary(self, prefix: str = "train", print_output: bool = True) -> None: """Prints a summary of the metrics. If a value is not recorded by the Trainer, it is not included in the summary.""" if not self.enabled: return @@ -170,6 +170,9 @@ def print_summary(self, prefix: str = "train") -> None: self.values.clear() + if not print_output: + return + summary_str = ( f"{prefix.title():>{len('train')}} iter:" f" {self.summary_dict['iter']:>{self.max_iter_pad}}/{self.max_iter}" diff --git a/arctic_training/trainer/sft_trainer.py b/arctic_training/trainer/sft_trainer.py index b9c2fed2..f5dc71e8 100644 --- a/arctic_training/trainer/sft_trainer.py +++ b/arctic_training/trainer/sft_trainer.py @@ -17,6 +17,7 @@ from typing import Union import torch +import torch.nn as nn import torch.distributed.nn.functional from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledFusedLogitsLoss @@ -41,16 +42,37 @@ class SFTTrainer(Trainer): optimizer_factory: Union[FusedAdamOptimizerFactory, CPUAdamOptimizerFactory] scheduler_factory: Union[HFSchedulerFactory] tokenizer_factory: Union[HFTokenizerFactory] + loss_fct = nn.CrossEntropyLoss(reduction="none") def loss(self, batch) -> torch.Tensor: batch = to_device(batch, self.device) + # print(self.tokenizer.decode(batch['input_ids'][0])) + # print((batch['labels'] != -100).sum()) + # import pdb; pdb.set_trace() if self.config.sequence_parallel_size == 1: # if model.type=liger is configured - this will use a much more efficient fused # logits+loss liger kernel - using significantly less gpu memory and a bit faster # compute (liger fused logits+loss kernel does not repeat forward during backward) + labels = batch.pop("labels").view(-1) + loss_mask = (labels != -100).view(-1)[1:].contiguous() + # pred = labels[loss_mask] + # print(self.tokenizer.decode(pred)) + shift_labels = batch['input_ids'][:, 1:].view(-1).contiguous() + input_ids = batch['input_ids'] + # self.tokenizer.decode(shift_labels[loss_mask]) + # import pdb; pdb.set_trace() + outputs = self.model(**batch, use_cache=False) - loss = outputs.loss + logits = outputs.logits + shift_logits = logits[..., :-1, :].contiguous() + shift_logits = shift_logits.view(-1, logits.shape[-1]) + shift_labels = shift_labels.view(-1) + loss = self.loss_fct(shift_logits, shift_labels) + loss = (loss * loss_mask.contiguous()).sum() + # import pdb; pdb.set_trace() + # loss = outputs.loss + # import pdb; pdb.set_trace() return loss # Ulysses SP expectations: diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 5be667b2..a167e22e 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -483,10 +483,10 @@ def epoch(self) -> None: self.metrics.restart_timer("iter") - if self.gas_boundary: - if self.config.train_log_iter_interval != 0: - self.metrics.print_summary() + if self.config.train_log_iter_interval != 0: + self.metrics.print_summary(print_output=self.gas_boundary) + if self.gas_boundary: if ( self.global_rank == 0 and self.config.train_log_iter_interval != 0