From 14b90bd6aac9848343fa434eda4c069a2f1c6ec7 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 4 Dec 2025 17:51:42 +0000 Subject: [PATCH 01/14] Plot DS wall clock timers in W&B Signed-off-by: Olatunji Ruwase --- arctic_training/trainer/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 55a8946e..68cbcfa5 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -456,6 +456,8 @@ def epoch(self) -> None: and self.global_step % self.config.train_log_iter_interval == 0 ): metrics = {k: v for k, v in self.metrics.summary_dict.items()} + ds_timers = self.model.get_wall_clock_timers() + metrics.update(ds_timers) append_json_file(self.config.train_log_metrics_path, metrics) From ff821ba12a5ff0048f450b4377ee7afedab82655 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 4 Dec 2025 18:43:43 +0000 Subject: [PATCH 02/14] BC safety Signed-off-by: Olatunji Ruwase --- arctic_training/trainer/trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 68cbcfa5..66f250ac 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -208,6 +208,7 @@ def __init__(self, config: TrainerConfig, mode: str = "train") -> None: dschf = HfDeepSpeedConfig(self.config.deepspeed) # noqa: F841 model_factory = self.config.model.factory(self) self.model = model_factory() + self.ds_wall_clock_available = hasattr(self.model, 'get_wall_clock_timers') # prevent causal mask from being created in HF Transformers - it's a huge `[bs, seqlen, seqlen]` tensor # XXX: This should also benefit a single gpu use case when SDPA is used - so perhaps remove the SP>1 check? @@ -456,8 +457,9 @@ def epoch(self) -> None: and self.global_step % self.config.train_log_iter_interval == 0 ): metrics = {k: v for k, v in self.metrics.summary_dict.items()} - ds_timers = self.model.get_wall_clock_timers() - metrics.update(ds_timers) + if self.ds_wall_clock_available: + ds_timers = self.model.get_wall_clock_timers() + metrics.update(ds_timers) append_json_file(self.config.train_log_metrics_path, metrics) From e47c58c2223d1b7b86e66557af1c1e71a8646a1f Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 5 Dec 2025 14:37:15 +0000 Subject: [PATCH 03/14] Fix bug Signed-off-by: Olatunji Ruwase --- arctic_training/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 66f250ac..63728653 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -208,7 +208,6 @@ def __init__(self, config: TrainerConfig, mode: str = "train") -> None: dschf = HfDeepSpeedConfig(self.config.deepspeed) # noqa: F841 model_factory = self.config.model.factory(self) self.model = model_factory() - self.ds_wall_clock_available = hasattr(self.model, 'get_wall_clock_timers') # prevent causal mask from being created in HF Transformers - it's a huge `[bs, seqlen, seqlen]` tensor # XXX: This should also benefit a single gpu use case when SDPA is used - so perhaps remove the SP>1 check? @@ -235,6 +234,8 @@ def __init__(self, config: TrainerConfig, mode: str = "train") -> None: mpu=mpu, ) + self.ds_wall_clock_available = hasattr(self.model, "get_wall_clock_timers") + if self.config.sequence_parallel_size > 1: # deepspeed.initialize needs to run first from deepspeed.utils import groups From e20671e1f65fac5336d448bb8c07a5815e3f5e6b Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Mon, 26 Jan 2026 16:07:16 +0000 Subject: [PATCH 04/14] Debug sample lens Signed-off-by: Tunji Ruwase --- arctic_training/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 63728653..df75248b 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -441,7 +441,7 @@ def epoch(self) -> None: for idx in range(len(batch["input_ids"])) ] self.metrics.seqlens = sample_seqlens - + print(f'AT: {self.global_rank=} {self.train_batch_idx=} {sample_seqlens=} {sum(sample_seqlens[0])=}') self.metrics.start_timer("step") self.step(batch) self.metrics.stop_timer("step") From 5cc68ecc049f3a104df40bcd9bf251db87273d6f Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 29 Jan 2026 11:28:58 +0000 Subject: [PATCH 05/14] Sample packing improvements Signed-off-by: Olatunji Ruwase --- arctic_training/data/factory.py | 4 +- arctic_training/data/sft_factory.py | 144 ++++++++++++++++++++++++++-- arctic_training/metrics.py | 8 ++ arctic_training/trainer/trainer.py | 5 +- 4 files changed, 152 insertions(+), 9 deletions(-) diff --git a/arctic_training/data/factory.py b/arctic_training/data/factory.py index 82c7aec4..74a4f9f3 100644 --- a/arctic_training/data/factory.py +++ b/arctic_training/data/factory.py @@ -227,12 +227,12 @@ def split_data(self, training_data: DatasetType) -> Tuple[DatasetType, Optional[ return training_data, evaluation_data @callback_wrapper("create_dataloader") - def create_dataloader(self, dataset: DatasetType) -> DataLoader: + def create_dataloader(self, dataset: DatasetType, sampler_shuffle: bool = True) -> DataLoader: """Create a torch DataLoader from the dataset.""" return DataLoader( dataset, batch_size=self.micro_batch_size, - sampler=DistributedSampler(dataset, num_replicas=self.world_size, rank=self.global_rank), + sampler=DistributedSampler(dataset, num_replicas=self.world_size, rank=self.global_rank, shuffle=sampler_shuffle), num_workers=self.config.dl_num_workers, persistent_workers=True, drop_last=True, diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 22283d14..816f9530 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -38,7 +38,7 @@ from arctic_training.data.utils import DatasetType IGNORE_INDEX = -100 - +PACKING_KEYS = ("input_ids", "labels", "position_ids", "packed_sample_seqlens", "attention_mask", "packed_sort_index", "packed_seqlens_square_sum") # this function is modified from TRL trl.trainer.utils.py def pad( @@ -139,12 +139,17 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]: # torch.tensor(example["attention_mask"]) for example in instances # ] + if "position_ids" in instances[0]: position_ids = [torch.tensor(example["position_ids"]) for example in instances] packed_sample_seqlens = [example["packed_sample_seqlens"] for example in instances] + packed_sort_index = [example["packed_sort_index"] for example in instances] + packed_seqlens_square_sum = [example["packed_seqlens_square_sum"] for example in instances] else: position_ids = [torch.tensor(list(range(len(example["input_ids"])))) for example in instances] packed_sample_seqlens = [[len(example["input_ids"])] for example in instances] + packed_sort_index = [-1 for example in instances] + packed_seqlens_square_sum = [-1 for example in instances] fake_unpacked_long_seq = False # fake_unpacked_long_seq = True @@ -178,10 +183,111 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]: "labels": labels, "position_ids": position_ids, "packed_sample_seqlens": packed_sample_seqlens, + "packed_sort_index": packed_sort_index, + "packed_seqlens_square_sum": packed_seqlens_square_sum, } -def pack_sft_batch( +def pack_sft_batch_balance_length( + batch: Dict[str, List[List[int]]], + max_length: int, + always_max_length: bool, + drop_last: bool, + fuse_positions_prob: float, + seed: int, +) -> Dict[str, List[List[int]]]: + keys = PACKING_KEYS + packed_batch: Dict[str, List[List[int]]] = {k: [] for k in keys} + current_sample: Dict[str, List[int]] = {k: [] for k in keys} + + rng = random.Random(seed) + + # Best-fit-decreasing bin packing to maximize utilization of `max_length`. + # This reorders the samples within the provided batch for denser packing. + samples = list( + zip(batch["input_ids"], batch["labels"], batch["attention_mask"]) + ) + # Sort by length descending; tie-breaker is deterministic to keep runs reproducible. + sorted_indices = sorted(range(len(samples)), key=lambda i: len(samples[i][0]), reverse=True) + + print(f"pack_sft_batch: packing batch of {len(samples)=}") + + bins: List[Dict[str, List[int]]] = [] + bin_lengths: List[int] = [] + + def start_new_bin() -> int: + bins.append({k: [] for k in keys}) + bin_lengths.append(0) + return len(bins) - 1 + + for idx in sorted_indices: + input_ids, labels, attention_mask = samples[idx] + sample_len = len(input_ids) + + # Find the bin that leaves the least remaining space after insertion. + best_bin = None + best_remaining = None + for bin_idx, current_len in enumerate(bin_lengths): + remaining = max_length - current_len + if remaining <= 0: + continue + if not always_max_length and sample_len > remaining: + continue + take_len = min(sample_len, remaining) + remaining_after = remaining - take_len + if best_remaining is None or remaining_after < best_remaining: + best_remaining = remaining_after + best_bin = bin_idx + + if best_bin is None: + best_bin = start_new_bin() + + target_bin = bins[best_bin] + remaining = max_length - bin_lengths[best_bin] + if remaining <= 0: + continue # should not happen, but guard against negative remaining + take_len = min(sample_len, remaining) if always_max_length else sample_len + take_len = min(take_len, remaining) + + target_bin["input_ids"].extend(input_ids[:take_len]) + target_bin["labels"].extend(labels[:take_len]) + target_bin["attention_mask"].extend(attention_mask[:take_len]) + target_bin["position_ids"].extend(range(take_len)) + target_bin["packed_sample_seqlens"].append(take_len) + bin_lengths[best_bin] += take_len + + for bin_idx, packed in enumerate(bins): + total_len = bin_lengths[bin_idx] + if drop_last and total_len < max_length: + continue + if fuse_positions_prob and rng.random() <= fuse_positions_prob: + packed["position_ids"] = list(range(len(packed["input_ids"]))) + + packed["packed_sort_index"] = bin_idx + packed["packed_seqlens_square_sum"] = sum([len * len for len in packed["packed_sample_seqlens"]]) + + for k in keys: + packed_batch[k].append(packed[k]) + + return packed_batch + + +def pack_analysis(dataset: DatasetType): + + def dump_stats(input_stats, tag): + min_value = min(input_stats) + max_value = max(input_stats) + avg_value = sum(input_stats)/len(input_stats) + + print(f"pack_analysis {tag}: {min_value=} {max_value=} {avg_value=}") + + attention_compute = [sum([(len * len)/1e9 for len in batch['packed_sample_seqlens']]) for batch in dataset] + packed_len = [sum([len for len in batch['packed_sample_seqlens']]) for batch in dataset] + + dump_stats(attention_compute, "attention_compute_billions") + dump_stats(packed_len, "packed_lengths") + +def pack_sft_batch_naive( batch: Dict[str, List[List[int]]], max_length: int, always_max_length: bool, @@ -189,7 +295,7 @@ def pack_sft_batch( fuse_positions_prob: float, seed: int, ) -> Dict[str, List[List[int]]]: - keys = ("input_ids", "labels", "position_ids", "packed_sample_seqlens", "attention_mask") + keys = PACKING_KEYS packed_batch: Dict[str, List[List[int]]] = {k: [] for k in keys} current_sample: Dict[str, List[int]] = {k: [] for k in keys} @@ -203,6 +309,10 @@ def flush() -> None: if len(current_sample["input_ids"]) > 0: if fuse_positions_prob and rng.random() <= fuse_positions_prob: current_sample["position_ids"] = list(range(len(current_sample["input_ids"]))) + + current_sample["packed_sort_index"] = len(packed_batch["packed_sort_index"]) + current_sample["packed_seqlens_square_sum"] = sum([len * len for len in current_sample["packed_sample_seqlens"]]) + for k in keys: packed_batch[k].append(current_sample[k]) current_sample[k] = [] @@ -248,6 +358,17 @@ class SFTDataConfig(DataConfig): pack_samples: bool = False """ Whether to pack multiple samples into samples up to size `max_length`. """ + pack_samples_mode: Literal["naive", "balance_length"] = "naive" + + shuffle_samples: bool = True + """ Whether to use a sampler that shuffles indices. """ + + sort_packed_samples: bool = False + """ Whether to sort packed samples. """ + + sort_packed_samples_order: Literal["ascend", "descend"] = "descend" + """ Sorting order for packed samples. """ + drop_last: bool = False """ Whether to drop the last packed sample, which might be shorter than `max_length`. """ @@ -328,15 +449,22 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: if not self.config.pack_samples: return dataset + if self.config.repeat_to_pack_max_length: dataset = repeat_dataset(dataset=dataset, max_length=self.config.max_length, num_proc=self.config.num_proc) batch_size = len(dataset) // self.config.num_proc + 1 # for huge datasets keep the bs to a sane size to avoid cpu-oom - batch_size = int(min(batch_size, 1e3)) + # batch_size = int(min(batch_size, 1e3)) + batch_size = int(min(batch_size, 1e4)) dataset = dataset.shuffle(seed=self.config.seed) + if self.config.pack_samples_mode == "balance_length": + packing_fn = pack_sft_batch_balance_length + else: + packing_fn = pack_sft_batch_naive + dataset = dataset.map( - lambda x: pack_sft_batch( + lambda x: packing_fn( x, max_length=self.config.max_length, always_max_length=self.config.always_max_length, @@ -349,6 +477,10 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: num_proc=self.config.num_proc, desc="Packing dataset", ) + + if self.config.sort_packed_samples: + dataset = dataset.sort("packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order=="descend")) + if len(dataset) < 1: raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}") return dataset @@ -466,6 +598,6 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu return output def create_dataloader(self, dataset: DatasetType) -> DataLoader: - dataloader = super().create_dataloader(dataset) + dataloader = super().create_dataloader(dataset, sampler_shuffle=self.config.shuffle_samples) dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config) return dataloader diff --git a/arctic_training/metrics.py b/arctic_training/metrics.py index 4e2885e8..7b59dd1c 100644 --- a/arctic_training/metrics.py +++ b/arctic_training/metrics.py @@ -134,6 +134,9 @@ def print_summary(self, prefix: str = "train") -> None: / self.trainer.config.sequence_parallel_size ) self.values["seqlen_total"] = seqlen_subtotal + + # self.seqlens is list[list[int]] + self.values["seqlen_square_sum"] = sum([sum([len*len for len in seqlens]) for seqlens in self.seqlens]) if "loss" in self.values: loss = sum(gather_object(self.values["loss"], self.trainer.world_size)) / self.trainer.world_size @@ -162,6 +165,9 @@ def print_summary(self, prefix: str = "train") -> None: seq_len_total = sum(gather_object(self.values["seqlen_total"], self.trainer.world_size)) self.summary_dict["seqlen"] = seq_len_total / self.trainer.world_size + seqlen_square_sum_total = sum(gather_object(self.values["seqlen_square_sum"], self.trainer.world_size)) + self.summary_dict["seqlen_square_sum"] = seqlen_square_sum_total / self.trainer.world_size + if "step_time" in self.values: step_time_total = sum(gather_object(self.values["step_time"], self.trainer.world_size)) self.summary_dict["step_time"] = step_time_total / self.trainer.world_size @@ -186,6 +192,8 @@ def print_summary(self, prefix: str = "train") -> None: summary_str += f" | lr: {self.summary_dict['lr']:.3E}" if "seqlen" in self.summary_dict: summary_str += f" | seqlen: {human_format_base10_number(self.summary_dict['seqlen'])}" + if "seqlen_square_sum" in self.summary_dict: + summary_str += f" | seqlen_square_sum: {human_format_base10_number(self.summary_dict['seqlen_square_sum'])}" if "step_time" in self.summary_dict: summary_str += f" | step time: {human_format_secs(self.summary_dict['step_time'])}" if "step_tflops" in self.summary_dict: diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 8c018ddb..e1d99b19 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -173,6 +173,7 @@ def __init__(self, config: TrainerConfig, mode: str = "train") -> None: data_factory = self.config.data.factory(self) self.train_dataloader, self.eval_dataloader = data_factory() + if mode == "process-data": return @@ -471,13 +472,15 @@ def epoch(self) -> None: if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation: # deal correctly with packed samples under FA2/FA3, by calculating each seqlen tflos separately sample_seqlens = batch.pop("packed_sample_seqlens") + packed_sort_index = batch.pop("packed_sort_index") + packed_seqlens_square_sum = float(batch.pop("packed_seqlens_square_sum")[0]) else: sample_seqlens = [ [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size] for idx in range(len(batch["input_ids"])) ] self.metrics.seqlens = sample_seqlens - print(f'AT: {self.global_rank=} {self.train_batch_idx=} {sample_seqlens=} {sum(sample_seqlens[0])=}') + self.metrics.start_timer("step") self.step(batch) self.metrics.stop_timer("step") From 41c56e7a92e663c93ca48a4a6225368976cf3809 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 29 Jan 2026 11:37:27 +0000 Subject: [PATCH 06/14] Cleanup Signed-off-by: Olatunji Ruwase --- arctic_training/trainer/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index e1d99b19..77c92bce 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -173,7 +173,6 @@ def __init__(self, config: TrainerConfig, mode: str = "train") -> None: data_factory = self.config.data.factory(self) self.train_dataloader, self.eval_dataloader = data_factory() - if mode == "process-data": return From 12a4fbf8287fc37dac7cd7e90d07ba9ff11ec635 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 29 Jan 2026 16:55:01 +0000 Subject: [PATCH 07/14] Formatting and cleanup Signed-off-by: Olatunji Ruwase --- arctic_training/data/factory.py | 4 +- arctic_training/data/sft_factory.py | 72 ++++++++++++++--------------- arctic_training/metrics.py | 10 ++-- arctic_training/trainer/trainer.py | 3 +- 4 files changed, 45 insertions(+), 44 deletions(-) diff --git a/arctic_training/data/factory.py b/arctic_training/data/factory.py index 74a4f9f3..7a0fbb64 100644 --- a/arctic_training/data/factory.py +++ b/arctic_training/data/factory.py @@ -232,7 +232,9 @@ def create_dataloader(self, dataset: DatasetType, sampler_shuffle: bool = True) return DataLoader( dataset, batch_size=self.micro_batch_size, - sampler=DistributedSampler(dataset, num_replicas=self.world_size, rank=self.global_rank, shuffle=sampler_shuffle), + sampler=DistributedSampler( + dataset, num_replicas=self.world_size, rank=self.global_rank, shuffle=sampler_shuffle + ), num_workers=self.config.dl_num_workers, persistent_workers=True, drop_last=True, diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 816f9530..38cf5f0c 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -20,6 +20,7 @@ from typing import Literal from typing import Optional from typing import Tuple +from typing import Union import numpy as np import torch @@ -38,7 +39,16 @@ from arctic_training.data.utils import DatasetType IGNORE_INDEX = -100 -PACKING_KEYS = ("input_ids", "labels", "position_ids", "packed_sample_seqlens", "attention_mask", "packed_sort_index", "packed_seqlens_square_sum") +PACKING_KEYS = ( + "input_ids", + "labels", + "position_ids", + "packed_sample_seqlens", + "attention_mask", +) + +Packed_Data_Type = Dict[str, List[Union[List[int], int]]] + # this function is modified from TRL trl.trainer.utils.py def pad( @@ -139,16 +149,13 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]: # torch.tensor(example["attention_mask"]) for example in instances # ] - if "position_ids" in instances[0]: position_ids = [torch.tensor(example["position_ids"]) for example in instances] packed_sample_seqlens = [example["packed_sample_seqlens"] for example in instances] - packed_sort_index = [example["packed_sort_index"] for example in instances] packed_seqlens_square_sum = [example["packed_seqlens_square_sum"] for example in instances] else: position_ids = [torch.tensor(list(range(len(example["input_ids"])))) for example in instances] packed_sample_seqlens = [[len(example["input_ids"])] for example in instances] - packed_sort_index = [-1 for example in instances] packed_seqlens_square_sum = [-1 for example in instances] fake_unpacked_long_seq = False @@ -183,7 +190,6 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]: "labels": labels, "position_ids": position_ids, "packed_sample_seqlens": packed_sample_seqlens, - "packed_sort_index": packed_sort_index, "packed_seqlens_square_sum": packed_seqlens_square_sum, } @@ -195,23 +201,19 @@ def pack_sft_batch_balance_length( drop_last: bool, fuse_positions_prob: float, seed: int, -) -> Dict[str, List[List[int]]]: +) -> Packed_Data_Type: keys = PACKING_KEYS - packed_batch: Dict[str, List[List[int]]] = {k: [] for k in keys} - current_sample: Dict[str, List[int]] = {k: [] for k in keys} + packed_batch: Packed_Data_Type = {k: [] for k in keys} + packed_batch["packed_seqlens_square_sum"] = [] rng = random.Random(seed) # Best-fit-decreasing bin packing to maximize utilization of `max_length`. # This reorders the samples within the provided batch for denser packing. - samples = list( - zip(batch["input_ids"], batch["labels"], batch["attention_mask"]) - ) + samples = list(zip(batch["input_ids"], batch["labels"], batch["attention_mask"])) # Sort by length descending; tie-breaker is deterministic to keep runs reproducible. sorted_indices = sorted(range(len(samples)), key=lambda i: len(samples[i][0]), reverse=True) - print(f"pack_sft_batch: packing batch of {len(samples)=}") - bins: List[Dict[str, List[int]]] = [] bin_lengths: List[int] = [] @@ -263,8 +265,10 @@ def start_new_bin() -> int: if fuse_positions_prob and rng.random() <= fuse_positions_prob: packed["position_ids"] = list(range(len(packed["input_ids"]))) - packed["packed_sort_index"] = bin_idx - packed["packed_seqlens_square_sum"] = sum([len * len for len in packed["packed_sample_seqlens"]]) + # Add sum(seqlen^2) field + packed_batch["packed_seqlens_square_sum"].append( + sum([seqlen**2 for seqlen in packed["packed_sample_seqlens"]]) + ) for k in keys: packed_batch[k].append(packed[k]) @@ -272,21 +276,6 @@ def start_new_bin() -> int: return packed_batch -def pack_analysis(dataset: DatasetType): - - def dump_stats(input_stats, tag): - min_value = min(input_stats) - max_value = max(input_stats) - avg_value = sum(input_stats)/len(input_stats) - - print(f"pack_analysis {tag}: {min_value=} {max_value=} {avg_value=}") - - attention_compute = [sum([(len * len)/1e9 for len in batch['packed_sample_seqlens']]) for batch in dataset] - packed_len = [sum([len for len in batch['packed_sample_seqlens']]) for batch in dataset] - - dump_stats(attention_compute, "attention_compute_billions") - dump_stats(packed_len, "packed_lengths") - def pack_sft_batch_naive( batch: Dict[str, List[List[int]]], max_length: int, @@ -294,11 +283,13 @@ def pack_sft_batch_naive( drop_last: bool, fuse_positions_prob: float, seed: int, -) -> Dict[str, List[List[int]]]: +) -> Packed_Data_Type: keys = PACKING_KEYS - packed_batch: Dict[str, List[List[int]]] = {k: [] for k in keys} + packed_batch: Packed_Data_Type = {k: [] for k in keys} current_sample: Dict[str, List[int]] = {k: [] for k in keys} + packed_batch["packed_seqlens_square_sum"] = [] + rng = random.Random(seed) def should_flush() -> bool: @@ -310,8 +301,14 @@ def flush() -> None: if fuse_positions_prob and rng.random() <= fuse_positions_prob: current_sample["position_ids"] = list(range(len(current_sample["input_ids"]))) - current_sample["packed_sort_index"] = len(packed_batch["packed_sort_index"]) - current_sample["packed_seqlens_square_sum"] = sum([len * len for len in current_sample["packed_sample_seqlens"]]) + # current_sample["packed_seqlens_square_sum"] = sum( + # [len * len for len in current_sample["packed_sample_seqlens"]] + # ) + + # Add sum(seqlen^2) field + packed_batch["packed_seqlens_square_sum"].append( + sum([seqlen**2 for seqlen in current_sample["packed_sample_seqlens"]]) + ) for k in keys: packed_batch[k].append(current_sample[k]) @@ -365,7 +362,7 @@ class SFTDataConfig(DataConfig): sort_packed_samples: bool = False """ Whether to sort packed samples. """ - + sort_packed_samples_order: Literal["ascend", "descend"] = "descend" """ Sorting order for packed samples. """ @@ -449,7 +446,6 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: if not self.config.pack_samples: return dataset - if self.config.repeat_to_pack_max_length: dataset = repeat_dataset(dataset=dataset, max_length=self.config.max_length, num_proc=self.config.num_proc) @@ -479,7 +475,9 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: ) if self.config.sort_packed_samples: - dataset = dataset.sort("packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order=="descend")) + dataset = dataset.sort( + "packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order == "descend") + ) if len(dataset) < 1: raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}") diff --git a/arctic_training/metrics.py b/arctic_training/metrics.py index 7b59dd1c..0deeee89 100644 --- a/arctic_training/metrics.py +++ b/arctic_training/metrics.py @@ -134,9 +134,9 @@ def print_summary(self, prefix: str = "train") -> None: / self.trainer.config.sequence_parallel_size ) self.values["seqlen_total"] = seqlen_subtotal - + # self.seqlens is list[list[int]] - self.values["seqlen_square_sum"] = sum([sum([len*len for len in seqlens]) for seqlens in self.seqlens]) + self.values["seqlen_square_sum"] = sum([sum([len * len for len in seqlens]) for seqlens in self.seqlens]) if "loss" in self.values: loss = sum(gather_object(self.values["loss"], self.trainer.world_size)) / self.trainer.world_size @@ -166,7 +166,7 @@ def print_summary(self, prefix: str = "train") -> None: self.summary_dict["seqlen"] = seq_len_total / self.trainer.world_size seqlen_square_sum_total = sum(gather_object(self.values["seqlen_square_sum"], self.trainer.world_size)) - self.summary_dict["seqlen_square_sum"] = seqlen_square_sum_total / self.trainer.world_size + self.summary_dict["seqlen_square_sum"] = seqlen_square_sum_total / self.trainer.world_size if "step_time" in self.values: step_time_total = sum(gather_object(self.values["step_time"], self.trainer.world_size)) @@ -193,7 +193,9 @@ def print_summary(self, prefix: str = "train") -> None: if "seqlen" in self.summary_dict: summary_str += f" | seqlen: {human_format_base10_number(self.summary_dict['seqlen'])}" if "seqlen_square_sum" in self.summary_dict: - summary_str += f" | seqlen_square_sum: {human_format_base10_number(self.summary_dict['seqlen_square_sum'])}" + summary_str += ( + f" | seqlen_square_sum: {human_format_base10_number(self.summary_dict['seqlen_square_sum'])}" + ) if "step_time" in self.summary_dict: summary_str += f" | step time: {human_format_secs(self.summary_dict['step_time'])}" if "step_tflops" in self.summary_dict: diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 77c92bce..80a87ed8 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -471,8 +471,7 @@ def epoch(self) -> None: if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation: # deal correctly with packed samples under FA2/FA3, by calculating each seqlen tflos separately sample_seqlens = batch.pop("packed_sample_seqlens") - packed_sort_index = batch.pop("packed_sort_index") - packed_seqlens_square_sum = float(batch.pop("packed_seqlens_square_sum")[0]) + batch.pop("packed_seqlens_square_sum") else: sample_seqlens = [ [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size] From 4b8f8f6888cdd3fd4e999e8c43680090ccf60089 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 29 Jan 2026 16:55:34 +0000 Subject: [PATCH 08/14] Formatting and cleanup Signed-off-by: Olatunji Ruwase --- arctic_training/data/sft_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 38cf5f0c..7681674b 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -363,7 +363,7 @@ class SFTDataConfig(DataConfig): sort_packed_samples: bool = False """ Whether to sort packed samples. """ - sort_packed_samples_order: Literal["ascend", "descend"] = "descend" + sort_packed_samples_order: Literal["ascend", "descend"] = "descend" """ Sorting order for packed samples. """ drop_last: bool = False From 9945e6aee79e84d92059681beb7c005d6c006d2c Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 29 Jan 2026 16:56:03 +0000 Subject: [PATCH 09/14] Formatting Signed-off-by: Olatunji Ruwase --- arctic_training/data/sft_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 7681674b..38cf5f0c 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -363,7 +363,7 @@ class SFTDataConfig(DataConfig): sort_packed_samples: bool = False """ Whether to sort packed samples. """ - sort_packed_samples_order: Literal["ascend", "descend"] = "descend" + sort_packed_samples_order: Literal["ascend", "descend"] = "descend" """ Sorting order for packed samples. """ drop_last: bool = False From b0a2c65085fdc31f2745524bf755482dee4fb44a Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 29 Jan 2026 17:30:41 +0000 Subject: [PATCH 10/14] UT fix Signed-off-by: Olatunji Ruwase --- arctic_training/data/factory.py | 14 ++++++++++++-- arctic_training/data/sft_factory.py | 6 +++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/arctic_training/data/factory.py b/arctic_training/data/factory.py index 7a0fbb64..48dfec33 100644 --- a/arctic_training/data/factory.py +++ b/arctic_training/data/factory.py @@ -73,6 +73,7 @@ def _validate_subclass(cls) -> None: _validate_class_method(cls, "process", ["self", "dataset"]) _validate_class_method(cls, "split_data", ["self", "training_data"]) _validate_class_method(cls, "create_dataloader", ["self", "dataset"]) + _validate_class_method(cls, "create_dataloader_no_shuffle", ["self", "dataset"]) def __init__(self, trainer: "Trainer", config: Optional[DataConfig] = None) -> None: if config is None: @@ -226,8 +227,7 @@ def split_data(self, training_data: DatasetType) -> Tuple[DatasetType, Optional[ return training_data, evaluation_data - @callback_wrapper("create_dataloader") - def create_dataloader(self, dataset: DatasetType, sampler_shuffle: bool = True) -> DataLoader: + def _create_dataloader(self, dataset: DatasetType, sampler_shuffle: bool = True) -> DataLoader: """Create a torch DataLoader from the dataset.""" return DataLoader( dataset, @@ -239,3 +239,13 @@ def create_dataloader(self, dataset: DatasetType, sampler_shuffle: bool = True) persistent_workers=True, drop_last=True, ) + + @callback_wrapper("create_dataloader") + def create_dataloader(self, dataset: DatasetType) -> DataLoader: + """Create a torch DataLoader from the dataset.""" + return self._create_dataloader(dataset) + + @callback_wrapper("create_dataloader_no_shuffle") + def create_dataloader_no_shuffle(self, dataset: DatasetType) -> DataLoader: + """Create a torch DataLoader from the dataset.""" + return self._create_dataloader(dataset, sampler_shuffle=False) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 38cf5f0c..fb728c70 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -596,6 +596,10 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu return output def create_dataloader(self, dataset: DatasetType) -> DataLoader: - dataloader = super().create_dataloader(dataset, sampler_shuffle=self.config.shuffle_samples) + dataloader = ( + super().create_dataloader(dataset) + if self.config.shuffle_samples + else super().create_dataloader_no_shuffle(dataset) + ) dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config) return dataloader From e5b8e4da0b014b3cd48f9c380602f84aca1f2eb2 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 29 Jan 2026 17:34:16 -0500 Subject: [PATCH 11/14] Update arctic_training/data/sft_factory.py Co-authored-by: Stas Bekman --- arctic_training/data/sft_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index fb728c70..c044349b 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -209,7 +209,7 @@ def pack_sft_batch_balance_length( rng = random.Random(seed) # Best-fit-decreasing bin packing to maximize utilization of `max_length`. - # This reorders the samples within the provided batch for denser packing. + # This packs multiple short samples within the provided batch into larger samples each trying to be as close as possible to max_length. samples = list(zip(batch["input_ids"], batch["labels"], batch["attention_mask"])) # Sort by length descending; tie-breaker is deterministic to keep runs reproducible. sorted_indices = sorted(range(len(samples)), key=lambda i: len(samples[i][0]), reverse=True) From 982daaa0772d68205f43229a3895fe754402cbca Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 29 Jan 2026 22:45:19 +0000 Subject: [PATCH 12/14] Fix naming and cleanup Signed-off-by: Olatunji Ruwase --- arctic_training/data/sft_factory.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index c044349b..557eb2fc 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -301,10 +301,6 @@ def flush() -> None: if fuse_positions_prob and rng.random() <= fuse_positions_prob: current_sample["position_ids"] = list(range(len(current_sample["input_ids"]))) - # current_sample["packed_seqlens_square_sum"] = sum( - # [len * len for len in current_sample["packed_sample_seqlens"]] - # ) - # Add sum(seqlen^2) field packed_batch["packed_seqlens_square_sum"].append( sum([seqlen**2 for seqlen in current_sample["packed_sample_seqlens"]]) @@ -357,8 +353,8 @@ class SFTDataConfig(DataConfig): pack_samples_mode: Literal["naive", "balance_length"] = "naive" - shuffle_samples: bool = True - """ Whether to use a sampler that shuffles indices. """ + dl_shuffle_samples: bool = True + """ Whether dataloader should shuffles samples. """ sort_packed_samples: bool = False """ Whether to sort packed samples. """ @@ -451,7 +447,6 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: batch_size = len(dataset) // self.config.num_proc + 1 # for huge datasets keep the bs to a sane size to avoid cpu-oom - # batch_size = int(min(batch_size, 1e3)) batch_size = int(min(batch_size, 1e4)) dataset = dataset.shuffle(seed=self.config.seed) if self.config.pack_samples_mode == "balance_length": @@ -598,7 +593,7 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu def create_dataloader(self, dataset: DatasetType) -> DataLoader: dataloader = ( super().create_dataloader(dataset) - if self.config.shuffle_samples + if self.config.dl_shuffle_samples else super().create_dataloader_no_shuffle(dataset) ) dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config) From f1c13e759e4eb33c1fedacfaac229584b5c076f7 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 3 Feb 2026 12:47:07 +0000 Subject: [PATCH 13/14] Local sorting Signed-off-by: Olatunji Ruwase --- arctic_training/data/sft_factory.py | 48 ++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 557eb2fc..58eae66f 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -194,6 +194,28 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]: } +def sort_packed_sft_batch(batch: Packed_Data_Type, reverse: bool) -> Packed_Data_Type: + packed_list = [] + packed_keys = list(batch.keys()) + + for idx in range(len(batch["input_ids"])): + packed_dict = {key: batch[key][idx] for key in packed_keys} + packed_list.append(packed_dict) + + def sum_square_compare(packed_sample): + return sum([seqlen**2 for seqlen in packed_sample["packed_sample_seqlens"]]) + + packed_list.sort(key=sum_square_compare, reverse=reverse) + + packed_batch: Packed_Data_Type = {k: [] for k in packed_keys} + + for packed_sample in packed_list: + for key in packed_keys: + packed_batch[key].append(packed_sample[key]) + + return packed_batch + + def pack_sft_batch_balance_length( batch: Dict[str, List[List[int]]], max_length: int, @@ -352,6 +374,10 @@ class SFTDataConfig(DataConfig): """ Whether to pack multiple samples into samples up to size `max_length`. """ pack_samples_mode: Literal["naive", "balance_length"] = "naive" + """ What packing algorithm to use. The default is a greedy packing algorithm""" + + max_pack_batch_size: int = 10**4 + """ Maximum batch/chunk size for packing samples. Helps to avoid CPU OOM""" dl_shuffle_samples: bool = True """ Whether dataloader should shuffles samples. """ @@ -362,6 +388,9 @@ class SFTDataConfig(DataConfig): sort_packed_samples_order: Literal["ascend", "descend"] = "descend" """ Sorting order for packed samples. """ + sort_packed_samples_scope: Literal["local", "global"] = "local" + """ Sorting order for packed samples. """ + drop_last: bool = False """ Whether to drop the last packed sample, which might be shorter than `max_length`. """ @@ -446,8 +475,10 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: dataset = repeat_dataset(dataset=dataset, max_length=self.config.max_length, num_proc=self.config.num_proc) batch_size = len(dataset) // self.config.num_proc + 1 + # for huge datasets keep the bs to a sane size to avoid cpu-oom - batch_size = int(min(batch_size, 1e4)) + batch_size = int(min(batch_size, self.config.max_pack_batch_size)) + dataset = dataset.shuffle(seed=self.config.seed) if self.config.pack_samples_mode == "balance_length": packing_fn = pack_sft_batch_balance_length @@ -470,9 +501,18 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: ) if self.config.sort_packed_samples: - dataset = dataset.sort( - "packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order == "descend") - ) + if self.config.sort_packed_samples_scope == "local": + dataset = dataset.map( + lambda x: sort_packed_sft_batch(x, reverse=(self.config.sort_packed_samples_order == "descend")), + batched=True, + batch_size=batch_size, + num_proc=self.config.num_proc, + desc="Local sorting dataset", + ) + else: + dataset = dataset.sort( + "packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order == "descend") + ) if len(dataset) < 1: raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}") From fa91599c716a1a0832fb3174073a3c186984f337 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 5 Feb 2026 19:47:40 +0000 Subject: [PATCH 14/14] PR feedback Signed-off-by: Olatunji Ruwase --- arctic_training/data/sft_factory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 58eae66f..fd72434c 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -373,7 +373,7 @@ class SFTDataConfig(DataConfig): pack_samples: bool = False """ Whether to pack multiple samples into samples up to size `max_length`. """ - pack_samples_mode: Literal["naive", "balance_length"] = "naive" + pack_samples_mode: Literal["greedy", "balance_length"] = "greedy" """ What packing algorithm to use. The default is a greedy packing algorithm""" max_pack_batch_size: int = 10**4 @@ -388,7 +388,7 @@ class SFTDataConfig(DataConfig): sort_packed_samples_order: Literal["ascend", "descend"] = "descend" """ Sorting order for packed samples. """ - sort_packed_samples_scope: Literal["local", "global"] = "local" + sort_packed_samples_scope: Literal["batch", "all"] = "batch" """ Sorting order for packed samples. """ drop_last: bool = False @@ -501,7 +501,7 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: ) if self.config.sort_packed_samples: - if self.config.sort_packed_samples_scope == "local": + if self.config.sort_packed_samples_scope == "batch": dataset = dataset.map( lambda x: sort_packed_sft_batch(x, reverse=(self.config.sort_packed_samples_order == "descend")), batched=True,