diff --git a/src/accelerate/parallelism_config.py b/src/accelerate/parallelism_config.py index c4135f2f791..5ff4ebabfaf 100644 --- a/src/accelerate/parallelism_config.py +++ b/src/accelerate/parallelism_config.py @@ -236,8 +236,11 @@ def build_device_mesh(self, device_type: str): ) if self.dp_dim_names: device_mesh[self.dp_dim_names]._flatten("dp") - if self.dp_shard_cp_dim_names: - device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp") + # Always create dp_shard_cp submesh from available dims in the mesh + # (dp_shard is always present after _get_mesh, even when size == 1) + shard_cp_dims = [d for d in ["dp_shard", "cp"] if d in mesh_dim_names] + if shard_cp_dims: + device_mesh[shard_cp_dims]._flatten("dp_shard_cp") if self.dp_cp_dim_names: device_mesh[self.dp_cp_dim_names]._flatten("dp_cp") @@ -263,6 +266,11 @@ def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]: # Build mesh dimensions dictionary mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims} + # Always include dp_shard for composable parallelism so that FSDP2 has + # a valid submesh even when dp_shard_size == 1 (no-op shard). + if mesh_dims and "dp_shard" not in mesh_dims: + mesh_dims["dp_shard"] = self.dp_shard_size + # Apply canonical ordering mesh_order = ["dp_replicate", "dp_shard", "cp", "sp", "tp"] sorted_items = sorted(