Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ def _iter_with_split(self):
def _iter_with_no_split(self):
initial_data = []
batch_to_yield = []
# Get the total number of batches, if available
try:
last_batch_idx = len(self.batch_sampler) - 1 if len(self.batch_sampler) > 0 else -1
except TypeError:
# If batch_sampler doesn't have __len__, we can't determine the last batch
# This can happen with custom batch samplers that are generators
last_batch_idx = -1
for idx, batch in enumerate(self.batch_sampler):
# We gather the initial indices in case we need to circle back at the end.
if not self.drop_last and idx < self.num_processes:
Expand All @@ -226,6 +233,25 @@ def _iter_with_no_split(self):
# yielding it.
if idx % self.num_processes == self.process_index:
batch_to_yield = batch
# Check if we should yield without waiting for synchronization
# This prevents deadlocks when:
# 1. We can't determine the batch count (custom batch sampler without __len__), OR
# 2. We're at the last batch and it's incomplete (would be dropped with drop_last=True)
is_last_batch_incomplete = self.batch_size is not None and len(batch) < self.batch_size
at_last_batch = last_batch_idx >= 0 and idx == last_batch_idx
should_yield_immediately = (last_batch_idx < 0) or (at_last_batch and is_last_batch_incomplete)

if should_yield_immediately and self.drop_last:
if idx % self.num_processes == self.process_index:
yield batch_to_yield
# If we know the batch count, break at the last batch
# Otherwise continue to see if there are more batches
if last_batch_idx >= 0:
break
# For generators (last_batch_idx < 0), continue to next iteration
continue

# Normal yield logic: wait for synchronization point
if idx % self.num_processes == self.num_processes - 1 and (
self.batch_size is None or len(batch) == self.batch_size
):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,56 @@ def test_batch_sampler_with_varying_batch_size(self):
assert list(batch_sampler_shards[0]) == [[0, 1, 2], [5, 6, 7, 8], [12, 13]]
assert list(batch_sampler_shards[1]) == [[3, 4], [9, 10, 11]]

def test_batch_sampler_with_drop_last_no_len(self):
    """
    Ensure ``BatchSamplerShard`` works with custom batch samplers that lack
    ``__len__`` (generator-style samplers) when ``drop_last=True``.

    Regression test for issue #3814: with a length-less batch sampler and
    ``drop_last=True``, multi-process iteration could deadlock. The fix makes
    each process yield its last batch immediately instead of waiting on a
    synchronization point it can never reach.
    """

    class LenlessBatchSampler:
        """Batch sampler exposing only __iter__ — deliberately no __len__."""

        def __init__(self, batches, batch_size, drop_last):
            self.batches = batches
            self.batch_size = batch_size
            self.drop_last = drop_last

        def __iter__(self):
            for candidate in self.batches:
                # Mirror drop_last semantics: skip incomplete trailing batches.
                if self.drop_last and len(candidate) < self.batch_size:
                    continue
                yield candidate

    # Two simulated processes with unequal batch counts:
    # rank 0 sees three full batches, rank 1 only two.
    per_rank_batches = [
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]],  # rank 0
        [[0, 1, 2], [3, 4, 5]],  # rank 1
    ]

    shards = [
        BatchSamplerShard(
            LenlessBatchSampler(batches, batch_size=3, drop_last=True),
            num_processes=2,
            process_index=rank,
            even_batches=False,
        )
        for rank, batches in enumerate(per_rank_batches)
    ]

    # Iterating both shards must run to completion without deadlocking.
    results = [list(shard) for shard in shards]

    # Round-robin sharding: rank 0 takes batch indices 0 and 2,
    # rank 1 takes batch index 1.
    assert results[0] == [[0, 1, 2], [6, 7, 8]]
    assert results[1] == [[3, 4, 5]]

def check_iterable_dataset_shards(
self, dataset, seed, batch_size, drop_last=False, num_processes=2, split_batches=False
):
Expand Down