diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 0bc7650f682..06bf2909cba 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -717,6 +717,10 @@ def use_stateful_dataloader(self):
             return self.dataloader_config.use_stateful_dataloader
         return False
 
+    @property
+    def custom_dataloader_classes(self):
+        return tuple(getattr(self.dataloader_config, "custom_classes", ()))
+
     @property
     def project_dir(self):
         return self.project_configuration.project_dir
@@ -1397,7 +1401,7 @@ def print(self, *args, **kwargs):
     def _prepare_one(self, obj, first_pass=False, device_placement=None):
         # First pass of preparation: DataLoader, model, optimizer
         if first_pass:
-            if isinstance(obj, torch.utils.data.DataLoader):
+            if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, self.custom_dataloader_classes):
                 return self.prepare_data_loader(obj, device_placement=device_placement)
             elif isinstance(obj, torch.nn.Module):
                 return self.prepare_model(obj, device_placement=device_placement)
@@ -2660,9 +2664,7 @@ def _prepare_msamp(self, *args, device_placement):
                 device_placement[optimizer_index] = False
         return tuple(result), device_placement
 
-    def prepare_data_loader(
-        self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
-    ):
+    def prepare_data_loader(self, data_loader, device_placement=None, slice_fn_for_dispatch=None):
         """
         Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use
         [`Accelerator.prepare`] instead.
@@ -2715,6 +2717,7 @@ def prepare_data_loader(
             non_blocking=self.non_blocking,
             use_stateful_dataloader=self.use_stateful_dataloader,
             torch_device_mesh=device_mesh,
+            custom_classes=self.custom_dataloader_classes,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index a8d7eaa01a0..8adcd7e8ac5 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -15,7 +15,7 @@
 import importlib
 import math
 from contextlib import suppress
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from packaging import version
@@ -658,6 +658,68 @@ def set_sampler(self, sampler):
             self.batch_sampler.batch_sampler.sampler = sampler
 
 
+class CustomIterableDataLoader(DataLoaderStateMixin):
+    """
+    Lightweight wrapper around custom iterable dataloader-like objects.
+
+    This wrapper only handles optional device placement and keeps dataloader state tracking
+    for gradient synchronization.
+    """
+
+    def __init__(self, dataloader, device=None, _non_blocking: bool = False):
+        self.base_dataloader = dataloader
+        self.device = device
+        self._non_blocking = _non_blocking
+        self.gradient_state = GradientState()
+        self._drop_last = False
+
+    def __getattr__(self, name):
+        if name == "base_dataloader":
+            raise AttributeError()
+        return getattr(self.base_dataloader, name)
+
+    def __iter__(self):
+        self.begin()
+        dataloader_iter = iter(self.base_dataloader)
+        # We iterate one batch ahead to identify the last yielded batch.
+        try:
+            current_batch = next(dataloader_iter)
+        except StopIteration:
+            self.end()
+            return
+
+        while True:
+            try:
+                if self.device is not None:
+                    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
+                next_batch = next(dataloader_iter)
+                yield current_batch
+                current_batch = next_batch
+            except StopIteration:
+                self.end_of_dataloader = True
+                if self.device is not None:
+                    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
+                yield current_batch
+                break
+
+        self.end()
+
+    def __len__(self):
+        return len(self.base_dataloader)
+
+    def set_epoch(self, epoch: int):
+        if hasattr(self.base_dataloader, "set_epoch"):
+            self.base_dataloader.set_epoch(epoch)
+
+    @property
+    def dataset(self):
+        return getattr(self.base_dataloader, "dataset", self.base_dataloader)
+
+    @property
+    def total_batch_size(self):
+        return getattr(self.base_dataloader, "batch_size", None) or 1
+
+
 if is_torch_xla_available():
     import torch_xla.distributed.parallel_loader as xpl
 
@@ -1004,7 +1066,7 @@ def get_sampler(dataloader):
 
 
 def prepare_data_loader(
-    dataloader: DataLoader,
+    dataloader: Any,
     device: Optional[torch.device] = None,
     num_processes: Optional[int] = None,
     process_index: Optional[int] = None,
@@ -1019,7 +1081,8 @@ def prepare_data_loader(
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
     torch_device_mesh=None,
-) -> DataLoader:
+    custom_classes: Optional[tuple[type[Any], ...]] = None,
+) -> Any:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
 
@@ -1088,6 +1151,9 @@ def prepare_data_loader(
             This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
         torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
            PyTorch device mesh.
+        custom_classes (`tuple[type, ...]`, *optional*, defaults to `None`):
+            A tuple of custom iterable dataloader-like classes to match with `isinstance`. Matching objects are
+            wrapped in a lightweight accelerator wrapper that only handles optional device placement.
 
 
     Returns:
@@ -1100,6 +1166,13 @@ def prepare_data_loader(
 
 
     """
+    if custom_classes and isinstance(dataloader, custom_classes):
+        return CustomIterableDataLoader(
+            dataloader,
+            device=device if put_on_device else None,
+            _non_blocking=non_blocking,
+        )
+
     if dispatch_batches is None:
         if not put_on_device:
             dispatch_batches = False
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index dd3fca53c6f..acad7fd9260 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -848,6 +848,9 @@ class DataLoaderConfiguration:
             If set to `True`, the dataloader prepared by the Accelerator will be backed by
             [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
             This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
+        custom_classes (`tuple[type, ...]`, defaults to `()`):
+            A tuple of custom iterable dataloader-like classes. Matching objects will be prepared via
+            `Accelerator.prepare` and wrapped for optional device placement.
     """
 
     split_batches: bool = field(
@@ -906,6 +909,23 @@ class DataLoaderConfiguration:
             "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
         },
     )
+    custom_classes: tuple[type[Any], ...] = field(
+        default_factory=tuple,
+        metadata={
+            "help": "A tuple of custom iterable dataloader-like classes to treat as dataloaders in `Accelerator.prepare()`."
+        },
+    )
+
+    def __post_init__(self):
+        if self.custom_classes is None:
+            self.custom_classes = ()
+        elif isinstance(self.custom_classes, type):
+            self.custom_classes = (self.custom_classes,)
+        else:
+            self.custom_classes = tuple(self.custom_classes)
+
+        if not all(isinstance(cls, type) for cls in self.custom_classes):
+            raise TypeError("`custom_classes` must contain class objects.")
 
 
 @dataclass
diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index c56b00ec3f5..45a816f2e26 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -106,6 +106,19 @@ def get_dataset(n_batches):
     return (train_dataloader, valid_dataloader)
 
 
+class CustomIterableDataLoader:
+    def __init__(self, values):
+        self.values = values
+        self.batch_size = 1
+
+    def __iter__(self):
+        for value in self.values:
+            yield torch.tensor([value])
+
+    def __len__(self):
+        return len(self.values)
+
+
 def get_signature(model):
     return sum(param.abs().sum().item() for param in model.parameters())
 
@@ -702,6 +715,17 @@ def test_can_pickle_dataloader(self, dispatch_batches):
         assert len(loaded_skip_dl) == len(original_dl) - 2
         assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:]
 
+    def test_prepare_with_custom_iterable_dataloader(self):
+        dataloader_config = DataLoaderConfiguration(custom_classes=(CustomIterableDataLoader,))
+        accelerator = Accelerator(cpu=True, dataloader_config=dataloader_config)
+
+        dataloader = CustomIterableDataLoader([1, 2, 3])
+        prepared_dataloader = accelerator.prepare(dataloader)
+
+        assert prepared_dataloader in accelerator._dataloaders
+        assert prepared_dataloader._is_accelerate_prepared
+        assert [batch.tolist() for batch in prepared_dataloader] == [[1], [2], [3]]
+
     # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward.
     @require_torchdata_stateful_dataloader
     def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py
index 2057990a967..741cd74beba 100644
--- a/tests/test_data_loader.py
+++ b/tests/test_data_loader.py
@@ -23,6 +23,7 @@
 from accelerate import Accelerator, PartialState
 from accelerate.data_loader import (
     BatchSamplerShard,
+    CustomIterableDataLoader,
     DataLoaderDispatcher,
     DataLoaderShard,
     DataLoaderStateMixin,
@@ -95,6 +96,19 @@ def set_epoch(self, epoch):
         self.epoch = epoch
 
 
+class CustomIterableLoader:
+    def __init__(self, values):
+        self.values = values
+        self.batch_size = 1
+
+    def __iter__(self):
+        for value in self.values:
+            yield torch.tensor([value])
+
+    def __len__(self):
+        return len(self.values)
+
+
 class DataLoaderTester(AccelerateTestCase):
     def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):
         batch_sampler_shards = [
@@ -438,6 +452,13 @@ def collate_fn(features):
         assert isinstance(d["tensor"], torch.Tensor)
         assert d["non_tensor"] == "non_tensor_value"
 
+    def test_prepare_data_loader_with_custom_iterable_loader(self):
+        dataloader = CustomIterableLoader([1, 2, 3])
+        prepared = prepare_data_loader(dataloader, custom_classes=(CustomIterableLoader,))
+
+        assert isinstance(prepared, CustomIterableDataLoader)
+        assert [batch.tolist() for batch in prepared] == [[1], [2], [3]]
+
     @parameterized.expand([1, 2], name_func=parameterized_custom_name_func)
     def test_reproducibility(self, num_processes):
         set_seed(21)
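
Usage sketch (not part of the diff above): assuming the `custom_classes` option on `DataLoaderConfiguration` and the `CustomIterableDataLoader` wrapper introduced in this change, a custom iterable loader might be registered roughly as follows. `StreamingLoader` is a hypothetical user-defined class, not part of the library.

```python
import torch

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration


class StreamingLoader:
    """Hypothetical user loader: not a torch DataLoader, but iterable and sized."""

    def __init__(self, values, batch_size=2):
        self.values = values
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, len(self.values), self.batch_size):
            yield torch.tensor(self.values[start : start + self.batch_size])

    def __len__(self):
        # Number of batches; used by the wrapper's gradient-state bookkeeping.
        return (len(self.values) + self.batch_size - 1) // self.batch_size


# Registering the class makes `accelerator.prepare` treat instances as dataloaders:
# matching objects are wrapped so batches are moved to `accelerator.device` during iteration.
dataloader_config = DataLoaderConfiguration(custom_classes=(StreamingLoader,))
accelerator = Accelerator(dataloader_config=dataloader_config)

loader = accelerator.prepare(StreamingLoader(list(range(10))))
for batch in loader:
    print(batch.device, batch.tolist())
```

Per the docstring added in the diff, the wrapper only handles optional device placement and gradient-state tracking; any sharding across processes remains the responsibility of the custom loader itself.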