Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,10 @@ def use_stateful_dataloader(self):
return self.dataloader_config.use_stateful_dataloader
return False

@property
def custom_dataloader_classes(self):
    """Tuple of user-registered dataloader-like classes from the dataloader config.

    Returns an empty tuple when the config does not define `custom_classes`.
    """
    configured = getattr(self.dataloader_config, "custom_classes", ())
    return tuple(configured)

@property
def project_dir(self):
    """Directory where project artifacts are saved, taken from the project configuration."""
    configuration = self.project_configuration
    return configuration.project_dir
Expand Down Expand Up @@ -1397,7 +1401,7 @@ def print(self, *args, **kwargs):
def _prepare_one(self, obj, first_pass=False, device_placement=None):
# First pass of preparation: DataLoader, model, optimizer
if first_pass:
if isinstance(obj, torch.utils.data.DataLoader):
if isinstance(obj, torch.utils.data.DataLoader) or isinstance(obj, self.custom_dataloader_classes):
return self.prepare_data_loader(obj, device_placement=device_placement)
elif isinstance(obj, torch.nn.Module):
return self.prepare_model(obj, device_placement=device_placement)
Expand Down Expand Up @@ -2660,9 +2664,7 @@ def _prepare_msamp(self, *args, device_placement):
device_placement[optimizer_index] = False
return tuple(result), device_placement

def prepare_data_loader(
self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
):
def prepare_data_loader(self, data_loader, device_placement=None, slice_fn_for_dispatch=None):
"""
Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use
[`Accelerator.prepare`] instead.
Expand Down Expand Up @@ -2715,6 +2717,7 @@ def prepare_data_loader(
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
torch_device_mesh=device_mesh,
custom_classes=self.custom_dataloader_classes,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
Expand Down
79 changes: 76 additions & 3 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import importlib
import math
from contextlib import suppress
from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
from packaging import version
Expand Down Expand Up @@ -658,6 +658,68 @@ def set_sampler(self, sampler):
self.batch_sampler.batch_sampler.sampler = sampler


class CustomIterableDataLoader(DataLoaderStateMixin):
    """
    Lightweight wrapper around custom iterable dataloader-like objects.

    This wrapper only handles optional device placement and keeps dataloader state tracking
    for gradient synchronization.

    Args:
        dataloader: Any iterable yielding batches. Attribute access falls through to it.
        device (`torch.device`, *optional*): If set, every yielded batch is moved to this
            device via `send_to_device`.
        _non_blocking (`bool`, defaults to `False`): Passed through to `send_to_device`.
    """

    def __init__(self, dataloader, device=None, _non_blocking: bool = False):
        self.base_dataloader = dataloader
        self.device = device
        self._non_blocking = _non_blocking
        self.gradient_state = GradientState()
        # Custom iterables have no notion of incomplete final batches, so never drop them.
        self._drop_last = False

    def __getattr__(self, name):
        # Guard against infinite recursion before `base_dataloader` is set (e.g. during
        # unpickling); everything else is delegated to the wrapped object.
        if name == "base_dataloader":
            raise AttributeError(name)
        return getattr(self.base_dataloader, name)

    def __iter__(self):
        self.begin()
        dataloader_iter = iter(self.base_dataloader)
        # We iterate one batch ahead so the final batch can be flagged via
        # `end_of_dataloader` before it is yielded.
        try:
            current_batch = next(dataloader_iter)
        except StopIteration:
            self.end()
            return

        while True:
            # Move exactly once per batch; the original double-moved the last batch.
            if self.device is not None:
                current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
            try:
                next_batch = next(dataloader_iter)
            except StopIteration:
                self.end_of_dataloader = True
                yield current_batch
                break
            yield current_batch
            current_batch = next_batch

        self.end()

    def __len__(self):
        return len(self.base_dataloader)

    def set_epoch(self, epoch: int):
        # Best-effort: only forward if the wrapped object supports epoch-based shuffling.
        if hasattr(self.base_dataloader, "set_epoch"):
            self.base_dataloader.set_epoch(epoch)

    @property
    def dataset(self):
        # Fall back to the wrapped object itself when it exposes no `dataset` attribute.
        return getattr(self.base_dataloader, "dataset", self.base_dataloader)

    @property
    def total_batch_size(self):
        # NOTE(review): a falsy `batch_size` (0 or None) is coerced to 1 — assumed intentional.
        return getattr(self.base_dataloader, "batch_size", None) or 1


if is_torch_xla_available():
import torch_xla.distributed.parallel_loader as xpl

Expand Down Expand Up @@ -1004,7 +1066,7 @@ def get_sampler(dataloader):


def prepare_data_loader(
dataloader: DataLoader,
dataloader: Any,
device: Optional[torch.device] = None,
num_processes: Optional[int] = None,
process_index: Optional[int] = None,
Expand All @@ -1019,7 +1081,8 @@ def prepare_data_loader(
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
torch_device_mesh=None,
) -> DataLoader:
custom_classes: Optional[tuple[type[Any], ...]] = None,
) -> Any:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.

Expand Down Expand Up @@ -1088,6 +1151,9 @@ def prepare_data_loader(
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
PyTorch device mesh.
custom_classes (`tuple[type, ...]`, *optional*, defaults to `None`):
A tuple of custom iterable dataloader-like classes to match with `isinstance`. Matching objects are
wrapped in a lightweight accelerator wrapper that only handles optional device placement.


Returns:
Expand All @@ -1100,6 +1166,13 @@ def prepare_data_loader(

</Tip>
"""
if custom_classes and isinstance(dataloader, custom_classes):
return CustomIterableDataLoader(
dataloader,
device=device if put_on_device else None,
_non_blocking=non_blocking,
)

if dispatch_batches is None:
if not put_on_device:
dispatch_batches = False
Expand Down
20 changes: 20 additions & 0 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,9 @@ class DataLoaderConfiguration:
If set to `True`, the dataloader prepared by the Accelerator will be backed by
[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
custom_classes (`tuple[type, ...]`, defaults to `()`):
A tuple of custom iterable dataloader-like classes. Matching objects will be prepared via
`Accelerator.prepare` and wrapped for optional device placement.
"""

split_batches: bool = field(
Expand Down Expand Up @@ -906,6 +909,23 @@ class DataLoaderConfiguration:
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
},
)
custom_classes: tuple[type[Any], ...] = field(
default_factory=tuple,
metadata={
"help": "A tuple of custom iterable dataloader-like classes to treat as dataloaders in `Accelerator.prepare()`."
},
)

def __post_init__(self):
    """Normalize `custom_classes` to a tuple of classes and validate its contents.

    Accepts `None` (normalized to an empty tuple), a single class (wrapped in a
    one-element tuple), or any iterable of classes.

    Raises:
        TypeError: If `custom_classes` is not `None`, a class, or an iterable of classes.
    """
    if self.custom_classes is None:
        self.custom_classes = ()
    elif isinstance(self.custom_classes, type):
        self.custom_classes = (self.custom_classes,)
    else:
        try:
            self.custom_classes = tuple(self.custom_classes)
        except TypeError:
            # A non-iterable, non-class value (e.g. an int) would otherwise surface an
            # unhelpful "'int' object is not iterable" error; raise the intended message.
            raise TypeError("`custom_classes` must contain class objects.") from None

    if not all(isinstance(cls, type) for cls in self.custom_classes):
        raise TypeError("`custom_classes` must contain class objects.")


@dataclass
Expand Down
24 changes: 24 additions & 0 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ def get_dataset(n_batches):
return (train_dataloader, valid_dataloader)


class CustomIterableDataLoader:
    """Minimal dataloader-like iterable used to exercise `custom_classes` support."""

    def __init__(self, values):
        self.values = values
        # Mimic the attribute real dataloaders expose.
        self.batch_size = 1

    def __iter__(self):
        return (torch.tensor([value]) for value in self.values)

    def __len__(self):
        return len(self.values)


def get_signature(model):
return sum(param.abs().sum().item() for param in model.parameters())

Expand Down Expand Up @@ -702,6 +715,17 @@ def test_can_pickle_dataloader(self, dispatch_batches):
assert len(loaded_skip_dl) == len(original_dl) - 2
assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:]

def test_prepare_with_custom_iterable_dataloader(self):
    """`Accelerator.prepare` wraps objects whose class is registered via `custom_classes`."""
    config = DataLoaderConfiguration(custom_classes=(CustomIterableDataLoader,))
    accelerator = Accelerator(cpu=True, dataloader_config=config)

    prepared = accelerator.prepare(CustomIterableDataLoader([1, 2, 3]))

    # The wrapper is tracked like any other prepared dataloader and yields the same batches.
    assert prepared in accelerator._dataloaders
    assert prepared._is_accelerate_prepared
    assert [batch.tolist() for batch in prepared] == [[1], [2], [3]]

# Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward.
@require_torchdata_stateful_dataloader
def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from accelerate import Accelerator, PartialState
from accelerate.data_loader import (
BatchSamplerShard,
CustomIterableDataLoader,
DataLoaderDispatcher,
DataLoaderShard,
DataLoaderStateMixin,
Expand Down Expand Up @@ -95,6 +96,19 @@ def set_epoch(self, epoch):
self.epoch = epoch


class CustomIterableLoader:
    """Bare-bones iterable standing in for a user-defined dataloader class."""

    def __init__(self, values):
        self.values = values
        self.batch_size = 1  # attribute real dataloaders expose

    def __iter__(self):
        yield from (torch.tensor([item]) for item in self.values)

    def __len__(self):
        return len(self.values)

class DataLoaderTester(AccelerateTestCase):
def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):
batch_sampler_shards = [
Expand Down Expand Up @@ -438,6 +452,13 @@ def collate_fn(features):
assert isinstance(d["tensor"], torch.Tensor)
assert d["non_tensor"] == "non_tensor_value"

def test_prepare_data_loader_with_custom_iterable_loader(self):
    """Objects matching `custom_classes` are wrapped in `CustomIterableDataLoader`."""
    loader = CustomIterableLoader([1, 2, 3])
    wrapped = prepare_data_loader(loader, custom_classes=(CustomIterableLoader,))

    assert isinstance(wrapped, CustomIterableDataLoader)
    assert [batch.tolist() for batch in wrapped] == [[1], [2], [3]]

@parameterized.expand([1, 2], name_func=parameterized_custom_name_func)
def test_reproducibility(self, num_processes):
set_seed(21)
Expand Down