Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions codegen/tools/gen_max_kernel_num.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -10,7 +10,12 @@
it as a C header.

Total = sum of (op, kernel_key) variants across all input YAMLs
+ prim ops always registered by kernels/prim_ops/register_prim_ops.cpp.
+ prim ops registered by kernels/prim_ops/register_prim_ops.cpp.

The prim-ops contribution is counted from the .cpp source by default. When
ET_PRIM_OPS_SELECTIVE_BUILD is active (only some prim ops compile in), pass
--selected-prim-ops-header instead so the count reflects what's actually
linked.

See runtime/kernel/operator_registry.cpp for how the emitted header is
consumed and the full precedence order. Users that register kernels outside
Expand Down Expand Up @@ -53,6 +58,12 @@
)
PRIM_OPS_KERNEL_RE = re.compile(r"\bKernel\s*\(")

# Matches the `#define INCLUDE_<MACRO>` lines emitted by gen_selected_prim_ops.py
# into selected_prim_ops.h. Each define gates one prim op entry in
# register_prim_ops.cpp under ET_PRIM_OPS_SELECTIVE_BUILD, so the count of
# defines equals the number of prim ops actually compiled in.
#
# re.MULTILINE lets `^` anchor at the start of every line, so each define line
# is matched on its own. Optional whitespace is tolerated before `#` and
# between `#` and `define`. A line such as `// ... INCLUDE_FOO` does not match
# because `#define` has to be the first non-blank content.
# NOTE(review): every `#define INCLUDE_*` line counts, repeated macros
# included — confirm the generator emits each define once and never emits an
# INCLUDE_-prefixed header guard.
SELECTED_PRIM_OPS_DEFINE_RE = re.compile(r"^\s*#\s*define\s+INCLUDE_\S+", re.MULTILINE)


def _count_prim_ops(prim_ops_source: Path) -> int:
source = prim_ops_source.read_text()
Expand All @@ -72,6 +83,10 @@
return count


def _count_selected_prim_ops(selected_prim_ops_header: Path) -> int:
    """Return how many `#define INCLUDE_<MACRO>` lines the header contains.

    Args:
        selected_prim_ops_header: path to a selected_prim_ops.h file.

    Returns:
        The number of matches of SELECTED_PRIM_OPS_DEFINE_RE in the file's
        text (0 when the header has no such define).
    """
    header_text = selected_prim_ops_header.read_text()
    include_defines = SELECTED_PRIM_OPS_DEFINE_RE.findall(header_text)
    return len(include_defines)


def _count_yaml_kernels(yaml_path: Path) -> Optional[int]:
"""Returns the kernel count for one YAML, or None if the YAML opts into
include_all_operators / include_all_overloads (callers should skip the
Expand Down Expand Up @@ -116,10 +131,18 @@

def gen_max_kernel_num(
oplist_yamls: List[Path],
prim_ops_source: Path,
output_path: Path,
prim_ops_source: Optional[Path] = None,
selected_prim_ops_header: Optional[Path] = None,
) -> Optional[int]:
total = 0
if (prim_ops_source is None) == (selected_prim_ops_header is None):
raise ValueError(
"Pass exactly one of prim_ops_source / selected_prim_ops_header. "
"Use selected_prim_ops_header when ET_PRIM_OPS_SELECTIVE_BUILD is "
"active so the counter matches what actually compiles in."
)

yaml_kernels = 0
for yaml_path in oplist_yamls:
yaml_count = _count_yaml_kernels(yaml_path)
if yaml_count is None:
Expand All @@ -130,9 +153,26 @@
)
_write_if_different(output_path, OPT_OUT_HEADER)
return None
total += yaml_count
yaml_kernels += yaml_count

# No selective build deps means none of the binary's kernel libraries went
# through this counter — auto-sizing would produce a too-small registry.
# Emit the opt-out header so operator_registry.cpp falls through to the
# compile-time default.
if yaml_kernels == 0:
print(
"gen_max_kernel_num: no kernels enumerated across input YAMLs; "
"emitting opt-out header (registry will use default size).",
file=sys.stderr,
)
_write_if_different(output_path, OPT_OUT_HEADER)
return None

total += _count_prim_ops(prim_ops_source)
if selected_prim_ops_header is not None:
total = yaml_kernels + _count_selected_prim_ops(selected_prim_ops_header)
else:
assert prim_ops_source is not None
total = yaml_kernels + _count_prim_ops(prim_ops_source)

_write_if_different(output_path, HEADER_TEMPLATE.format(count=total))
return total
Expand All @@ -147,11 +187,19 @@
required=True,
help="Path to a selected_operators.yaml. May be repeated.",
)
parser.add_argument(
prim_ops_group = parser.add_mutually_exclusive_group(required=True)
prim_ops_group.add_argument(
"--prim-ops-source",
"--prim_ops_source",
required=True,
help="Path to kernels/prim_ops/register_prim_ops.cpp.",
help="Path to kernels/prim_ops/register_prim_ops.cpp. Use this when "
"ET_PRIM_OPS_SELECTIVE_BUILD is not active (all prim ops compile in).",
)
prim_ops_group.add_argument(
"--selected-prim-ops-header",
"--selected_prim_ops_header",
help="Path to the aggregated selected_prim_ops.h emitted by "
"gen_selected_prim_ops.py. Use this when ET_PRIM_OPS_SELECTIVE_BUILD "
"is active so only enabled prim ops are counted.",
)
parser.add_argument(
"--output-path",
Expand All @@ -163,8 +211,13 @@

count = gen_max_kernel_num(
oplist_yamls=[Path(p) for p in args.oplist_yaml],
prim_ops_source=Path(args.prim_ops_source),
output_path=Path(args.output_path),
prim_ops_source=Path(args.prim_ops_source) if args.prim_ops_source else None,
selected_prim_ops_header=(
Path(args.selected_prim_ops_header)
if args.selected_prim_ops_header
else None
),
)
if count is not None:
print(f"gen_max_kernel_num: wrote {args.output_path} (count={count})")
Expand Down
20 changes: 20 additions & 0 deletions codegen/tools/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,26 @@ def define_common_targets(is_fbcode = False):
_is_external_target = True,
)

# Library target for gen_max_kernel_num.py. With base_module set to
# "executorch.codegen.tools" the source is importable as
# executorch.codegen.tools.gen_max_kernel_num, which is the module path the
# binary below names as its main_module.
runtime.python_library(
name = "gen_max_kernel_num_lib",
srcs = ["gen_max_kernel_num.py"],
base_module = "executorch.codegen.tools",
visibility = ["//executorch/..."],
)

# Executable wrapper around the library above. It is invoked from genrules via
# `$(exe //executorch/codegen/tools:gen_max_kernel_num)` (see
# runtime/kernel/selective_build.bzl) — presumably why visibility is PUBLIC.
runtime.python_binary(
name = "gen_max_kernel_num",
main_module = "executorch.codegen.tools.gen_max_kernel_num",
package_style = "inplace",
visibility = [
"PUBLIC",
],
deps = [
":gen_max_kernel_num_lib",
],
_is_external_target = True,
)


runtime.cxx_python_extension(
name = "selective_build",
Expand Down
86 changes: 86 additions & 0 deletions codegen/tools/test/test_gen_max_kernel_num.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from executorch.codegen.tools.gen_max_kernel_num import (
_count_prim_ops,
_count_selected_prim_ops,
_count_yaml_kernels,
gen_max_kernel_num,
)
Expand Down Expand Up @@ -64,6 +65,91 @@ def test_counts_prim_ops_errors_when_array_empty(self) -> None:
with self.assertRaises(RuntimeError):
_count_prim_ops(empty_array)

def test_counts_selected_prim_ops_from_header(self) -> None:
    """Only genuine `#define INCLUDE_*` lines are counted.

    The fixture has three defines plus a `#pragma once` line and a comment
    that merely mentions INCLUDE_FOO; the expected count is 3.
    """
    header_text = (
        "#pragma once\n"
        "#define INCLUDE_ATEN_SYM_SIZE_INT\n"
        "#define INCLUDE_EXECUTORCH_PRIM_ADD_INT_INT\n"
        "// not a define: INCLUDE_FOO\n"
        "#define INCLUDE_EXECUTORCH_PRIM_MUL_INT_INT\n"
    )
    header_path = self.tmp / "selected_prim_ops.h"
    header_path.write_text(header_text)

    counted = _count_selected_prim_ops(header_path)
    self.assertEqual(counted, 3)

def test_rejects_both_prim_ops_inputs(self) -> None:
    """Passing prim_ops_source and selected_prim_ops_header together raises ValueError."""
    selected_ops = {
        "operators": {"aten::add.out": {}},
        "et_kernel_metadata": {"aten::add.out": ["v1/6"]},
    }
    oplist_path = self.tmp / "selected_operators.yaml"
    _write_yaml(oplist_path, selected_ops)

    prim_ops_header = self.tmp / "selected_prim_ops.h"
    prim_ops_header.write_text("#define INCLUDE_FOO\n")

    # Both prim-ops inputs are supplied at once; the generator must refuse.
    with self.assertRaises(ValueError):
        gen_max_kernel_num(
            oplist_yamls=[oplist_path],
            output_path=self.output,
            prim_ops_source=self.prim_ops_source,
            selected_prim_ops_header=prim_ops_header,
        )

def test_rejects_neither_prim_ops_input(self) -> None:
    """Omitting both prim_ops_source and selected_prim_ops_header raises ValueError."""
    selected_ops = {
        "operators": {"aten::add.out": {}},
        "et_kernel_metadata": {"aten::add.out": ["v1/6"]},
    }
    oplist_path = self.tmp / "selected_operators.yaml"
    _write_yaml(oplist_path, selected_ops)

    # Neither prim-ops input is given, so the call must be rejected.
    with self.assertRaises(ValueError):
        gen_max_kernel_num(oplist_yamls=[oplist_path], output_path=self.output)

def test_empty_yaml_writes_opt_out_header(self) -> None:
    """A YAML with no operators yields None and an opt-out header.

    The output file must still be written, but it must not define
    EXECUTORCH_SELECTED_MAX_KERNEL_NUM.
    """
    empty_selection = {"operators": {}, "et_kernel_metadata": {}}
    oplist_path = self.tmp / "selected_operators.yaml"
    _write_yaml(oplist_path, empty_selection)

    result = gen_max_kernel_num(
        oplist_yamls=[oplist_path],
        prim_ops_source=self.prim_ops_source,
        output_path=self.output,
    )

    self.assertIsNone(result)
    self.assertTrue(self.output.exists())
    emitted_header = self.output.read_text()
    self.assertNotIn("#define EXECUTORCH_SELECTED_MAX_KERNEL_NUM", emitted_header)

def test_end_to_end_with_selected_prim_ops_header(self) -> None:
    """One YAML kernel plus two selected prim ops produce a count of 3.

    Exercises the selected_prim_ops_header path and checks both the returned
    total and the define written to the output header.
    """
    selected_ops = {
        "operators": {"aten::add.out": {}},
        "et_kernel_metadata": {"aten::add.out": ["v1/6"]},
    }
    oplist_path = self.tmp / "selected_operators.yaml"
    _write_yaml(oplist_path, selected_ops)

    prim_ops_header = self.tmp / "selected_prim_ops.h"
    prim_ops_header.write_text(
        "#define INCLUDE_ATEN_SYM_SIZE_INT\n"
        "#define INCLUDE_EXECUTORCH_PRIM_ADD_INT_INT\n"
    )

    result = gen_max_kernel_num(
        oplist_yamls=[oplist_path],
        selected_prim_ops_header=prim_ops_header,
        output_path=self.output,
    )

    # 1 kernel variant from the YAML + 2 defines from the header.
    self.assertEqual(result, 1 + 2)
    emitted_header = self.output.read_text()
    self.assertIn("#define EXECUTORCH_SELECTED_MAX_KERNEL_NUM 3", emitted_header)

def test_counts_single_variant_per_op(self) -> None:
yaml_path = self.tmp / "selected_operators.yaml"
_write_yaml(
Expand Down
121 changes: 121 additions & 0 deletions runtime/kernel/selective_build.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(
"@fbsource//xplat/executorch/runtime/kernel:targets.bzl",
"operator_registry_preprocessor_flags",
)

# Layout of the per-binary header tree, matching the angle-bracket includes
# operator_registry.cpp uses (`<executorch/runtime/kernel/...>`).
# Directory, relative to the staging genrule's $OUT, that holds both headers.
_HEADER_DIR = "executorch/runtime/kernel"
# Staged path of the generated selected_max_kernel_num.h; also used as its
# exported-header key in operator_registry_selective.
_MAX_KERNEL_NUM_HEADER = _HEADER_DIR + "/selected_max_kernel_num.h"
# Staged path of the copied operator_registry.h; also used as its
# exported-header key in operator_registry_selective.
_OP_REGISTRY_HEADER = _HEADER_DIR + "/operator_registry.h"

def gen_max_kernel_num_genrule(
        name,
        oplist_yaml_target,
        selected_prim_ops_header_target = None,
        platforms = "CXX"):
    """Define a genrule that emits selected_max_kernel_num.h for one binary.

    The genrule runs //executorch/codegen/tools:gen_max_kernel_num on the
    selected_operators.yaml found under `oplist_yaml_target`; the emitted
    header defines EXECUTORCH_SELECTED_MAX_KERNEL_NUM.

    Args:
        name: name of the genrule target to create.
        oplist_yaml_target: target whose output location contains
            selected_operators.yaml.
        selected_prim_ops_header_target: optional target whose output location
            contains selected_prim_ops.h. Provide it when
            ET_PRIM_OPS_SELECTIVE_BUILD is active for this binary so the prim
            ops contribution matches what actually compiles in. When omitted,
            the count comes from parsing register_prim_ops.cpp directly.
        platforms: forwarded unchanged to runtime.genrule.
    """

    # Exactly one prim-ops flag is passed to the tool, chosen by whether a
    # selected-prim-ops header target was supplied.
    if selected_prim_ops_header_target:
        prim_ops_flag = "--selected-prim-ops-header=$(location {})/selected_prim_ops.h".format(
            selected_prim_ops_header_target,
        )
    else:
        prim_ops_flag = "--prim-ops-source=$(location //executorch/kernels/prim_ops:prim_ops_sources)/register_prim_ops.cpp"

    # The header is written flat at the artifact root so consumers can
    # reference it as $(location :name)/selected_max_kernel_num.h (mirrors the
    # selected_operators.yaml / selected_prim_ops.h conventions used by the
    # adjacent genrules). Each fragment carries its own trailing space.
    command = "".join([
        "$(exe //executorch/codegen/tools:gen_max_kernel_num) ",
        "--oplist-yaml=$(location {})/selected_operators.yaml ".format(oplist_yaml_target),
        "--output-path=$OUT/selected_max_kernel_num.h ",
        prim_ops_flag,
    ])

    runtime.genrule(
        name = name,
        cmd = command,
        outs = {"selected_max_kernel_num.h": ["selected_max_kernel_num.h"]},
        default_outs = ["."],
        platforms = platforms,
    )

def operator_registry_selective(
        name,
        selected_max_kernel_num_header_target,
        aten_suffix = "",
        platforms = "CXX",
        **kwargs):
    """Per-binary operator_registry variant whose registry capacity is sized
    to the kernels its consumer actually selected.

    A staging genrule copies operator_registry.cpp, operator_registry.h and
    the generated selected_max_kernel_num.h into one artifact tree. The .cpp
    is then compiled with both headers exported at the expected
    `<executorch/runtime/kernel/...>` paths, so operator_registry.cpp's
    existing `__has_include` ladder picks up
    EXECUTORCH_SELECTED_MAX_KERNEL_NUM. A user-supplied
    `-c executorch.max_kernel_num=N` still wins via the same preprocessor
    flags the shared target uses.

    NOTE: the operator registry is intentionally a global; this target
    defines the same external-linkage symbols as the shared
    `//executorch/runtime/kernel:operator_registry`. Linking both into one
    binary produces ODR / duplicate-symbol errors. Consumers must arrange
    for only one to be linked transitively, which is why
    `executorch_generated_lib(auto_size_kernel_registry = True)` is opt-in.

    Args:
        name: name of the resulting cxx_library; the staging genrule is named
            `<name>_operator_registry_srcs_copy`.
        selected_max_kernel_num_header_target: target whose output location
            contains selected_max_kernel_num.h.
        aten_suffix: suffix appended to the `//executorch/runtime/core:evalue`
            dependency label.
        platforms: forwarded to both the genrule and the cxx_library.
        **kwargs: forwarded to runtime.cxx_library.
    """
    registry_cpp = "operator_registry.cpp"
    staging_rule = name + "_operator_registry_srcs_copy"
    staged = ":" + staging_rule

    # Shell steps that lay out the staged tree under $OUT: the .cpp at the
    # root, both headers under the include directory.
    copy_steps = [
        "mkdir -p $OUT/{}".format(_HEADER_DIR),
        "cp -f $(location {})/{} $OUT/{}".format(
            "//executorch/runtime/kernel:operator_registry_sources",
            registry_cpp,
            registry_cpp,
        ),
        "cp -f $(location {})/operator_registry.h $OUT/{}".format(
            "//executorch/runtime/kernel:operator_registry_headers",
            _OP_REGISTRY_HEADER,
        ),
        "cp -f $(location {})/selected_max_kernel_num.h $OUT/{}".format(
            selected_max_kernel_num_header_target,
            _MAX_KERNEL_NUM_HEADER,
        ),
    ]

    runtime.genrule(
        name = staging_rule,
        cmd = " && ".join(copy_steps),
        outs = {
            registry_cpp: [registry_cpp],
            "operator_registry.h": [_OP_REGISTRY_HEADER],
            "selected_max_kernel_num.h": [_MAX_KERNEL_NUM_HEADER],
        },
        default_outs = ["."],
        platforms = platforms,
    )

    runtime.cxx_library(
        name = name,
        srcs = [staged + "[" + registry_cpp + "]"],
        exported_headers = {
            _OP_REGISTRY_HEADER: staged + "[operator_registry.h]",
            _MAX_KERNEL_NUM_HEADER: staged + "[selected_max_kernel_num.h]",
        },
        # The dict keys above are already fully-qualified include paths (e.g.
        # `executorch/runtime/kernel/operator_registry.h`); env_interface.bzl
        # would otherwise prepend `executorch/<consumer-package>/` to them.
        header_namespace = "",
        visibility = ["PUBLIC"],
        # @lint-ignore BUCKLINT link_whole, the registry contains a global table.
        link_whole = True,
        preprocessor_flags = operator_registry_preprocessor_flags(),
        exported_deps = [
            "//executorch/runtime/core:core",
            "//executorch/runtime/core:evalue" + aten_suffix,
        ],
        platforms = platforms,
        **kwargs
    )
Loading
Loading