diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 6d1d5d61c03..7e4ef504d69 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -631,8 +631,7 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        if isinstance(self.sampler, BatchSampler):
             self.sampler.sampler = sampler
         else:
             self.batch_sampler.sampler = sampler
@@ -958,8 +957,7 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        if isinstance(self.sampler, BatchSampler):
             self.sampler.sampler = sampler
         else:
             self.batch_sampler.sampler = sampler
@@ -977,10 +975,8 @@ def get_sampler(dataloader):
     Returns:
         `torch.utils.data.Sampler`: The sampler associated to the dataloader
     """
-    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-    if sampler_is_batch_sampler:
-        sampler = getattr(dataloader.sampler, "sampler", None)
-    else:
+    sampler = getattr(dataloader.sampler, "sampler", None)
+    if sampler is None:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     return sampler
 
@@ -1155,11 +1151,18 @@ def prepare_data_loader(
     new_dataset = dataloader.dataset
     # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
+    if isinstance(dataloader.sampler, BatchSampler):
+        logger.warning(
+            "BatchSampler was passed to sampler argument. "
+            "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead. "
+            "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
+        )
+
     new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
-    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+
     synchronized_generator = None
+    sampler = dataloader.sampler
-    sampler = get_sampler(dataloader)
     if isinstance(sampler, RandomSampler) and use_seedable_sampler:
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
@@ -1208,9 +1211,8 @@ def prepare_data_loader(
                     seed = int(torch.empty((), dtype=torch.int64).random_().item())
                     sampler.generator.manual_seed(seed)
                 synchronized_generator = sampler.generator
-            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
-                batch_sampler,
+                dataloader.batch_sampler,
                 num_processes=num_processes,
                 process_index=process_index,
                 split_batches=split_batches,
@@ -1254,19 +1256,6 @@ def prepare_data_loader(
             torch_device_mesh=torch_device_mesh,
             **kwargs,
         )
-    elif sampler_is_batch_sampler:
-        dataloader = DataLoaderShard(
-            new_dataset,
-            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
-            sampler=new_batch_sampler,
-            batch_size=dataloader.batch_size,
-            rng_types=rng_types,
-            _drop_last=dataloader.drop_last,
-            _non_blocking=non_blocking,
-            synchronized_generator=synchronized_generator,
-            use_stateful_dataloader=use_stateful_dataloader,
-            **kwargs,
-        )
     else:
         dataloader = DataLoaderShard(
             new_dataset,
@@ -1361,13 +1350,17 @@ def skip_first_batches(dataloader, num_batches=0):
         dataloader = dataloader.dataloader
 
     dataset = dataloader.dataset
-    sampler_is_batch_sampler = False
+    if isinstance(dataloader.sampler, BatchSampler):
+        logger.warning(
+            "BatchSampler was passed to sampler argument. "
+            "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead. "
+            "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
+        )
+
     if isinstance(dataset, IterableDataset):
         new_batch_sampler = None
     else:
-        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
-        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
+        new_batch_sampler = SkipBatchSampler(dataloader.batch_sampler, skip_batches=num_batches)
 
     # We ignore all of those since they are all dealt with by our new_batch_sampler
     ignore_kwargs = [
@@ -1404,9 +1397,6 @@ def skip_first_batches(dataloader, num_batches=0):
         if new_batch_sampler is None:
             # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
-        elif sampler_is_batch_sampler:
-            kwargs["sampler"] = new_batch_sampler
-            kwargs["batch_size"] = dataloader.batch_size
        else:
            kwargs["batch_sampler"] = new_batch_sampler
        dataloader = DataLoaderShard(
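
For reviewers, a minimal usage sketch (not part of the patch) of the pattern the new warning points users toward: a sampler that yields lists of batch indices belongs in the DataLoader's `batch_sampler=` argument, not `sampler=`. The toy `TensorDataset`, batch size, and single-process run are illustrative assumptions.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

from accelerate import Accelerator

dataset = TensorDataset(torch.arange(32, dtype=torch.float32))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# DataLoader(dataset, sampler=batch_sampler, ...) would now log the warning added
# in this patch when the loader is prepared; pass the BatchSampler through
# `batch_sampler=` instead so prepare_data_loader() / skip_first_batches() can wrap
# it in BatchSamplerShard / SkipBatchSampler as the hunks above do.
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

accelerator = Accelerator()
dataloader = accelerator.prepare(dataloader)
for (batch,) in dataloader:
    print(batch.shape)  # e.g. torch.Size([4]) in a single-process run
```

Note that when `batch_sampler=` is used, `batch_size`, `shuffle`, `sampler`, and `drop_last` must be left at their defaults, per the PyTorch `DataLoader` contract linked in the warning.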