Improve sample packing #347
```diff
@@ -20,6 +20,7 @@
 from typing import Literal
 from typing import Optional
 from typing import Tuple
+from typing import Union

 import numpy as np
 import torch
```
```diff
@@ -38,6 +39,15 @@
 from arctic_training.data.utils import DatasetType

 IGNORE_INDEX = -100
+PACKING_KEYS = (
+    "input_ids",
+    "labels",
+    "position_ids",
+    "packed_sample_seqlens",
+    "attention_mask",
+)
+
+Packed_Data_Type = Dict[str, List[Union[List[int], int]]]


 # this function is modified from TRL trl.trainer.utils.py
```
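For orientation, here is the shape of data these helpers produce. This is a minimal sketch (token values invented for illustration) of one packed row; at the batch level, `Packed_Data_Type` maps each of these keys to one such list per packed row:

```python
# One packed row: two samples of lengths 3 and 2 concatenated together.
packed_row = {
    "input_ids": [101, 7592, 102, 101, 102],   # both samples' tokens, back to back
    "labels": [-100, 7592, 102, -100, 102],    # IGNORE_INDEX (-100) masks non-target tokens
    "position_ids": [0, 1, 2, 0, 1],           # positions restart at each sample boundary
    "packed_sample_seqlens": [3, 2],           # per-sample lengths within the row
    "attention_mask": [1, 1, 1, 1, 1],
}
```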
```diff
@@ -142,9 +152,11 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]:
         if "position_ids" in instances[0]:
             position_ids = [torch.tensor(example["position_ids"]) for example in instances]
             packed_sample_seqlens = [example["packed_sample_seqlens"] for example in instances]
+            packed_seqlens_square_sum = [example["packed_seqlens_square_sum"] for example in instances]
         else:
             position_ids = [torch.tensor(list(range(len(example["input_ids"])))) for example in instances]
             packed_sample_seqlens = [[len(example["input_ids"])] for example in instances]
+            packed_seqlens_square_sum = [-1 for example in instances]

         fake_unpacked_long_seq = False
         # fake_unpacked_long_seq = True
```
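One detail worth noting in this collator change: when the instances were never packed upstream, there are no packing statistics to collate, so the `else` branch synthesizes per-instance defaults and (as I read the diff) uses `-1` as a "not packed" sentinel for the square sum. A sketch of what that fallback amounts to for a single instance, with made-up tokens:

```python
example = {"input_ids": [101, 7592, 102]}  # hypothetical unpacked instance
L = len(example["input_ids"])
position_ids = list(range(L))     # one contiguous sequence starting at 0
packed_sample_seqlens = [L]       # the row holds exactly one sample
packed_seqlens_square_sum = -1    # sentinel: no packing statistics available
```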
```diff
@@ -178,21 +190,106 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]:
             "labels": labels,
             "position_ids": position_ids,
             "packed_sample_seqlens": packed_sample_seqlens,
+            "packed_seqlens_square_sum": packed_seqlens_square_sum,
         }


-def pack_sft_batch(
+def pack_sft_batch_balance_length(
+    batch: Dict[str, List[List[int]]],
+    max_length: int,
+    always_max_length: bool,
+    drop_last: bool,
+    fuse_positions_prob: float,
+    seed: int,
+) -> Packed_Data_Type:
+    keys = PACKING_KEYS
+    packed_batch: Packed_Data_Type = {k: [] for k in keys}
+    packed_batch["packed_seqlens_square_sum"] = []
+
+    rng = random.Random(seed)
+
+    # Best-fit-decreasing bin packing to maximize utilization of `max_length`.
+    # This reorders the samples within the provided batch for denser packing.
+    samples = list(zip(batch["input_ids"], batch["labels"], batch["attention_mask"]))
+    # Sort by length descending; tie-breaker is deterministic to keep runs reproducible.
+    sorted_indices = sorted(range(len(samples)), key=lambda i: len(samples[i][0]), reverse=True)
+
+    bins: List[Dict[str, List[int]]] = []
+    bin_lengths: List[int] = []
+
+    def start_new_bin() -> int:
+        bins.append({k: [] for k in keys})
+        bin_lengths.append(0)
+        return len(bins) - 1
+
+    for idx in sorted_indices:
+        input_ids, labels, attention_mask = samples[idx]
+        sample_len = len(input_ids)
+
+        # Find the bin that leaves the least remaining space after insertion.
+        best_bin = None
+        best_remaining = None
+        for bin_idx, current_len in enumerate(bin_lengths):
+            remaining = max_length - current_len
+            if remaining <= 0:
+                continue
+            if not always_max_length and sample_len > remaining:
+                continue
+            take_len = min(sample_len, remaining)
+            remaining_after = remaining - take_len
+            if best_remaining is None or remaining_after < best_remaining:
+                best_remaining = remaining_after
+                best_bin = bin_idx
+
+        if best_bin is None:
+            best_bin = start_new_bin()
+
+        target_bin = bins[best_bin]
+        remaining = max_length - bin_lengths[best_bin]
+        if remaining <= 0:
+            continue  # should not happen, but guard against negative remaining
+        take_len = min(sample_len, remaining) if always_max_length else sample_len
+        take_len = min(take_len, remaining)
+
+        target_bin["input_ids"].extend(input_ids[:take_len])
+        target_bin["labels"].extend(labels[:take_len])
+        target_bin["attention_mask"].extend(attention_mask[:take_len])
+        target_bin["position_ids"].extend(range(take_len))
+        target_bin["packed_sample_seqlens"].append(take_len)
+        bin_lengths[best_bin] += take_len
+
+    for bin_idx, packed in enumerate(bins):
+        total_len = bin_lengths[bin_idx]
+        if drop_last and total_len < max_length:
+            continue
+        if fuse_positions_prob and rng.random() <= fuse_positions_prob:
+            packed["position_ids"] = list(range(len(packed["input_ids"])))
+
+        # Add sum(seqlen^2) field
+        packed_batch["packed_seqlens_square_sum"].append(
+            sum([seqlen**2 for seqlen in packed["packed_sample_seqlens"]])
+        )
+
+        for k in keys:
+            packed_batch[k].append(packed[k])
+
+    return packed_batch
+
+
+def pack_sft_batch_naive(
     batch: Dict[str, List[List[int]]],
     max_length: int,
     always_max_length: bool,
     drop_last: bool,
     fuse_positions_prob: float,
     seed: int,
-) -> Dict[str, List[List[int]]]:
-    keys = ("input_ids", "labels", "position_ids", "packed_sample_seqlens", "attention_mask")
-    packed_batch: Dict[str, List[List[int]]] = {k: [] for k in keys}
+) -> Packed_Data_Type:
+    keys = PACKING_KEYS
+    packed_batch: Packed_Data_Type = {k: [] for k in keys}
     current_sample: Dict[str, List[int]] = {k: [] for k in keys}

+    packed_batch["packed_seqlens_square_sum"] = []
+
    rng = random.Random(seed)

    def should_flush() -> bool:
```
```diff
@@ -203,6 +300,16 @@ def flush() -> None:
         if len(current_sample["input_ids"]) > 0:
             if fuse_positions_prob and rng.random() <= fuse_positions_prob:
                 current_sample["position_ids"] = list(range(len(current_sample["input_ids"])))
+
+            # current_sample["packed_seqlens_square_sum"] = sum(
+            #     [len * len for len in current_sample["packed_sample_seqlens"]]
+            # )
+
+            # Add sum(seqlen^2) field
+            packed_batch["packed_seqlens_square_sum"].append(
+                sum([seqlen**2 for seqlen in current_sample["packed_sample_seqlens"]])
+            )
+
             for k in keys:
                 packed_batch[k].append(current_sample[k])
                 current_sample[k] = []
```
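To see the best-fit-decreasing behavior of `pack_sft_batch_balance_length` end to end, here is a toy driver (my own example, not from the PR) that packs four variable-length samples into rows of `max_length=8`:

```python
# Lengths 5, 4, 3, 2: sorted descending, 5 opens bin 0, 4 opens bin 1 (5+4 > 8),
# 3 best-fits into bin 0 (fills it exactly), and 2 lands in bin 1.
batch = {
    "input_ids": [[1] * 5, [2] * 4, [3] * 3, [4] * 2],
    "labels": [[1] * 5, [2] * 4, [3] * 3, [4] * 2],
    "attention_mask": [[1] * 5, [1] * 4, [1] * 3, [1] * 2],
}
packed = pack_sft_batch_balance_length(
    batch,
    max_length=8,
    always_max_length=False,
    drop_last=False,
    fuse_positions_prob=0.0,
    seed=42,
)
print(packed["packed_sample_seqlens"])      # [[5, 3], [4, 2]]
print(packed["packed_seqlens_square_sum"])  # [34, 20] = [5**2 + 3**2, 4**2 + 2**2]
```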
```diff
@@ -248,6 +355,17 @@ class SFTDataConfig(DataConfig):
     pack_samples: bool = False
     """ Whether to pack multiple samples into samples up to size `max_length`. """

+    pack_samples_mode: Literal["naive", "balance_length"] = "naive"
+
+    shuffle_samples: bool = True
+    """ Whether to use a sampler that shuffles indices. """
+
+    sort_packed_samples: bool = False
+    """ Whether to sort packed samples. """
+
+    sort_packed_samples_order: Literal["ascend", "descend"] = "descend"
+    """ Sorting order for packed samples. """
+
     drop_last: bool = False
     """ Whether to drop the last packed sample, which might be shorter than `max_length`. """
```

Review thread on `shuffle_samples`:

**Collaborator:** This name is ambiguous, as it doesn't say *when* samples get shuffled, and we do want them to get shuffled always! If we don't, we should fix that. I think this should say […]

**Author:** Will fix the naming. Can you clarify why we want shuffling always, since it is a major cause of imbalance across the ranks?

**Collaborator:** I was trying to say that the dataset must always be shuffled for proper training. So if it doesn't get shuffled at the DL level, then we need to shuffle it at the dataset level before we sort or pack. So if we are going for the work of this PR, we have 2 choices: (1) shuffle the dataset before sorting/packing, or (2) pack first and then shuffle the packed samples. I think (1) makes the most sense, since we want to shuffle at the source level; packing first is likely to result in less randomization. And of course, as flagged by your experiments, we have an issue with multiple datasets not being blended but concatenated, which leads to loss spikes when the domain changes abruptly. But that is out of the scope of this PR.

**Author:** Got it, thanks for the clarification. I think we are talking about shuffling from different angles. The new […]

**Author:** My understanding is that with either option, the preprocessed dataset will first be saved to disk before loading for the actual training. Is that correct?

**Collaborator:** Yes and no. If you're taking away the shuffle at the DL level, we need to make sure shuffling happens elsewhere. And that's correct, we cache the result; only if hparams change will it get rebuilt.

**Author:** Got it. That means the perf-balancing recommendation of this PR to avoid shuffling is not a practical solution.

**Collaborator:** If the dataset happens to be pre-shuffled enough, then the training quality shouldn't be impacted, but if it isn't, all sorts of skewed learning may occur. So I'd check with the modeling guys first. But it should be easy to overcome: if dl_shuffle is false, do […]

**Author:** My concern is how packing/sorting affects the random order of the pre-shuffled dataset. I realize your proposal to sort within chunks, as opposed to globally, should help to retain some randomness, but it is unclear to me how effective that would be. I will get more data and follow up with you. Thanks!

**Collaborator:** Remember that when shuffled data is packed, even just a few small samples together, the new long samples' content is already shuffled, so sorting these longer samples will not undo the shuffling effect. Does that make sense? The randomness will only be lost if most samples are already at max_len, which is highly unlikely. Regardless, it massively beats global sorting. So shuffle first, then pack, should be a solid strategy to retain randomness.
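As a quick sanity check on the "shuffle first, then pack" argument above, here is a throwaway simulation (my own sketch, not part of the PR): it shuffles sample indices, greedily packs them into rows, globally sorts the rows by total length (the worst case discussed in the thread), and then measures whether the flattened order still correlates with the original dataset order:

```python
import random

import numpy as np

rng = random.Random(0)
lengths = [rng.randint(16, 512) for _ in range(5000)]  # hypothetical sample lengths
order = list(range(len(lengths)))
rng.shuffle(order)  # shuffle at the source level, before packing

# Naive greedy packing into rows of at most max_length tokens.
max_length = 1024
rows, cur, cur_len = [], [], 0
for i in order:
    if cur_len + lengths[i] > max_length and cur:
        rows.append(cur)
        cur, cur_len = [], 0
    cur.append(i)
    cur_len += lengths[i]
if cur:
    rows.append(cur)

# Global sort of the *packed* rows by total length.
rows.sort(key=lambda r: sum(lengths[i] for i in r), reverse=True)

flat = np.array([i for r in rows for i in r])
# Correlation between final position and original sample index:
# near 0 means the source-level shuffle survived packing + sorting.
corr = np.corrcoef(np.arange(len(flat)), flat)[0, 1]
print(f"correlation with original order: {corr:.3f}")
```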
```diff
@@ -333,10 +451,16 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType:
         batch_size = len(dataset) // self.config.num_proc + 1
         # for huge datasets keep the bs to a sane size to avoid cpu-oom
-        batch_size = int(min(batch_size, 1e3))
+        # batch_size = int(min(batch_size, 1e3))
+        batch_size = int(min(batch_size, 1e4))
+        dataset = dataset.shuffle(seed=self.config.seed)
+        if self.config.pack_samples_mode == "balance_length":
+            packing_fn = pack_sft_batch_balance_length
+        else:
+            packing_fn = pack_sft_batch_naive

         dataset = dataset.map(
-            lambda x: pack_sft_batch(
+            lambda x: packing_fn(
                 x,
                 max_length=self.config.max_length,
                 always_max_length=self.config.always_max_length,
```
```diff
@@ -349,6 +473,12 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType:
             num_proc=self.config.num_proc,
             desc="Packing dataset",
         )

+        if self.config.sort_packed_samples:
+            dataset = dataset.sort(
+                "packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order == "descend")
+            )
+
         if len(dataset) < 1:
             raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}")
         return dataset
```
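A note on why the sort key is `packed_seqlens_square_sum` rather than total token count (my reading; the PR doesn't spell it out): with block-diagonal attention over the packed sub-sequences, a row's attention cost grows roughly with the sum of squared sub-sequence lengths, so two rows with identical token counts can differ substantially in cost:

```python
# Two rows, both 1024 tokens, but roughly 2x apart in attention cost.
def square_sum(seqlens):
    return sum(s**2 for s in seqlens)

print(square_sum([512, 512]))  # 524288  (two 512-token sub-sequences)
print(square_sum([1024]))      # 1048576 (one 1024-token sequence)
```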
```diff
@@ -466,6 +596,10 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tuple]):
         return output

     def create_dataloader(self, dataset: DatasetType) -> DataLoader:
-        dataloader = super().create_dataloader(dataset)
+        dataloader = (
+            super().create_dataloader(dataset)
+            if self.config.shuffle_samples
+            else super().create_dataloader_no_shuffle(dataset)
+        )
         dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config)
         return dataloader
```
Review thread on `create_dataloader`:

**Collaborator:** Do we need to require a new method for this? Why not extend the existing `create_dataloader`? This would avoid the need to add this new method for each data factory (`_validate_class_method(cls, "create_dataloader_no_shuffle", ["self", "dataset"])`).

**Author:** I previously extended `create_dataloader` with an optional `shuffle` flag. However, this caused a UT failure on `arctic_training/data/factory.py` line 75 (as of f472557). It seems the validation does not support optional args, or at least I don't know how to achieve that. I also didn't want `shuffle` to be mandatory for `create_dataloader`. But if that is preferred, I can make the change. What do you prefer?
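For concreteness, the reviewer's alternative would look roughly like the sketch below. This is hypothetical: the real base-class signature and the `_validate_class_method` check live in `arctic_training/data/factory.py`, and as the author notes, the current validator rejects the extra optional argument; `micro_batch_size` here is a stand-in attribute, not the factory's actual config field:

```python
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

def create_dataloader(self, dataset, shuffle: bool = True) -> DataLoader:
    # Single entry point: choose the sampler instead of duplicating the method.
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    return DataLoader(
        dataset,
        sampler=sampler,
        batch_size=self.micro_batch_size,  # stand-in for the factory's real batch size config
    )
```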