Fix DTensor attr handling in make_fx tracer #2998
Conversation
@tugsbayasgalan For CooR we do need device meshes as inputs to the graph. The way CooR works is that we hoist device meshes as graph inputs and then use custom ops to extract process groups, so that the graph is identical across all ranks while each rank uses its own process group at runtime. I wonder if a more principled approach would be a DCE pass? cc @aorenste
@bobrenjc93 Specifically for this PR, the bug was that we were lifting DeviceMesh objects as graph inputs because we incorrectly treated them as tensors. But yes, I do think we should find a more principled way to handle this, probably something like this (https://github.com/pytorch/pytorch/blob/665a8750269104209a9e0f1ce35e642db0c31b4f/torch/_functorch/_aot_autograd/subclass_utils.py#L256). Basically, we should reuse the subclass wrapping/unwrapping from AOTAutograd as much as possible; it has custom logic to handle opaque objects like DeviceMesh separately.
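A minimal sketch of that wrapping/unwrapping split (illustrative helper names, not the actual subclass_utils API): the subclass is separated into tensor leaves, which may become graph inputs, and an opaque context, which for DTensor carries the mesh/placement info and is replayed verbatim when the subclass is rebuilt.

```python
import torch


def unwrap_subclass(t: torch.Tensor):
    # __tensor_flatten__ is the traceable-subclass protocol: it returns the
    # names of the inner tensor attrs plus an opaque context (for DTensor the
    # context carries the spec, i.e. the DeviceMesh and placements).
    attrs, ctx = t.__tensor_flatten__()
    leaves = [getattr(t, name) for name in attrs]
    spec = (type(t), attrs, ctx, t.shape, t.stride())
    return leaves, spec


def wrap_subclass(leaves, spec):
    subclass_type, attrs, ctx, outer_size, outer_stride = spec
    inner = dict(zip(attrs, leaves))
    # __tensor_unflatten__ rebuilds the subclass from the tensor leaves and
    # the opaque context, so the DeviceMesh never shows up as a graph input.
    return subclass_type.__tensor_unflatten__(inner, ctx, outer_size, outer_stride)
```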
Fix graph_trainer make_fx tracing for DTensor-backed module state by keeping non-tensor DTensor attrs out of the traced graph inputs.
Previously, our subclass unwrap logic flattened every attribute returned by tensor_flatten(). For DTensor
this included non-tensor attrs like device_mesh, which leaked into the graph signature as untyped placeholders.
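A hedged sketch of the fix's intent (helper names are illustrative, not the exact graph_trainer code): when collecting placeholders from module state, lift only objects that are actually `torch.Tensor`, and keep anything else produced by the subclass flatten (DeviceMesh, placements, plain Python metadata) out of the graph signature.

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def collect_graph_inputs(state):
    """Illustrative: pick only real tensors from module state as placeholders."""
    placeholders = []
    for name, value in state.items():
        if is_traceable_wrapper_subclass(value):
            # For subclasses like DTensor, recurse into the flattened tensor
            # attrs instead of trusting every attribute to be a tensor.
            attrs, _ctx = value.__tensor_flatten__()
            for attr in attrs:
                inner = getattr(value, attr)
                if isinstance(inner, torch.Tensor):
                    placeholders.append((f"{name}.{attr}", inner))
        elif isinstance(value, torch.Tensor):
            placeholders.append((name, value))
        # Non-tensor objects (e.g. a DeviceMesh) are deliberately skipped so
        # they never become untyped placeholders in the traced graph.
    return placeholders
```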