Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 61 additions & 27 deletions arctic_training/data/sft_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,32 +195,66 @@ def pack_sft_batch(

rng = random.Random(seed)

def should_flush() -> bool:
    """Decide whether the in-progress packed sample must be emitted first.

    Returns True when the accumulated ``input_ids`` already exceed
    ``max_length``. Otherwise, unless ``always_max_length`` is set, returns
    True when appending the pending ``input_ids`` would push the total past
    ``max_length``.
    """
    packed_len = len(current_sample["input_ids"])
    if packed_len > max_length:
        return True
    if always_max_length:
        # In this mode a sample is only emitted once it has overflowed.
        return False
    return packed_len + len(input_ids) > max_length

# Emit the sample accumulated in `current_sample` into `packed_batch` and
# reset the accumulator.
# NOTE(review): indentation is not visible in this view; the statements below
# are presumably all nested under the non-empty check - confirm.
def flush() -> None:
# Only act when something has been accumulated.
if len(current_sample["input_ids"]) > 0:
# With probability `fuse_positions_prob`, overwrite the position ids with a
# single contiguous 0..N-1 range covering the whole accumulated sample.
if fuse_positions_prob and rng.random() <= fuse_positions_prob:
current_sample["position_ids"] = list(range(len(current_sample["input_ids"])))
# Hand every field over to the output and start a fresh, empty list for it.
for k in keys:
packed_batch[k].append(current_sample[k])
current_sample[k] = []

# Pack multiple samples into one sample: walk the batch in order, emitting the
# accumulated sample via flush() whenever should_flush() asks for it, then
# append the current sample's fields to the accumulator.
for input_ids, labels, attention_mask in zip(batch["input_ids"], batch["labels"], batch["attention_mask"]):
if should_flush():
flush()

current_sample["input_ids"].extend(input_ids)
current_sample["labels"].extend(labels)
current_sample["attention_mask"].extend(attention_mask)
# Position ids restart at 0 for each appended sample.
current_sample["position_ids"].extend(range(len(input_ids)))
# Record the length of each appended sample.
current_sample["packed_sample_seqlens"].extend([len(input_ids)])

# Add the last example: emit whatever is still accumulated, unless the caller
# asked for it to be dropped.
if not drop_last:
flush()
# Pack with a best-fit-decreasing strategy to make the most of `max_length`:
# the longest samples are placed first, so samples are reordered within the
# provided batch.
samples = list(zip(batch["input_ids"], batch["labels"], batch["attention_mask"]))

# Visit samples longest-first. The sort is stable, so equally long samples keep
# their batch order and the resulting packing is reproducible.
sorted_indices = list(range(len(samples)))
sorted_indices.sort(key=lambda i: len(samples[i][0]), reverse=True)

# Parallel lists: bins[i] holds the packed fields, bin_lengths[i] its current length.
bin_lengths: List[int] = []
bins: List[Dict[str, List[int]]] = []

def start_new_bin() -> int:
    """Open an empty bin and return its index.

    Appends a dict holding an empty list for every entry of ``keys`` to
    ``bins`` and a matching zero to ``bin_lengths``.
    """
    new_index = len(bins)
    bin_lengths.append(0)
    bins.append({field: [] for field in keys})
    return new_index

# Place every sample, longest first, into the bin it fills most tightly.
for idx in sorted_indices:
    input_ids, labels, attention_mask = samples[idx]
    sample_len = len(input_ids)

    # Candidate bins still have free space and, unless `always_max_length`
    # permits truncation, enough of it for the whole sample. Each candidate is
    # scored by the space left after insertion; taking the minimum of
    # (leftover, bin index) pairs picks the tightest fit and, on ties, the
    # earliest bin.
    free_space = ((bin_idx, max_length - used) for bin_idx, used in enumerate(bin_lengths))
    fits = [
        (free - min(sample_len, free), bin_idx)
        for bin_idx, free in free_space
        if free > 0 and (always_max_length or sample_len <= free)
    ]
    best_bin = min(fits)[1] if fits else start_new_bin()

    target_bin = bins[best_bin]
    free = max_length - bin_lengths[best_bin]
    if free <= 0:
        # Only reachable when `max_length` is not positive: a chosen bin always
        # has free space and a fresh bin has `max_length` of it. The sample is
        # skipped in that case.
        continue

    # A sample longer than the free space is truncated to fit; the tokens past
    # `take_len` are discarded, not carried over to another bin.
    take_len = min(sample_len, free)

    target_bin["input_ids"].extend(input_ids[:take_len])
    target_bin["labels"].extend(labels[:take_len])
    target_bin["attention_mask"].extend(attention_mask[:take_len])
    # Position ids restart at 0 for each sample placed in the bin.
    target_bin["position_ids"].extend(range(take_len))
    target_bin["packed_sample_seqlens"].append(take_len)
    bin_lengths[best_bin] += take_len

# Emit the finished bins into `packed_batch`, in the order they were opened.
#
# `drop_last` discards at most one bin per call: the most recently opened one,
# and only when it is shorter than `max_length`. Earlier bins are kept even
# when they are not completely full; dropping every under-filled bin would
# throw away most of the batch when samples are packed without truncation,
# where bins routinely end a few tokens short of `max_length`.
last_bin_idx = len(bins) - 1
for bin_idx, packed in enumerate(bins):
    total_len = bin_lengths[bin_idx]
    if total_len == 0:
        # Never emit an empty packed sample.
        continue
    if drop_last and bin_idx == last_bin_idx and total_len < max_length:
        continue
    if fuse_positions_prob and rng.random() <= fuse_positions_prob:
        # With probability `fuse_positions_prob`, number the positions of the
        # whole bin contiguously from 0 instead of restarting per sample.
        packed["position_ids"] = list(range(len(packed["input_ids"])))
    for k in keys:
        packed_batch[k].append(packed[k])

return packed_batch

Expand Down Expand Up @@ -333,7 +367,7 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType:

batch_size = len(dataset) // self.config.num_proc + 1
# for huge datasets keep the bs to a sane size to avoid cpu-oom
batch_size = int(min(batch_size, 1e3))
batch_size = int(min(batch_size, 1e4))
dataset = dataset.shuffle(seed=self.config.seed)
dataset = dataset.map(
lambda x: pack_sft_batch(
Expand Down
Loading