From 848c6a3549fa1263ed872d82fe81c00e338487c0 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 7 May 2026 23:04:45 +0100 Subject: [PATCH 1/6] Honor mooncake_overlay in primitive callmeta inference `MooncakeInterpreter` previously routed every primitive call site through `NativeInterpreter` for `CallMeta`, but `NativeInterpreter` is overlay-blind: when a primitive has a `@mooncake_overlay` that changes its return type, the inferred type at the call site was wrong and downstream dispatch compiled against the original type. Detect overlay matches via `Method.external_mt` and route them through `@invoke` instead, keeping the `NativeInterpreter` fast path for non-overlay primitives. Fixes #1169. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/interpreter/abstract_interpretation.jl | 33 ++++++++++++++++++--- test/interpreter/abstract_interpretation.jl | 16 ++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index 5a482bf09f..33a2626294 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -168,13 +168,28 @@ function Core.Compiler.abstract_call_gf_by_type( # primitives, and continue that search down the callee tree. That extra work is # unnecessary for a primitive with a hand-written rule. # + # Exception: `NativeInterpreter` is overlay-blind, so fall back to + # overlay-aware default dispatch when a `@mooncake_overlay` applies (issue #1169). + # # `noinline_callmeta` below then blocks inlining/const-folding so the primitive # call stays in the caller IR and Mooncake can dispatch its `rrule!!` at runtime. # See PR #1115 for more discussion. - native_interp = CC.NativeInterpreter(interp.world) - ret = CC.abstract_call_gf_by_type( - native_interp, f, arginfo, si, atype, sv, max_methods - ) + ret = if any_matches_overlay(applicable) + @invoke CC.abstract_call_gf_by_type( + interp::CC.AbstractInterpreter, + f::Any, + arginfo::CC.ArgInfo, + si::CC.StmtInfo, + atype::Any, + sv::CC.AbsIntState, + max_methods::Int, + ) + else + native_interp = CC.NativeInterpreter(interp.world) + CC.abstract_call_gf_by_type( + native_interp, f, arginfo, si, atype, sv, max_methods + ) + end @static if VERSION < v"1.12-" call = ret::CC.CallMeta # Keep primitives in caller IR by blocking const-folding and inlining @@ -216,6 +231,16 @@ function any_matches_primitive(applicable, C, M, world) false end +function any_matches_overlay(applicable) + for app in applicable + method = VERSION < v"1.12-" ? app.method : app.match.method + if isdefined(method, :external_mt) && method.external_mt !== nothing + return true + end + end + false +end + """ widen_rettype_callmeta(call, argtypes) diff --git a/test/interpreter/abstract_interpretation.jl b/test/interpreter/abstract_interpretation.jl index b600ad0d9d..a53bcc48d5 100644 --- a/test/interpreter/abstract_interpretation.jl +++ b/test/interpreter/abstract_interpretation.jl @@ -3,6 +3,16 @@ non_primitive(x) = sin(x) Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(a_primitive),Float64} +# Issue #1169: a `@mooncake_overlay` on a primitive must propagate to the +# inferred return type at the call site, otherwise downstream dispatch +# compiles against the un-overlaid type. +struct OverlayA end +struct OverlayB end +overlay_switch() = OverlayA() +Mooncake.@mooncake_overlay overlay_switch() = OverlayB() +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(overlay_switch)} +overlay_caller() = overlay_switch() + contains_primitive(x) = @inline a_primitive(x) contains_non_primitive(x) = @inline non_primitive(x) contains_primitive_behind_call(x) = @inline contains_primitive(x) @@ -123,6 +133,12 @@ end @test val == 1.0 @test grad[2] == [2.0, 0.0, 1.0] end + + @testset "1169 - overlay propagates to inferred return type" begin + interp = Mooncake.MooncakeInterpreter(DefaultCtx, ReverseMode) + sig = Tuple{typeof(overlay_caller)} + @test Base.code_ircode_by_type(sig; interp)[1][2] == OverlayB + end end @testset "Config(empty_cache=true)" begin From 5631f8f895f23e739ecf7075e1bde52d419225ab Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 8 May 2026 14:32:26 +0100 Subject: [PATCH 2/6] Update src/interpreter/abstract_interpretation.jl Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Signed-off-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/interpreter/abstract_interpretation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index 33a2626294..20642eb2aa 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -234,7 +234,7 @@ end function any_matches_overlay(applicable) for app in applicable method = VERSION < v"1.12-" ? app.method : app.match.method - if isdefined(method, :external_mt) && method.external_mt !== nothing + if isdefined(method, :external_mt) && method.external_mt === mooncake_method_table return true end end From a0f11465ecde73864a3aaeaf5af9e224ba365759 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 8 May 2026 14:41:59 +0100 Subject: [PATCH 3/6] Share version-dispatch line between any_matches helpers Both `any_matches_primitive` and `any_matches_overlay` need the same pre-1.12 / 1.12+ unwrap of an `applicable` entry to a `MethodMatch`. Use a single inline `match = VERSION < v"1.12-" ? app : app.match` in each so the version skew sits in one place per function. Also tighten the overlay check to compare directly against `mooncake_method_table` rather than `external_mt !== nothing`, so an unrelated downstream overlay table cannot trip the overlay path. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/interpreter/abstract_interpretation.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index 20642eb2aa..c5d7adfefd 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -219,12 +219,8 @@ end function any_matches_primitive(applicable, C, M, world) for app in applicable - if VERSION < v"1.12-" - sig = app.spec_types - else - sig = app.match.spec_types - end - if is_primitive(C, M, sig, world) + match = VERSION < v"1.12-" ? app : app.match + if is_primitive(C, M, match.spec_types, world) return true end end @@ -233,7 +229,8 @@ end function any_matches_overlay(applicable) for app in applicable - method = VERSION < v"1.12-" ? app.method : app.match.method + match = VERSION < v"1.12-" ? app : app.match + method = match.method if isdefined(method, :external_mt) && method.external_mt === mooncake_method_table return true end From ed166c797d05bf415d6abd954cbc50e10d0919a1 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 14 May 2026 11:17:37 +0100 Subject: [PATCH 4/6] Add primitives/overlays developer doc and #1169 regression test Documents how `@mooncake_overlay` and `@is_primitive` compose, including the supported case (same function/argtypes) and the unsupported case (overlay reachable only from inside a primitive's body). Adds a short section to known_limitations.md and a cross-reference from defining_rules.md. Regression test exercises overlay propagation through downstream dispatch for the same-signature path. --- docs/make.jl | 1 + .../primitives_and_overlays.md | 213 ++++++++++++++++++ docs/src/known_limitations.md | 7 + docs/src/utilities/defining_rules.md | 2 + test/interpreter/abstract_interpretation.jl | 41 ++++ 5 files changed, 264 insertions(+) create mode 100644 docs/src/developer_documentation/primitives_and_overlays.md diff --git a/docs/make.jl b/docs/make.jl index 0700e66885..3742820c8c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -55,6 +55,7 @@ makedocs(; joinpath("developer_documentation", "ir_representation.md"), joinpath("developer_documentation", "forwards_mode_design.md"), joinpath("developer_documentation", "reverse_mode_design.md"), + joinpath("developer_documentation", "primitives_and_overlays.md"), joinpath("developer_documentation", "misc_internals_notes.md"), joinpath("developer_documentation", "advanced_debugging.md"), joinpath("developer_documentation", "internal_docstrings.md"), diff --git a/docs/src/developer_documentation/primitives_and_overlays.md b/docs/src/developer_documentation/primitives_and_overlays.md new file mode 100644 index 0000000000..880ff39261 --- /dev/null +++ b/docs/src/developer_documentation/primitives_and_overlays.md @@ -0,0 +1,213 @@ +# Primitives and Overlays + +[`Mooncake.@is_primitive`](@ref) and [`Mooncake.@mooncake_overlay`](@ref) are the two main ways a user can intervene in how Mooncake differentiates a function. They act at different layers: + +- `@mooncake_overlay` changes **method lookup under Mooncake's interpreter**. It gives Mooncake a different primal body to see. +- `@is_primitive` changes **what Mooncake differentiates**. It declares a boundary where Mooncake stops differentiating the primal body and uses a hand-written rule instead. + +This page explains each macro on its own terms, describes how Mooncake uses them during type inference, and then derives which compositions are supported and which are not. Implementation-level details are tucked into collapsible blocks for readers who want to follow the mechanism. + +## Primitives + +To differentiate a function, Mooncake runs a *compilation step*: it walks the function's inferred source IR statement by statement, rewriting each into the forward- and reverse-pass code that will execute at differentiation time. The output is a derived rule — a callable that runs the program forward while tracking the data needed to compute the gradient on the reverse pass. + +[`Mooncake.@is_primitive`](@ref) declares a function as a stopping point for that rewriting. At a matching call site: + +- The body is not walked: Mooncake leaves the call statement in the transformed IR as-is. +- At runtime, the registered [`Mooncake.rrule!!`](@ref) (or [`Mooncake.frule!!`](@ref)) is dispatched in place of the primal call. + +The rule, not the body, produces the value at this call site; its return type — typically `Tuple{CoDual{B,F}, Pullback}` — is what the surrounding AD code sees. + +!!! details "Mechanism" + Mooncake's `AbstractInterpreter` override of `abstract_call_gf_by_type` (in `src/interpreter/abstract_interpretation.jl`) checks each call site against the primitive table via `any_matches_primitive`. When a match is detected, the resulting `CallMeta` is wrapped in a `NoInlineCallInfo`, which Mooncake's `inlining_policy` / `src_inlining_policy` then refuses to inline. As a result, the primitive call survives into the IR that AD construction sees, and the rule-dispatch code is emitted at that statement instead of inlined primal code. + +## Overlays + +`@mooncake_overlay` registers an additional method for a function in a private method table, `Mooncake.mooncake_method_table`. Only Mooncake's interpreter consults this table; plain Julia dispatch and `Core.Compiler.NativeInterpreter` do not. + +For example: + +```julia +# Imagine `slow_or_unsupported` hits something Mooncake doesn't handle +# (a foreign call, a `try`/`catch`, ...) or handles only inefficiently. +f(x) = slow_or_unsupported(x) + +# An AD-friendly body that returns the same value and the same type. +Mooncake.@mooncake_overlay f(x) = ad_friendly_alternative(x) +``` + +When `MooncakeInterpreter` infers a call to `f`, method lookup goes through `OverlayMethodTable` and resolves to the overlay (`ad_friendly_alternative(x)`). Plain Julia dispatch is unchanged — code calling `f` outside `MooncakeInterpreter` still executes the original (`slow_or_unsupported(x)`). Inside `MooncakeInterpreter`, the *primal* of `f` is the overlay body whenever a matching overlay exists, so the inferred source IR — and any AD rewriting subsequently applied to it — sees the overlay body, not the original. + +The intended use is to substitute a body Mooncake can't differentiate (e.g. a foreign call, or a construct that hits a known limitation) — or one Mooncake can differentiate but only inefficiently — with an equivalent body that AD handles better. Mooncake doesn't verify equivalence; the author is responsible for ensuring the overlay returns the same value and the same type as the original, so that differentiating the overlay yields a derivative of the original semantics. + +!!! details "Mechanism" + `mooncake_method_table` is a `Core.MethodTable` created by [`Base.Experimental.@MethodTable`](https://docs.julialang.org/en/v1/base/base/#Base.Experimental.@MethodTable). `@mooncake_overlay` is essentially a thin wrapper around [`Base.Experimental.@overlay`](https://docs.julialang.org/en/v1/base/base/#Base.Experimental.@overlay): both rewrite the method definition's call head into an `Expr(:overlay, mt, name)`, which the frontend registers into `mt` (visible as the resulting `Method`'s `external_mt` field) rather than the global method table. + + The lookup that makes overlays "win" is `Core.Compiler.OverlayMethodTable` (defined in [`Compiler/src/methodtable.jl`](https://github.com/JuliaLang/julia/blob/master/Compiler/src/methodtable.jl)). `CC.method_table(::MooncakeInterpreter)` returns one constructed over `mooncake_method_table`, and during inference every method lookup goes through it: if `mooncake_method_table` has a matching method that fully covers the signature, it wins; otherwise lookup falls back to the global table. + +## Type inference + +Mooncake's IR transformation is driven by inferred type information. Three places matter, and they fire in this order: + +1. **Source-IR inference.** The function being differentiated is inferred via `MooncakeInterpreter`. This produces the IR that the AD transformation rewrites. +2. **Per-call `CallMeta`.** At each call statement during the source-IR walk, Mooncake needs the return type, effects, and call info. `abstract_call_gf_by_type` produces this `CallMeta`. Primitive call sites are handled specially here — see [Inference at primitive call sites](@ref). + + !!! details "What `abstract_call_gf_by_type` does" + This is Julia's central per-call-site inference entry point, in `Compiler/src/abstractinterpretation.jl`. Given a function value, the call's argument info / `atype`, the current inference state, and a cap on how many methods to consider, it returns a `Future{CallMeta}` with the inferred return type, exception type, effects, and call-site info. + + At a high level it does three things: (a) **method lookup** via `find_method_matches`, finding all method candidates whose signatures intersect the call's `atype`; (b) **per-match abstract interpretation**, calling `abstract_call_method` to recursively infer each candidate's body and optionally running constant propagation; (c) **aggregation**, joining each candidate's return type, exception type, and effects over the IPO lattice, recording inference edges, and producing a final `CallMeta`. The per-match loop is cooperatively pausable — if a sub-inference is in flight, the work is rescheduled — which is how stackless inference on Julia 1.12+ works. + + Mooncake overrides this function for `MooncakeInterpreter` and inserts its primitive / overlay logic *before* the recursive per-match step: if the call site is a primitive, we don't want `abstract_call_method` recursion at all, just a `CallMeta`. See [Inference at primitive call sites](@ref) for why this matters and how Mooncake produces the `CallMeta` without recursion. +3. **Rule-type inference.** Later, during AD IR construction, Mooncake calls `Core.Compiler.return_type` with the default interpreter — for example when emitting a `pullback_type` lookup — to learn the type the rule itself returns. + +The key asymmetry to internalise: **Mooncake's source-function inference is overlay-aware via `OverlayMethodTable`; `NativeInterpreter`, used at primitive boundaries, is not.** + +### Inference at primitive call sites + +At every call site in the source IR, Mooncake needs a return type — downstream code is typed against it. At a primitive call site this is no different: the surrounding code wants the primal's return type, and the rule is an *implementation* keyed to that type, not a *source* for it. So inference asks the primal what it returns; the rule isn't consulted at this stage. + +That leaves the question of *how* to obtain the primal `CallMeta`. Why not just use `MooncakeInterpreter` for this inference too? It is, after all, the AD-aware interpreter we are already running on the source IR — recursing into a primitive's body with it is the natural choice. + +The problem is what that recursion costs. `MooncakeInterpreter`'s `abstract_call_gf_by_type` override re-passes itself into the recursive walk: every call site visited during inference of the body re-enters the override, runs the primitive check, and re-enters Julia's recursive inference under `MooncakeInterpreter` again. `MooncakeInterpreter` also uses its own inference caches separate from Julia's global cache, so the re-walk does not reuse Julia's already-warm results — every function Mooncake differentiates triggers a fresh walk of its transitive call tree. + +For most call trees this is fine. Some real-world ones are not — see [PR #1115](https://github.com/chalk-lab/Mooncake.jl/pull/1115) for a SciML-shaped case where this recursion explodes into a silent compile-time hang. + +The license to do something different comes from observing what the primitive's body is actually used for at this point in AD. The body is *not* going to be rewritten into AD-generated code: at a primitive call site, the registered rule replaces the body at runtime, and the AD transformation emits a rule dispatch rather than walking the body's statements. The body's contents therefore do not need to be inspected by an AD-aware interpreter; only its return type needs to flow into the surrounding primal IR's `CallMeta`. Any interpreter that produces a correct `CallMeta` for the boundary is sufficient. + +Mooncake therefore asks `NativeInterpreter` for the `CallMeta` and stops. Standard inference still walks the body to compute the type, but it does so against Julia's global cache and without re-firing Mooncake's primitive-detection machinery at every nested call site. The recursion that would otherwise cascade through `solve`'s call tree under `MooncakeInterpreter` is bounded at each primitive boundary. + +The wrinkle: `NativeInterpreter` is overlay-blind. Any overlay that would affect the primitive — directly, or indirectly via a call inside its body — is invisible to inference at the primitive boundary. + +## Composition + +The supported and unsupported combinations follow directly from the layers each macro touches. + +### Overlay only, no primitive + +Mooncake's interpreter sees the overlay's body in place of the original. AD differentiates the overlay body. Fully supported; this is the canonical use of `@mooncake_overlay`. + +### Primitive only, no overlay + +The rule replaces the primal at runtime. Inference at the call site asks `NativeInterpreter` for the *original* body's `CallMeta`. Fully supported, on the standard contract that the rule's primal return type matches the original's. + +### Overlay and primitive on the *same* signature — supported corner case + +From the previous discussion: as long as the overlay returns the same type as the original, `NativeInterpreter`'s overlay-blindness doesn't matter — inference and the rule agree on the type at the call site either way. The corner case is when the overlay *changes* the return type: inference would see the original's type, the rule would produce the overlay's, and downstream code would be typed against the wrong one. + +Mooncake detects this configuration and routes inference through the overlay-aware default path, so the inferred return type at the call site matches the overlay's, not the original's. At runtime, the registered `rrule!!` still fires. + +In effect: the rule produces the value and the adjoint; the overlay's only job is to align inference's view of the return type with what the rule actually returns. This matters when the rule returns a value of a different type from the original primal and downstream code dispatches on that type. Most users should not need this pattern — prefer to express the change as either an overlay or a primitive, not both — but it is supported. + +!!! details "Mechanism" + `any_matches_overlay` (in `src/interpreter/abstract_interpretation.jl`) walks the applicable methods and checks `method.external_mt === mooncake_method_table`. When that returns true, `abstract_call_gf_by_type` takes the `@invoke` branch — i.e. it defers to the default `abstract_call_gf_by_type` *with `MooncakeInterpreter` still as the interpreter*, so method lookup inside that call still goes through `OverlayMethodTable` and resolves to the overlay's body. The `NativeInterpreter` fast path is reserved for primitives whose applicable methods have no overlay. + +### Primitive called from inside an overlay's body — supported + +An overlay's body may itself call a registered primitive. This is the ordinary, supported flow: Mooncake walks the overlay body for AD, and any primitive call inside it is handled by the same machinery that handles primitive calls anywhere else (primitive detection, `NativeInterpreter` for the `CallMeta`, rule dispatch at runtime). No special arrangement is needed; this is in fact the most common reason to write an overlay — substituting an AD-unfriendly body with one that bottoms out on a hand-written rule. + +```julia +my_primitive(x::Float64) = 2x +Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(my_primitive),Float64} +function Mooncake.rrule!!(::CoDual{typeof(my_primitive)}, x::CoDual{Float64}) + pb(dy) = NoRData(), 2dy + return Mooncake.zero_fcodual(2 * Mooncake.primal(x)), pb +end + +# `original_f` has some body Mooncake handles awkwardly. Overlay redirects through +# `my_primitive`, whose rule supplies the derivative. +original_f(x::Float64) = unsupported_or_expensive(x) +Mooncake.@mooncake_overlay original_f(x::Float64) = my_primitive(x) +``` + +Differentiating any caller of `original_f` walks the overlay's body, hits `my_primitive`, dispatches its `rrule!!`, and computes the gradient via the rule's adjoint. The [drift](@ref "Drift between rules and overlays") hazard applies as anywhere else: the primitive's rule must agree with its inferred primal return type. + +### Overlay on a non-primitive called from inside a primitive's body — not supported + +Although the primitive's body is not *differentiated*, it is still *inferred* — `NativeInterpreter` walks it to produce the primitive's `CallMeta`. Because `NativeInterpreter` does not consult `mooncake_method_table`, any overlay on a nested call within the body is invisible to that walk. Inference of the primitive's return type therefore sees the original definitions of its nested calls, not the overlays. + +Example: + +```julia +helper(::A) = A() +Mooncake.@mooncake_overlay helper(::A) = B() + +primitive_wrapper(x::A) = helper(x) +Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(primitive_wrapper), A} +``` + +Call sites of `primitive_wrapper` will infer the return type as `A`, even though the overlay would return `B` if it were honoured. + +!!! details "Mechanism — why inference sees `A`" + Walking the layers for a call `primitive_wrapper(a::A)`: + + 1. The outer function containing this call is inferred under `MooncakeInterpreter`. At the `primitive_wrapper(a)` statement, `abstract_call_gf_by_type` is called with `MooncakeInterpreter`. + 2. Applicable-method lookup uses `OverlayMethodTable`, but `primitive_wrapper` itself has no overlay. The match resolves to the ordinary `primitive_wrapper(::A)`. + 3. `any_matches_primitive` returns true. `any_matches_overlay` returns false (the method's `external_mt` is unset). + 4. The branch added in #1170 takes the `NativeInterpreter` fast path and asks for the `CallMeta` of `primitive_wrapper(::A)` under `NativeInterpreter`. + 5. `NativeInterpreter` infers `primitive_wrapper`'s body. At the `helper(x)` statement inside that body, its method lookup uses the standard method table — `mooncake_method_table` is invisible to it — so it resolves to `helper(::A) = A()` and infers the return as `A`. + 6. The inferred return type of `primitive_wrapper(::A)` therefore propagates as `A` back to the outer caller, even though under Mooncake the overlay would have made it `B`. + + The break is at step (5): the right layer (Mooncake) knows about the overlay, but it has delegated this lookup to a layer that does not. + +**The wrong-gradient mechanism.** The reported cases of [#1169](https://github.com/chalk-lab/Mooncake.jl/issues/1169) — including the SciMLBase `Originator` shape — involve primitives whose return types are singletons. For these, inference produces not just a type but a `CC.Const(value)`. The consequence is that the rule never gets a chance to fire: Julia constant-folds the call to the literal value before AD construction sees it. + +Walked out: + +1. `NativeInterpreter` (overlay-blind) infers the primitive call as `Const(original_value)` — the singleton instance from the *original* body. +2. `widen_rettype_callmeta` exists to prevent `Const` from causing primitive calls to fold away, but it has a documented carve-out: if every runtime argument at the call site is also `Const`, folding is treated as safe (the `sin(1.0)`-with-a-literal case). A zero-runtime-argument primitive trivially satisfies this; many SciML-style overlays do too. +3. Const propagation in subsequent compiler passes replaces the primitive call with the literal value — the *original*'s singleton, not the overlay's. +4. By the time AD construction processes the IR, there is no primitive call site at this location, only a constant. No `rrule!!` call is emitted; no `Core.typeassert` is emitted; no runtime check fires. +5. Downstream code is compiled against the inferred (wrong) singleton type and picks rules keyed to it. The runtime never has the opportunity to course-correct. + +The result is a silent wrong gradient: not because the rule produced the wrong value, but because the rule was never called. The typeassert that Mooncake emits at primitive call sites (in `src/interpreter/reverse_mode.jl`) is not the safety net here — by the time it would have run, the call has already been replaced by a literal. + +For overlays that don't yield a `Const` (e.g. a primitive whose return is concrete but not a singleton), the failure mode is different: either inference and the rule happen to agree on the type and there is no problem, or they disagree and the typeassert traps with a `TypeError` — loud rather than silent. The dangerous combination is overlay + singleton return, which is exactly the shape both #1169's MWE and the SciMLBase usage take. + +This is by design: Mooncake treats primitives as sealed boundaries and does not walk into a primitive's body to discover what overlays might affect it. The fix in [#1170](https://github.com/chalk-lab/Mooncake.jl/pull/1170) extends overlay-awareness only to the *primitive's own signature* — the boundary inference is already looking at. For overlays reachable from inside a primitive's body, the rule and Mooncake's inferred type may diverge, and keeping them coherent is the rule author's responsibility (see [Drift between rules and overlays](@ref) for the contract). + +An alternative approach in [PR #1168](https://github.com/chalk-lab/Mooncake.jl/pull/1168) instead walks into primitive bodies with overlay-aware method lookup (via a wrapper around `NativeInterpreter` that uses `mooncake_method_table`); it would fix the inside-body case as well. It was not adopted in #1170 — the sealed-boundary policy makes for a smaller and more focused fix. + +Workaround when you do encounter this shape: lift the overlay to the level the user actually calls. Either remove the primitive declaration on the wrapper and let AD differentiate it, or register the desired behaviour as a primitive on the outer function. + +### Drift between rules and overlays + +A rule's return type is hand-written and fixed when `rrule!!` is authored. If an overlay is introduced later — directly on the primitive, or on a function called in its body — that changes the primal type Mooncake's inference produces, the rule and the inference can disagree. + +The invariant the author must maintain is: + +```text +inferred primal return type at the call site + == +primal type inside the CoDual returned by the rule +``` + +When the types are concrete and non-singleton, the typeassert Mooncake emits on the rule's primal output (in `src/interpreter/reverse_mode.jl`) catches violations at runtime. For example: + +```julia +struct DriftOld; v::Float64 end +struct DriftNew; v::Float64 end + +f(x::Float64) = DriftOld(x) + +# Overlay introduced later, changing the return type. +Mooncake.@mooncake_overlay f(x::Float64) = DriftNew(x) + +Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(f), Float64} + +# Stale rule, authored before the overlay was added: still returns DriftOld(...). +function Mooncake.rrule!!(g::CoDual{typeof(f)}, x::CoDual{Float64}) + pb(dy) = NoRData(), NoRData(), dy + return zero_fcodual(DriftOld(primal(x))), pb +end +``` + +Differentiating any caller of `f` then traps with: + +```text +TypeError: in typeassert, expected CoDual{DriftNew, NoFData}, + got a value of type CoDual{DriftOld, NoFData} +``` + +When the inferred return is a singleton (`CC.Const`), the primitive call is liable to be const-folded to the literal value before the rule fires, so the typeassert is bypassed. The runtime then follows the inferred-type path — which is what the overlay would have produced — and the rule's stale return type ends up irrelevant in practice. Convenient, but it's coincidence: don't rely on it. + +#1170 makes inference at primitive boundaries overlay-aware so the left-hand side reflects what the overlay-modified primal would actually return. Keeping the right-hand side in sync — adjusting the rule when the overlay changes the type — is the author's responsibility. diff --git a/docs/src/known_limitations.md b/docs/src/known_limitations.md index e7429b0da9..f14713576f 100644 --- a/docs/src/known_limitations.md +++ b/docs/src/known_limitations.md @@ -201,3 +201,10 @@ Honestly, your best bet is just to avoid differentiating functions whose argumen ```@meta DocTestSetup = nothing ``` + +## Composition of `@mooncake_overlay` and `@is_primitive` + +`@mooncake_overlay` substitutes the function body that Mooncake differentiates; `@is_primitive` marks a call site as a boundary where Mooncake stops differentiating the body and dispatches a hand-written rule instead. The two operate at different layers of the AD pipeline, and not every combination of them is supported. [Primitives and Overlays](@ref) covers the full picture; the practical summary is: + +- **Overlay on the same function and argument types as a primitive** is supported. The rule still runs at runtime, and the overlay only adjusts what Mooncake infers as the call's return type so it matches what the rule produces. +- **Overlay on a function called from inside a primitive's body** is not supported. Mooncake does not look into a primitive's body for overlays, so the overlay has no effect there; if it would have changed the return type, Mooncake's inferred type and the rule's actual output disagree, which can produce silently wrong gradients. diff --git a/docs/src/utilities/defining_rules.md b/docs/src/utilities/defining_rules.md index a6b87ac985..f24f1917e1 100644 --- a/docs/src/utilities/defining_rules.md +++ b/docs/src/utilities/defining_rules.md @@ -10,6 +10,8 @@ In this section, we detail some useful strategies which can help you avoid havin Mooncake.@mooncake_overlay ``` +See [Primitives and Overlays](@ref) for how overlays interact with primitives, including which compositions are supported and which are not. + ## Functions with Zero Adjoint If the above strategy does not work, but you find yourself in the surprisingly common diff --git a/test/interpreter/abstract_interpretation.jl b/test/interpreter/abstract_interpretation.jl index a53bcc48d5..a654c9f18a 100644 --- a/test/interpreter/abstract_interpretation.jl +++ b/test/interpreter/abstract_interpretation.jl @@ -6,6 +6,12 @@ Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(a_primitive),Float64} # Issue #1169: a `@mooncake_overlay` on a primitive must propagate to the # inferred return type at the call site, otherwise downstream dispatch # compiles against the un-overlaid type. +# +# NOTE: this fixture combines `@mooncake_overlay` with `@is_primitive` on the +# same signature purely to exercise the inference path. Combining them in user +# code is not a supported pattern — at a primitive call site AD dispatches the +# hand-written rule, so the overlay's only effect is to change inference's view +# of the return type. Use one or the other. struct OverlayA end struct OverlayB end overlay_switch() = OverlayA() @@ -13,6 +19,34 @@ Mooncake.@mooncake_overlay overlay_switch() = OverlayB() Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(overlay_switch)} overlay_caller() = overlay_switch() +# rrule returns the overlay-typed value, so an inference path that honours the +# overlay routes the downstream call to `overlay_use(::OverlayB, ...)` (2x); +# a path that ignores the overlay routes it to `overlay_use(::OverlayA, ...)` (1x). +function Mooncake.rrule!!(f::CoDual{typeof(overlay_switch)}) + return Mooncake.zero_fcodual(OverlayB()), Mooncake.NoPullback(f) +end + +overlay_use(::OverlayA, x::Float64) = x +overlay_use(::OverlayB, x::Float64) = 2x + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(overlay_use),OverlayA,Float64} +function Mooncake.rrule!!( + f::CoDual{typeof(overlay_use)}, ::CoDual{OverlayA}, x::CoDual{Float64} +) + overlay_use_a_pb(dy) = NoRData(), NoRData(), dy + return Mooncake.zero_fcodual(Mooncake.primal(x)), overlay_use_a_pb +end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(overlay_use),OverlayB,Float64} +function Mooncake.rrule!!( + f::CoDual{typeof(overlay_use)}, ::CoDual{OverlayB}, x::CoDual{Float64} +) + overlay_use_b_pb(dy) = NoRData(), NoRData(), 2dy + return Mooncake.zero_fcodual(2 * Mooncake.primal(x)), overlay_use_b_pb +end + +overlay_outer(x::Float64) = overlay_use(overlay_switch(), x) + contains_primitive(x) = @inline a_primitive(x) contains_non_primitive(x) = @inline non_primitive(x) contains_primitive_behind_call(x) = @inline contains_primitive(x) @@ -139,6 +173,13 @@ end sig = Tuple{typeof(overlay_caller)} @test Base.code_ircode_by_type(sig; interp)[1][2] == OverlayB end + + @testset "1169 - overlay propagates through downstream dispatch" begin + cache = Mooncake.prepare_gradient_cache(overlay_outer, 1.0) + val, (_, grad) = Mooncake.value_and_gradient!!(cache, overlay_outer, 1.0) + @test val == 2.0 + @test grad == 2.0 + end end @testset "Config(empty_cache=true)" begin From b3571f75079da31c48eb822e31220c26dcd1dd0f Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 14 May 2026 11:46:50 +0100 Subject: [PATCH 5/6] Tighten primitives/overlays doc MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cuts redundant explanatory layers in the developer doc: - Drops the standalone "Drift between rules and overlays" section (the contract is restated in-place at the two referring points; the typeassert/TypeError mechanism is already covered in the const-folding section). - Removes the nested abstract_call_gf_by_type details block and the concrete-but-not-singleton aside. - Compresses the "why NativeInterpreter at primitive boundaries" rationale from five paragraphs to one. - Hides the five-step const-folding walk in a collapsible details block; the headline summary stays visible. - Folds the redundant "why inference sees A" mechanism into one inline sentence. - Renames "supported corner case" to "Direct overlay on a primitive signature" (neutral framing). - Fixes pullback arity in the worked example (2 slots, not 3). - Softens "primitive call survives into the IR" to "survives inlining" — const-folding is covered separately. - Adds GitHub source links for the .jl files referenced in the details blocks, and an @ref for widen_rettype_callmeta. Page goes from ~210 lines to ~130. Documenter build passes with no unresolved refs. --- .../primitives_and_overlays.md | 115 +++--------------- 1 file changed, 17 insertions(+), 98 deletions(-) diff --git a/docs/src/developer_documentation/primitives_and_overlays.md b/docs/src/developer_documentation/primitives_and_overlays.md index 880ff39261..dc06b20e4f 100644 --- a/docs/src/developer_documentation/primitives_and_overlays.md +++ b/docs/src/developer_documentation/primitives_and_overlays.md @@ -19,7 +19,7 @@ To differentiate a function, Mooncake runs a *compilation step*: it walks the fu The rule, not the body, produces the value at this call site; its return type — typically `Tuple{CoDual{B,F}, Pullback}` — is what the surrounding AD code sees. !!! details "Mechanism" - Mooncake's `AbstractInterpreter` override of `abstract_call_gf_by_type` (in `src/interpreter/abstract_interpretation.jl`) checks each call site against the primitive table via `any_matches_primitive`. When a match is detected, the resulting `CallMeta` is wrapped in a `NoInlineCallInfo`, which Mooncake's `inlining_policy` / `src_inlining_policy` then refuses to inline. As a result, the primitive call survives into the IR that AD construction sees, and the rule-dispatch code is emitted at that statement instead of inlined primal code. + Mooncake's `AbstractInterpreter` override of `abstract_call_gf_by_type` (in [`src/interpreter/abstract_interpretation.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/abstract_interpretation.jl)) checks each call site against the primitive table via `any_matches_primitive`. When a match is detected, the resulting `CallMeta` is wrapped in a `NoInlineCallInfo`, which Mooncake's `inlining_policy` / `src_inlining_policy` then refuses to inline. The primitive call therefore survives inlining, and the rule-dispatch code is emitted at that statement instead of inlined primal code. ## Overlays @@ -50,31 +50,16 @@ The intended use is to substitute a body Mooncake can't differentiate (e.g. a fo Mooncake's IR transformation is driven by inferred type information. Three places matter, and they fire in this order: 1. **Source-IR inference.** The function being differentiated is inferred via `MooncakeInterpreter`. This produces the IR that the AD transformation rewrites. -2. **Per-call `CallMeta`.** At each call statement during the source-IR walk, Mooncake needs the return type, effects, and call info. `abstract_call_gf_by_type` produces this `CallMeta`. Primitive call sites are handled specially here — see [Inference at primitive call sites](@ref). - - !!! details "What `abstract_call_gf_by_type` does" - This is Julia's central per-call-site inference entry point, in `Compiler/src/abstractinterpretation.jl`. Given a function value, the call's argument info / `atype`, the current inference state, and a cap on how many methods to consider, it returns a `Future{CallMeta}` with the inferred return type, exception type, effects, and call-site info. - - At a high level it does three things: (a) **method lookup** via `find_method_matches`, finding all method candidates whose signatures intersect the call's `atype`; (b) **per-match abstract interpretation**, calling `abstract_call_method` to recursively infer each candidate's body and optionally running constant propagation; (c) **aggregation**, joining each candidate's return type, exception type, and effects over the IPO lattice, recording inference edges, and producing a final `CallMeta`. The per-match loop is cooperatively pausable — if a sub-inference is in flight, the work is rescheduled — which is how stackless inference on Julia 1.12+ works. - - Mooncake overrides this function for `MooncakeInterpreter` and inserts its primitive / overlay logic *before* the recursive per-match step: if the call site is a primitive, we don't want `abstract_call_method` recursion at all, just a `CallMeta`. See [Inference at primitive call sites](@ref) for why this matters and how Mooncake produces the `CallMeta` without recursion. +2. **Per-call `CallMeta`.** At each call statement during the source-IR walk, Mooncake needs the return type, effects, and call info. Julia's `abstract_call_gf_by_type` (in `Compiler/src/abstractinterpretation.jl`) is the per-call-site inference entry point that produces this; Mooncake overrides it for `MooncakeInterpreter` to insert primitive/overlay handling before the recursive per-match step. See [Inference at primitive call sites](@ref) for the primitive case. 3. **Rule-type inference.** Later, during AD IR construction, Mooncake calls `Core.Compiler.return_type` with the default interpreter — for example when emitting a `pullback_type` lookup — to learn the type the rule itself returns. The key asymmetry to internalise: **Mooncake's source-function inference is overlay-aware via `OverlayMethodTable`; `NativeInterpreter`, used at primitive boundaries, is not.** ### Inference at primitive call sites -At every call site in the source IR, Mooncake needs a return type — downstream code is typed against it. At a primitive call site this is no different: the surrounding code wants the primal's return type, and the rule is an *implementation* keyed to that type, not a *source* for it. So inference asks the primal what it returns; the rule isn't consulted at this stage. - -That leaves the question of *how* to obtain the primal `CallMeta`. Why not just use `MooncakeInterpreter` for this inference too? It is, after all, the AD-aware interpreter we are already running on the source IR — recursing into a primitive's body with it is the natural choice. - -The problem is what that recursion costs. `MooncakeInterpreter`'s `abstract_call_gf_by_type` override re-passes itself into the recursive walk: every call site visited during inference of the body re-enters the override, runs the primitive check, and re-enters Julia's recursive inference under `MooncakeInterpreter` again. `MooncakeInterpreter` also uses its own inference caches separate from Julia's global cache, so the re-walk does not reuse Julia's already-warm results — every function Mooncake differentiates triggers a fresh walk of its transitive call tree. +At every call site in the source IR, Mooncake needs a return type — downstream code is typed against it. At a primitive call site, the surrounding code still wants the primal's return type; the rule is an *implementation* keyed to that type, not a *source* for it. So inference asks the primal what it returns; the rule isn't consulted at this stage. -For most call trees this is fine. Some real-world ones are not — see [PR #1115](https://github.com/chalk-lab/Mooncake.jl/pull/1115) for a SciML-shaped case where this recursion explodes into a silent compile-time hang. - -The license to do something different comes from observing what the primitive's body is actually used for at this point in AD. The body is *not* going to be rewritten into AD-generated code: at a primitive call site, the registered rule replaces the body at runtime, and the AD transformation emits a rule dispatch rather than walking the body's statements. The body's contents therefore do not need to be inspected by an AD-aware interpreter; only its return type needs to flow into the surrounding primal IR's `CallMeta`. Any interpreter that produces a correct `CallMeta` for the boundary is sufficient. - -Mooncake therefore asks `NativeInterpreter` for the `CallMeta` and stops. Standard inference still walks the body to compute the type, but it does so against Julia's global cache and without re-firing Mooncake's primitive-detection machinery at every nested call site. The recursion that would otherwise cascade through `solve`'s call tree under `MooncakeInterpreter` is bounded at each primitive boundary. +The natural choice — recursing into the body with `MooncakeInterpreter` — is expensive and unnecessary. It is expensive because `MooncakeInterpreter` re-fires its primitive/overlay check at every nested call site and uses its own inference cache separate from Julia's global one, so each function Mooncake differentiates triggers a fresh walk of its transitive call tree (see [PR #1115](https://github.com/chalk-lab/Mooncake.jl/pull/1115) for a SciML-shaped case where this explodes into a silent compile-time hang). It is unnecessary because the body isn't being rewritten into AD code, only inferred for its return type — any interpreter that produces a correct `CallMeta` is sufficient. Mooncake therefore delegates to `NativeInterpreter` at primitive boundaries, bounding the recursion at each one. The wrinkle: `NativeInterpreter` is overlay-blind. Any overlay that would affect the primitive — directly, or indirectly via a call inside its body — is invisible to inference at the primitive boundary. @@ -90,16 +75,12 @@ Mooncake's interpreter sees the overlay's body in place of the original. AD diff The rule replaces the primal at runtime. Inference at the call site asks `NativeInterpreter` for the *original* body's `CallMeta`. Fully supported, on the standard contract that the rule's primal return type matches the original's. -### Overlay and primitive on the *same* signature — supported corner case - -From the previous discussion: as long as the overlay returns the same type as the original, `NativeInterpreter`'s overlay-blindness doesn't matter — inference and the rule agree on the type at the call site either way. The corner case is when the overlay *changes* the return type: inference would see the original's type, the rule would produce the overlay's, and downstream code would be typed against the wrong one. - -Mooncake detects this configuration and routes inference through the overlay-aware default path, so the inferred return type at the call site matches the overlay's, not the original's. At runtime, the registered `rrule!!` still fires. +### Direct overlay on a primitive signature -In effect: the rule produces the value and the adjoint; the overlay's only job is to align inference's view of the return type with what the rule actually returns. This matters when the rule returns a value of a different type from the original primal and downstream code dispatches on that type. Most users should not need this pattern — prefer to express the change as either an overlay or a primitive, not both — but it is supported. +When the overlay returns the same type as the original, `NativeInterpreter`'s overlay-blindness is harmless: inference and the rule agree on the type at the call site. When the overlay *changes* the return type, Mooncake detects this configuration and routes inference through the overlay-aware default path, so the inferred return type matches the overlay's, not the original's. At runtime, the registered `rrule!!` still fires; the overlay's only job is to align inference's view of the return type with what the rule actually returns. Most users should not need this pattern — prefer to express the change as either an overlay or a primitive, not both — but it is supported. !!! details "Mechanism" - `any_matches_overlay` (in `src/interpreter/abstract_interpretation.jl`) walks the applicable methods and checks `method.external_mt === mooncake_method_table`. When that returns true, `abstract_call_gf_by_type` takes the `@invoke` branch — i.e. it defers to the default `abstract_call_gf_by_type` *with `MooncakeInterpreter` still as the interpreter*, so method lookup inside that call still goes through `OverlayMethodTable` and resolves to the overlay's body. The `NativeInterpreter` fast path is reserved for primitives whose applicable methods have no overlay. + `any_matches_overlay` (in [`src/interpreter/abstract_interpretation.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/abstract_interpretation.jl)) walks the applicable methods and checks `method.external_mt === mooncake_method_table`. When that returns true, `abstract_call_gf_by_type` takes the `@invoke` branch — i.e. it defers to the default `abstract_call_gf_by_type` *with `MooncakeInterpreter` still as the interpreter*, so method lookup inside that call still goes through `OverlayMethodTable` and resolves to the overlay's body. The `NativeInterpreter` fast path is reserved for primitives whose applicable methods have no overlay. ### Primitive called from inside an overlay's body — supported @@ -119,7 +100,7 @@ original_f(x::Float64) = unsupported_or_expensive(x) Mooncake.@mooncake_overlay original_f(x::Float64) = my_primitive(x) ``` -Differentiating any caller of `original_f` walks the overlay's body, hits `my_primitive`, dispatches its `rrule!!`, and computes the gradient via the rule's adjoint. The [drift](@ref "Drift between rules and overlays") hazard applies as anywhere else: the primitive's rule must agree with its inferred primal return type. +Differentiating any caller of `original_f` walks the overlay's body, hits `my_primitive`, dispatches its `rrule!!`, and computes the gradient via the rule's adjoint. The standard contract applies as anywhere else: the primitive's rule must return a `CoDual` whose primal type matches Mooncake's inferred return type at the call site. ### Overlay on a non-primitive called from inside a primitive's body — not supported @@ -135,79 +116,17 @@ primitive_wrapper(x::A) = helper(x) Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(primitive_wrapper), A} ``` -Call sites of `primitive_wrapper` will infer the return type as `A`, even though the overlay would return `B` if it were honoured. +From any caller, inference at the `primitive_wrapper` call site reports the return type as `A`, even though the overlay would return `B` if honoured. The nested overlay is invisible only at this boundary — `MooncakeInterpreter` walking the body directly would resolve it. The break is concrete: Mooncake takes the `NativeInterpreter` fast path for the primitive's `CallMeta`, and `NativeInterpreter`'s method lookup uses the standard method table, so the overlay on `helper` registered in `mooncake_method_table` simply isn't seen. -!!! details "Mechanism — why inference sees `A`" - Walking the layers for a call `primitive_wrapper(a::A)`: +**The wrong-gradient mechanism.** The reported cases of [#1169](https://github.com/chalk-lab/Mooncake.jl/issues/1169) — including the SciMLBase `Originator` shape — involve primitives whose return types are singletons. For these, inference produces not just a type but a `CC.Const(value)`, and Julia constant-folds the call to the literal value of the *original* body before AD construction sees it. The result is a silent wrong gradient: not because the rule produced the wrong value, but because the rule was never called. The typeassert that Mooncake emits at primitive call sites (in [`src/interpreter/reverse_mode.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/reverse_mode.jl)) is not the safety net here — by the time it would have run, the call has already been replaced by a literal. - 1. The outer function containing this call is inferred under `MooncakeInterpreter`. At the `primitive_wrapper(a)` statement, `abstract_call_gf_by_type` is called with `MooncakeInterpreter`. - 2. Applicable-method lookup uses `OverlayMethodTable`, but `primitive_wrapper` itself has no overlay. The match resolves to the ordinary `primitive_wrapper(::A)`. - 3. `any_matches_primitive` returns true. `any_matches_overlay` returns false (the method's `external_mt` is unset). - 4. The branch added in #1170 takes the `NativeInterpreter` fast path and asks for the `CallMeta` of `primitive_wrapper(::A)` under `NativeInterpreter`. - 5. `NativeInterpreter` infers `primitive_wrapper`'s body. At the `helper(x)` statement inside that body, its method lookup uses the standard method table — `mooncake_method_table` is invisible to it — so it resolves to `helper(::A) = A()` and infers the return as `A`. - 6. The inferred return type of `primitive_wrapper(::A)` therefore propagates as `A` back to the outer caller, even though under Mooncake the overlay would have made it `B`. +!!! details "How the call gets folded away" + 1. `NativeInterpreter` (overlay-blind) infers the primitive call as `Const(original_value)` — the singleton instance from the *original* body. + 2. [`widen_rettype_callmeta`](@ref Mooncake.widen_rettype_callmeta) exists to prevent `Const` from causing primitive calls to fold away, but it has a documented carve-out: if every runtime argument at the call site is also `Const`, folding is treated as safe (the `sin(1.0)`-with-a-literal case). A zero-runtime-argument primitive trivially satisfies this; many SciML-style overlays do too. + 3. Const propagation in subsequent compiler passes replaces the primitive call with the literal value — the *original*'s singleton, not the overlay's. + 4. By the time AD construction processes the IR, there is no primitive call site at this location, only a constant. No `rrule!!` call is emitted; no `Core.typeassert` is emitted; no runtime check fires. + 5. Downstream code is compiled against the inferred (wrong) singleton type and picks rules keyed to it. The runtime never has the opportunity to course-correct. - The break is at step (5): the right layer (Mooncake) knows about the overlay, but it has delegated this lookup to a layer that does not. - -**The wrong-gradient mechanism.** The reported cases of [#1169](https://github.com/chalk-lab/Mooncake.jl/issues/1169) — including the SciMLBase `Originator` shape — involve primitives whose return types are singletons. For these, inference produces not just a type but a `CC.Const(value)`. The consequence is that the rule never gets a chance to fire: Julia constant-folds the call to the literal value before AD construction sees it. - -Walked out: - -1. `NativeInterpreter` (overlay-blind) infers the primitive call as `Const(original_value)` — the singleton instance from the *original* body. -2. `widen_rettype_callmeta` exists to prevent `Const` from causing primitive calls to fold away, but it has a documented carve-out: if every runtime argument at the call site is also `Const`, folding is treated as safe (the `sin(1.0)`-with-a-literal case). A zero-runtime-argument primitive trivially satisfies this; many SciML-style overlays do too. -3. Const propagation in subsequent compiler passes replaces the primitive call with the literal value — the *original*'s singleton, not the overlay's. -4. By the time AD construction processes the IR, there is no primitive call site at this location, only a constant. No `rrule!!` call is emitted; no `Core.typeassert` is emitted; no runtime check fires. -5. Downstream code is compiled against the inferred (wrong) singleton type and picks rules keyed to it. The runtime never has the opportunity to course-correct. - -The result is a silent wrong gradient: not because the rule produced the wrong value, but because the rule was never called. The typeassert that Mooncake emits at primitive call sites (in `src/interpreter/reverse_mode.jl`) is not the safety net here — by the time it would have run, the call has already been replaced by a literal. - -For overlays that don't yield a `Const` (e.g. a primitive whose return is concrete but not a singleton), the failure mode is different: either inference and the rule happen to agree on the type and there is no problem, or they disagree and the typeassert traps with a `TypeError` — loud rather than silent. The dangerous combination is overlay + singleton return, which is exactly the shape both #1169's MWE and the SciMLBase usage take. - -This is by design: Mooncake treats primitives as sealed boundaries and does not walk into a primitive's body to discover what overlays might affect it. The fix in [#1170](https://github.com/chalk-lab/Mooncake.jl/pull/1170) extends overlay-awareness only to the *primitive's own signature* — the boundary inference is already looking at. For overlays reachable from inside a primitive's body, the rule and Mooncake's inferred type may diverge, and keeping them coherent is the rule author's responsibility (see [Drift between rules and overlays](@ref) for the contract). - -An alternative approach in [PR #1168](https://github.com/chalk-lab/Mooncake.jl/pull/1168) instead walks into primitive bodies with overlay-aware method lookup (via a wrapper around `NativeInterpreter` that uses `mooncake_method_table`); it would fix the inside-body case as well. It was not adopted in #1170 — the sealed-boundary policy makes for a smaller and more focused fix. +This is by design: Mooncake treats primitives as sealed boundaries and does not walk into a primitive's body to discover what overlays might affect it. The fix in [#1170](https://github.com/chalk-lab/Mooncake.jl/pull/1170) extends overlay-awareness only to the *primitive's own signature* — the boundary inference is already looking at. For overlays reachable from inside a primitive's body, the rule and Mooncake's inferred type may diverge, and keeping them coherent is the rule author's responsibility. Workaround when you do encounter this shape: lift the overlay to the level the user actually calls. Either remove the primitive declaration on the wrapper and let AD differentiate it, or register the desired behaviour as a primitive on the outer function. - -### Drift between rules and overlays - -A rule's return type is hand-written and fixed when `rrule!!` is authored. If an overlay is introduced later — directly on the primitive, or on a function called in its body — that changes the primal type Mooncake's inference produces, the rule and the inference can disagree. - -The invariant the author must maintain is: - -```text -inferred primal return type at the call site - == -primal type inside the CoDual returned by the rule -``` - -When the types are concrete and non-singleton, the typeassert Mooncake emits on the rule's primal output (in `src/interpreter/reverse_mode.jl`) catches violations at runtime. For example: - -```julia -struct DriftOld; v::Float64 end -struct DriftNew; v::Float64 end - -f(x::Float64) = DriftOld(x) - -# Overlay introduced later, changing the return type. -Mooncake.@mooncake_overlay f(x::Float64) = DriftNew(x) - -Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(f), Float64} - -# Stale rule, authored before the overlay was added: still returns DriftOld(...). -function Mooncake.rrule!!(g::CoDual{typeof(f)}, x::CoDual{Float64}) - pb(dy) = NoRData(), NoRData(), dy - return zero_fcodual(DriftOld(primal(x))), pb -end -``` - -Differentiating any caller of `f` then traps with: - -```text -TypeError: in typeassert, expected CoDual{DriftNew, NoFData}, - got a value of type CoDual{DriftOld, NoFData} -``` - -When the inferred return is a singleton (`CC.Const`), the primitive call is liable to be const-folded to the literal value before the rule fires, so the typeassert is bypassed. The runtime then follows the inferred-type path — which is what the overlay would have produced — and the rule's stale return type ends up irrelevant in practice. Convenient, but it's coincidence: don't rely on it. - -#1170 makes inference at primitive boundaries overlay-aware so the left-hand side reflects what the overlay-modified primal would actually return. Keeping the right-hand side in sync — adjusting the rule when the overlay changes the type — is the author's responsibility. From 8732f295609a7331585334a28fd04d13d15b37c2 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 15 May 2026 10:08:01 +0100 Subject: [PATCH 6/6] Refine primitives/overlays doc Tighten the direct-overlay-on-primitive section as not-recommended, split the unsupported case into singleton vs non-singleton failure shapes (clarifying where the typeassert catches and where it can't), sharpen the contrast between the two overlay-plus-primitive sections, and align terminology in the known-limitations summary. --- .../primitives_and_overlays.md | 31 ++++++++++--------- docs/src/known_limitations.md | 4 +-- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/docs/src/developer_documentation/primitives_and_overlays.md b/docs/src/developer_documentation/primitives_and_overlays.md index dc06b20e4f..4846eb327b 100644 --- a/docs/src/developer_documentation/primitives_and_overlays.md +++ b/docs/src/developer_documentation/primitives_and_overlays.md @@ -16,13 +16,15 @@ To differentiate a function, Mooncake runs a *compilation step*: it walks the fu - The body is not walked: Mooncake leaves the call statement in the transformed IR as-is. - At runtime, the registered [`Mooncake.rrule!!`](@ref) (or [`Mooncake.frule!!`](@ref)) is dispatched in place of the primal call. -The rule, not the body, produces the value at this call site; its return type — typically `Tuple{CoDual{B,F}, Pullback}` — is what the surrounding AD code sees. +The rule, not the body, produces the value at this call site; what the surrounding AD code sees is a `CoDual` whose primal type matches the primitive's return type. Examples on this page use reverse mode (`rrule!!`) for concreteness; the same machinery applies to forward mode (`frule!!`). !!! details "Mechanism" - Mooncake's `AbstractInterpreter` override of `abstract_call_gf_by_type` (in [`src/interpreter/abstract_interpretation.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/abstract_interpretation.jl)) checks each call site against the primitive table via `any_matches_primitive`. When a match is detected, the resulting `CallMeta` is wrapped in a `NoInlineCallInfo`, which Mooncake's `inlining_policy` / `src_inlining_policy` then refuses to inline. The primitive call therefore survives inlining, and the rule-dispatch code is emitted at that statement instead of inlined primal code. + Mooncake's `AbstractInterpreter` override of `abstract_call_gf_by_type` (in [`src/interpreter/abstract_interpretation.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/abstract_interpretation.jl)) checks each call site against the primitive table via `any_matches_primitive`. When a match is detected, the resulting `CallMeta` is wrapped in a `NoInlineCallInfo`, which Mooncake's inlining policy (`inlining_policy` pre-1.12, `src_inlining_policy` from 1.12) then refuses to inline. The primitive call therefore survives inlining, and the rule-dispatch code is emitted at that statement instead of inlined primal code. ## Overlays +See [Simplifying Code via Overlays](@ref) in the Defining Rules guide for the `@mooncake_overlay` docstring and a user-facing introduction; this section covers the inference-level picture. + `@mooncake_overlay` registers an additional method for a function in a private method table, `Mooncake.mooncake_method_table`. Only Mooncake's interpreter consults this table; plain Julia dispatch and `Core.Compiler.NativeInterpreter` do not. For example: @@ -47,17 +49,17 @@ The intended use is to substitute a body Mooncake can't differentiate (e.g. a fo ## Type inference -Mooncake's IR transformation is driven by inferred type information. Three places matter, and they fire in this order: +Mooncake's IR transformation is driven by inferred type information. Three places matter, broadly in this order: 1. **Source-IR inference.** The function being differentiated is inferred via `MooncakeInterpreter`. This produces the IR that the AD transformation rewrites. 2. **Per-call `CallMeta`.** At each call statement during the source-IR walk, Mooncake needs the return type, effects, and call info. Julia's `abstract_call_gf_by_type` (in `Compiler/src/abstractinterpretation.jl`) is the per-call-site inference entry point that produces this; Mooncake overrides it for `MooncakeInterpreter` to insert primitive/overlay handling before the recursive per-match step. See [Inference at primitive call sites](@ref) for the primitive case. 3. **Rule-type inference.** Later, during AD IR construction, Mooncake calls `Core.Compiler.return_type` with the default interpreter — for example when emitting a `pullback_type` lookup — to learn the type the rule itself returns. -The key asymmetry to internalise: **Mooncake's source-function inference is overlay-aware via `OverlayMethodTable`; `NativeInterpreter`, used at primitive boundaries, is not.** +The key asymmetry to internalise: **Mooncake's source-function inference (1) is overlay-aware via `OverlayMethodTable`; `NativeInterpreter`, used at primitive boundaries during (2), is not.** ### Inference at primitive call sites -At every call site in the source IR, Mooncake needs a return type — downstream code is typed against it. At a primitive call site, the surrounding code still wants the primal's return type; the rule is an *implementation* keyed to that type, not a *source* for it. So inference asks the primal what it returns; the rule isn't consulted at this stage. +At every call site in the source IR, Mooncake needs a return type — downstream code is typed against it. At a primitive call site, the surrounding code still wants the primal's return type; the rule is an *implementation* keyed to that type, not a *source* for it. So inference asks the primal what it returns. The rule itself is not consulted at this stage. The natural choice — recursing into the body with `MooncakeInterpreter` — is expensive and unnecessary. It is expensive because `MooncakeInterpreter` re-fires its primitive/overlay check at every nested call site and uses its own inference cache separate from Julia's global one, so each function Mooncake differentiates triggers a fresh walk of its transitive call tree (see [PR #1115](https://github.com/chalk-lab/Mooncake.jl/pull/1115) for a SciML-shaped case where this explodes into a silent compile-time hang). It is unnecessary because the body isn't being rewritten into AD code, only inferred for its return type — any interpreter that produces a correct `CallMeta` is sufficient. Mooncake therefore delegates to `NativeInterpreter` at primitive boundaries, bounding the recursion at each one. @@ -77,14 +79,14 @@ The rule replaces the primal at runtime. Inference at the call site asks `Native ### Direct overlay on a primitive signature -When the overlay returns the same type as the original, `NativeInterpreter`'s overlay-blindness is harmless: inference and the rule agree on the type at the call site. When the overlay *changes* the return type, Mooncake detects this configuration and routes inference through the overlay-aware default path, so the inferred return type matches the overlay's, not the original's. At runtime, the registered `rrule!!` still fires; the overlay's only job is to align inference's view of the return type with what the rule actually returns. Most users should not need this pattern — prefer to express the change as either an overlay or a primitive, not both — but it is supported. +This is not a recommended pattern — choose one of `@mooncake_overlay` or `@is_primitive` on a given signature, not both. Mooncake currently supports it as a special case ([#1170](https://github.com/chalk-lab/Mooncake.jl/pull/1170)): when both apply, the rule still fires at runtime, and Mooncake routes call-site inference through the overlay-aware default path so the inferred return type matches what the rule actually returns rather than what the original primal would have returned. When the overlay's return type happens to equal the original's, the routing change is harmless; when it differs, it is what keeps inference and the rule coherent. !!! details "Mechanism" - `any_matches_overlay` (in [`src/interpreter/abstract_interpretation.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/abstract_interpretation.jl)) walks the applicable methods and checks `method.external_mt === mooncake_method_table`. When that returns true, `abstract_call_gf_by_type` takes the `@invoke` branch — i.e. it defers to the default `abstract_call_gf_by_type` *with `MooncakeInterpreter` still as the interpreter*, so method lookup inside that call still goes through `OverlayMethodTable` and resolves to the overlay's body. The `NativeInterpreter` fast path is reserved for primitives whose applicable methods have no overlay. + `any_matches_overlay` (in [`src/interpreter/abstract_interpretation.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/abstract_interpretation.jl)) walks the applicable methods returned by `find_method_matches` and checks `method.external_mt === mooncake_method_table` on each. The check is per-method and signature-aware: an overlay registered for `f(::Float64)` is invisible at a call site that dispatches to a different method (e.g. `f(::Int)`). It also does not try to detect whether the overlay actually *changes* the return type — any applicable overlay triggers the same path, even when the overlay-aware and overlay-blind inference paths would agree. When the check returns true, `abstract_call_gf_by_type` takes the `@invoke` branch — i.e. it defers to the default `abstract_call_gf_by_type` *with `MooncakeInterpreter` still as the interpreter*, so method lookup inside that call still goes through `OverlayMethodTable` and resolves to the overlay's body. The `NativeInterpreter` fast path is reserved for primitives whose applicable methods have no overlay. ### Primitive called from inside an overlay's body — supported -An overlay's body may itself call a registered primitive. This is the ordinary, supported flow: Mooncake walks the overlay body for AD, and any primitive call inside it is handled by the same machinery that handles primitive calls anywhere else (primitive detection, `NativeInterpreter` for the `CallMeta`, rule dispatch at runtime). No special arrangement is needed; this is in fact the most common reason to write an overlay — substituting an AD-unfriendly body with one that bottoms out on a hand-written rule. +Unlike the previous section, the two macros sit on *different* functions here: an overlay replaces one function's body so that it bottoms out on a *separate* function carrying a hand-written rule. This is the ordinary, supported flow — Mooncake walks the overlay body for AD, and the primitive call inside it is handled by the same machinery as any other primitive call (primitive detection, `NativeInterpreter` for the `CallMeta`, rule dispatch at runtime). No special arrangement is needed; this is in fact the most common reason to write an overlay. ```julia my_primitive(x::Float64) = 2x @@ -104,7 +106,12 @@ Differentiating any caller of `original_f` walks the overlay's body, hits `my_pr ### Overlay on a non-primitive called from inside a primitive's body — not supported -Although the primitive's body is not *differentiated*, it is still *inferred* — `NativeInterpreter` walks it to produce the primitive's `CallMeta`. Because `NativeInterpreter` does not consult `mooncake_method_table`, any overlay on a nested call within the body is invisible to that walk. Inference of the primitive's return type therefore sees the original definitions of its nested calls, not the overlays. +If a primitive's return type depends on an overlay applied to a function it calls internally, AD silently uses the un-overlaid return type. The failure surfaces in one of two shapes, depending on whether the affected return value is a singleton: + +- **Singleton return (the SciMLBase `Originator` shape in [#1169](https://github.com/chalk-lab/Mooncake.jl/issues/1169)).** Inference produces a `CC.Const(value)` from the *original* body, and Julia constant-folds the primitive call to that literal *before* AD construction sees it. No `rrule!!` call is emitted, no typeassert fires, and downstream code is compiled against the wrong singleton type — a silent wrong gradient. +- **Non-singleton return.** No const-folding, but inferred return type and rule output still disagree: downstream dispatch is keyed to the inferred (un-overlaid) type while the rule returns the overlaid type. This is the case the runtime `Core.typeassert` Mooncake emits at primitive call sites (in [`src/interpreter/reverse_mode.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/reverse_mode.jl)) normally catches at the rule-output boundary, surfacing as a `TypeError` rather than a silent wrong gradient. In the singleton case the typeassert is *not* a safety net — by the time it would have run, the primitive call has already been replaced by a literal, so no typeassert is emitted. + +The mechanism is the same in both shapes: although the primitive's body is not *differentiated*, it is still *inferred* — `NativeInterpreter` walks it to produce the primitive's `CallMeta`. Because `NativeInterpreter` does not consult `mooncake_method_table`, any overlay on a nested call within the body is invisible to that walk. Inference of the primitive's return type therefore sees the original definitions of its nested calls, not the overlays. Example: @@ -118,9 +125,7 @@ Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(primitive_wrapper), A} From any caller, inference at the `primitive_wrapper` call site reports the return type as `A`, even though the overlay would return `B` if honoured. The nested overlay is invisible only at this boundary — `MooncakeInterpreter` walking the body directly would resolve it. The break is concrete: Mooncake takes the `NativeInterpreter` fast path for the primitive's `CallMeta`, and `NativeInterpreter`'s method lookup uses the standard method table, so the overlay on `helper` registered in `mooncake_method_table` simply isn't seen. -**The wrong-gradient mechanism.** The reported cases of [#1169](https://github.com/chalk-lab/Mooncake.jl/issues/1169) — including the SciMLBase `Originator` shape — involve primitives whose return types are singletons. For these, inference produces not just a type but a `CC.Const(value)`, and Julia constant-folds the call to the literal value of the *original* body before AD construction sees it. The result is a silent wrong gradient: not because the rule produced the wrong value, but because the rule was never called. The typeassert that Mooncake emits at primitive call sites (in [`src/interpreter/reverse_mode.jl`](https://github.com/chalk-lab/Mooncake.jl/blob/main/src/interpreter/reverse_mode.jl)) is not the safety net here — by the time it would have run, the call has already been replaced by a literal. - -!!! details "How the call gets folded away" +!!! details "How a singleton call gets folded away" 1. `NativeInterpreter` (overlay-blind) infers the primitive call as `Const(original_value)` — the singleton instance from the *original* body. 2. [`widen_rettype_callmeta`](@ref Mooncake.widen_rettype_callmeta) exists to prevent `Const` from causing primitive calls to fold away, but it has a documented carve-out: if every runtime argument at the call site is also `Const`, folding is treated as safe (the `sin(1.0)`-with-a-literal case). A zero-runtime-argument primitive trivially satisfies this; many SciML-style overlays do too. 3. Const propagation in subsequent compiler passes replaces the primitive call with the literal value — the *original*'s singleton, not the overlay's. @@ -128,5 +133,3 @@ From any caller, inference at the `primitive_wrapper` call site reports the retu 5. Downstream code is compiled against the inferred (wrong) singleton type and picks rules keyed to it. The runtime never has the opportunity to course-correct. This is by design: Mooncake treats primitives as sealed boundaries and does not walk into a primitive's body to discover what overlays might affect it. The fix in [#1170](https://github.com/chalk-lab/Mooncake.jl/pull/1170) extends overlay-awareness only to the *primitive's own signature* — the boundary inference is already looking at. For overlays reachable from inside a primitive's body, the rule and Mooncake's inferred type may diverge, and keeping them coherent is the rule author's responsibility. - -Workaround when you do encounter this shape: lift the overlay to the level the user actually calls. Either remove the primitive declaration on the wrapper and let AD differentiate it, or register the desired behaviour as a primitive on the outer function. diff --git a/docs/src/known_limitations.md b/docs/src/known_limitations.md index f14713576f..d4868da8f5 100644 --- a/docs/src/known_limitations.md +++ b/docs/src/known_limitations.md @@ -206,5 +206,5 @@ DocTestSetup = nothing `@mooncake_overlay` substitutes the function body that Mooncake differentiates; `@is_primitive` marks a call site as a boundary where Mooncake stops differentiating the body and dispatches a hand-written rule instead. The two operate at different layers of the AD pipeline, and not every combination of them is supported. [Primitives and Overlays](@ref) covers the full picture; the practical summary is: -- **Overlay on the same function and argument types as a primitive** is supported. The rule still runs at runtime, and the overlay only adjusts what Mooncake infers as the call's return type so it matches what the rule produces. -- **Overlay on a function called from inside a primitive's body** is not supported. Mooncake does not look into a primitive's body for overlays, so the overlay has no effect there; if it would have changed the return type, Mooncake's inferred type and the rule's actual output disagree, which can produce silently wrong gradients. +- **Direct overlay on a primitive signature** is supported but not recommended — choose one of `@mooncake_overlay` or `@is_primitive`, not both. When both apply, the rule still runs at runtime, and the overlay only adjusts what Mooncake infers as the call's return type so it matches what the rule produces. +- **Overlay on a non-primitive called from inside a primitive's body** is not supported. Mooncake does not look into a primitive's body for overlays, so the overlay has no effect there; if it would have changed the return type, Mooncake's inferred type and the rule's actual output disagree, which can produce silently wrong gradients (singleton-returning primitives) or a runtime `TypeError` (otherwise).