diff --git a/devtools/intermediate_output_tap/TARGETS b/devtools/intermediate_output_tap/TARGETS new file mode 100644 index 00000000000..5ae2ddc8380 --- /dev/null +++ b/devtools/intermediate_output_tap/TARGETS @@ -0,0 +1,85 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "spec", + srcs = ["_spec.py"], +) + +runtime.python_library( + name = "custom_ops_lib", + srcs = ["custom_ops_lib.py"], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "selectors", + srcs = ["_selectors.py"], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "reducers", + srcs = ["_reducers.py"], + deps = [ + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ], +) + +runtime.python_library( + name = "tap_pass", + srcs = ["_tap_pass.py"], + deps = [ + "//caffe2:torch", + ":custom_ops_lib", + ":reducers", + ":selectors", + ":spec", + ], +) + +runtime.python_library( + name = "strip_pass", + srcs = ["_strip_pass.py"], + deps = [ + "//caffe2:torch", + ":reducers", + ":tap_pass", + ], +) + +runtime.python_library( + name = "convenience", + srcs = ["_convenience.py"], + deps = [ + "fbsource//third-party/pypi/pandas:pandas", + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/runtime:runtime", + ":reducers", + ":selectors", + ":spec", + ":strip_pass", + ":tap_pass", + ], +) + +runtime.python_library( + name = "lib", + srcs = ["__init__.py"], + deps = [ + ":convenience", + ":custom_ops_lib", + ":reducers", + ":selectors", + ":spec", + ":strip_pass", + ":tap_pass", + ], +) diff --git a/devtools/intermediate_output_tap/__init__.py b/devtools/intermediate_output_tap/__init__.py new file mode 100644 index 00000000000..07f99ad7148 --- /dev/null +++ b/devtools/intermediate_output_tap/__init__.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Public API for the ExecuTorch numerical debugger. + +Backend-agnostic intermediate-value tap: + +- Runtime side : USER_OUTPUT taps (this module — works through delegates without + any backend-side changes) + +Typical usage: + + from executorch.devtools.intermediate_output_tap import ( + compare_aot_runtime_dataframe, + tap_intermediate_outputs, strip_taps_, STATS, + ) + + ep = export(model, example_inputs) + ep_tapped, specs = tap_intermediate_outputs(ep, reducer=STATS) + aot_flat, _ = pytree.tree_flatten(ep_tapped.module()(*example_inputs)) + edge = to_edge_transform_and_lower(ep_tapped, partitioner=[XnnpackPartitioner()]) + strip_taps_(edge) + et_program = edge.to_executorch() + + rt_flat = runtime.forward(*example_inputs) + df = compare_aot_runtime_dataframe(specs, aot_flat, rt_flat) +""" + +# Importing this module registers torch.ops.executorch_devtools.tap.Tensor. 
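+# (Import is for its side effect only, hence the noqa below; `_tap_pass` also
+# imports it directly, so the op exists even when this package is imported
+# piecemeal.)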
+from executorch.devtools.intermediate_output_tap import custom_ops_lib # noqa: F401 +from executorch.devtools.intermediate_output_tap._convenience import ( + compare_aot_runtime_dataframe, + specs_to_dataframe, + tap_all_and_run, + tap_compare, +) +from executorch.devtools.intermediate_output_tap._reducers import ( + FULL_TENSOR, + get_reducer, + StatReducer, + STATS, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all, + select_all_call_function, + select_any, + select_by_meta_tag, + select_by_module_class, + select_by_module_path, + select_by_op_type, + select_not, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + find_tap_nodes, + is_tap_node, + tap_intermediate_outputs, + TapRule, +) + + +__all__ = [ + # Core API + "tap_intermediate_outputs", + "strip_taps_", + "TapSpec", + "TapRule", + # Convenience + "tap_compare", + "tap_all_and_run", + "specs_to_dataframe", + "compare_aot_runtime_dataframe", + # Reducers + "StatReducer", + "FULL_TENSOR", + "STATS", + "get_reducer", + # Selectors + "NodeSelector", + "select_all_call_function", + "select_by_op_type", + "select_by_module_path", + "select_by_module_class", + "select_by_meta_tag", + "select_any", + "select_all", + "select_not", + # Helpers + "find_tap_nodes", + "is_tap_node", +] diff --git a/devtools/intermediate_output_tap/_convenience.py b/devtools/intermediate_output_tap/_convenience.py new file mode 100644 index 00000000000..a12ec0922ab --- /dev/null +++ b/devtools/intermediate_output_tap/_convenience.py @@ -0,0 +1,266 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Convenience helpers built on top of `tap_intermediate_outputs` / `strip_taps_`: + +* `tap_compare`: one-shot helper that exports a model, taps it, lowers with + the user's partitioner, runs through the ExecuTorch runtime, and returns + the AOT-vs-runtime comparison DataFrame plus the tap specs. The simplest + way to use the intermediate-output tap. +* `tap_all_and_run`: smoke-test wrapper that exports a model, taps every + call_function, lowers with the user's partitioner, runs through the + ExecuTorch runtime, and returns a per-tap DataFrame of runtime values. +* `specs_to_dataframe`: build a per-tap DataFrame from a tap_specs list and + the runtime's flat output tuple. +* `compare_aot_runtime_dataframe`: side-by-side AOT-vs-runtime DataFrame from + the flat outputs of the *tapped* ExportedProgram (eager) and the post-strip + runtime program. 
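+
+Minimal usage sketch (`MyModel` / `example_inputs` are placeholders, not part
+of this API; with the default STATS reducer the frame carries aot_mean/rt_mean
+columns, among others):
+
+    df, specs = tap_compare(MyModel().eval(), example_inputs)
+    print(df[["node_name", "aot_mean", "rt_mean"]])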
+""" + +from __future__ import annotations + +import os +import tempfile +from collections.abc import Sequence +from typing import Any + +import pandas as pd +import torch +import torch.utils._pytree as pytree +from executorch.devtools.intermediate_output_tap._reducers import StatReducer +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all_call_function, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + tap_intermediate_outputs, + TapRule, +) + + +def tap_compare( + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + partitioner: list | None = None, + *, + selector: NodeSelector | None = None, + reducer: str | StatReducer = "STATS", + rules: Sequence[TapRule] | None = None, + error_on_empty: bool = True, +) -> tuple[pd.DataFrame, list[TapSpec]]: + """ + One-shot AOT-vs-runtime numerical-debugging helper. + + Runs the full pipeline: export -> tap -> capture AOT reference values + -> lower with `partitioner` -> strip -> to_executorch -> runtime + -> AOT-vs-runtime DataFrame. + + Args: + model: Eager nn.Module to debug. + example_inputs: Positional args to the model's forward. + partitioner: Optional list of partitioners passed to + `to_edge_transform_and_lower`. Defaults to `[]` (no delegation). + selector / reducer / rules: Same semantics as + `tap_intermediate_outputs`. Pass either `selector`/`reducer` + (single rule) or `rules=` (multiple rules with first-match + semantics). + error_on_empty: Same semantics as `tap_intermediate_outputs`. + + Returns: + A `(df, specs)` tuple where: + - `df`: side-by-side AOT-vs-runtime DataFrame from + `compare_aot_runtime_dataframe`. + - `specs`: list of `TapSpec`s in tap-creation order. + """ + from executorch.exir import to_edge_transform_and_lower + + ep = torch.export.export(model, example_inputs, strict=True) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=selector, + reducer=reducer, + rules=rules, + error_on_empty=error_on_empty, + ) + + # AOT-side reference values: tap.Tensor's eager impl applies the reducer, + # so the flat outputs of the tapped EP already contain reduced values at + # the same positions the runtime will use. + aot_out = ep_t.module()(*example_inputs) + aot_flat, _ = pytree.tree_flatten(aot_out) + + edge = to_edge_transform_and_lower(ep_t, partitioner=partitioner or []) + strip_taps_(edge) + et_program = edge.to_executorch() + + flat_inputs, _ = pytree.tree_flatten(example_inputs) + rt_flat = list(_run_pte(et_program, flat_inputs)) + + df = compare_aot_runtime_dataframe(specs, aot_flat, rt_flat) + return df, specs + + +def tap_all_and_run( + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + partitioner: list | None = None, + reducer: str | StatReducer = "STATS", + selector: NodeSelector | None = None, + skip_if_no_debug_handle: bool = True, +) -> pd.DataFrame: + """ + Export -> tap -> lower -> strip -> to_executorch -> run -> DataFrame. + + Returns a DataFrame indexed by tap with columns: + node_name, op_target, debug_handle, output_index, reducer_name, plus + one column per reducer field (or `value` for FULL_TENSOR). 
+ """ + from executorch.exir import to_edge_transform_and_lower + + selector = selector or select_all_call_function() + ep = torch.export.export(model, example_inputs, strict=True) + ep_tapped, specs = tap_intermediate_outputs( + ep, + selector=selector, + reducer=reducer, + skip_if_no_debug_handle=skip_if_no_debug_handle, + ) + edge = to_edge_transform_and_lower(ep_tapped, partitioner=partitioner or []) + strip_taps_(edge) + et_program = edge.to_executorch() + + flat_outputs = _run_pte(et_program, example_inputs) + return specs_to_dataframe(specs, flat_outputs) + + +def _run_pte(et_program, example_inputs: tuple[Any, ...]) -> Sequence[Any]: + from executorch.runtime import Runtime, Verification + + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + return method.execute(example_inputs) + + +def specs_to_dataframe( + specs: Sequence[TapSpec], + flat_outputs: Sequence[Any], +) -> pd.DataFrame: + """Build a per-tap DataFrame from the tap_specs + flat output tuple.""" + rows = [] + for spec in specs: + runtime_value = flat_outputs[spec.output_index] + row: dict[str, Any] = { + "node_name": spec.node_name, + "op_target": spec.op_target, + "debug_handle": spec.debug_handle, + "output_index": spec.output_index, + "reducer_name": spec.reducer_name, + } + if spec.fields: + tensor_vals = ( + runtime_value.detach().cpu().tolist() + if isinstance(runtime_value, torch.Tensor) + else list(runtime_value) + ) + for i, field in enumerate(spec.fields): + row[field] = tensor_vals[i] if i < len(tensor_vals) else None + else: + row["value"] = runtime_value + rows.append(row) + return pd.DataFrame(rows) + + +def _flat_floats(v: Any) -> list[float]: + """Flatten a tap value (tensor / list / scalar) to a flat list of floats.""" + if isinstance(v, torch.Tensor): + return [ + float(x) for x in v.detach().to(torch.float32).cpu().reshape(-1).tolist() + ] + if isinstance(v, (list, tuple)): + out: list[float] = [] + for x in v: + out.extend(_flat_floats(x)) + return out + try: + return [float(v)] + except (TypeError, ValueError): + return [] + + +def _sqnr_db(aot_vals: list[float], rt_vals: list[float]) -> float: + """Signal-to-quantization-noise ratio in dB. Higher is better. + + Thin wrapper around `torch.ao.ns.fx.utils.compute_sqnr` (the canonical + implementation already used by `backends/test/harness/error_statistics.py`). + """ + from torch.ao.ns.fx.utils import compute_sqnr + + n = min(len(aot_vals), len(rt_vals)) + if n == 0: + return float("nan") + aot_t = torch.tensor(aot_vals[:n], dtype=torch.float32) + rt_t = torch.tensor(rt_vals[:n], dtype=torch.float32) + return float(compute_sqnr(rt_t, aot_t)) + + +def compare_aot_runtime_dataframe( + specs: Sequence[TapSpec], + aot_flat: Sequence[Any], + rt_flat: Sequence[Any], +) -> pd.DataFrame: + """ + Build a side-by-side AOT-vs-runtime DataFrame from the flat outputs of + the *tapped* ExportedProgram (eager) and the post-strip runtime program. + + Both `aot_flat[spec.output_index]` and `rt_flat[spec.output_index]` already + contain the *reduced* tap value, since `tap.Tensor`'s eager impl applies + the named reducer (see `custom_ops_lib.py`). + + Output columns per spec: + - For non-FULL_TENSOR reducers: one `aot_` and `rt_` per + reducer field (e.g. `aot_min`, `rt_min`, ...). 
+ - For FULL_TENSOR: `sqnr_db` (signal-to-noise of aot vs rt over the + whole tensor, in dB) + """ + rows: list[dict[str, Any]] = [] + for spec in specs: + aot_vals = _flat_floats(aot_flat[spec.output_index]) + rt_vals = _flat_floats(rt_flat[spec.output_index]) + + row: dict[str, Any] = { + "node_name": spec.node_name, + "module_path": spec.module_path, + "module_class": spec.module_class, + "op_target": spec.op_target, + "reducer_name": spec.reducer_name, + "output_index": spec.output_index, + } + + if spec.reducer_name == "FULL_TENSOR": + row["sqnr_db"] = _sqnr_db(aot_vals, rt_vals) + row["aot_numel"] = len(aot_vals) + row["rt_numel"] = len(rt_vals) + else: + fields = ( + list(spec.fields) + if spec.fields + else [f"v{i}" for i in range(max(len(aot_vals), len(rt_vals)))] + ) + for i, f in enumerate(fields): + row[f"aot_{f}"] = aot_vals[i] if i < len(aot_vals) else float("nan") + row[f"rt_{f}"] = rt_vals[i] if i < len(rt_vals) else float("nan") + + rows.append(row) + return pd.DataFrame(rows) diff --git a/devtools/intermediate_output_tap/_reducers.py b/devtools/intermediate_output_tap/_reducers.py new file mode 100644 index 00000000000..481dfb2f55e --- /dev/null +++ b/devtools/intermediate_output_tap/_reducers.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Stat reducers used by `tap_intermediate_outputs`. + +A `StatReducer` is a small specification consumed by `strip_taps_` (after +`to_backend`) to materialise a portable reducer subgraph in place of the +`executorch_devtools::tap.Tensor` placeholder. + +`emit(graph, src_node) -> fx.Node` builds the reducer subgraph in `graph` +just before the placeholder, using the source tensor `src_node` as input, +and returns the final node whose output replaces the placeholder's output. + +`eager(tensor) -> tensor` is the pure-torch equivalent that callers can use +to reproduce, in eager mode, what the runtime will compute. `tap.Tensor`'s +own dispatch impl uses this to produce the *reduced* value at AOT time, so +that `ep.module()(*inputs)` returns the same flat outputs as the runtime. + +The emit functions cast to fp32 first for cross-backend numerical stability +and produce a fixed-shape output (0-D or 1-D) regardless of the source +tensor's rank, so callers don't need to track per-tap shapes. + +We ship two built-ins: + +* `FULL_TENSOR` — identity. The whole source tensor is surfaced. +* `STATS` — a comprehensive bundle of debugging-friendly scalars: + min, max, mean, abs_max, abs_mean, std, rms, l1_norm, l2_norm, + nan_count, inf_count, zero_count, p99_abs. +""" + +from __future__ import annotations + +import operator +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +from executorch.exir.dialects._ops import ops as exir_ops + + +if TYPE_CHECKING: + import torch.fx as fx + + +# --- Reducer dataclass --------------------------------------------------- + +EmitFn = Callable[["fx.Graph", "fx.Node"], "fx.Node"] +EagerFn = Callable[[torch.Tensor], torch.Tensor] + + +@dataclass(frozen=True) +class StatReducer: + """ + A reducer specification. + + `emit` is invoked by `strip_taps_` to materialise the reducer subgraph + in the post-lowering FX graph. 
`eager` is the equivalent pure-torch + implementation, used by callers that want to reproduce what the runtime + will compute (e.g. AOT-vs-runtime comparisons without a debugger). + + `name` is what the user types and what's stored on each TapSpec. + `fields` enumerates the columns of the 1-D output tensor (empty for + FULL_TENSOR which preserves a tensor of values). + """ + + name: str + fields: tuple[str, ...] + emit: EmitFn + eager: EagerFn + + +# --- Helpers ------------------------------------------------------------- + + +def _cast_fp32(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + """Insert a fp32 cast (no-op semantically if already fp32).""" + return graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(x,), + kwargs={"dtype": torch.float32}, + ) + + +def _scalar_node(graph: "fx.Graph", op, x: "fx.Node") -> "fx.Node": + """Call a full-reduction op (amin/amax/mean/sum) producing a 0-d tensor.""" + return graph.call_function(op, args=(x,)) + + +def _stack(graph: "fx.Graph", scalars: list["fx.Node"]) -> "fx.Node": + """Stack a list of 0-d tensors into a 1-D tensor.""" + return graph.call_function( + exir_ops.edge.aten.stack.default, + args=(scalars,), + kwargs={"dim": 0}, + ) + + +def _abs(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + return graph.call_function(exir_ops.edge.aten.abs.default, args=(x,)) + + +def _square(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + return graph.call_function(exir_ops.edge.aten.pow.Tensor_Scalar, args=(x, 2.0)) + + +def _sqrt(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + return graph.call_function(exir_ops.edge.aten.sqrt.default, args=(x,)) + + +def _full_sum(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + """Full-tensor sum via aten.sum.dim_IntList(dim=[]) — portable + has out variant.""" + return graph.call_function(exir_ops.edge.aten.sum.dim_IntList, args=(x, [])) + + +def _bool_to_fp32_count(graph: "fx.Graph", mask: "fx.Node") -> "fx.Node": + """Sum of a bool mask cast to fp32 → a 0-d fp32 count.""" + casted = graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(mask,), + kwargs={"dtype": torch.float32}, + ) + return _full_sum(graph, casted) + + +# --- FULL_TENSOR --------------------------------------------------------- + + +def _emit_full_tensor(_graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + return src + + +def _eager_full_tensor(t: torch.Tensor) -> torch.Tensor: + return t.detach() + + +FULL_TENSOR: StatReducer = StatReducer( + name="FULL_TENSOR", + fields=(), + emit=_emit_full_tensor, + eager=_eager_full_tensor, +) + + +# --- STATS --------------------------------------------------------------- + + +_STATS_FIELDS: tuple[str, ...] = ( + "min", + "max", + "mean", + "abs_max", + "abs_mean", + "std", + "rms", + "l1_norm", + "l2_norm", + "nan_count", + "inf_count", + "zero_count", + "p99_abs", +) + + +def _emit_stats(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + f = _cast_fp32(graph, src) + abs_f = _abs(graph, f) + sq_f = _square(graph, f) + + mn = _scalar_node(graph, exir_ops.edge.aten.amin.default, f) + mx = _scalar_node(graph, exir_ops.edge.aten.amax.default, f) + me = _scalar_node(graph, exir_ops.edge.aten.mean.default, f) + abs_max = _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_f) + abs_mean = _scalar_node(graph, exir_ops.edge.aten.mean.default, abs_f) + + sum_sq = _full_sum(graph, sq_f) + mean_sq = _scalar_node(graph, exir_ops.edge.aten.mean.default, sq_f) + rms = _sqrt(graph, mean_sq) + + # std = sqrt( E[x^2] - E[x]^2 ); avoids aten.var which lacks an out variant. 
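+    # Worked example: x = [1, 3] gives E[x^2] = 5 and E[x]^2 = 4, so var = 1
+    # and std = 1, i.e. the population variance, torch.var(x, correction=0).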
+ me_sq_scalar = graph.call_function( + exir_ops.edge.aten.pow.Tensor_Scalar, args=(me, 2.0) + ) + var = graph.call_function( + exir_ops.edge.aten.sub.Tensor, args=(mean_sq, me_sq_scalar) + ) + # Variance can be slightly negative due to fp roundoff; clamp at 0 via abs. + var = graph.call_function(exir_ops.edge.aten.abs.default, args=(var,)) + std = _sqrt(graph, var) + + l1 = _full_sum(graph, abs_f) + l2 = _sqrt(graph, sum_sq) + + nan_mask = graph.call_function(exir_ops.edge.aten.isnan.default, args=(f,)) + nan_count = _bool_to_fp32_count(graph, nan_mask) + + inf_mask = graph.call_function(exir_ops.edge.aten.isinf.default, args=(f,)) + inf_count = _bool_to_fp32_count(graph, inf_mask) + + zero_mask = graph.call_function(exir_ops.edge.aten.eq.Scalar, args=(f, 0.0)) + zero_count = _bool_to_fp32_count(graph, zero_mask) + + # p99_abs: use topk on flattened |x| to get the k-th largest, where + # k = max(1, ceil(numel * 0.01)). Numel is read from the source's + # FakeTensor at graph-build time. + fake = src.meta.get("val") + numel = int(fake.numel()) if fake is not None else 1 + k = max(1, (numel + 99) // 100) # ceil(numel/100) + abs_flat = graph.call_function( + exir_ops.edge.aten.view_copy.default, args=(abs_f, [-1]) + ) + topk_out = graph.call_function( + exir_ops.edge.aten.topk.default, + args=(abs_flat, k), + kwargs={"dim": -1, "largest": True, "sorted": True}, + ) + topk_values = graph.call_function(operator.getitem, args=(topk_out, 0)) + p99_abs = graph.call_function( + exir_ops.edge.aten.select_copy.int, args=(topk_values, 0, k - 1) + ) + + return _stack( + graph, + [ + mn, + mx, + me, + abs_max, + abs_mean, + std, + rms, + l1, + l2, + nan_count, + inf_count, + zero_count, + p99_abs, + ], + ) + + +def _eager_stats(t: torch.Tensor) -> torch.Tensor: + f = t.detach().to(torch.float32) + abs_f = f.abs() + sq = f.pow(2.0) + + # std via E[x^2] - E[x]^2 (population variance) to match the emit subgraph. 
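+    # Same quantity as torch.var(f, correction=0) plus the abs() clamp against
+    # negative fp roundoff; kept explicit so eager tracks the emitted subgraph.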
+ if f.numel() > 0: + var = (sq.mean() - f.mean().pow(2)).abs() + std = var.sqrt() + else: + std = torch.tensor(0.0) + + sum_sq = sq.sum() + rms = sq.mean().sqrt() + l1 = abs_f.sum() + l2 = sum_sq.sqrt() + + nan_count = torch.isnan(f).to(torch.float32).sum() + inf_count = torch.isinf(f).to(torch.float32).sum() + zero_count = (f == 0).to(torch.float32).sum() + + numel = f.numel() + k = max(1, (numel + 99) // 100) + if numel > 0: + topk_vals = torch.topk(abs_f.reshape(-1), k=k, largest=True, sorted=True).values + p99_abs = topk_vals[k - 1] + else: + p99_abs = torch.tensor(float("nan")) + + return torch.stack( + [ + f.amin(), + f.amax(), + f.mean(), + abs_f.amax(), + abs_f.mean(), + std, + rms, + l1, + l2, + nan_count, + inf_count, + zero_count, + p99_abs, + ], + dim=0, + ) + + +STATS: StatReducer = StatReducer( + name="STATS", + fields=_STATS_FIELDS, + emit=_emit_stats, + eager=_eager_stats, +) + + +# --- Registry ------------------------------------------------------------- + +_BUILTIN_REDUCERS: dict[str, StatReducer] = {r.name: r for r in (FULL_TENSOR, STATS)} + + +def get_reducer(name_or_reducer: str | StatReducer) -> StatReducer: + """Look up a built-in by name, or return a user-supplied StatReducer as-is.""" + if isinstance(name_or_reducer, StatReducer): + return name_or_reducer + if name_or_reducer not in _BUILTIN_REDUCERS: + raise ValueError( + f"Unknown reducer {name_or_reducer!r}; " + f"available: {sorted(_BUILTIN_REDUCERS)}" + ) + return _BUILTIN_REDUCERS[name_or_reducer] diff --git a/devtools/intermediate_output_tap/_selectors.py b/devtools/intermediate_output_tap/_selectors.py new file mode 100644 index 00000000000..059d2069692 --- /dev/null +++ b/devtools/intermediate_output_tap/_selectors.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +Predicates for selecting which FX nodes to tap. + +A `NodeSelector` is just `Callable[[fx.Node], bool]`. The provided builders +let you compose them by op type, by `nn_module_stack` path, by arbitrary meta +tag, and via boolean combinators. + +Examples: + selector = select_any( + select_by_op_type("aten.linear.default", "aten.matmul.default"), + select_by_module_path("layers.*.attention"), + ) + selector = select_all(selector, select_not(select_by_op_type("aten.view.default"))) +""" + +from __future__ import annotations + +import fnmatch +from collections.abc import Callable +from typing import Any + +import torch.fx as fx + + +NodeSelector = Callable[[fx.Node], bool] + + +def select_all_call_function( + exclude: tuple[str, ...] = ("getitem",), +) -> NodeSelector: + """Match every `call_function` node whose target name is not in `exclude`.""" + excluded = set(exclude) + + def predicate(n: fx.Node) -> bool: + if n.op != "call_function": + return False + target_name = getattr(n.target, "__name__", str(n.target)) + # `getitem` shows up as the builtin name; also normalise common aten suffixes. + return target_name not in excluded and str(n.target) not in excluded + + return predicate + + +def select_by_op_type(*op_targets: str) -> NodeSelector: + """ + Match nodes whose `str(node.target)` ends with any of `op_targets`. + + The "ends with" rule lets the user write either the short name + ("aten.linear.default") or a fully-qualified name and have it match. 
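+
+    Example (a sketch; `gm` stands in for any fx.GraphModule):
+
+        sel = select_by_op_type("aten.linear.default", "aten.matmul.default")
+        matches = [n for n in gm.graph.nodes if sel(n)]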
+ """ + if not op_targets: + raise ValueError("select_by_op_type requires at least one op target") + suffixes = tuple(op_targets) + + def predicate(n: fx.Node) -> bool: + if n.op != "call_function": + return False + target_str = str(n.target) + return any(target_str.endswith(s) or target_str == s for s in suffixes) + + return predicate + + +def select_by_module_path(*patterns: str) -> NodeSelector: + """ + Match nodes whose `nn_module_stack` (the chain of nn.Module attribute names + walked to reach this op during tracing) contains a path matching ANY of + `patterns`. Each pattern is a shell-glob (fnmatch) — e.g. "layers.*", + "layers.0.attention", "*.attention.*". + + Example: + # Match anything inside layers 5 through 9. + select_by_module_path(*[f"*.layers.{i}.*" for i in range(5, 10)]) + """ + if not patterns: + raise ValueError("select_by_module_path requires at least one pattern") + pats = tuple(patterns) + + def predicate(n: fx.Node) -> bool: + stack = n.meta.get("nn_module_stack") + if not stack: + return False + # nn_module_stack is an OrderedDict: id -> (qualified_path, module_type) + for entry in stack.values(): + path = entry[0] if isinstance(entry, tuple) else entry + for p in pats: + if fnmatch.fnmatchcase(path, p): + return True + return False + + return predicate + + +def _bare_class_name(mod_type) -> str: + """Extract the bare class name from an `nn_module_stack` type entry.""" + cls_name = getattr(mod_type, "__name__", None) + if cls_name is None: + cls_name = str(mod_type).rsplit(".", 1)[-1].rstrip("'>") + return cls_name + + +def select_by_module_class( + *class_names: str, + output_only: bool = False, +) -> NodeSelector: + """ + Match nodes whose `nn_module_stack` contains a module of one of + `class_names`. + + `class_names` are matched against the bare class name (e.g. "Attention", + "RMSNorm"), not the fully-qualified type. This lets the selector survive + moves between modules, and avoids the fnmatch escaping needed for paths. + + The `nn_module_stack` entry's second element is typically either a class + object or a string like `"my.pkg.module.MyClass"`; both forms are handled. + + Args: + class_names: One or more bare class names to match. + output_only: If True, match only the *terminal* node of each matching + module instance — i.e., a node N is matched iff (a) N is inside an + instance of a target class, and (b) no user of N is inside the + same module instance. This taps only the value(s) flowing out of + the module, not every intermediate op. Defaults to False (match + every intermediate op). + + Example: + # Tap every intermediate op inside any RMSNorm. + select_by_module_class("RMSNorm") + + # Tap only the value flowing out of each RMSNorm. + select_by_module_class("RMSNorm", output_only=True) + """ + if not class_names: + raise ValueError("select_by_module_class requires at least one class name") + names = set(class_names) + + def _matching_instance_ids(n: fx.Node) -> list: + stack = n.meta.get("nn_module_stack") + if not stack: + return [] + ids = [] + for mod_id, entry in stack.items(): + if not isinstance(entry, tuple) or len(entry) < 2: + continue + if _bare_class_name(entry[1]) in names: + ids.append(mod_id) + return ids + + if not output_only: + def predicate(n: fx.Node) -> bool: + return bool(_matching_instance_ids(n)) + + return predicate + + def predicate_terminal(n: fx.Node) -> bool: + my_ids = _matching_instance_ids(n) + if not my_ids: + return False + # Terminal iff no user shares any of `my_ids` in its nn_module_stack. 
+ my_id_set = set(my_ids) + for user in n.users: + user_stack = user.meta.get("nn_module_stack") or {} + if any(uid in my_id_set for uid in user_stack.keys()): + return False + return True + + return predicate_terminal + + +# Sentinel: matches when the meta key exists at all, regardless of value. +_ANY_VALUE: object = object() + + +def select_by_meta_tag(key: str, value: Any = _ANY_VALUE) -> NodeSelector: + """ + Match nodes that carry `node.meta[key]`. If `value` is provided, also + requires `node.meta[key] == value`. + """ + + def predicate(n: fx.Node) -> bool: + if key not in n.meta: + return False + if value is _ANY_VALUE: + return True + return n.meta[key] == value + + return predicate + + +def select_any(*selectors: NodeSelector) -> NodeSelector: + """Match if ANY of `selectors` matches.""" + if not selectors: + return lambda _n: False + sels = tuple(selectors) + return lambda n: any(s(n) for s in sels) + + +def select_all(*selectors: NodeSelector) -> NodeSelector: + """Match if ALL of `selectors` match.""" + if not selectors: + return lambda _n: True + sels = tuple(selectors) + return lambda n: all(s(n) for s in sels) + + +def select_not(selector: NodeSelector) -> NodeSelector: + """Match if `selector` does NOT match.""" + return lambda n: not selector(n) diff --git a/devtools/intermediate_output_tap/_spec.py b/devtools/intermediate_output_tap/_spec.py new file mode 100644 index 00000000000..1c24c753110 --- /dev/null +++ b/devtools/intermediate_output_tap/_spec.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +TapSpec records one tap inserted by `tap_intermediate_outputs(...)`. + +A list of TapSpecs is returned to the user from the AOT pass; the user uses +the `output_index` on each spec to demux the runtime program's flat output +tuple back into per-op intermediate values (e.g. via +`compare_aot_runtime_dataframe`). +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TapSpec: + """ + Metadata about a single tap. + + Attributes: + node_name: The FX node name of the *source* node (the tapped op) at the + time the AOT pass ran. Useful for debugging / pretty-printing. + op_target: `str(node.target)` of the source node, e.g. + "aten.linear.default". + debug_handle: `node.meta["debug_handle"]` of the source node, or None + if the source had no debug handle. Set at AOT-pass time. + output_index: 0-based index into the runtime program's flat output + tuple where this tap's value lands. Computed at AOT time and stable + through `to_edge` / `to_backend` / `to_executorch` because we only + ever *append* to the output node and `OutputSpec`. + reducer_name: Name of the StatReducer used (e.g. "STATS"). + fields: Names of the per-element fields in the reducer's output tensor + (e.g. ("min", "max", "abs_max")). Empty tuple for FULL_TENSOR. + stack_trace: `node.meta["stack_trace"]` of the source node if present, + for human-readable error messages. + module_path: The `nn_module_stack` path of the source node, e.g. + "layers.1.attention.wvs.0", or None if not available. + module_class: Bare class name of the leaf nn.Module the source node + ran inside (e.g. "Linear", "_RMSNorm"), or None if not available. 
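+
+    Example:
+
+        # Demux one tap's value out of the runtime's flat output tuple.
+        value = rt_flat[spec.output_index]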
+ """ + + node_name: str + op_target: str + debug_handle: int | None + output_index: int + reducer_name: str + fields: tuple[str, ...] + stack_trace: str | None = None + module_path: str | None = None + module_class: str | None = None diff --git a/devtools/intermediate_output_tap/_strip_pass.py b/devtools/intermediate_output_tap/_strip_pass.py new file mode 100644 index 00000000000..3ef4cada91a --- /dev/null +++ b/devtools/intermediate_output_tap/_strip_pass.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Post-`to_backend` pass: replace each `executorch_devtools::tap.Tensor` node +with either an identity edge (FULL_TENSOR) or a portable reducer subgraph +(STATS, or any user-supplied StatReducer). + +Pattern stolen from `remove_graph_break_` in +`executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. + +This pass MUST run *after* `to_edge_transform_and_lower(...)` and *before* +`to_executorch()`. Running it before partitioning would defeat the whole +mechanism (the reducer ops would be eligible for delegation). +""" + +from __future__ import annotations + +import torch.fx as fx +from executorch.devtools.intermediate_output_tap._reducers import get_reducer +from executorch.devtools.intermediate_output_tap._tap_pass import find_tap_nodes + + +def strip_taps_(edge_manager) -> None: + """ + Replace every `tap.Tensor(src, reducer_name, debug_handle)` node in every + method of `edge_manager` with the materialised reducer subgraph, in place. + + For FULL_TENSOR the placeholder is collapsed (the source node's value + flows directly to whatever consumed the placeholder). + + Args: + edge_manager: An EdgeProgramManager (post-`to_edge_transform_and_lower`). + """ + for method_name in edge_manager.methods: + ep = edge_manager.exported_program(method_name) + _strip_taps_in_graph_module(ep.graph_module) + + +def _strip_taps_in_graph_module(gm: fx.GraphModule) -> None: + """Strip taps in a single GraphModule, in place.""" + graph = gm.graph + tap_nodes = find_tap_nodes(gm) + if not tap_nodes: + return + + output_node = graph.output_node() + + for tap in tap_nodes: + # tap.args = (src_node, reducer_name, debug_handle) + src, reducer_name = tap.args[0], tap.args[1] + reducer = get_reducer(str(reducer_name)) + + if reducer.name == "FULL_TENSOR": + # Identity: re-route all consumers to the source. + tap.replace_all_uses_with(src) + continue + + # Build the reducer subgraph (reads from src). + with graph.inserting_before(tap): + replacement = reducer.emit(graph, src) + replacement.meta["is_tap"] = True + replacement.meta["source_node"] = src.name if isinstance(src, fx.Node) else None + + # `tap` may have ended up in the data path during to_edge's re-trace + # (because CompositeExplicitAutograd preserves the op as an identity + # node, and re-traced consumers point at it instead of `src`). So: + # - the OUTPUT-node use becomes the reducer (the value we want + # surfaced as a tap). + # - every OTHER use is rewritten back to `src` (identity passthrough), + # restoring the original data path. 
+ for use_node in list(tap.users.keys()): + if use_node is output_node: + new_outs = tuple( + replacement if a is tap else a for a in output_node.args[0] + ) + output_node.args = (new_outs,) + else: + use_node.replace_input_with(tap, src) + + graph.eliminate_dead_code() + gm.recompile() diff --git a/devtools/intermediate_output_tap/_tap_pass.py b/devtools/intermediate_output_tap/_tap_pass.py new file mode 100644 index 00000000000..6089feb7ac9 --- /dev/null +++ b/devtools/intermediate_output_tap/_tap_pass.py @@ -0,0 +1,327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +AOT pass: insert `tap.Tensor` placeholders after selected nodes and surface +them as additional USER_OUTPUTs of the ExportedProgram. + +Pattern stolen from `executorch/exir/passes/weights_to_outputs_pass.py`: +- find existing output node +- build new output args (existing + new tap nodes) +- create new output node, replace_all_uses_with, erase old +- append OutputSpec(USER_OUTPUT) entries to gs.output_specs +- eliminate_dead_code() + recompile() +""" + +from __future__ import annotations + +import copy +import warnings +from collections.abc import Callable, Sequence + +import torch +import torch.fx as fx +from executorch.devtools.intermediate_output_tap import ( # noqa: F401 registers tap.Tensor + custom_ops_lib, +) +from executorch.devtools.intermediate_output_tap._reducers import ( + get_reducer, + StatReducer, + STATS, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all_call_function, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from torch.export import ExportedProgram +from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument + + +# Public alias: one (selector, reducer) pair, for use with the `rules=` arg. +TapRule = tuple[NodeSelector, "str | StatReducer"] + + +# Don't ever tap our own tap nodes if a user runs the pass twice. +# `tap.Tensor` is already an OpOverload (not a packet) since "Tensor" is the +# overload name — same convention as torch.ops.executorch_utils.graph_break.Tensor. +_TAP_TARGET = torch.ops.executorch_devtools.tap.Tensor + + +def _is_tap_node(n: fx.Node) -> bool: + return n.op == "call_function" and n.target is _TAP_TARGET + + +def tap_intermediate_outputs( # noqa: C901 + ep: ExportedProgram, + selector: NodeSelector | None = None, + reducer: str | StatReducer = STATS, + *, + rules: Sequence[TapRule] | None = None, + tap_name_prefix: str = "tap_", + skip_if_no_debug_handle: bool = False, + max_taps: int | None = None, + inplace: bool = False, + error_on_empty: bool = True, +) -> tuple[ExportedProgram, list[TapSpec]]: + """ + Rewrite `ep` so each node matching a selector has its output appended to + the program outputs (wrapped in a `tap.Tensor` placeholder that survives + partitioning). Returns the new ExportedProgram and a list of TapSpecs. + + The returned EP is safe to feed to + `to_edge_transform_and_lower(...).to_executorch()` *after* calling + `strip_taps_(edge_manager)` to replace the placeholders with their + reducer subgraphs (or identities, for FULL_TENSOR). + + Args: + ep: The ExportedProgram to tap. + selector: A predicate over fx.Node. Defaults to + `select_all_call_function()`. Tap nodes themselves are always + excluded so re-running the pass is idempotent. Ignored when + `rules=` is provided. 
+        reducer: Either a built-in reducer name ("STATS", "FULL_TENSOR")
+            or a custom StatReducer instance. Ignored when `rules=` is
+            provided.
+        rules: Optional sequence of `(selector, reducer)` pairs. When
+            provided, each node is tested against the rules in order and
+            tapped with the FIRST matching rule's reducer. A single
+            `(selector, reducer)` tuple is also accepted as a shortcut for
+            `[(selector, reducer)]`. Mutually exclusive with
+            `selector`/`reducer`.
+        tap_name_prefix: Prefix for the tap nodes' names. Helps when
+            grepping the dumped graph.
+        skip_if_no_debug_handle: If True, only tap nodes that already
+            carry `node.meta["debug_handle"]`.
+        max_taps: If set, stop after inserting this many taps. Handy for
+            very large models.
+        inplace: If False (default), deep-copy `ep` before mutating.
+        error_on_empty: If True (default), raise `ValueError` when no nodes
+            match any selector. Set to False to only emit a `UserWarning`
+            and return `(ep, [])` — handy when iterating on selector
+            patterns.
+    """
+    # Normalize to a list of (selector, reducer_obj) pairs.
+    if rules is not None:
+        if selector is not None:
+            raise ValueError(
+                "tap_intermediate_outputs: pass either `selector`/`reducer` "
+                "or `rules=`, not both."
+            )
+        # Allow a single `(selector, reducer)` tuple as syntactic sugar for
+        # `[(selector, reducer)]`. We detect "single rule" as a 2-tuple whose
+        # first element is callable (i.e. a NodeSelector).
+        if (
+            isinstance(rules, tuple)
+            and len(rules) == 2
+            and callable(rules[0])
+        ):
+            rules = [rules]
+        if not rules:
+            raise ValueError("tap_intermediate_outputs: `rules` must be non-empty.")
+        normalized_rules: list[tuple[NodeSelector, StatReducer]] = [
+            (sel, get_reducer(red)) for sel, red in rules
+        ]
+    else:
+        sel = selector if selector is not None else select_all_call_function()
+        normalized_rules = [(sel, get_reducer(reducer))]
+
+    if not inplace:
+        ep = copy.deepcopy(ep)
+
+    gs = ep.graph_signature
+    gm = ep.graph_module
+    graph = gm.graph
+    output_node = graph.output_node()
+    existing_outputs = list(output_node.args[0])
+
+    # Snapshot before we start mutating the graph.
+    candidate_nodes = [n for n in graph.nodes if not _is_tap_node(n)]
+
+    specs: list[TapSpec] = []
+    new_tap_nodes: list[fx.Node] = []
+
+    for node in candidate_nodes:
+        if node.op != "call_function":
+            continue
+        # First-match wins.
+        matched: StatReducer | None = None
+        for sel, red in normalized_rules:
+            if sel(node):
+                matched = red
+                break
+        if matched is None:
+            continue
+        debug_handle = node.meta.get("debug_handle")
+        if skip_if_no_debug_handle and debug_handle is None:
+            continue
+        if max_taps is not None and len(specs) >= max_taps:
+            break
+
+        # tap.Tensor's int arg cannot be None; sentinel 0 means "no handle".
+        dh_arg = int(debug_handle) if isinstance(debug_handle, int) else 0
+
+        with graph.inserting_after(node):
+            tap_node = graph.call_function(
+                _TAP_TARGET,
+                args=(node, matched.name, dh_arg),
+            )
+            # Don't override the auto-assigned name — FX guarantees uniqueness.
+            # Stash the prefixed-source-name in meta for human-readable logs.
+            tap_node.meta["tap_label"] = f"{tap_name_prefix}{node.name}"
+            # Preserve provenance for users that pretty-print the graph.
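+            # (`from_node` keeps the re-trace provenance chain intact;
+            # `stack_trace` and `nn_module_stack` are what the meta-based
+            # selectors and the TapSpec fields read back later.)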
+ if debug_handle is not None: + tap_node.meta["debug_handle"] = debug_handle + if "from_node" in node.meta: + tap_node.meta["from_node"] = node.meta["from_node"] + if "stack_trace" in node.meta: + tap_node.meta["stack_trace"] = node.meta["stack_trace"] + if "nn_module_stack" in node.meta: + tap_node.meta["nn_module_stack"] = node.meta["nn_module_stack"] + tap_node.meta["is_tap"] = True + tap_node.meta["source_node"] = node.name + + new_tap_nodes.append(tap_node) + # Leaf module FQN + bare class from nn_module_stack (e.g., + # "layers.0.attention.wqs.0" / "Linear"). + module_path: str | None = None + module_class: str | None = None + stack = node.meta.get("nn_module_stack") + if stack: + try: + last_entry = list(stack.values())[-1] + if isinstance(last_entry, tuple): + module_path = last_entry[0] + if len(last_entry) >= 2: + mod_type = last_entry[1] + cls_name = getattr(mod_type, "__name__", None) + if cls_name is None: + cls_name = ( + str(mod_type).rsplit(".", 1)[-1].rstrip("'>") + ) + module_class = cls_name + else: + module_path = str(last_entry) + except Exception: + module_path = None + module_class = None + specs.append( + TapSpec( + node_name=node.name, + op_target=str(node.target), + debug_handle=debug_handle if isinstance(debug_handle, int) else None, + output_index=len(existing_outputs) + len(specs), + reducer_name=matched.name, + fields=matched.fields, + stack_trace=node.meta.get("stack_trace"), + module_path=module_path, + module_class=module_class, + ) + ) + + if not new_tap_nodes: + msg = ( + "tap_intermediate_outputs: selector matched 0 nodes. " + "Double-check your selector predicates, " + "or pass `error_on_empty=False` to suppress this error." + ) + if error_on_empty: + raise ValueError(msg) + warnings.warn(msg, UserWarning, stacklevel=2) + return ep, [] + + # Splice new outputs into the graph (mirror weights_to_outputs_pass). + new_output_args = tuple(existing_outputs + new_tap_nodes) + with graph.inserting_before(output_node): + new_output = graph.output(new_output_args) + output_node.replace_all_uses_with(new_output) + graph.erase_node(output_node) + + # Append OutputSpec entries so the EP's signature matches the graph. + for tap_node in new_tap_nodes: + gs.output_specs.append( + OutputSpec( + kind=OutputKind.USER_OUTPUT, + arg=TensorArgument(name=tap_node.name), + target=None, + ) + ) + + # Update each ModuleCallSignature's out_spec so `to_edge`'s re-trace can + # unflatten the new flat output list. The "" (root) entry holds the + # user-facing forward output structure; we wrap it in a tuple alongside + # the new tap leaves and re-derive the spec. + _extend_module_call_graph_outputs(ep, new_tap_nodes) + + graph.eliminate_dead_code() + gm.recompile() + return ep, specs + + +def _extend_module_call_graph_outputs( + ep: ExportedProgram, + new_tap_nodes: list[fx.Node], +) -> None: + """ + Append `len(new_tap_nodes)` extra leaves to the root module-call entry's + `out_spec` so the pytree unflatten step in `run_decompositions` works. + Also extends the entry's `outputs: list[ArgumentSpec]`. + + NOTE: We append TensorArgument(name="") for each new tap output. Empty + names are *skipped* by `_verify_exported_program_module_call_graph` (its + check is `if arg.name and arg.name not in nodes`). We can't use the + pre-trace tap node names because `to_edge`'s re-trace renames nodes via + `from_node` chains, and our tap nodes' provenance wouldn't update them + correctly — leading to "Output X does not exist in the graph" errors. 
The verifier's name check is metadata-only; the actual pytree unflatten
+    only needs `out_spec` to have the correct number of leaves.
+    """
+    import torch.utils._pytree as pytree
+    from torch.export.exported_program import TensorArgument as _TensorArgument
+
+    n_new = len(new_tap_nodes)
+    if n_new == 0:
+        return
+
+    for entry in ep._module_call_graph:
+        if entry.fqn != "":
+            continue
+        sig = entry.signature
+        if sig is None:
+            continue
+        old_spec = sig.out_spec
+        # Build a dummy structure matching the old spec, then wrap with N new
+        # leaves and re-derive the spec. This handles arbitrary pytree shapes.
+        old_dummy = pytree.tree_unflatten([0] * old_spec.num_leaves, old_spec)
+        if isinstance(old_dummy, tuple):
+            new_dummy = (*old_dummy, *([0] * n_new))
+        else:
+            new_dummy = (old_dummy, *([0] * n_new))
+        sig.out_spec = pytree.tree_structure(new_dummy)
+        for _ in range(n_new):
+            sig.outputs.append(_TensorArgument(name=""))
+        break
+
+
+def find_tap_nodes(gm: fx.GraphModule) -> list[fx.Node]:
+    """Helper: enumerate tap.Tensor nodes in a GraphModule (any dialect)."""
+    out: list[fx.Node] = []
+    for n in gm.graph.nodes:
+        if n.op != "call_function":
+            continue
+        # Match across dialects:
+        #   pre-edge:  torch.ops.executorch_devtools.tap.Tensor — str ends with name
+        #   post-edge: str(n.target) renders as "<qualified name>: schema = ..."
+        # so substring-match the qualified name.
+        if "executorch_devtools.tap.Tensor" in str(n.target) or n.target is _TAP_TARGET:
+            out.append(n)
+    return out
+
+
+# Re-export the predicate so callers can identify tap nodes without importing
+# torch.ops directly.
+is_tap_node: Callable[[fx.Node], bool] = _is_tap_node
diff --git a/devtools/intermediate_output_tap/custom_ops_lib.py b/devtools/intermediate_output_tap/custom_ops_lib.py
new file mode 100644
index 00000000000..79436915c27
--- /dev/null
+++ b/devtools/intermediate_output_tap/custom_ops_lib.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+"""
+Custom op registration for the intermediate-output tap mechanism.
+
+The op `executorch_devtools::tap.Tensor(Tensor x, str reducer_name, int debug_handle) -> Tensor`
+is a placeholder whose sole job is to be an unknown-to-every-partitioner FX
+node that "uses" a tapped tensor `x`. Because `x` now has a user outside any
+partition, every ExecuTorch partitioner must surface `x` as a partition output
+(this is the canonical contract enforced in
+`executorch/exir/lowered_backend_module.py`).
+
+After `to_edge_transform_and_lower(...)` the tap.Tensor node still exists in
+the parent graph; `strip_taps_` (see `_strip_pass.py`) replaces it with either
+an identity edge (FULL_TENSOR) or a small reducer subgraph of portable aten
+ops.
+
+The op's eager impl computes the named reducer's eager equivalent (e.g.
+`min/max/mean/abs_max/...` for STATS). Two reasons for this:
+
+1. **Re-trace safety.** `to_edge_transform_and_lower` re-traces the graph.
+   If `tap.Tensor` simply returned `x` literally, the re-traced FX graph
+   would treat the tap output as identical to `x` and re-route downstream
+   consumers (which would otherwise be reading `x`) through the tap node,
+   pulling it into a delegate's input list.
Returning a *different* tensor + (different shape for non-FULL_TENSOR; `x.detach()` for FULL_TENSOR) + keeps consumers wired to `x` directly so the tap stays a host-only stub + with the FX `output` node as its sole consumer. +2. **AOT/runtime parity.** Calling `ep_t.module()(*inputs)` then returns the + same reduced values the runtime emits post-strip, removing the need for + callers to reapply the reducer themselves. + +The dispatch key MUST be `CompositeExplicitAutograd` (not +`CompositeImplicitAutograd`) so the op survives tracing/decomposition; +otherwise it would inline at export time and disappear before partitioning. +This mirrors the pattern in +`executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. + +`reducer_name` and `debug_handle` are stored as op arguments (not just +node.meta) so they survive any meta-stripping pass between `to_edge` and +`strip_taps_`. +""" + +from __future__ import annotations + +from torch.library import impl, Library + +# Library namespace verified collision-free across fbsource as of Nov 2025. +lib: Library = Library("executorch_devtools", "DEF") + +lib.define("tap.Tensor(Tensor x, str reducer_name, int debug_handle) -> Tensor") + + +@impl(lib, "tap.Tensor", "CompositeExplicitAutograd") +def tap_tensor_impl(x, reducer_name, debug_handle): # noqa: ARG001 + # Defer the import to break a module-import cycle (`_reducers` → torch → + # custom_ops_lib registration). + from executorch.devtools.intermediate_output_tap._reducers import get_reducer + + return get_reducer(reducer_name).eager(x) diff --git a/devtools/intermediate_output_tap/tests/TARGETS b/devtools/intermediate_output_tap/tests/TARGETS new file mode 100644 index 00000000000..e0c8c82dd13 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/TARGETS @@ -0,0 +1,58 @@ +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") + +oncall("executorch") + +python_unittest( + name = "test_selectors", + srcs = ["test_selectors.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:selectors", + ], +) + +python_unittest( + name = "test_reducers", + srcs = ["test_reducers.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + ], +) + +python_unittest( + name = "test_tap_pass", + srcs = ["test_tap_pass.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + "//executorch/devtools/intermediate_output_tap:selectors", + "//executorch/devtools/intermediate_output_tap:tap_pass", + ], +) + +python_unittest( + name = "test_strip_pass", + srcs = ["test_strip_pass.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + "//executorch/devtools/intermediate_output_tap:selectors", + "//executorch/devtools/intermediate_output_tap:strip_pass", + "//executorch/devtools/intermediate_output_tap:tap_pass", + "//executorch/exir:lib", + ], +) + +python_unittest( + name = "test_xnnpack_e2e", + srcs = ["test_xnnpack_e2e.py"], + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/devtools/backend_debug:delegation_info", + "//executorch/devtools/intermediate_output_tap:lib", + "//executorch/exir:lib", + "//executorch/runtime:runtime", + ], +) diff --git a/devtools/intermediate_output_tap/tests/test_coreml_e2e.py b/devtools/intermediate_output_tap/tests/test_coreml_e2e.py new file mode 100644 index 00000000000..7635291009d --- /dev/null +++ 
b/devtools/intermediate_output_tap/tests/test_coreml_e2e.py
@@ -0,0 +1,254 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+"""
+Numerical-debugging tutorial for the ExecuTorch CoreML backend.
+
+This script walks through end-to-end use of the intermediate-output tap
+infrastructure on the static-attention Llama from
+`examples/apple/coreml/llama/`. It is a smoke test (random weights, tiny
+ModelArgs — no checkpoint download required) and produces two tables:
+
+1. A delegation summary showing how many subgraphs ExecuTorch handed off
+   to CoreML and which operators ran on which side.
+
+2. An AOT-vs-runtime comparison of the tapped intermediate values, so you
+   can see numerical drift between eager-PyTorch and the CoreML runtime
+   at hand-picked points in the model.
+
+Run with:
+    python devtools/intermediate_output_tap/tests/test_coreml_e2e.py
+"""
+
+import os
+import tempfile
+import types
+
+import coremltools as ct
+import pandas as pd
+import torch
+import torch.utils._pytree as pytree
+from executorch.backends.apple.coreml.compiler import CoreMLBackend
+from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.devtools.backend_debug import get_delegation_info
+from executorch.devtools.intermediate_output_tap import (
+    compare_aot_runtime_dataframe,
+    FULL_TENSOR,
+    select_all,
+    select_any,
+    select_by_module_path,
+    select_by_op_type,
+    STATS,
+    strip_taps_,
+    tap_intermediate_outputs,
+)
+from executorch.examples.apple.coreml.llama.export_static_llm_coreml import (
+    _create_example_inputs,
+    _transform_eager_model,
+    remove_graph_break_,
+)
+from executorch.examples.models.llama.llama_transformer import construct_transformer
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.exir import to_edge_transform_and_lower
+from executorch.runtime import Runtime, Verification
+from torch.export import export
+
+
+def _build_model() -> tuple[torch.nn.Module, ModelArgs]:
+    """Build a tiny static-attention Llama with random weights."""
+    args = ModelArgs(
+        dim=64,
+        n_layers=2,
+        n_heads=4,
+        n_kv_heads=2,
+        vocab_size=128,
+        hidden_dim=128,
+        max_seq_len=64,
+        max_context_len=64,
+        generate_full_logits=True,
+    )
+    args.attention_type = "static_mha"
+    args.attention_kwargs = {"decompose_sdpa_in_mha": True}
+
+    model = construct_transformer(args)
+    transform_args = types.SimpleNamespace(
+        target_split_size=None,
+        max_splits=8,
+        embedding_quantize="",
+        linear_quantize="c4w",
+        no_graph_breaks=False,
+    )
+    model = _transform_eager_model(model, transform_args, torch.float16)
+    return model, args
+
+
+def main() -> None:
+    # ------------------------------------------------------------------
+    # Step 1: Build and quantize the model.
+    # ------------------------------------------------------------------
+    model, model_args = _build_model()
+
+    # ------------------------------------------------------------------
+    # Step 2: Create example inputs and export.
+    # ------------------------------------------------------------------
+    input_len = 8
+    example_inputs, cache_len = _create_example_inputs(
+        model_args,
+        input_len=input_len,
+        max_context_len=model_args.max_context_len,
+        float_dtype=torch.float16,
+    )
+    print(f"input_len={input_len} cache_len={cache_len}")
+
+    with torch.no_grad():
+        _ = model(*example_inputs)  # eager sanity check
+
+    print("Exporting...")
+    ep = export(model, example_inputs)
+
+    # ------------------------------------------------------------------
+    # Step 3: Pick which intermediate values to tap.
+    #
+    # We use two reducers in a single pass via `rules=`:
+    #
+    # * FULL_TENSOR for the embedding output and `layers.1.attention.wvs.0`
+    #   — surfaces raw activation tensors; the comparison DataFrame
+    #   computes SQNR over all elements.
+    #
+    # * STATS for everything else (`output`, all wqs/wks linears, layer 0's
+    #   wvs, and layer 1's RMSNorm output mul) — gives a rich set of
+    #   debugging scalars (min/max/mean/std/rms/l1/l2/abs_max/abs_mean/
+    #   nan_count/inf_count/zero_count/p99_abs).
+    #
+    # Patterns use `*` between `layers.<i>.` and the inner module so they
+    # match both the bare path (`layers.<i>.attention...`) and the wrapped
+    # path (`layers.<i>.block.attention...`) that BlockWithGraphBreak
+    # introduces at the partition boundaries.
+    #
+    # Rules are tried in order; the first match wins per node.
+    # ------------------------------------------------------------------
+    selector_full_tensor = select_any(
+        # Token-embedding output (one big tensor, before any transformer block).
+        select_by_op_type("aten.embedding.default"),
+        # First wvs linear in layer 1 — captures full activation post-Q/K/V.
+        select_all(
+            select_by_op_type("aten.linear.default"),
+            select_by_module_path("layers.1.*attention.wvs.*"),
+        ),
+    )
+    selector_stats = select_any(
+        select_all(
+            select_by_op_type("aten.linear.default"),
+            select_any(
+                select_by_module_path("output"),
+                select_by_module_path("*.attention.wqs.*"),
+                select_by_module_path("*.attention.wks.*"),
+                select_by_module_path("layers.0.*attention.wvs.*"),
+            ),
+        ),
+        select_all(
+            select_by_op_type("aten.mul.Tensor"),
+            select_any(
+                select_by_module_path("layers.1.*attention_norm"),
+                select_by_module_path("layers.1.*attention_norm.*"),
+                select_by_module_path("layers.1.*ffn_norm"),
+                select_by_module_path("layers.1.*ffn_norm.*"),
+            ),
+        ),
+    )
+
+    ep_t, specs = tap_intermediate_outputs(
+        ep,
+        rules=[
+            (selector_full_tensor, FULL_TENSOR),
+            (selector_stats, STATS),
+        ],
+    )
+    n_full = sum(1 for s in specs if s.reducer_name == "FULL_TENSOR")
+    n_stats = sum(1 for s in specs if s.reducer_name == "STATS")
+    print(f"Inserted {len(specs)} tap(s) ({n_full} FULL_TENSOR + {n_stats} STATS).")
+
+    # ------------------------------------------------------------------
+    # Step 4: Capture the AOT-side reference values.
+    #
+    # `tap.Tensor`'s eager impl applies the reducer, so the flat outputs of
+    # the tapped EP already contain reduced values at the same positions
+    # the runtime will use. We pytree-flatten because the static-llama
+    # forward returns nested (logits, (k_caches, v_caches)).
+    # ------------------------------------------------------------------
+    aot_out = ep_t.module()(*example_inputs)
+    aot_flat, _ = pytree.tree_flatten(aot_out)
+
+    # ------------------------------------------------------------------
+    # Step 5: Lower to CoreML, strip the taps, and show what got delegated.
+ # ------------------------------------------------------------------ + coreml_partitioner = CoreMLPartitioner( + compile_specs=CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT16, + compute_unit=ct.ComputeUnit.CPU_AND_NE, + ), + ) + print("Lowering to CoreML...") + edge = to_edge_transform_and_lower(ep_t, partitioner=[coreml_partitioner]) + # Drop the `executorch_utils::graph_break.Tensor` placeholders that + # `_transform_eager_model` inserted to force partition boundaries — they + # have no out-variant kernel, so they must not survive into the runtime + # program. + remove_graph_break_(edge) + strip_taps_(edge) + + delegation_info = get_delegation_info(edge.exported_program().graph_module) + print( + "\n=== Delegation summary " + f"(num_delegated_subgraphs={delegation_info.num_delegated_subgraphs}) ===" + ) + print(delegation_info.get_summary()) + with pd.option_context( + "display.max_columns", None, + "display.width", 240, + "display.max_colwidth", 60, + ): + print( + delegation_info.get_operator_delegation_dataframe().to_string(index=False) + ) + + # ------------------------------------------------------------------ + # Step 6: Save the .pte and run it through the ExecuTorch runtime. + # ------------------------------------------------------------------ + et_program = edge.to_executorch() + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + print(f"\nSaved PTE: {pte_path} ({os.path.getsize(pte_path)} bytes)") + + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + # Runtime takes a flat tensor list — flatten the (tokens, options_dict) + # pytree the same way torch.export did. + flat_inputs, _ = pytree.tree_flatten(example_inputs) + rt_flat = list(method.execute(flat_inputs)) + + # ------------------------------------------------------------------ + # Step 7: Compare AOT vs runtime. + # ------------------------------------------------------------------ + df = compare_aot_runtime_dataframe(specs, aot_flat, rt_flat) + with pd.option_context( + "display.max_columns", None, + "display.width", 280, + "display.max_colwidth", 30, + "display.float_format", "{:.4g}".format, + ): + print(f"\n{len(specs)} tap(s) — AOT vs CoreML runtime:") + print(df.to_string(index=False)) + + +if __name__ == "__main__": + main() diff --git a/devtools/intermediate_output_tap/tests/test_reducers.py b/devtools/intermediate_output_tap/tests/test_reducers.py new file mode 100644 index 00000000000..9970da265d9 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_reducers.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._reducers import ( + FULL_TENSOR, + get_reducer, + StatReducer, + STATS, +) + + +class ReducersTest(unittest.TestCase): + def test_get_reducer_by_name(self): + self.assertIs(get_reducer("FULL_TENSOR"), FULL_TENSOR) + self.assertIs(get_reducer("STATS"), STATS) + + def test_get_reducer_passthrough(self): + custom = StatReducer( + name="X", + fields=("a",), + emit=lambda g, n: n, + eager=lambda t: t, + ) + self.assertIs(get_reducer(custom), custom) + + def test_get_reducer_unknown_raises(self): + with self.assertRaises(ValueError): + get_reducer("DOES_NOT_EXIST") + + def test_reducer_field_counts(self): + self.assertEqual(FULL_TENSOR.fields, ()) + self.assertEqual( + STATS.fields, + ( + "min", + "max", + "mean", + "abs_max", + "abs_mean", + "std", + "rms", + "l1_norm", + "l2_norm", + "nan_count", + "inf_count", + "zero_count", + "p99_abs", + ), + ) + + def test_reducer_names_unique(self): + names = {r.name for r in (FULL_TENSOR, STATS)} + self.assertEqual(len(names), 2) + + def test_full_tensor_eager_is_identity(self): + t = torch.randn(2, 3, 4) + out = FULL_TENSOR.eager(t) + self.assertEqual(out.shape, t.shape) + torch.testing.assert_close(out, t.detach()) + + def test_stats_eager_correctness(self): + torch.manual_seed(0) + t = torch.randn(64) + out = STATS.eager(t) + self.assertEqual(out.shape, (len(STATS.fields),)) + + f = t.to(torch.float32) + expected = { + "min": float(f.amin()), + "max": float(f.amax()), + "mean": float(f.mean()), + "abs_max": float(f.abs().amax()), + "abs_mean": float(f.abs().mean()), + "rms": float(f.pow(2).mean().sqrt()), + "l1_norm": float(f.abs().sum()), + "l2_norm": float(f.pow(2).sum().sqrt()), + "nan_count": 0.0, + "inf_count": 0.0, + "zero_count": float((f == 0).to(torch.float32).sum()), + } + for i, field in enumerate(STATS.fields): + if field in expected: + torch.testing.assert_close( + float(out[i]), expected[field], rtol=1e-4, atol=1e-5 + ) + # std uses E[x^2] - E[x]^2 (population variance); compare to that. + pop_var = float((f.pow(2).mean() - f.mean().pow(2)).abs()) + torch.testing.assert_close( + float(out[STATS.fields.index("std")]) ** 2, + pop_var, + rtol=1e-4, + atol=1e-5, + ) + + def test_stats_p99_abs_matches_topk(self): + torch.manual_seed(0) + t = torch.randn(1000) + out = STATS.eager(t) + numel = t.numel() + k = max(1, (numel + 99) // 100) + expected = float( + torch.topk(t.abs().reshape(-1), k=k, largest=True, sorted=True).values[ + k - 1 + ] + ) + torch.testing.assert_close( + float(out[STATS.fields.index("p99_abs")]), + expected, + rtol=1e-4, + atol=1e-5, + ) + + def test_stats_counts_nan_and_inf(self): + t = torch.tensor( + [1.0, float("nan"), 2.0, float("inf"), 0.0, -float("inf"), 0.0] + ) + out = STATS.eager(t) + i_nan = STATS.fields.index("nan_count") + i_inf = STATS.fields.index("inf_count") + i_zero = STATS.fields.index("zero_count") + self.assertEqual(float(out[i_nan]), 1.0) + self.assertEqual(float(out[i_inf]), 2.0) + self.assertEqual(float(out[i_zero]), 2.0) diff --git a/devtools/intermediate_output_tap/tests/test_selectors.py b/devtools/intermediate_output_tap/tests/test_selectors.py new file mode 100644 index 00000000000..80b5f1b73f4 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_selectors.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._selectors import ( + select_all, + select_all_call_function, + select_any, + select_by_meta_tag, + select_by_module_class, + select_by_module_path, + select_by_op_type, + select_not, +) +from torch.export import export + + +class _Inner(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear(x).relu() + + +class _Outer(torch.nn.Module): + def __init__(self): + super().__init__() + self.inner = _Inner() + self.head = torch.nn.Linear(4, 2) + + def forward(self, x): + return self.head(self.inner(x)) + + +def _exported_graph(): + ep = export(_Outer(), (torch.randn(2, 4),), strict=True) + return ep.graph_module.graph + + +class SelectorsTest(unittest.TestCase): + def setUp(self): + self.graph = _exported_graph() + self.call_nodes = [n for n in self.graph.nodes if n.op == "call_function"] + + def test_select_all_call_function_excludes_getitem(self): + sel = select_all_call_function() + for n in self.call_nodes: + if "getitem" in str(n.target): + self.assertFalse(sel(n)) + else: + self.assertTrue(sel(n)) + + def test_select_by_op_type_matches_suffix(self): + sel = select_by_op_type("aten.linear.default", "aten.relu.default") + matched = [n for n in self.call_nodes if sel(n)] + # Two linears + one relu in the model. + self.assertGreaterEqual(len(matched), 2) + for n in matched: + self.assertTrue( + str(n.target).endswith("aten.linear.default") + or str(n.target).endswith("aten.relu.default") + ) + + def test_select_by_op_type_requires_target(self): + with self.assertRaises(ValueError): + select_by_op_type() + + def test_select_by_module_path(self): + sel = select_by_module_path("inner.*") + matched = [n for n in self.call_nodes if sel(n)] + # inner contains a linear and a relu. + self.assertGreater(len(matched), 0) + for n in matched: + stack = n.meta.get("nn_module_stack") or {} + paths = [v[0] if isinstance(v, tuple) else v for v in stack.values()] + self.assertTrue(any(p.startswith("inner") for p in paths)) + + def test_select_by_module_path_multi_pattern(self): + # Multi-pattern call should be equivalent to OR-ing single-pattern selectors. + multi = select_by_module_path("inner.*", "head") + a = select_by_module_path("inner.*") + b = select_by_module_path("head") + for n in self.call_nodes: + self.assertEqual(multi(n), a(n) or b(n)) + + def test_select_by_module_path_requires_arg(self): + with self.assertRaises(ValueError): + select_by_module_path() + + def test_select_by_module_class_matches_inner(self): + # `_Inner` is the only nested module class; we should match every op + # that lives inside an `_Inner` instance (the inner linear + relu). + sel = select_by_module_class("_Inner") + matched = [n for n in self.call_nodes if sel(n)] + self.assertGreater(len(matched), 0) + # And the head linear (which is owned by `_Outer`, not `_Inner`) + # should NOT match. + head_linears = [ + n for n in self.call_nodes + if str(n.target).endswith("aten.linear.default") + and not sel(n) + ] + self.assertGreaterEqual(len(head_linears), 1) + + def test_select_by_module_class_multi(self): + # Either Inner or Outer should pick up every call node that has any + # nn_module_stack entry. 
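+        # (Assumed equivalent to OR-ing single-class selectors, i.e.:
+        #
+        #     select_any(select_by_module_class("_Inner"),
+        #                select_by_module_class("_Outer"))
+        #
+        # mirroring the multi-pattern OR semantics that
+        # test_select_by_module_path_multi_pattern checks for paths.)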
+ sel = select_by_module_class("_Inner", "_Outer") + for n in self.call_nodes: + if n.meta.get("nn_module_stack"): + self.assertTrue(sel(n)) + + def test_select_by_module_class_output_only(self): + # `_Inner` does `self.linear(x).relu()` — the relu is the value the + # outer module receives, so it's the only terminal of `_Inner`. + sel_all = select_by_module_class("_Inner") + sel_term = select_by_module_class("_Inner", output_only=True) + all_in_inner = [n for n in self.call_nodes if sel_all(n)] + terminals = [n for n in self.call_nodes if sel_term(n)] + self.assertGreater(len(all_in_inner), len(terminals)) + # Exactly one terminal per `_Inner` instance — the relu. + self.assertEqual(len(terminals), 1) + self.assertTrue(str(terminals[0].target).endswith("aten.relu.default")) + + def test_select_by_module_class_requires_arg(self): + with self.assertRaises(ValueError): + select_by_module_class() + + def test_select_by_module_class_no_match(self): + sel = select_by_module_class("NoSuchClass") + for n in self.call_nodes: + self.assertFalse(sel(n)) + + def test_select_by_meta_tag_presence(self): + for n in self.call_nodes[:1]: + n.meta["debug_me"] = "yes" + sel = select_by_meta_tag("debug_me") + self.assertTrue(sel(self.call_nodes[0])) + self.assertFalse(sel(self.call_nodes[1])) + + def test_select_by_meta_tag_value(self): + self.call_nodes[0].meta["color"] = "blue" + self.call_nodes[1].meta["color"] = "red" + sel = select_by_meta_tag("color", "blue") + self.assertTrue(sel(self.call_nodes[0])) + self.assertFalse(sel(self.call_nodes[1])) + + def test_select_combinators(self): + a = select_by_op_type("aten.linear.default") + b = select_by_op_type("aten.relu.default") + any_sel = select_any(a, b) + all_sel = select_all(a, b) + not_sel = select_not(a) + + for n in self.call_nodes: + if a(n) or b(n): + self.assertTrue(any_sel(n)) + self.assertEqual(all_sel(n), a(n) and b(n)) + self.assertEqual(not_sel(n), not a(n)) + + def test_select_any_empty(self): + for n in self.call_nodes: + self.assertFalse(select_any()(n)) + + def test_select_all_empty(self): + for n in self.call_nodes: + self.assertTrue(select_all()(n)) diff --git a/devtools/intermediate_output_tap/tests/test_strip_pass.py b/devtools/intermediate_output_tap/tests/test_strip_pass.py new file mode 100644 index 00000000000..0f8dc5fec10 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_strip_pass.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-unsafe
+
+import unittest
+
+import torch
+from executorch.devtools.intermediate_output_tap._reducers import (
+    FULL_TENSOR,
+    STATS,
+)
+from executorch.devtools.intermediate_output_tap._selectors import select_by_op_type
+from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_
+from executorch.devtools.intermediate_output_tap._tap_pass import (
+    find_tap_nodes,
+    tap_intermediate_outputs,
+)
+from executorch.exir import to_edge
+from torch.export import export
+
+
+class _MLP(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.l1 = torch.nn.Linear(8, 8)
+        self.l2 = torch.nn.Linear(8, 4)
+
+    def forward(self, x):
+        return self.l2(self.l1(x).relu())
+
+
+def _tapped_edge(reducer):
+    ep = export(_MLP(), (torch.randn(2, 8),), strict=True)
+    ep_t, specs = tap_intermediate_outputs(
+        ep,
+        selector=select_by_op_type("aten.linear.default"),
+        reducer=reducer,
+    )
+    return to_edge(ep_t), specs
+
+
+class StripPassTest(unittest.TestCase):
+    def test_strip_removes_all_tap_nodes_full_tensor(self):
+        edge, _ = _tapped_edge(FULL_TENSOR)
+        # Pre-strip: tap nodes present.
+        for method_name in edge.methods:
+            ep = edge.exported_program(method_name)
+            self.assertGreater(len(find_tap_nodes(ep.graph_module)), 0)
+
+        strip_taps_(edge)
+
+        # Post-strip: no tap nodes.
+        for method_name in edge.methods:
+            ep = edge.exported_program(method_name)
+            self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0)
+
+    def test_strip_full_tensor_routes_source_to_output(self):
+        edge, specs = _tapped_edge(FULL_TENSOR)
+        strip_taps_(edge)
+        # Output node should still have all the user outputs + tap outputs.
+        for method_name in edge.methods:
+            ep = edge.exported_program(method_name)
+            outs = list(ep.graph_module.graph.output_node().args[0])
+            # Original outputs + 2 linears tapped.
+            self.assertGreaterEqual(len(outs), len(specs))
+
+    def test_strip_stats_emits_subgraph(self):
+        edge, _ = _tapped_edge(STATS)
+        strip_taps_(edge)
+        for method_name in edge.methods:
+            ep = edge.exported_program(method_name)
+            self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0)
+            # Some reduction op (amin/amax/mean) should now be in the graph.
+            # Substring match because an EdgeOpOverload's str() looks like
+            # "<EdgeOpOverload: aten.amin.default>: schema = ..." (no clean
+            # endswith).
+            targets = {str(n.target) for n in ep.graph_module.graph.nodes}
+            self.assertTrue(
+                any(
+                    "aten.amin" in t or "aten.amax" in t or "aten.mean" in t
+                    for t in targets
+                ),
+                f"expected reducer ops in graph, got {targets}",
+            )
+
+    def test_strip_idempotent(self):
+        edge, _ = _tapped_edge(FULL_TENSOR)
+        strip_taps_(edge)
+        # Second call should be a no-op.
+        strip_taps_(edge)
+        for method_name in edge.methods:
+            ep = edge.exported_program(method_name)
+            self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0)
diff --git a/devtools/intermediate_output_tap/tests/test_tap_pass.py b/devtools/intermediate_output_tap/tests/test_tap_pass.py
new file mode 100644
index 00000000000..ddd3204e70c
--- /dev/null
+++ b/devtools/intermediate_output_tap/tests/test_tap_pass.py
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +# pyre-unsafe + +import copy +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._reducers import ( + FULL_TENSOR, + STATS, +) +from executorch.devtools.intermediate_output_tap._selectors import select_by_op_type +from executorch.devtools.intermediate_output_tap._tap_pass import ( + is_tap_node, + tap_intermediate_outputs, +) +from torch.export import export +from torch.export.exported_program import OutputKind + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 16) + self.l2 = torch.nn.Linear(16, 8) + self.l3 = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.l3(self.l2(self.l1(x).relu()).relu()) + + +def _export(): + return export(_MLP(), (torch.randn(2, 8),), strict=True) + + +class TapPassTest(unittest.TestCase): + def test_inserts_tap_per_selected_node(self): + ep = _export() + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + # MLP has 3 linears. + self.assertEqual(len(specs), 3) + tap_nodes = [n for n in ep_t.graph_module.graph.nodes if is_tap_node(n)] + self.assertEqual(len(tap_nodes), 3) + + def test_appends_user_outputs(self): + ep = _export() + original_user_outs = sum( + 1 + for s in ep.graph_signature.output_specs + if s.kind == OutputKind.USER_OUTPUT + ) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + new_user_outs = sum( + 1 + for s in ep_t.graph_signature.output_specs + if s.kind == OutputKind.USER_OUTPUT + ) + self.assertEqual(new_user_outs, original_user_outs + len(specs)) + + def test_output_indices_contiguous_after_user_outputs(self): + ep = _export() + original_user_outs = sum( + 1 + for s in ep.graph_signature.output_specs + if s.kind == OutputKind.USER_OUTPUT + ) + _, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + for i, spec in enumerate(specs): + self.assertEqual(spec.output_index, original_user_outs + i) + + def test_default_reducer_is_default_stats(self): + ep = _export() + _, specs = tap_intermediate_outputs( + ep, selector=select_by_op_type("aten.linear.default") + ) + for s in specs: + self.assertEqual(s.reducer_name, STATS.name) + self.assertEqual(s.fields, STATS.fields) + + def test_inplace_false_does_not_mutate_original(self): + ep = _export() + before_outs = len(list(ep.graph_module.graph.output_node().args[0])) + before_specs = len(ep.graph_signature.output_specs) + _ = tap_intermediate_outputs( + ep, selector=select_by_op_type("aten.linear.default"), reducer=FULL_TENSOR + ) + after_outs = len(list(ep.graph_module.graph.output_node().args[0])) + after_specs = len(ep.graph_signature.output_specs) + self.assertEqual(before_outs, after_outs) + self.assertEqual(before_specs, after_specs) + + def test_max_taps(self): + ep = _export() + _, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + max_taps=2, + ) + self.assertEqual(len(specs), 2) + + def test_idempotent_does_not_tap_taps(self): + ep = _export() + ep_once, specs1 = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + # Running again should not add NEW taps for our existing tap nodes. 
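+        # (Conceptually, the pass filters candidates roughly like
+        #
+        #     [n for n in nodes if selector(n) and not is_tap_node(n)]
+        #
+        # so the second invocation sees the same three linears rather than
+        # the linears plus their taps. A sketch of the intent, not the
+        # actual implementation.)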
+        ep_twice, specs2 = tap_intermediate_outputs(
+            ep_once,
+            selector=select_by_op_type("aten.linear.default"),
+            reducer=FULL_TENSOR,
+        )
+        # Same number of linears matched; tap.Tensor itself is excluded.
+        self.assertEqual(len(specs2), len(specs1))
+
+    def test_no_match_returns_empty_specs(self):
+        ep = _export()
+        ep_t, specs = tap_intermediate_outputs(
+            ep,
+            selector=select_by_op_type("aten.does.not.exist"),
+            reducer=FULL_TENSOR,
+            error_on_empty=False,
+        )
+        self.assertEqual(specs, [])
+        # Original graph signature is unchanged.
+        self.assertEqual(
+            len(ep_t.graph_signature.output_specs),
+            len(ep.graph_signature.output_specs),
+        )
+
+    def test_skip_if_no_debug_handle(self):
+        ep = _export()
+        # Strip all debug handles to simulate a graph without them.
+        ep_clean = copy.deepcopy(ep)
+        for n in ep_clean.graph_module.graph.nodes:
+            n.meta.pop("debug_handle", None)
+        _, specs = tap_intermediate_outputs(
+            ep_clean,
+            selector=select_by_op_type("aten.linear.default"),
+            reducer=FULL_TENSOR,
+            skip_if_no_debug_handle=True,
+            error_on_empty=False,
+        )
+        self.assertEqual(specs, [])
diff --git a/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py
new file mode 100644
index 00000000000..86c097b6e8e
--- /dev/null
+++ b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py
@@ -0,0 +1,335 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+"""
+Numerical-debugging tutorial for the ExecuTorch XNNPACK backend.
+
+Showcases three ways to use the intermediate-output tap:
+
+1. `test_tap_compare_xnnpack` — the one-shot `tap_compare(...)` helper.
+   Shortest path; recommended for most users.
+
+2. `test_low_level_pipeline_xnnpack` — the manual pipeline using
+   `tap_intermediate_outputs` + `strip_taps_` + `compare_aot_runtime_dataframe`
+   directly. Use this when you need to insert custom edge-program transforms
+   (e.g., `remove_graph_break_` in the CoreML pipeline) between lowering and
+   stripping, or when you want delegation summaries / introspection on the
+   edge program.
+
+3. `test_tap_compare_static_transformer` — a small static-attention
+   transformer with RMSNorm. Demonstrates per-layer tap selection: `wo`
+   (attention output projection) in layers 2/5/8, plus the terminal output
+   of each RMSNorm in those same layers, using `select_by_module_class` +
+   multi-pattern `select_by_module_path`.
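+
+Run with (assuming `executorch` is importable from your checkout; the module
+path simply mirrors where this diff places the file):
+
+    python -m unittest \
+        executorch.devtools.intermediate_output_tap.tests.test_xnnpack_e2e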
+""" + +import os +import sys +import tempfile +import unittest + +import pandas as pd +import torch +import torch.utils._pytree as pytree +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.devtools.backend_debug import get_delegation_info +from executorch.devtools.intermediate_output_tap import ( + compare_aot_runtime_dataframe, + FULL_TENSOR, + select_all, + select_by_module_class, + select_by_module_path, + select_by_op_type, + STATS, + strip_taps_, + tap_compare, + tap_intermediate_outputs, +) +from executorch.exir import to_edge_transform_and_lower +from executorch.runtime import Runtime, Verification +from torch.export import export + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 16) + self.l2 = torch.nn.Linear(16, 4) + + def forward(self, x): + return self.l2(self.l1(x).relu()) + + +# ---------------------------------------------------------------------- +# Tiny static-attention transformer used by the third test. +# ---------------------------------------------------------------------- +class _RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x): + norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return norm * self.weight + + +class _Attention(torch.nn.Module): + def __init__(self, dim: int, n_heads: int): + super().__init__() + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.wq = torch.nn.Linear(dim, dim, bias=False) + self.wk = torch.nn.Linear(dim, dim, bias=False) + self.wv = torch.nn.Linear(dim, dim, bias=False) + self.wo = torch.nn.Linear(dim, dim, bias=False) + + def forward(self, x): + b, t, d = x.shape + q = self.wq(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2) + k = self.wk(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2) + v = self.wv(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2) + scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5) + attn = torch.softmax(scores, dim=-1) + out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(b, t, d) + return self.wo(out) + + +class _FeedForward(torch.nn.Module): + def __init__(self, dim: int, hidden: int): + super().__init__() + self.w1 = torch.nn.Linear(dim, hidden, bias=False) + self.w2 = torch.nn.Linear(hidden, dim, bias=False) + + def forward(self, x): + return self.w2(torch.nn.functional.relu(self.w1(x))) + + +class _Block(torch.nn.Module): + def __init__(self, dim: int, n_heads: int, hidden: int): + super().__init__() + self.attn_norm = _RMSNorm(dim) + self.attention = _Attention(dim, n_heads) + self.ffn_norm = _RMSNorm(dim) + self.feed_forward = _FeedForward(dim, hidden) + + def forward(self, x): + x = x + self.attention(self.attn_norm(x)) + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + +class _StaticTransformer(torch.nn.Module): + def __init__( + self, + n_layers: int = 10, + dim: int = 32, + n_heads: int = 4, + hidden: int = 64, + vocab_size: int = 64, + ): + super().__init__() + self.tok_embeddings = torch.nn.Embedding(vocab_size, dim) + self.layers = torch.nn.ModuleList( + [_Block(dim, n_heads, hidden) for _ in range(n_layers)] + ) + self.norm = _RMSNorm(dim) + self.output = torch.nn.Linear(dim, vocab_size, bias=False) + + def forward(self, tokens): + x = self.tok_embeddings(tokens) + for layer in self.layers: + x = layer(x) + return self.output(self.norm(x)) + + +def 
_build_rules(): + """Two-rule selection: FULL_TENSOR on `l1`, STATS on `l2`.""" + return [ + ( + select_all( + select_by_op_type("aten.linear.default"), + select_by_module_path("l1"), + ), + FULL_TENSOR, + ), + ( + select_all( + select_by_op_type("aten.linear.default"), + select_by_module_path("l2"), + ), + STATS, + ), + ] + + +def _print_df(df: pd.DataFrame, header: str) -> None: + with pd.option_context( + "display.max_columns", None, + "display.width", 280, + "display.max_colwidth", 30, + "display.float_format", "{:.4g}".format, + ): + print(f"\n{header}") + print(df.to_string(index=False)) + + +def _assert_df_shape(df: pd.DataFrame, specs) -> None: + """Common shape assertions for both test paths.""" + assert len(specs) == 2, f"expected 2 taps, got {len(specs)}" + assert {s.reducer_name for s in specs} == {"FULL_TENSOR", "STATS"} + assert "sqnr_db" in df.columns, "expected FULL_TENSOR sqnr_db column" + assert "aot_min" in df.columns, "expected STATS aot_min column" + + +@unittest.skipIf( + sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows" +) +class XnnpackEndToEndTest(unittest.TestCase): + # ------------------------------------------------------------------ + # Path 1: One-shot `tap_compare(...)` + # ------------------------------------------------------------------ + def test_tap_compare_xnnpack(self): + model = _MLP() + example_inputs = (torch.randn(2, 8),) + + df, specs = tap_compare( + model, + example_inputs, + partitioner=[XnnpackPartitioner()], + rules=_build_rules(), + ) + + _assert_df_shape(df, specs) + _print_df(df, f"[tap_compare] {len(specs)} tap(s) — AOT vs XNNPACK runtime:") + + # ------------------------------------------------------------------ + # Path 2: Low-level pipeline (manual steps) + # ------------------------------------------------------------------ + def test_low_level_pipeline_xnnpack(self): + model = _MLP() + example_inputs = (torch.randn(2, 8),) + + # Step 1: Export. + ep = export(model, example_inputs, strict=True) + + # Step 2: Insert taps. + ep_t, specs = tap_intermediate_outputs(ep, rules=_build_rules()) + + # Step 3: Capture AOT-side reference values via the tapped EP. + # `tap.Tensor`'s eager impl applies the reducer, so the flat outputs + # already contain reduced values at the same positions the runtime + # will use. + aot_out = ep_t.module()(*example_inputs) + aot_flat, _ = pytree.tree_flatten(aot_out) + + # Step 4: Lower to XNNPACK and strip the taps. + edge = to_edge_transform_and_lower( + ep_t, partitioner=[XnnpackPartitioner()] + ) + # (At this point you can run any custom edge transforms — e.g., + # `remove_graph_break_(edge)` — before stripping.) + strip_taps_(edge) + + # Bonus: print the delegation summary so you can see what XNNPACK + # took. This is one of the things the low-level path gives you that + # `tap_compare` hides. + delegation_info = get_delegation_info( + edge.exported_program().graph_module + ) + print( + "\n[low-level] === Delegation summary " + f"(num_delegated_subgraphs={delegation_info.num_delegated_subgraphs}) ===" + ) + print(delegation_info.get_summary()) + + # Step 5: Save the .pte and run it through the ExecuTorch runtime. 
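+        # (As in the CoreML tutorial, the runtime takes pytree-flattened
+        # inputs; for this MLP that is simply `(x,)`, but flattening keeps
+        # the code shape-agnostic.)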
+ et_program = edge.to_executorch() + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + flat_inputs, _ = pytree.tree_flatten(example_inputs) + rt_flat = list(method.execute(flat_inputs)) + + # Step 6: Diff AOT vs runtime. + df = compare_aot_runtime_dataframe(specs, aot_flat, rt_flat) + + _assert_df_shape(df, specs) + _print_df(df, f"[low-level] {len(specs)} tap(s) — AOT vs XNNPACK runtime:") + + # ------------------------------------------------------------------ + # Path 3: Per-layer selection on a small static transformer. + # + # Taps `attention.wo` and every op inside `_RMSNorm` in layers 2, 5, 8. + # Demonstrates `select_by_module_class` and the multi-pattern form of + # `select_by_module_path`. + # ------------------------------------------------------------------ + def test_tap_compare_static_transformer(self): + torch.manual_seed(0) + model = _StaticTransformer( + n_layers=10, dim=32, n_heads=4, hidden=64, vocab_size=64 + ) + example_inputs = (torch.randint(0, 64, (1, 8)),) + + target_layers = (2, 5, 8) + layer_patterns = [f"layers.{i}.*" for i in target_layers] + + rules = [ + # `wo` (attention output projection) in the target layers. + ( + select_all( + select_by_op_type("aten.linear.default"), + select_by_module_path( + *[f"layers.{i}.attention.wo" for i in target_layers] + ), + ), + FULL_TENSOR, + ), + # The terminal output of each `_RMSNorm` instance in the target + # layers (one tap per RMSNorm instance — not every internal op). + ( + select_all( + select_by_module_class("_RMSNorm", output_only=True), + select_by_module_path(*layer_patterns), + ), + STATS, + ), + ] + + df, specs = tap_compare( + model, + example_inputs, + partitioner=[XnnpackPartitioner()], + rules=rules, + ) + + # We expect exactly three `wo` taps (one per target layer). + wo_specs = [s for s in specs if s.reducer_name == "FULL_TENSOR"] + norm_specs = [s for s in specs if s.reducer_name == "STATS"] + self.assertEqual(len(wo_specs), 3, f"got {len(wo_specs)} wo taps: {wo_specs}") + # Every `wo` tap should live in one of the target layers. + for s in wo_specs: + self.assertTrue( + any(f"layers.{i}.attention.wo" in (s.module_path or "") for i in target_layers), + f"unexpected wo module_path: {s.module_path}", + ) + # Two RMSNorms per block (attn_norm + ffn_norm) × 3 target layers + # = 6 RMSNorm output taps. + self.assertEqual( + len(norm_specs), 6, f"got {len(norm_specs)} RMSNorm taps: {norm_specs}" + ) + + _print_df( + df, + f"[transformer] {len(specs)} tap(s) — AOT vs XNNPACK runtime " + f"({len(wo_specs)} wo + {len(norm_specs)} RMSNorm outputs):", + )
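+
+
+# Optional direct-run entry point. The build-system test runner does not
+# need this guard; it only enables `python test_xnnpack_e2e.py` for parity
+# with the CoreML tutorial, which runs as a plain script.
+if __name__ == "__main__":
+    unittest.main()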