diff --git a/arctic_training/data/sft_factory.py b/arctic_training/data/sft_factory.py index 22283d14..1981d4b4 100644 --- a/arctic_training/data/sft_factory.py +++ b/arctic_training/data/sft_factory.py @@ -195,32 +195,66 @@ def pack_sft_batch( rng = random.Random(seed) - def should_flush() -> bool: - total_len = len(current_sample["input_ids"]) - return total_len > max_length or (not always_max_length and total_len + len(input_ids) > max_length) - - def flush() -> None: - if len(current_sample["input_ids"]) > 0: - if fuse_positions_prob and rng.random() <= fuse_positions_prob: - current_sample["position_ids"] = list(range(len(current_sample["input_ids"]))) - for k in keys: - packed_batch[k].append(current_sample[k]) - current_sample[k] = [] - - # Pack multiple samples into one sample - for input_ids, labels, attention_mask in zip(batch["input_ids"], batch["labels"], batch["attention_mask"]): - if should_flush(): - flush() - - current_sample["input_ids"].extend(input_ids) - current_sample["labels"].extend(labels) - current_sample["attention_mask"].extend(attention_mask) - current_sample["position_ids"].extend(range(len(input_ids))) - current_sample["packed_sample_seqlens"].extend([len(input_ids)]) - - # Add the last example - if not drop_last: - flush() + # Best-fit-decreasing bin packing to maximize utilization of `max_length`. + # This reorders the samples within the provided batch for denser packing. + samples = list( + zip(batch["input_ids"], batch["labels"], batch["attention_mask"]) + ) + # Sort by length descending; tie-breaker is deterministic to keep runs reproducible. + sorted_indices = sorted(range(len(samples)), key=lambda i: len(samples[i][0]), reverse=True) + + bins: List[Dict[str, List[int]]] = [] + bin_lengths: List[int] = [] + + def start_new_bin() -> int: + bins.append({k: [] for k in keys}) + bin_lengths.append(0) + return len(bins) - 1 + + for idx in sorted_indices: + input_ids, labels, attention_mask = samples[idx] + sample_len = len(input_ids) + + # Find the bin that leaves the least remaining space after insertion. + best_bin = None + best_remaining = None + for bin_idx, current_len in enumerate(bin_lengths): + remaining = max_length - current_len + if remaining <= 0: + continue + if not always_max_length and sample_len > remaining: + continue + take_len = min(sample_len, remaining) + remaining_after = remaining - take_len + if best_remaining is None or remaining_after < best_remaining: + best_remaining = remaining_after + best_bin = bin_idx + + if best_bin is None: + best_bin = start_new_bin() + + target_bin = bins[best_bin] + remaining = max_length - bin_lengths[best_bin] + if remaining <= 0: + continue # should not happen, but guard against negative remaining + take_len = min(sample_len, remaining) if always_max_length else sample_len + take_len = min(take_len, remaining) + + target_bin["input_ids"].extend(input_ids[:take_len]) + target_bin["labels"].extend(labels[:take_len]) + target_bin["attention_mask"].extend(attention_mask[:take_len]) + target_bin["position_ids"].extend(range(take_len)) + target_bin["packed_sample_seqlens"].append(take_len) + bin_lengths[best_bin] += take_len + + for bin_idx, packed in enumerate(bins): + total_len = bin_lengths[bin_idx] + if drop_last and total_len < max_length: + continue + if fuse_positions_prob and rng.random() <= fuse_positions_prob: + packed["position_ids"] = list(range(len(packed["input_ids"]))) + for k in keys: + packed_batch[k].append(packed[k]) return packed_batch @@ -333,7 +367,7 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType: batch_size = len(dataset) // self.config.num_proc + 1 # for huge datasets keep the bs to a sane size to avoid cpu-oom - batch_size = int(min(batch_size, 1e3)) + batch_size = int(min(batch_size, 1e4)) dataset = dataset.shuffle(seed=self.config.seed) dataset = dataset.map( lambda x: pack_sft_batch(