From d87e0953bdec4ce3cbf8986d72dea0c12bc610c2 Mon Sep 17 00:00:00 2001
From: Kabir08
Date: Thu, 26 Feb 2026 00:18:51 +0530
Subject: [PATCH] Fix deadlock in BatchSamplerShard with drop_last=True and custom
 batch samplers without __len__

This fix addresses issue #3814 where deadlocks could occur when using custom
batch samplers (generators) with drop_last=True in multi-process scenarios.

Changes:
- Added logic to determine last_batch_idx from batch sampler, handling cases
  where batch sampler doesn't have __len__ (set last_batch_idx = -1)
- Added immediate yield logic when drop_last=True and either:
  1. Batch sampler doesn't have __len__ (generator-based), OR
  2. At last batch and it's incomplete

Also added test_batch_sampler_with_drop_last_no_len() to verify the fix works
correctly with custom batch samplers that don't have __len__.
---
NOTE(review): with a sized sampler and drop_last=True, the process that owns a
short final batch yields it before the `break`, although the inline comment
says that batch "would be dropped"; the other processes yield nothing for that
round. Confirm this is the intended behaviour.

 src/accelerate/data_loader.py | 26 ++++++++++++++++++
 tests/test_data_loader.py     | 50 +++++++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index e06f9fddf14..569d3ace665 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -218,6 +218,13 @@ def _iter_with_split(self):
     def _iter_with_no_split(self):
         initial_data = []
         batch_to_yield = []
+        # Index of the final batch, or -1 if the sampler is empty or has no __len__
+        try:
+            last_batch_idx = len(self.batch_sampler) - 1 if len(self.batch_sampler) > 0 else -1
+        except TypeError:
+            # If batch_sampler doesn't have __len__, we can't determine the last batch
+            # This can happen with custom batch samplers that are generators
+            last_batch_idx = -1
         for idx, batch in enumerate(self.batch_sampler):
             # We gather the initial indices in case we need to circle back at the end.
             if not self.drop_last and idx < self.num_processes:
@@ -226,6 +233,25 @@ def _iter_with_no_split(self):
             # yielding it.
             if idx % self.num_processes == self.process_index:
                 batch_to_yield = batch
+            # Check if we should yield without waiting for synchronization
+            # This prevents deadlocks when:
+            # 1. We can't determine the batch count (custom batch sampler without __len__), OR
+            # 2. We're at the last batch and it's incomplete (would be dropped with drop_last=True)
+            is_last_batch_incomplete = self.batch_size is not None and len(batch) < self.batch_size
+            at_last_batch = last_batch_idx >= 0 and idx == last_batch_idx
+            should_yield_immediately = (last_batch_idx < 0) or (at_last_batch and is_last_batch_incomplete)
+
+            if should_yield_immediately and self.drop_last:
+                if idx % self.num_processes == self.process_index:
+                    yield batch_to_yield
+                # If we know the batch count, break at the last batch
+                # Otherwise continue to see if there are more batches
+                if last_batch_idx >= 0:
+                    break
+                # For generators (last_batch_idx < 0), continue to next iteration
+                continue
+
+            # Normal yield logic: wait for synchronization point
             if idx % self.num_processes == self.num_processes - 1 and (
                 self.batch_size is None or len(batch) == self.batch_size
             ):
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py
index cbb68538874..bda395d960e 100644
--- a/tests/test_data_loader.py
+++ b/tests/test_data_loader.py
@@ -358,6 +358,56 @@ def test_batch_sampler_with_varying_batch_size(self):
         assert list(batch_sampler_shards[0]) == [[0, 1, 2], [5, 6, 7, 8], [12, 13]]
         assert list(batch_sampler_shards[1]) == [[3, 4], [9, 10, 11]]
 
+    def test_batch_sampler_with_drop_last_no_len(self):
+        """
+        Test that BatchSamplerShard works correctly with custom batch samplers that don't have
+        __len__ (i.e., are generators) and have drop_last=True.
+
+        This tests the fix for issue #3814 where deadlocks could occur when using custom
+        batch samplers with drop_last=True in multi-process scenarios.
+
+        The fix ensures that when we can't determine the batch count (last_batch_idx < 0)
+        and drop_last=True, each process yields its own batches immediately without waiting
+        for synchronization with other processes.
+        """
+
+        class CustomBatchSamplerNoLen:
+            """Custom batch sampler that doesn't have __len__ (generator-based)."""
+
+            def __init__(self, batches, batch_size, drop_last):
+                self.batches = batches
+                self.batch_size = batch_size
+                self.drop_last = drop_last
+
+            def __iter__(self):
+                for batch in self.batches:
+                    if self.drop_last and len(batch) < self.batch_size:
+                        continue
+                    yield batch
+
+        # Give each shard its own sampler so the two shards see different numbers of batches
+        # Sampler for process 0 has 3 batches: [[0,1,2], [3,4,5], [6,7,8]]
+        # Sampler for process 1 has 2 batches: [[0,1,2], [3,4,5]]
+        # This simulates a scenario where different processes have different batch counts
+        batches_process_0 = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+        batches_process_1 = [[0, 1, 2], [3, 4, 5]]
+
+        batch_sampler_0 = CustomBatchSamplerNoLen(batches_process_0, batch_size=3, drop_last=True)
+        batch_sampler_1 = CustomBatchSamplerNoLen(batches_process_1, batch_size=3, drop_last=True)
+
+        # Create shards for two processes
+        shard_0 = BatchSamplerShard(batch_sampler_0, num_processes=2, process_index=0, even_batches=False)
+        shard_1 = BatchSamplerShard(batch_sampler_1, num_processes=2, process_index=1, even_batches=False)
+
+        # Both should iterate without deadlock
+        result_0 = list(shard_0)
+        result_1 = list(shard_1)
+
+        # Process 0 should get batches at indices 0 and 2 (every 2nd batch starting from 0)
+        assert result_0 == [[0, 1, 2], [6, 7, 8]]
+        # Process 1 should get batch at index 1 (every 2nd batch starting from 1)
+        assert result_1 == [[3, 4, 5]]
+
     def check_iterable_dataset_shards(
         self, dataset, seed, batch_size, drop_last=False, num_processes=2, split_batches=False
     ):