diff --git a/torchtitan/experiments/graph_trainer/make_fx_tracer.py b/torchtitan/experiments/graph_trainer/make_fx_tracer.py
index 22b4a32ee5..0de6dfec82 100644
--- a/torchtitan/experiments/graph_trainer/make_fx_tracer.py
+++ b/torchtitan/experiments/graph_trainer/make_fx_tracer.py
@@ -48,7 +48,7 @@ class SubclassMeta:
     cls: type
     attrs: list[str]
     ctx: Any
-    inner_metas: dict[str, tuple[int, Any]]
+    inner_metas: dict[str, "TensorAttrMeta | StaticAttrMeta"]
     outer_size: torch.Size
     outer_stride: tuple[int, ...]
 
@@ -59,6 +59,17 @@ class SubclassLayout:
     meta: SubclassMeta | None
 
 
+@dataclass
+class TensorAttrMeta:
+    num_tensors: int
+    meta: SubclassMeta | None
+
+
+@dataclass
+class StaticAttrMeta:
+    value: Any
+
+
 def _unwrap_subclass(t: torch.Tensor) -> tuple[list[torch.Tensor], SubclassMeta | None]:
     if not is_traceable_wrapper_subclass(t):
         return [t], None
@@ -66,10 +77,13 @@ def _unwrap_subclass(t: torch.Tensor) -> tuple[list[torch.Tensor], SubclassMeta
     all_inner = []
     inner_metas = {}
     for attr in attrs:
-        inner_t = getattr(t, attr)
-        tensors, meta = _unwrap_subclass(inner_t)
-        all_inner.extend(tensors)
-        inner_metas[attr] = (len(tensors), meta)
+        inner = getattr(t, attr)
+        if isinstance(inner, torch.Tensor):
+            tensors, meta = _unwrap_subclass(inner)
+            all_inner.extend(tensors)
+            inner_metas[attr] = TensorAttrMeta(len(tensors), meta)
+        else:
+            inner_metas[attr] = StaticAttrMeta(inner)
     meta = SubclassMeta(
         cls=type(t),
         attrs=attrs,
@@ -87,13 +101,18 @@ def _wrap_to_subclass(
     inner_dict = {}
     idx = 0
     for attr in meta.attrs:
-        num_inner, inner_meta = meta.inner_metas[attr]
-        inner_tensors = plain_tensors[idx : idx + num_inner]
-        idx += num_inner
-        if inner_meta is None:
+        attr_meta = meta.inner_metas[attr]
+        if isinstance(attr_meta, StaticAttrMeta):
+            inner_dict[attr] = attr_meta.value
+            continue
+
+        assert isinstance(attr_meta, TensorAttrMeta)
+        inner_tensors = plain_tensors[idx : idx + attr_meta.num_tensors]
+        idx += attr_meta.num_tensors
+        if attr_meta.meta is None:
             inner_dict[attr] = inner_tensors[0]
         else:
-            inner_dict[attr] = _wrap_to_subclass(list(inner_tensors), inner_meta)
+            inner_dict[attr] = _wrap_to_subclass(list(inner_tensors), attr_meta.meta)
     return meta.cls.__tensor_unflatten__(
         inner_dict, meta.ctx, meta.outer_size, meta.outer_stride
     )
diff --git a/torchtitan/experiments/graph_trainer/tests/test_trace_module.py b/torchtitan/experiments/graph_trainer/tests/test_trace_module.py
index 43346e1e9a..345b651862 100644
--- a/torchtitan/experiments/graph_trainer/tests/test_trace_module.py
+++ b/torchtitan/experiments/graph_trainer/tests/test_trace_module.py
@@ -348,11 +348,36 @@ def forward(model, tokens):
             layout.meta is not None for layout in traced.input_subclass_layouts.values()
         )
         self.assertTrue(has_subclass)
+        placeholders = [n for n in traced.gm.graph.nodes if n.op == "placeholder"]
+        self.assertTrue(
+            all(isinstance(n.meta.get("val"), torch.Tensor) for n in placeholders)
+        )
 
         out_eager = model(tokens_dt)
         wrapped = run_traced_train_step(traced, model, tokens_dt)
         self.assertTrue(torch.equal(out_eager.full_tensor(), wrapped.full_tensor()))
 
+    def test_dtensor_graph_has_no_untyped_placeholders(self):
+        from torch.distributed._tensor import DTensor, Replicate
+        from torch.distributed.device_mesh import init_device_mesh
+
+        mesh = init_device_mesh(self.DEVICE, (1,))
+
+        model = SimpleMLP().to(device=self.DEVICE, dtype=self.DTYPE)
+        self._distribute_params(model, mesh)
+
+        tokens = torch.randint(0, 256, (2, 32), device=self.DEVICE)
+        tokens_dt = DTensor.from_local(tokens, mesh, [Replicate()])
+
+        def forward(model, tokens):
+            return model(tokens)
+
+        traced = trace_train_step(forward)(model, tokens_dt)
+        placeholders = [n for n in traced.gm.graph.nodes if n.op == "placeholder"]
+        self.assertTrue(
+            all(isinstance(n.meta.get("val"), torch.Tensor) for n in placeholders)
+        )
+
     def test_dtensor_train_step(self):
         from torch.distributed._tensor import DTensor, Replicate
         from torch.distributed.device_mesh import init_device_mesh