Skip to content

[release/2.10] Cherry-pick: [FSDP2] support dataclass args/kwargs output without memory leakage (#174692)#3146

Draft
chinmaydk99 wants to merge 1 commit intoROCm:release/2.10from
chinmaydk99:cherry-pick-fsdp2-dataclass-fix
Draft

[release/2.10] Cherry-pick: [FSDP2] support dataclass args/kwargs output without memory leakage (#174692)#3146
chinmaydk99 wants to merge 1 commit intoROCm:release/2.10from
chinmaydk99:cherry-pick-fsdp2-dataclass-fix

Conversation

@chinmaydk99
Copy link
Copy Markdown

Motivation

Addressing this JIRA ticket: https://amd-hub.atlassian.net/browse/AIPYTORCH-396

FSDP2 with reshard_after_forward=True crashes during the backward pass with RuntimeError: setStorage: sizes [...] are inconsistent with scalar type Float and target storage of size 0 when the wrapped module's forward() returns a dataclass containing tensors (rather than a plain tensor or tuple).

This is a cherry-pick of upstream fix pytorch#174692 (reland of pytorch#173415).

Technical Details

FSDP2's hook registration uses tree_flatten to find grad-requiring tensors in forward outputs/inputs, but tree_flatten does not traverse dataclass instances. It treats them as opaque leaves. This means no pre-backward hooks get registered, so parameters are never unsharded before gradient computation, causing the crash.

The fix adds two dataclass-aware traversal utilities (collect_grad_tensors, replace_grad_tensors) in _common_utils.py and uses them in place of tree_flatten/tree_unflatten in:

  • _fsdp_state.py — _register_pre_backward_hook (output traversal)
  • _fsdp_param_group.py — _register_post_backward_hook (input traversal)

Upstream commit: ab1d15e

…ytorch#174692)

support dataclass output. reland pytorch#173415

`pytest -s test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_dataclass_input_output`

Pull Request resolved: pytorch#174692
Approved by: https://github.com/mori360
@rocm-repo-management-api
Copy link
Copy Markdown

rocm-repo-management-api bot commented Apr 10, 2026

Jenkins build for ef9bb3215728bf956a13bc30e0896e330356373d commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

Detected error during Pytorch building:

/var/lib/jenkins/pytorch/torch/headeronly/macros/Export.h:130:9: note: previous definition is here
  130 | #define TORCH_HIP_API C10_IMPORT
      |         ^
2 warnings generated when compiling for host.
[7459/8132] Building HIPCC object caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/sparse/hip/torch_hip_generated_SparseHIPTensorMath.hip.o
FAILED: caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/sparse/hip/torch_hip_generated_SparseHIPTensorMath.hip.o /var/lib/jenkins/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/sparse/hip/torch_hip_generated_SparseHIPTensorMath.hip.o 
cd /var/lib/jenkins/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/sparse/hip && /opt/conda/envs/py_3.12/lib/python3.12/site-packages/cmake/data/bin/cmake -E make_directory /var/lib/jenkins/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/sparse/hip/. && /opt/conda/envs/py_3.12/lib/python3.12/site-packages/cmake/data/bin/cmake -D verbose:BOOL=OFF -D build_configuration:STRING=RELEASE -D generated_file:STRING=/var/lib/jenkins/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/sparse/hip/./torch_hip_generated_SparseHIPTensorMath.hip.o -P /var/lib/jenkins/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/sparse/hip/torch_hip_generated_SparseHIPTensorMath.hip.o.cmake
sccache: encountered fatal error
sccache: error: Failed to parse included file path
sccache: caused by: Failed to parse included file path
failed to execute:/opt/rocm/llvm/bin/clang++  --offload-arch=gfx90a --offload-arch=gfx908 --offload-arch=gfx942 -O3  -c -x hip /var/lib/jenkins/pytorch/aten/src/ATen/native/sparse/hip/SparseHIPTensorMath.hip -o "/var/lib/jenkins/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/sparse/hip/./torch_hip_generated_SparseHIPTensorMath.hip.o" --offload-compress -fclang-abi-compat=17 -DUSE_NCCL -DUSE_ROCM -D__HIP_PLATFORM_AMD__ -DUSE_FLASH_ATTENTION -DFLASHATTENTION_DISABLE_ALIBI -DFLASHATTENTION_DISABLE_SOFTCAP -DFLASH_NAMESPACE=pytorch_flash -DUNFUSE_FMA -DUSE_MEM_EFF_ATTENTION -DUSE_C10D_NCCL -DTORCH_HIP_BUILD_MAIN_LIB -DROCM_VERSION=70201 -DTORCH_HIP_VERSION=702 -DUSE_LAYERNORM_FAST_RECIPROCAL -DONNX_ML=1 -DONNXIFI_ENABLE_EXT=1 -DONNX_NAMESPACE=onnx_torch -DIDEEP_USE_MKL -DHAVE_MMAP=1 -D_FILE_OFFSET_BITS=64 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_POSIX_FALLOCATE=1 -DUSE_EXTERNAL_MZCRC -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -D__HIP_PLATFORM_AMD__=1 -DUSE_PROF_API=1 -DAT_PER_OPERATOR_HEADERS -DUSE_DISTRIBUTED -DUSE_C10D_GLOO -DUSE_RPC -DUSE_TENSORPIPE -D__HIP_PLATFORM_AMD__ -DHIPBLASLT_USE_ROCROLLER -DFMT_HEADER_ONLY=1 -fPIC -D__HIP_PLATFORM_AMD__=1 -DCUDA_HAS_FP16=1 -DUSE_ROCM -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DTORCH_HIP_VERSION=702 -Wno-shift-count-negative -Wno-shift-count-overflow -DCAFFE2_USE_MIOPEN -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP -std=c++17 -DHIPBLAS_V2 -DHIP_ENABLE_WARP_SYNC_BUILTINS -DHIPBLASLT_OUTER_VEC -DUSE_ROCM_CK_GEMM -fno-gpu-rdc -I/var/lib/jenkins/pytorch/build/aten/src -I/var/lib/jenkins/pytorch/aten/src -I/var/lib/jenkins/pytorch/build -I/var/lib/jenkins/pytorch -I/opt/rocm-7.2.1/include -I/var/lib/jenkins/pytorch/build/third_party/gloo -I/var/lib/jenkins/pytorch/cmake/../third_party/gloo -I/var/lib/jenkins/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -I/var/lib/jenkins/pytorch/cmake/../third_party/googletest/googlemock/include -I/var/lib/jenkins/pytorch/cmake/../third_party/googletest/googletest/include -I/var/lib/jenkins/pytorch/third_party/protobuf/src -I/opt/conda/envs/py_3.12/include -I/var/lib/jenkins/pytorch/third_party/XNNPACK/include -I/var/lib/jenkins/pytorch/third_party/ittapi/include -I/var/lib/jenkins/pytorch/cmake/../third_party/eigen -I/opt/rocm/include -I/opt/rocm-7.2.1/include -I/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -I/var/lib/jenkins/pytorch/third_party/ideep/include -I/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -I/opt/conda/envs/py_3.12/include -I/var/lib/jenkins/pytorch/nlohmann -I/var/lib/jenkins/pytorch/INTERFACE -I/var/lib/jenkins/pytorch/third_party/nlohmann/include -I/var/lib/jenkins/pytorch/moodycamel -I/var/lib/jenkins/pytorch/INTERFACE -I/var/lib/jenkins/pytorch/third_party/concurrentqueue -I/var/lib/jenkins/pytorch/aten/src/THH -I/var/lib/jenkins/pytorch/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include -I/var/lib/jenkins/pytorch/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include -I/var/lib/jenkins/pytorch/aten/src/ATen/hip -I/var/lib/jenkins/pytorch/aten/src/ATen/../../../third_party/composable_kernel/include -I/var/lib/jenkins/pytorch/aten/src/ATen/../../../third_party/composable_kernel/library/include -I/var/lib/jenkins/pytorch/aten/src/ATen/../../../third_party/composable_kernel/example/ck_tile/01_fmha -I/var/lib/jenkins/pytorch/build/caffe2/aten/src/ATen/composable_kernel -I/var/lib/jenkins/pytorch/aten/src/ATen/../../../third_party/aiter/csrc/include -I/var/lib/jenkins/pytorch/third_party/fmt/include -I/var/lib/jenkins/pytorch/aten/src -I/var/lib/jenkins/pytorch/build/caffe2/aten/src -I/var/lib/jenkins/pytorch/build/aten/src -I/var/lib/jenkins/pytorch/aten/src -I/var/lib/jenkins/pytorch/aten/src/ATen/.. -I/var/lib/jenkins/pytorch/torch/include -I/opt/rocm-7.2.1/include -I/opt/rocm/include -I/var/lib/jenkins/pytorch/c10/hip/../.. -I/var/lib/jenkins/pytorch/build -I/var/lib/jenkins/pytorch/c10/../ -I/var/lib/jenkins/pytorch/build -I/var/lib/jenkins/pytorch/torch/csrc/api -I/var/lib/jenkins/pytorch/torch/csrc/api/include -I/var/lib/jenkins/pytorch/third_party/protobuf/src -I/opt/conda/envs/py_3.12/include -I/opt/rocm-7.2.1/include -I/opt/rocm/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include/hiprand -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include/rocrand -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm/include -I/opt/rocm/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include -I/opt/rocm/include -I/usr/include/libdrm -I/opt/rocm/include -I/var/lib/jenkins/pytorch/build/third_party/gloo/hip -I/var/lib/jenkins/pytorch/build/aten/src -I/var/lib/jenkins/pytorch/aten/src -I/var/lib/jenkins/pytorch/build -I/var/lib/jenkins/pytorch -I/opt/rocm-7.2.1/include -I/var/lib/jenkins/pytorch/build/third_party/gloo -I/var/lib/jenkins/pytorch/cmake/../third_party/gloo -I/var/lib/jenkins/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -I/var/lib/jenkins/pytorch/cmake/../third_party/googletest/googlemock/include -I/var/lib/jenkins/pytorch/cmake/../third_party/googletest/googletest/include -I/var/lib/jenkins/pytorch/third_party/protobuf/src -I/opt/conda/envs/py_3.12/include -I/var/lib/jenkins/pytorch/third_party/XNNPACK/include -I/var/lib/jenkins/pytorch/third_party/ittapi/include -I/var/lib/jenkins/pytorch/cmake/../third_party/eigen -I/opt/rocm/include -I/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -I/var/lib/jenkins/pytorch/third_party/ideep/include -I/var/lib/jenkins/pytorch/nlohmann -I/var/lib/jenkins/pytorch/INTERFACE -I/var/lib/jenkins/pytorch/third_party/nlohmann/include -I/var/lib/jenkins/pytorch/moodycamel -I/var/lib/jenkins/pytorch/third_party/concurrentqueue

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants