diff --git a/arctic_training/data/factory.py b/arctic_training/data/factory.py index 82c7aec4..48dfec33 100644 --- a/arctic_training/data/factory.py +++ b/arctic_training/data/factory.py @@ -73,6 +73,7 @@ def _validate_subclass(cls) -> None: _validate_class_method(cls, "process", ["self", "dataset"]) _validate_class_method(cls, "split_data", ["self", "training_data"]) _validate_class_method(cls, "create_dataloader", ["self", "dataset"]) + _validate_class_method(cls, "create_dataloader_no_shuffle", ["self", "dataset"]) def __init__(self, trainer: "Trainer", config: Optional[DataConfig] = None) -> None: if config is None: @@ -226,14 +227,25 @@ def split_data(self, training_data: DatasetType) -> Tuple[DatasetType, Optional[ return training_data, evaluation_data - @callback_wrapper("create_dataloader") - def create_dataloader(self, dataset: DatasetType) -> DataLoader: + def _create_dataloader(self, dataset: DatasetType, sampler_shuffle: bool = True) -> DataLoader: """Create a torch DataLoader from the dataset.""" return DataLoader( dataset, batch_size=self.micro_batch_size, - sampler=DistributedSampler(dataset, num_replicas=self.world_size, rank=self.global_rank), + sampler=DistributedSampler( + dataset, num_replicas=self.world_size, rank=self.global_rank, shuffle=sampler_shuffle + ), num_workers=self.config.dl_num_workers, persistent_workers=True, drop_last=True, ) + + @callback_wrapper("create_dataloader") + def create_dataloader(self, dataset: DatasetType) -> DataLoader: + """Create a torch DataLoader from the dataset.""" + return self._create_dataloader(dataset) + + @callback_wrapper("create_dataloader_no_shuffle") + def create_dataloader_no_shuffle(self, dataset: DatasetType) -> DataLoader: + """Create a torch DataLoader from the dataset.""" + return self._create_dataloader(dataset, sampler_shuffle=False) diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 22283d14..fd72434c 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -20,6 +20,7 @@ from typing import Literal from typing import Optional from typing import Tuple +from typing import Union import numpy as np import torch @@ -38,6 +39,15 @@ from arctic_training.data.utils import DatasetType IGNORE_INDEX = -100 +PACKING_KEYS = ( + "input_ids", + "labels", + "position_ids", + "packed_sample_seqlens", + "attention_mask", +) + +Packed_Data_Type = Dict[str, List[Union[List[int], int]]] # this function is modified from TRL trl.trainer.utils.py @@ -142,9 +152,11 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]: if "position_ids" in instances[0]: position_ids = [torch.tensor(example["position_ids"]) for example in instances] packed_sample_seqlens = [example["packed_sample_seqlens"] for example in instances] + packed_seqlens_square_sum = [example["packed_seqlens_square_sum"] for example in instances] else: position_ids = [torch.tensor(list(range(len(example["input_ids"])))) for example in instances] packed_sample_seqlens = [[len(example["input_ids"])] for example in instances] + packed_seqlens_square_sum = [-1 for example in instances] fake_unpacked_long_seq = False # fake_unpacked_long_seq = True @@ -178,21 +190,128 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]: "labels": labels, "position_ids": position_ids, "packed_sample_seqlens": packed_sample_seqlens, + "packed_seqlens_square_sum": packed_seqlens_square_sum, } -def pack_sft_batch( +def sort_packed_sft_batch(batch: Packed_Data_Type, reverse: bool) -> Packed_Data_Type: + packed_list = [] + packed_keys = list(batch.keys()) + + for idx in range(len(batch["input_ids"])): + packed_dict = {key: batch[key][idx] for key in packed_keys} + packed_list.append(packed_dict) + + def sum_square_compare(packed_sample): + return sum([seqlen**2 for seqlen in packed_sample["packed_sample_seqlens"]]) + + packed_list.sort(key=sum_square_compare, reverse=reverse) + + packed_batch: Packed_Data_Type = {k: [] for k in packed_keys} + + for packed_sample in packed_list: + for key in packed_keys: + packed_batch[key].append(packed_sample[key]) + + return packed_batch + + +def pack_sft_batch_balance_length( batch: Dict[str, List[List[int]]], max_length: int, always_max_length: bool, drop_last: bool, fuse_positions_prob: float, seed: int, -) -> Dict[str, List[List[int]]]: - keys = ("input_ids", "labels", "position_ids", "packed_sample_seqlens", "attention_mask") - packed_batch: Dict[str, List[List[int]]] = {k: [] for k in keys} +) -> Packed_Data_Type: + keys = PACKING_KEYS + packed_batch: Packed_Data_Type = {k: [] for k in keys} + packed_batch["packed_seqlens_square_sum"] = [] + + rng = random.Random(seed) + + # Best-fit-decreasing bin packing to maximize utilization of `max_length`. + # This packs multiple short samples within the provided batch into larger samples each trying to be as close as possible to max_length. + samples = list(zip(batch["input_ids"], batch["labels"], batch["attention_mask"])) + # Sort by length descending; tie-breaker is deterministic to keep runs reproducible. + sorted_indices = sorted(range(len(samples)), key=lambda i: len(samples[i][0]), reverse=True) + + bins: List[Dict[str, List[int]]] = [] + bin_lengths: List[int] = [] + + def start_new_bin() -> int: + bins.append({k: [] for k in keys}) + bin_lengths.append(0) + return len(bins) - 1 + + for idx in sorted_indices: + input_ids, labels, attention_mask = samples[idx] + sample_len = len(input_ids) + + # Find the bin that leaves the least remaining space after insertion. + best_bin = None + best_remaining = None + for bin_idx, current_len in enumerate(bin_lengths): + remaining = max_length - current_len + if remaining <= 0: + continue + if not always_max_length and sample_len > remaining: + continue + take_len = min(sample_len, remaining) + remaining_after = remaining - take_len + if best_remaining is None or remaining_after < best_remaining: + best_remaining = remaining_after + best_bin = bin_idx + + if best_bin is None: + best_bin = start_new_bin() + + target_bin = bins[best_bin] + remaining = max_length - bin_lengths[best_bin] + if remaining <= 0: + continue # should not happen, but guard against negative remaining + take_len = min(sample_len, remaining) if always_max_length else sample_len + take_len = min(take_len, remaining) + + target_bin["input_ids"].extend(input_ids[:take_len]) + target_bin["labels"].extend(labels[:take_len]) + target_bin["attention_mask"].extend(attention_mask[:take_len]) + target_bin["position_ids"].extend(range(take_len)) + target_bin["packed_sample_seqlens"].append(take_len) + bin_lengths[best_bin] += take_len + + for bin_idx, packed in enumerate(bins): + total_len = bin_lengths[bin_idx] + if drop_last and total_len < max_length: + continue + if fuse_positions_prob and rng.random() <= fuse_positions_prob: + packed["position_ids"] = list(range(len(packed["input_ids"]))) + + # Add sum(seqlen^2) field + packed_batch["packed_seqlens_square_sum"].append( + sum([seqlen**2 for seqlen in packed["packed_sample_seqlens"]]) + ) + + for k in keys: + packed_batch[k].append(packed[k]) + + return packed_batch + + +def pack_sft_batch_naive( + batch: Dict[str, List[List[int]]], + max_length: int, + always_max_length: bool, + drop_last: bool, + fuse_positions_prob: float, + seed: int, +) -> Packed_Data_Type: + keys = PACKING_KEYS + packed_batch: Packed_Data_Type = {k: [] for k in keys} current_sample: Dict[str, List[int]] = {k: [] for k in keys} + packed_batch["packed_seqlens_square_sum"] = [] + rng = random.Random(seed) def should_flush() -> bool: @@ -203,6 +322,12 @@ def flush() -> None: if len(current_sample["input_ids"]) > 0: if fuse_positions_prob and rng.random() <= fuse_positions_prob: current_sample["position_ids"] = list(range(len(current_sample["input_ids"]))) + + # Add sum(seqlen^2) field + packed_batch["packed_seqlens_square_sum"].append( + sum([seqlen**2 for seqlen in current_sample["packed_sample_seqlens"]]) + ) + for k in keys: packed_batch[k].append(current_sample[k]) current_sample[k] = [] @@ -248,6 +373,24 @@ class SFTDataConfig(DataConfig): pack_samples: bool = False """ Whether to pack multiple samples into samples up to size `max_length`. """ + pack_samples_mode: Literal["greedy", "balance_length"] = "greedy" + """ What packing algorithm to use. The default is a greedy packing algorithm""" + + max_pack_batch_size: int = 10**4 + """ Maximum batch/chunk size for packing samples. Helps to avoid CPU OOM""" + + dl_shuffle_samples: bool = True + """ Whether dataloader should shuffles samples. """ + + sort_packed_samples: bool = False + """ Whether to sort packed samples. """ + + sort_packed_samples_order: Literal["ascend", "descend"] = "descend" + """ Sorting order for packed samples. """ + + sort_packed_samples_scope: Literal["batch", "all"] = "batch" + """ Sorting order for packed samples. """ + drop_last: bool = False """ Whether to drop the last packed sample, which might be shorter than `max_length`. """ @@ -332,11 +475,18 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: dataset = repeat_dataset(dataset=dataset, max_length=self.config.max_length, num_proc=self.config.num_proc) batch_size = len(dataset) // self.config.num_proc + 1 + # for huge datasets keep the bs to a sane size to avoid cpu-oom - batch_size = int(min(batch_size, 1e3)) + batch_size = int(min(batch_size, self.config.max_pack_batch_size)) + dataset = dataset.shuffle(seed=self.config.seed) + if self.config.pack_samples_mode == "balance_length": + packing_fn = pack_sft_batch_balance_length + else: + packing_fn = pack_sft_batch_naive + dataset = dataset.map( - lambda x: pack_sft_batch( + lambda x: packing_fn( x, max_length=self.config.max_length, always_max_length=self.config.always_max_length, @@ -349,6 +499,21 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: num_proc=self.config.num_proc, desc="Packing dataset", ) + + if self.config.sort_packed_samples: + if self.config.sort_packed_samples_scope == "batch": + dataset = dataset.map( + lambda x: sort_packed_sft_batch(x, reverse=(self.config.sort_packed_samples_order == "descend")), + batched=True, + batch_size=batch_size, + num_proc=self.config.num_proc, + desc="Local sorting dataset", + ) + else: + dataset = dataset.sort( + "packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order == "descend") + ) + if len(dataset) < 1: raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}") return dataset @@ -466,6 +631,10 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu return output def create_dataloader(self, dataset: DatasetType) -> DataLoader: - dataloader = super().create_dataloader(dataset) + dataloader = ( + super().create_dataloader(dataset) + if self.config.dl_shuffle_samples + else super().create_dataloader_no_shuffle(dataset) + ) dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config) return dataloader diff --git a/arctic_training/metrics.py b/arctic_training/metrics.py index 4e2885e8..0deeee89 100644 --- a/arctic_training/metrics.py +++ b/arctic_training/metrics.py @@ -135,6 +135,9 @@ def print_summary(self, prefix: str = "train") -> None: ) self.values["seqlen_total"] = seqlen_subtotal + # self.seqlens is list[list[int]] + self.values["seqlen_square_sum"] = sum([sum([len * len for len in seqlens]) for seqlens in self.seqlens]) + if "loss" in self.values: loss = sum(gather_object(self.values["loss"], self.trainer.world_size)) / self.trainer.world_size self.summary_dict["loss"] = loss @@ -162,6 +165,9 @@ def print_summary(self, prefix: str = "train") -> None: seq_len_total = sum(gather_object(self.values["seqlen_total"], self.trainer.world_size)) self.summary_dict["seqlen"] = seq_len_total / self.trainer.world_size + seqlen_square_sum_total = sum(gather_object(self.values["seqlen_square_sum"], self.trainer.world_size)) + self.summary_dict["seqlen_square_sum"] = seqlen_square_sum_total / self.trainer.world_size + if "step_time" in self.values: step_time_total = sum(gather_object(self.values["step_time"], self.trainer.world_size)) self.summary_dict["step_time"] = step_time_total / self.trainer.world_size @@ -186,6 +192,10 @@ def print_summary(self, prefix: str = "train") -> None: summary_str += f" | lr: {self.summary_dict['lr']:.3E}" if "seqlen" in self.summary_dict: summary_str += f" | seqlen: {human_format_base10_number(self.summary_dict['seqlen'])}" + if "seqlen_square_sum" in self.summary_dict: + summary_str += ( + f" | seqlen_square_sum: {human_format_base10_number(self.summary_dict['seqlen_square_sum'])}" + ) if "step_time" in self.summary_dict: summary_str += f" | step time: {human_format_secs(self.summary_dict['step_time'])}" if "step_tflops" in self.summary_dict: diff --git a/arctic_training/trainer/trainer.py b/arctic_training/trainer/trainer.py index 29686c14..7d5e32cc 100644 --- a/arctic_training/trainer/trainer.py +++ b/arctic_training/trainer/trainer.py @@ -472,6 +472,7 @@ def epoch(self) -> None: if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation: # deal correctly with packed samples under FA2/FA3, by calculating each seqlen tflos separately sample_seqlens = batch.pop("packed_sample_seqlens") + batch.pop("packed_seqlens_square_sum") else: sample_seqlens = [ [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size]