67 changes: 67 additions & 0 deletions arctic_training/data/hf_input_output_source.py
@@ -0,0 +1,67 @@
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from typing import Dict

from pydantic import Field

from arctic_training.data.hf_source import HFDataSource
from arctic_training.data.hf_source import HFDataSourceConfig
from arctic_training.data.utils import DatasetType


class HFDataSourceConfigInputOutput(HFDataSourceConfig):
input_key: str = Field(default="input")
"""The field name for the input/instruction part of the data."""

output_key: str = Field(default="output")
"""The field name for the output/response part of the data."""


class HFDataSourceInputOutput(HFDataSource):
"""
DataSource class for datasets with INPUT/OUTPUT format.

This source handles datasets where:
- Data contains separate INPUT and OUTPUT fields
- The full sequence is INPUT + OUTPUT
- Only the OUTPUT portion is used for training (INPUT is masked)
"""

name = "huggingface_input_output"
config: HFDataSourceConfigInputOutput

def post_load_callback(self, dataset: DatasetType) -> DatasetType:
def process_example(example: Dict[str, Any]) -> Dict[str, Any]:
input_text = example.get(self.config.input_key, "")
output_text = example.get(self.config.output_key, "")

if not output_text:
raise ValueError(
f"Output field '{self.config.output_key}' is empty or missing. "
f"Available keys: {list(example.keys())}"
)

return {
"input": str(input_text),
"output": str(output_text),
}

return dataset.map(
process_example,
num_proc=self.data_factory.config.num_proc,
desc=f"Loading {self.config.name_or_path}",
)
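
A minimal sketch of the per-example normalization that post_load_callback performs in the new HFDataSourceInputOutput source. The field names "instruction" and "response" are hypothetical; the real keys come from input_key/output_key in the config (defaults: "input"/"output").

# Hypothetical raw record, assuming input_key="instruction" and output_key="response".
example = {"instruction": "Translate to French: cat", "response": "chat"}

normalized = {
    "input": str(example.get("instruction", "")),
    "output": str(example.get("response", "")),
}
# -> {"input": "Translate to French: cat", "output": "chat"}
# An empty or missing output field raises ValueError, mirroring the check above.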
131 changes: 106 additions & 25 deletions arctic_training/data/sft_factory.py
@@ -34,6 +34,7 @@
from arctic_training.config.data import DataConfig
from arctic_training.config.utils import HumanInt
from arctic_training.data.factory import DataFactory
from arctic_training.data.hf_input_output_source import HFDataSourceInputOutput
from arctic_training.data.hf_instruct_source import HFDataSourceInstruct
from arctic_training.data.utils import DatasetType

@@ -265,6 +266,9 @@ class SFTDataConfig(DataConfig):
ignore_empty_think: bool = False
""" Whether to mask the empty think tokens preventing the loss of thinking ability."""

data_format: Literal["messages", "input_output"] = "messages"
""" The format of the input data: 'messages' for chat format, 'input_output' for INPUT/OUTPUT format. """

@model_validator(mode="after")
def validate_padding(self) -> Self:
if self.pad_to == "max_length" and "div_length" in self.model_fields_set:
@@ -357,37 +361,65 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType:
class SFTDataFactory(DataFactory):
name = "sft"
config: SFTDataConfig
default_source_cls = HFDataSourceInstruct
callbacks = [
("post-load", filter_dataset_length),
("post-load", pack_dataset),
]

@property
def default_source_cls(self):
"""Return the appropriate data source class based on data_format config."""
if self.config.data_format == "input_output":
return HFDataSourceInputOutput
else:
return HFDataSourceInstruct

def process(self, dataset: DatasetType) -> DatasetType:
if "messages" not in dataset.column_names:
raise ValueError("Dataset must have 'messages' column to tokenize for SFTDataFactory.")
dataset = dataset.select_columns(["messages"])
# sft based tokenization,
# we assume the messages are in the format of:
# {'role': '...', 'content': '...'}
# datasets = datasets.select(range(100, 1100))
dataset = dataset.select(range(len(dataset)))
# datasets.disable_caching()
# tmp = tokenize_messages(datasets[0]["messages"][:2], tokenizer, mask_inputs=mask_inputs)
# import pdb; pdb.set_trace()
return dataset.map(
lambda ex: {
**self.tokenize_messages(
ex["messages"],
self.tokenizer,
mask_inputs=self.config.mask_inputs,
ignore_empty_think=self.config.ignore_empty_think,
)
},
remove_columns=dataset.column_names,
num_proc=self.config.num_proc,
desc="Tokenizing messages",
)
if self.config.data_format == "messages":
if "messages" not in dataset.column_names:
raise ValueError("Dataset must have 'messages' column to tokenize for SFTDataFactory with 'messages' format.")
dataset = dataset.select_columns(["messages"])
# sft based tokenization,
# we assume the messages are in the format of:
# {'role': '...', 'content': '...'}
# datasets = datasets.select(range(100, 1100))
dataset = dataset.select(range(len(dataset)))
# datasets.disable_caching()
# tmp = tokenize_messages(datasets[0]["messages"][:2], tokenizer, mask_inputs=mask_inputs)
# import pdb; pdb.set_trace()
return dataset.map(
lambda ex: {
**self.tokenize_messages(
ex["messages"],
self.tokenizer,
mask_inputs=self.config.mask_inputs,
ignore_empty_think=self.config.ignore_empty_think,
)
},
remove_columns=dataset.column_names,
num_proc=self.config.num_proc,
desc="Tokenizing messages",
)
elif self.config.data_format == "input_output":
if "input" not in dataset.column_names or "output" not in dataset.column_names:
raise ValueError("Dataset must have 'input' and 'output' columns to tokenize for SFTDataFactory with 'input_output' format.")
dataset = dataset.select_columns(["input", "output"])
dataset = dataset.select(range(len(dataset)))
return dataset.map(
lambda ex: {
**self.tokenize_input_output(
ex["input"],
ex["output"],
self.tokenizer,
mask_inputs=self.config.mask_inputs,
)
},
remove_columns=dataset.column_names,
num_proc=self.config.num_proc,
desc="Tokenizing input/output",
)
else:
raise ValueError(f"Unknown data_format: {self.config.data_format}. Valid values are 'messages' and 'input_output'.")

@classmethod
def tokenize_messages(
@@ -416,6 +448,55 @@ def tokenize_messages(

return conversation_ids

@classmethod
def tokenize_input_output(
cls,
input_text: str,
output_text: str,
tokenizer: PreTrainedTokenizerBase,
mask_inputs: bool = True,
) -> BatchEncoding:
"""
Tokenize INPUT and OUTPUT format where only OUTPUT is trained.

Args:
input_text: The input/instruction text (will be masked in labels)
output_text: The output/response text (will be trained)
tokenizer: The tokenizer to use
mask_inputs: Whether to mask the input portion in labels

Returns:
BatchEncoding with input_ids, attention_mask, and labels
"""
# Tokenize INPUT and OUTPUT separately to avoid boundary issues
input_encoding = tokenizer(
input_text,
add_special_tokens=False,
)
output_encoding = tokenizer(
output_text,
add_special_tokens=False,
)

# Concatenate input_ids from both encodings
input_ids = input_encoding["input_ids"] + output_encoding["input_ids"]

# Concatenate attention_mask from both encodings
attention_mask = input_encoding["attention_mask"] + output_encoding["attention_mask"]

if mask_inputs:
# Create labels: mask INPUT portion, keep OUTPUT portion
labels = [IGNORE_INDEX] * len(input_encoding["input_ids"]) + output_encoding["input_ids"]
else:
# Train on the full sequence
labels = input_ids

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}

@staticmethod
# this code is adapted from https://github.com/huggingface/trl/issues/632 (user: Peter-Devine)
def get_assistant_start_end_indices(
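
An illustrative sketch of the label masking performed by tokenize_input_output above, using made-up token ids instead of a real tokenizer (a real run calls tokenizer(text, add_special_tokens=False) as in the method):

IGNORE_INDEX = -100
input_token_ids = [101, 102, 103]   # tokens of the INPUT text
output_token_ids = [201, 202]       # tokens of the OUTPUT text

input_ids = input_token_ids + output_token_ids
attention_mask = [1] * len(input_ids)
labels = [IGNORE_INDEX] * len(input_token_ids) + output_token_ids  # mask_inputs=True
# input_ids -> [101, 102, 103, 201, 202]
# labels    -> [-100, -100, -100, 201, 202]  (loss is computed only on the OUTPUT tokens)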
5 changes: 4 additions & 1 deletion arctic_training/metrics.py
@@ -108,7 +108,7 @@ def get_value(self, key: str) -> Union[int, float]:
"""Returns the value stored in the metrics dictionary for the given key."""
return self.values[key]

def print_summary(self, prefix: str = "train") -> None:
def print_summary(self, prefix: str = "train", print_output: bool = True) -> None:
"""Prints a summary of the metrics. If a value is not recorded by the Trainer, it is not included in the summary."""
if not self.enabled:
return
@@ -170,6 +170,9 @@ def print_summary(self, prefix: str = "train") -> None:

self.values.clear()

if not print_output:
return

summary_str = (
f"{prefix.title():>{len('train')}} iter:"
f" {self.summary_dict['iter']:>{self.max_iter_pad}}/{self.max_iter}"
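
For context, a small sketch of how the new print_output flag is meant to be used (this mirrors the epoch() change in trainer.py below): the metric values are still consumed as before, but the summary line is only emitted when print_output is True.

# Hypothetical call site; gas_boundary is True only on gradient-accumulation boundaries.
metrics.print_summary(print_output=trainer.gas_boundary)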
24 changes: 23 additions & 1 deletion arctic_training/trainer/sft_trainer.py
@@ -17,6 +17,7 @@
from typing import Union

import torch
import torch.nn as nn
import torch.distributed.nn.functional
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledFusedLogitsLoss

@@ -41,16 +42,37 @@ class SFTTrainer(Trainer):
optimizer_factory: Union[FusedAdamOptimizerFactory, CPUAdamOptimizerFactory]
scheduler_factory: Union[HFSchedulerFactory]
tokenizer_factory: Union[HFTokenizerFactory]
loss_fct = nn.CrossEntropyLoss(reduction="none")

def loss(self, batch) -> torch.Tensor:
batch = to_device(batch, self.device)
# print(self.tokenizer.decode(batch['input_ids'][0]))
# print((batch['labels'] != -100).sum())
# import pdb; pdb.set_trace()

if self.config.sequence_parallel_size == 1:
# if model.type=liger is configured - this will use a much more efficient fused
# logits+loss liger kernel - using significantly less gpu memory and a bit faster
# compute (liger fused logits+loss kernel does not repeat forward during backward)
labels = batch.pop("labels").view(-1)
loss_mask = (labels != -100).view(-1)[1:].contiguous()
# pred = labels[loss_mask]
# print(self.tokenizer.decode(pred))
shift_labels = batch['input_ids'][:, 1:].view(-1).contiguous()
input_ids = batch['input_ids']
# self.tokenizer.decode(shift_labels[loss_mask])
# import pdb; pdb.set_trace()

outputs = self.model(**batch, use_cache=False)
loss = outputs.loss
logits = outputs.logits
shift_logits = logits[..., :-1, :].contiguous()
shift_logits = shift_logits.view(-1, logits.shape[-1])
shift_labels = shift_labels.view(-1)
loss = self.loss_fct(shift_logits, shift_labels)
loss = (loss * loss_mask.contiguous()).sum()
# import pdb; pdb.set_trace()
# loss = outputs.loss
# import pdb; pdb.set_trace()
return loss

# Ulysses SP expectations:
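
A minimal, self-contained sketch of the shift-and-mask cross-entropy used in the new loss path above (toy tensors, batch size 1; the real code operates on the model's logits):

import torch
import torch.nn as nn

IGNORE_INDEX = -100
loss_fct = nn.CrossEntropyLoss(reduction="none")

input_ids = torch.tensor([[3, 1, 4, 2]])
labels = torch.tensor([[IGNORE_INDEX, IGNORE_INDEX, 4, 2]])  # only the last two tokens are supervised
logits = torch.randn(1, 4, 5)                                # (batch, seq_len, vocab)

loss_mask = (labels != IGNORE_INDEX).view(-1)[1:]            # drop the first position after shifting
shift_labels = input_ids[:, 1:].reshape(-1)                  # next-token targets
shift_logits = logits[:, :-1, :].reshape(-1, logits.shape[-1])

per_token_loss = loss_fct(shift_logits, shift_labels)
loss = (per_token_loss * loss_mask).sum()                    # masked positions contribute zero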
5 changes: 2 additions & 3 deletions arctic_training/trainer/trainer.py
@@ -364,8 +364,7 @@ def training_horizon(self) -> int:
if self.config.train_iters:
return self.config.train_iters

# XXX: this was incorrect for GAS
return self.config.epochs * len(self.train_dataloader) # // self.config.gradient_accumulation_steps
return self.config.epochs * len(self.train_dataloader) // self.config.gradient_accumulation_steps

@callback_wrapper("loss")
@abstractmethod
Expand Down Expand Up @@ -485,7 +484,7 @@ def epoch(self) -> None:
self.metrics.restart_timer("iter")

if self.config.train_log_iter_interval != 0:
self.metrics.print_summary()
self.metrics.print_summary(print_output=self.gas_boundary)

if self.gas_boundary:
if (
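
A quick worked example of the corrected training horizon, with hypothetical numbers:

# 3 epochs, 1000 micro-batches per epoch, gradient_accumulation_steps = 4.
epochs, steps_per_epoch, gas = 3, 1000, 4
training_horizon = epochs * steps_per_epoch // gas  # 750 optimizer steps (the old formula gave 3000)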