torchtitan/experiments/graph_trainer/make_fx_tracer.py (29 additions, 10 deletions)
@@ -48,7 +48,7 @@ class SubclassMeta:
     cls: type
     attrs: list[str]
     ctx: Any
-    inner_metas: dict[str, tuple[int, Any]]
+    inner_metas: dict[str, "TensorAttrMeta | StaticAttrMeta"]
     outer_size: torch.Size
     outer_stride: tuple[int, ...]
@@ -59,17 +59,31 @@ class SubclassLayout:
     meta: SubclassMeta | None


+@dataclass
+class TensorAttrMeta:
+    num_tensors: int
+    meta: SubclassMeta | None
+
+
+@dataclass
+class StaticAttrMeta:
+    value: Any
+
+
 def _unwrap_subclass(t: torch.Tensor) -> tuple[list[torch.Tensor], SubclassMeta | None]:
     if not is_traceable_wrapper_subclass(t):
         return [t], None
     attrs, ctx = t.__tensor_flatten__()
     all_inner = []
     inner_metas = {}
     for attr in attrs:
-        inner_t = getattr(t, attr)
-        tensors, meta = _unwrap_subclass(inner_t)
-        all_inner.extend(tensors)
-        inner_metas[attr] = (len(tensors), meta)
+        inner = getattr(t, attr)
+        if isinstance(inner, torch.Tensor):
+            tensors, meta = _unwrap_subclass(inner)
+            all_inner.extend(tensors)
+            inner_metas[attr] = TensorAttrMeta(len(tensors), meta)
+        else:
+            inner_metas[attr] = StaticAttrMeta(inner)
     meta = SubclassMeta(
         cls=type(t),
         attrs=attrs,
@@ -87,13 +101,18 @@ def _wrap_to_subclass(
     inner_dict = {}
     idx = 0
     for attr in meta.attrs:
-        num_inner, inner_meta = meta.inner_metas[attr]
-        inner_tensors = plain_tensors[idx : idx + num_inner]
-        idx += num_inner
-        if inner_meta is None:
+        attr_meta = meta.inner_metas[attr]
+        if isinstance(attr_meta, StaticAttrMeta):
+            inner_dict[attr] = attr_meta.value
+            continue
+
+        assert isinstance(attr_meta, TensorAttrMeta)
+        inner_tensors = plain_tensors[idx : idx + attr_meta.num_tensors]
+        idx += attr_meta.num_tensors
+        if attr_meta.meta is None:
             inner_dict[attr] = inner_tensors[0]
         else:
-            inner_dict[attr] = _wrap_to_subclass(list(inner_tensors), inner_meta)
+            inner_dict[attr] = _wrap_to_subclass(list(inner_tensors), attr_meta.meta)
     return meta.cls.__tensor_unflatten__(
         inner_dict, meta.ctx, meta.outer_size, meta.outer_stride
     )
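For context, a minimal sketch of the case the new non-tensor branch handles: a hypothetical wrapper subclass whose __tensor_flatten__ lists a plain-float attribute next to its inner tensor. Only _unwrap_subclass and _wrap_to_subclass come from make_fx_tracer.py; ScaledTensor and the import path are illustrative assumptions, not part of this PR.

    import torch

    from torchtitan.experiments.graph_trainer.make_fx_tracer import (
        _unwrap_subclass,
        _wrap_to_subclass,
    )


    class ScaledTensor(torch.Tensor):
        """Hypothetical wrapper: one inner tensor plus a plain-float attribute."""

        @staticmethod
        def __new__(cls, inner: torch.Tensor, scale: float):
            return torch.Tensor._make_wrapper_subclass(cls, inner.shape, dtype=inner.dtype)

        def __init__(self, inner: torch.Tensor, scale: float):
            self._inner = inner
            self._scale = scale  # plain float, not a tensor

        def __tensor_flatten__(self):
            # "_scale" is not a tensor; the old code would recurse into it
            # via _unwrap_subclass and fail.
            return ["_inner", "_scale"], None

        @staticmethod
        def __tensor_unflatten__(inner, ctx, outer_size, outer_stride):
            return ScaledTensor(inner["_inner"], inner["_scale"])


    t = ScaledTensor(torch.randn(4), 0.5)
    plain, meta = _unwrap_subclass(t)   # plain == [t._inner]; 0.5 lands in a StaticAttrMeta
    restored = _wrap_to_subclass(plain, meta)
    assert isinstance(restored, ScaledTensor) and restored._scale == 0.5

With the old tuple-based inner_metas, the float attribute would have been treated as a nested tensor; the TensorAttrMeta/StaticAttrMeta split makes the two cases explicit at unflatten time.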
torchtitan/experiments/graph_trainer/tests/test_trace_module.py (25 additions, 0 deletions)
@@ -348,11 +348,36 @@ def forward(model, tokens):
             layout.meta is not None for layout in traced.input_subclass_layouts.values()
         )
         self.assertTrue(has_subclass)
+        placeholders = [n for n in traced.gm.graph.nodes if n.op == "placeholder"]
+        self.assertTrue(
+            all(isinstance(n.meta.get("val"), torch.Tensor) for n in placeholders)
+        )

         out_eager = model(tokens_dt)
         wrapped = run_traced_train_step(traced, model, tokens_dt)
         self.assertTrue(torch.equal(out_eager.full_tensor(), wrapped.full_tensor()))

+    def test_dtensor_graph_has_no_untyped_placeholders(self):
+        from torch.distributed._tensor import DTensor, Replicate
+        from torch.distributed.device_mesh import init_device_mesh
+
+        mesh = init_device_mesh(self.DEVICE, (1,))
+
+        model = SimpleMLP().to(device=self.DEVICE, dtype=self.DTYPE)
+        self._distribute_params(model, mesh)
+
+        tokens = torch.randint(0, 256, (2, 32), device=self.DEVICE)
+        tokens_dt = DTensor.from_local(tokens, mesh, [Replicate()])
+
+        def forward(model, tokens):
+            return model(tokens)
+
+        traced = trace_train_step(forward)(model, tokens_dt)
+        placeholders = [n for n in traced.gm.graph.nodes if n.op == "placeholder"]
+        self.assertTrue(
+            all(isinstance(n.meta.get("val"), torch.Tensor) for n in placeholders)
+        )
+
     def test_dtensor_train_step(self):
         from torch.distributed._tensor import DTensor, Replicate
         from torch.distributed.device_mesh import init_device_mesh
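The assertions added above depend on the tracer recording a fake-tensor "val" in each placeholder's meta. A standalone sketch of that behavior using plain PyTorch's make_fx (no torchtitan pieces; the function f is made up):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx


    def f(x):
        return x * 2


    gm = make_fx(f, tracing_mode="fake")(torch.randn(3))
    for n in gm.graph.nodes:
        if n.op == "placeholder":
            # n.meta["val"] is a FakeTensor, i.e. an instance of torch.Tensor,
            # which is exactly what the new tests assert for DTensor inputs.
            print(n.name, isinstance(n.meta.get("val"), torch.Tensor))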