Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/accelerate/parallelism_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,11 @@ def build_device_mesh(self, device_type: str):
)
if self.dp_dim_names:
device_mesh[self.dp_dim_names]._flatten("dp")
if self.dp_shard_cp_dim_names:
device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
# Create the dp_shard_cp submesh from whichever of dp_shard / cp are in the mesh
# (_get_mesh adds dp_shard to any non-empty mesh, even when its size == 1)
shard_cp_dims = [d for d in ["dp_shard", "cp"] if d in mesh_dim_names]
if shard_cp_dims:
device_mesh[shard_cp_dims]._flatten("dp_shard_cp")
if self.dp_cp_dim_names:
device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")

Expand All @@ -263,6 +266,11 @@ def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]:
# Build mesh dimensions dictionary
mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}

# For composable parallelism, include dp_shard whenever any mesh dimension is
# active, so that FSDP2 has a valid submesh even when dp_shard_size == 1 (no-op shard).
if mesh_dims and "dp_shard" not in mesh_dims:
mesh_dims["dp_shard"] = self.dp_shard_size

# Apply canonical ordering
mesh_order = ["dp_replicate", "dp_shard", "cp", "sp", "tp"]
sorted_items = sorted(
Expand Down