18 changes: 15 additions & 3 deletions arctic_training/data/factory.py
@@ -73,6 +73,7 @@ def _validate_subclass(cls) -> None:
_validate_class_method(cls, "process", ["self", "dataset"])
_validate_class_method(cls, "split_data", ["self", "training_data"])
_validate_class_method(cls, "create_dataloader", ["self", "dataset"])
_validate_class_method(cls, "create_dataloader_no_shuffle", ["self", "dataset"])

def __init__(self, trainer: "Trainer", config: Optional[DataConfig] = None) -> None:
if config is None:
@@ -226,14 +227,25 @@ def split_data(self, training_data: DatasetType) -> Tuple[DatasetType, Optional[

return training_data, evaluation_data

@callback_wrapper("create_dataloader")
def create_dataloader(self, dataset: DatasetType) -> DataLoader:
def _create_dataloader(self, dataset: DatasetType, sampler_shuffle: bool = True) -> DataLoader:
"""Create a torch DataLoader from the dataset."""
return DataLoader(
dataset,
batch_size=self.micro_batch_size,
sampler=DistributedSampler(dataset, num_replicas=self.world_size, rank=self.global_rank),
sampler=DistributedSampler(
dataset, num_replicas=self.world_size, rank=self.global_rank, shuffle=sampler_shuffle
),
num_workers=self.config.dl_num_workers,
persistent_workers=True,
drop_last=True,
)

@callback_wrapper("create_dataloader")
def create_dataloader(self, dataset: DatasetType) -> DataLoader:
"""Create a torch DataLoader from the dataset."""
return self._create_dataloader(dataset)

@callback_wrapper("create_dataloader_no_shuffle")
def create_dataloader_no_shuffle(self, dataset: DatasetType) -> DataLoader:
"""Create a torch DataLoader from the dataset."""
return self._create_dataloader(dataset, sampler_shuffle=False)
Comment on lines +248 to +251
Collaborator:

Do we need to require a new method for this? Why not extend the existing create_dataloader? This would avoid the need to add this new method for each data factory (_validate_class_method(cls, "create_dataloader_no_shuffle", ["self", "dataset"]))

def create_dataloader(self, dataset: DatasetType, shuffle: bool = True):
    return self._create_dataloader(dataset, sampler_shuffle=shuffle)

Collaborator (author):

I previously extended create_dataloader with an optional shuffle flag. However, this caused a unit-test failure on

_validate_class_method(cls, "create_dataloader", ["self", "dataset"])

It seems the validation does not support optional args, or at least I don't know how to achieve that.

I also didn't want shuffle to be mandatory for create_dataloader. But if this is preferred, I can make that change.

What do you prefer?
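
One possible direction, sketched as an editor's note only (the actual _validate_class_method implementation is not shown in this PR, so its internals below are assumptions): validate just the parameters that have no defaults, so optional flags such as shuffle would pass.

import inspect
from typing import List

def _validate_class_method(cls, name: str, expected_args: List[str]) -> None:
    """Hypothetical variant: ignore defaulted parameters so that optional
    flags (e.g. `shuffle`) do not break subclass validation."""
    params = inspect.signature(getattr(cls, name)).parameters.values()
    required = [p.name for p in params if p.default is inspect.Parameter.empty]
    if required != expected_args:
        raise TypeError(f"{cls.__name__}.{name} must accept {expected_args}, found {required}")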

148 changes: 141 additions & 7 deletions arctic_training/data/sft_factory.py
@@ -20,6 +20,7 @@
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
@@ -38,6 +39,15 @@
from arctic_training.data.utils import DatasetType

IGNORE_INDEX = -100
PACKING_KEYS = (
"input_ids",
"labels",
"position_ids",
"packed_sample_seqlens",
"attention_mask",
)

Packed_Data_Type = Dict[str, List[Union[List[int], int]]]


# this function is modified from TRL trl.trainer.utils.py
@@ -142,9 +152,11 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]:
if "position_ids" in instances[0]:
position_ids = [torch.tensor(example["position_ids"]) for example in instances]
packed_sample_seqlens = [example["packed_sample_seqlens"] for example in instances]
packed_seqlens_square_sum = [example["packed_seqlens_square_sum"] for example in instances]
else:
position_ids = [torch.tensor(list(range(len(example["input_ids"])))) for example in instances]
packed_sample_seqlens = [[len(example["input_ids"])] for example in instances]
packed_seqlens_square_sum = [-1 for example in instances]

fake_unpacked_long_seq = False
# fake_unpacked_long_seq = True
@@ -178,21 +190,106 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]:
"labels": labels,
"position_ids": position_ids,
"packed_sample_seqlens": packed_sample_seqlens,
"packed_seqlens_square_sum": packed_seqlens_square_sum,
}


def pack_sft_batch(
def pack_sft_batch_balance_length(
batch: Dict[str, List[List[int]]],
max_length: int,
always_max_length: bool,
drop_last: bool,
fuse_positions_prob: float,
seed: int,
) -> Packed_Data_Type:
keys = PACKING_KEYS
packed_batch: Packed_Data_Type = {k: [] for k in keys}
packed_batch["packed_seqlens_square_sum"] = []

rng = random.Random(seed)

# Best-fit-decreasing bin packing to maximize utilization of `max_length`.
# This reorders the samples within the provided batch for denser packing.
samples = list(zip(batch["input_ids"], batch["labels"], batch["attention_mask"]))
# Sort by length descending; tie-breaker is deterministic to keep runs reproducible.
sorted_indices = sorted(range(len(samples)), key=lambda i: len(samples[i][0]), reverse=True)

bins: List[Dict[str, List[int]]] = []
bin_lengths: List[int] = []

def start_new_bin() -> int:
bins.append({k: [] for k in keys})
bin_lengths.append(0)
return len(bins) - 1

for idx in sorted_indices:
input_ids, labels, attention_mask = samples[idx]
sample_len = len(input_ids)

# Find the bin that leaves the least remaining space after insertion.
best_bin = None
best_remaining = None
for bin_idx, current_len in enumerate(bin_lengths):
remaining = max_length - current_len
if remaining <= 0:
continue
if not always_max_length and sample_len > remaining:
continue
take_len = min(sample_len, remaining)
remaining_after = remaining - take_len
if best_remaining is None or remaining_after < best_remaining:
best_remaining = remaining_after
best_bin = bin_idx

if best_bin is None:
best_bin = start_new_bin()

target_bin = bins[best_bin]
remaining = max_length - bin_lengths[best_bin]
if remaining <= 0:
continue # should not happen, but guard against negative remaining
take_len = min(sample_len, remaining) if always_max_length else sample_len
take_len = min(take_len, remaining)

target_bin["input_ids"].extend(input_ids[:take_len])
target_bin["labels"].extend(labels[:take_len])
target_bin["attention_mask"].extend(attention_mask[:take_len])
target_bin["position_ids"].extend(range(take_len))
target_bin["packed_sample_seqlens"].append(take_len)
bin_lengths[best_bin] += take_len

for bin_idx, packed in enumerate(bins):
total_len = bin_lengths[bin_idx]
if drop_last and total_len < max_length:
continue
if fuse_positions_prob and rng.random() <= fuse_positions_prob:
packed["position_ids"] = list(range(len(packed["input_ids"])))

# Add sum(seqlen^2) field
packed_batch["packed_seqlens_square_sum"].append(
sum([seqlen**2 for seqlen in packed["packed_sample_seqlens"]])
)

for k in keys:
packed_batch[k].append(packed[k])

return packed_batch
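
For intuition, a toy invocation of the best-fit-decreasing packer (an illustrative sketch with made-up values, not part of the PR). With sample lengths [7, 5, 3, 2] and max_length=8, the 5- and 3-token samples share a bin while the 7-token sample gets its own:

batch = {
    "input_ids":      [[1] * 7, [2] * 5, [3] * 3, [4] * 2],
    "labels":         [[1] * 7, [2] * 5, [3] * 3, [4] * 2],
    "attention_mask": [[1] * 7, [1] * 5, [1] * 3, [1] * 2],
}
packed = pack_sft_batch_balance_length(
    batch, max_length=8, always_max_length=False,
    drop_last=False, fuse_positions_prob=0.0, seed=0,
)
assert packed["packed_sample_seqlens"] == [[7], [5, 3], [2]]
assert packed["packed_seqlens_square_sum"] == [49, 34, 4]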


def pack_sft_batch_naive(
batch: Dict[str, List[List[int]]],
max_length: int,
always_max_length: bool,
drop_last: bool,
fuse_positions_prob: float,
seed: int,
) -> Dict[str, List[List[int]]]:
keys = ("input_ids", "labels", "position_ids", "packed_sample_seqlens", "attention_mask")
packed_batch: Dict[str, List[List[int]]] = {k: [] for k in keys}
) -> Packed_Data_Type:
keys = PACKING_KEYS
packed_batch: Packed_Data_Type = {k: [] for k in keys}
current_sample: Dict[str, List[int]] = {k: [] for k in keys}

packed_batch["packed_seqlens_square_sum"] = []

rng = random.Random(seed)

def should_flush() -> bool:
@@ -203,6 +300,16 @@ def flush() -> None:
if len(current_sample["input_ids"]) > 0:
if fuse_positions_prob and rng.random() <= fuse_positions_prob:
current_sample["position_ids"] = list(range(len(current_sample["input_ids"])))

# current_sample["packed_seqlens_square_sum"] = sum(
# [len * len for len in current_sample["packed_sample_seqlens"]]
# )

# Add sum(seqlen^2) field
packed_batch["packed_seqlens_square_sum"].append(
sum([seqlen**2 for seqlen in current_sample["packed_sample_seqlens"]])
)

for k in keys:
packed_batch[k].append(current_sample[k])
current_sample[k] = []
@@ -248,6 +355,17 @@ class SFTDataConfig(DataConfig):
pack_samples: bool = False
""" Whether to pack multiple samples into samples up to size `max_length`. """

pack_samples_mode: Literal["naive", "balance_length"] = "naive"

shuffle_samples: bool = True
sfc-gh-sbekman (Collaborator), Jan 29, 2026:

This name is ambiguous, as it doesn't tell when samples get shuffled, and we do want them to get shuffled always! If we don't, we should fix that.

I think this should say dl_shuffle_samples or something similar to indicate it's the DL that is being controlled.

Collaborator (author):

Will fix naming.

> and we do want them to get shuffled always! If we don't, we should fix that.

Can you clarify why we want shuffling always, since it is a major cause of imbalance across the ranks?

Collaborator:

I was trying to say that the dataset must always be shuffled for proper training.

So if it doesn't get shuffled at the DL level, then we need to shuffle it at the dataset level before we sort or pack. If we are going ahead with the work of this PR, we have 2 choices:

  1. first shuffle the dataset, then pack, then sort
  2. pack first, then shuffle the packed samples, then sort

I think (1) makes the most sense, since we want to shuffle at the source level; packing first is likely to result in less randomization (see the sketch after this comment).

And of course, as flagged by your experiments, we have an issue with multiple datasets not being blended but concatenated, which leads to loss spikes when the domain changes abruptly. But of course this is out of the scope of this PR.
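
(Editor's sketch of option (1), mirroring what this PR's pack_dataset already does with the huggingface datasets API; packing_fn and batch_size stand in for the real arguments:)

dataset = dataset.shuffle(seed=config.seed)          # 1) shuffle at the source
dataset = dataset.map(packing_fn, batched=True,      # 2) pack the shuffled samples
                      batch_size=batch_size)
dataset = dataset.sort("packed_seqlens_square_sum",  # 3) sort packed samples by cost
                       reverse=True)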

Collaborator (author):

Got it, thanks for the clarification. I think we are talking about shuffling from different angles. The new dl_shuffle_samples is meant to give the user a way to work around the default shuffling behavior of the torch DistributedSampler. In other words, regardless of which of your algorithm proposals we adopt in AT, it seems that when data is loaded for training it will be shuffled here.

sfc-gh-truwase (Collaborator, author), Jan 30, 2026:

> so if we are going for this work of this PR, we have 2 choices:
>
>   • first shuffle the dataset, then pack, then sort
>   • pack first, then shuffle the packed samples, then sort

My understanding is that with either option, the preprocessed dataset will first be saved to disk before loading for the actual training. Is that correct?

Collaborator:

> I think we are talking about shuffling from different angles.

Yes and no. If you're taking away the shuffle at the DL level, we need to make sure shuffling happens elsewhere.

> the preprocessed dataset will first be saved to disk before loading for the actual training. Is that correct?

That's correct, we cache the result. Only if hparams change will it get rebuilt.

Collaborator (author):

> Yes and no. If you're taking away the shuffle at the DL level, we need to make sure shuffling happens elsewhere.

Got it. That means this PR's perf-balancing recommendation to avoid shuffling is not a practical solution.

Collaborator:

If the dataset happens to be pre-shuffled enough, then the training quality shouldn't be impacted, but if it isn't, all sorts of skewed learning may occur. So I'd check with the modeling folks first.

But it should be easy to overcome: if dl_shuffle is false, do ds.shuffle first, before doing the packing, no?

Collaborator (author):

My concern is how packing/sorting affects the random order of the pre-shuffled dataset. I realize your proposal to sort within chunks, as opposed to globally, should help to retain some randomness. But it is unclear to me how effective that would be. I will get more data and follow up with you. Thanks!

Collaborator:

Remember that when shuffled data is packed, even just a few small samples per bin, each new long sample's content is already shuffled, so sorting these longer samples will not undo the shuffling effect. Does that make sense?

The randomness will only be lost if most samples are already of max_len, which is highly unlikely.

Regardless, it massively beats global sorting.

So shuffle first, then pack, should be a solid strategy to retain randomness (a tiny illustration follows).
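
(Editor's toy illustration of this point, not from the PR:)

# After a source-level shuffle, each packed bin already mixes unrelated
# samples; sorting reorders whole bins, not their contents.
shuffled_ids = ["d", "a", "c", "b", "f", "e"]           # pre-shuffled sample ids
bins = [shuffled_ids[i:i + 2] for i in range(0, 6, 2)]  # naive packing, 2 per bin
# bins == [['d', 'a'], ['c', 'b'], ['f', 'e']]; sorting the bins by length or
# cost permutes the bins only; the random mixture inside each stays intact.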

""" Whether to use a sampler that shuffles indices. """

sort_packed_samples: bool = False
""" Whether to sort packed samples. """

sort_packed_samples_order: Literal["ascend", "descend"] = "descend"
""" Sorting order for packed samples. """

drop_last: bool = False
""" Whether to drop the last packed sample, which might be shorter than `max_length`. """

@@ -333,10 +451,16 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType:

batch_size = len(dataset) // self.config.num_proc + 1
# for huge datasets keep the bs to a sane size to avoid cpu-oom
batch_size = int(min(batch_size, 1e3))
# batch_size = int(min(batch_size, 1e3))
batch_size = int(min(batch_size, 1e4))
dataset = dataset.shuffle(seed=self.config.seed)
if self.config.pack_samples_mode == "balance_length":
packing_fn = pack_sft_batch_balance_length
else:
packing_fn = pack_sft_batch_naive

dataset = dataset.map(
lambda x: pack_sft_batch(
lambda x: packing_fn(
x,
max_length=self.config.max_length,
always_max_length=self.config.always_max_length,
@@ -349,6 +473,12 @@
num_proc=self.config.num_proc,
desc="Packing dataset",
)

if self.config.sort_packed_samples:
dataset = dataset.sort(
"packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order == "descend")
)

if len(dataset) < 1:
raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}")
return dataset
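
Why sorting pairs with a non-shuffling sampler (an editor's sketch of the intended interaction, relying on DistributedSampler's round-robin index sharding when shuffle=False):

# With shuffle=False, DistributedSampler hands rank r the indices
# r, r + W, r + 2W, ... (W = world size). On a cost-sorted dataset,
# neighboring indices have similar cost, so every rank draws a
# comparable workload at each step.
costs = [100, 96, 95, 90, 88, 85, 80, 79]  # per-sample costs, sorted (made up)
world_size = 4
per_rank = {r: [costs[i] for i in range(r, len(costs), world_size)]
            for r in range(world_size)}
assert per_rank == {0: [100, 88], 1: [96, 85], 2: [95, 80], 3: [90, 79]}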
@@ -466,6 +596,10 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu
return output

def create_dataloader(self, dataset: DatasetType) -> DataLoader:
dataloader = super().create_dataloader(dataset)
dataloader = (
super().create_dataloader(dataset)
if self.config.shuffle_samples
else super().create_dataloader_no_shuffle(dataset)
)
dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config)
return dataloader
10 changes: 10 additions & 0 deletions arctic_training/metrics.py
@@ -135,6 +135,9 @@ def print_summary(self, prefix: str = "train") -> None:
)
self.values["seqlen_total"] = seqlen_subtotal

# self.seqlens is list[list[int]]
self.values["seqlen_square_sum"] = sum([sum([len * len for len in seqlens]) for seqlens in self.seqlens])

if "loss" in self.values:
loss = sum(gather_object(self.values["loss"], self.trainer.world_size)) / self.trainer.world_size
self.summary_dict["loss"] = loss
@@ -162,6 +165,9 @@ def print_summary(self, prefix: str = "train") -> None:
seq_len_total = sum(gather_object(self.values["seqlen_total"], self.trainer.world_size))
self.summary_dict["seqlen"] = seq_len_total / self.trainer.world_size

seqlen_square_sum_total = sum(gather_object(self.values["seqlen_square_sum"], self.trainer.world_size))
self.summary_dict["seqlen_square_sum"] = seqlen_square_sum_total / self.trainer.world_size

if "step_time" in self.values:
step_time_total = sum(gather_object(self.values["step_time"], self.trainer.world_size))
self.summary_dict["step_time"] = step_time_total / self.trainer.world_size
@@ -186,6 +192,10 @@ def print_summary(self, prefix: str = "train") -> None:
summary_str += f" | lr: {self.summary_dict['lr']:.3E}"
if "seqlen" in self.summary_dict:
summary_str += f" | seqlen: {human_format_base10_number(self.summary_dict['seqlen'])}"
if "seqlen_square_sum" in self.summary_dict:
summary_str += (
f" | seqlen_square_sum: {human_format_base10_number(self.summary_dict['seqlen_square_sum'])}"
)
if "step_time" in self.summary_dict:
summary_str += f" | step time: {human_format_secs(self.summary_dict['step_time'])}"
if "step_tflops" in self.summary_dict:
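
Context for the new metric (editor's note, not part of the PR): with varlen flash attention, each packed sub-sequence of length l attends only within itself, so a packed sample's attention cost scales with the sum of squared sub-sequence lengths, not with the square of the total length. Tracking sum(seqlen^2) per rank therefore exposes attention-load imbalance that a plain token count hides:

# Same token budget, very different attention cost (illustrative):
sum(l * l for l in [4096])      # 16_777_216 -> one long 4096-token sequence
sum(l * l for l in [1] * 4096)  #      4_096 -> 4096 one-token sequences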
1 change: 1 addition & 0 deletions arctic_training/trainer/trainer.py
@@ -471,6 +471,7 @@ def epoch(self) -> None:
if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation:
# deal correctly with packed samples under FA2/FA3, by calculating each seqlen tflos separately
sample_seqlens = batch.pop("packed_sample_seqlens")
batch.pop("packed_seqlens_square_sum")
else:
sample_seqlens = [
[len(batch["input_ids"][idx]) * self.config.sequence_parallel_size]
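
A note on the trainer change (editor's reading): packed_seqlens_square_sum is packing/metrics metadata rather than a model input, so it is popped from the batch alongside packed_sample_seqlens, presumably to keep the batch dict limited to keys the model's forward pass accepts.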