Honor mooncake_overlay in primitive inference #1170
`MooncakeInterpreter` previously routed every primitive call site through `NativeInterpreter` for `CallMeta`, but `NativeInterpreter` is overlay-blind: when a primitive has a `@mooncake_overlay` that changes its return type, the inferred type at the call site was wrong and downstream dispatch compiled against the original type. Detect overlay matches via `Method.external_mt` and route them through `@invoke` instead, keeping the `NativeInterpreter` fast path for non-overlay primitives. Fixes #1169. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
0eddb03 to 848c6a3
On this issue, I slightly lean towards #1168 because it is indeed more general and the complexity diff w.r.t. this PR is not too bad. A minor issue with the current code is the use of
Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Signed-off-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Both `any_matches_primitive` and `any_matches_overlay` need the same pre-1.12 / 1.12+ unwrap of an `applicable` entry to a `MethodMatch`. Use a single inline `match = VERSION < v"1.12-" ? app : app.match` in each so the version skew sits in one place per function. Also tighten the overlay check to compare directly against `mooncake_method_table` rather than `external_mt !== nothing`, so an unrelated downstream overlay table cannot trip the overlay path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
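The `v"1.12-"` bound in that check relies on how Julia orders version numbers: a trailing hyphen denotes the lowest possible prerelease of that version, so development builds and release candidates of 1.12 land on the "1.12+" side. A minimal sketch of the comparison follows; `needs_unwrap` is a hypothetical helper used only for illustration and is not part of Mooncake.

```julia
# Hypothetical helper (not Mooncake API): true when an `applicable` entry
# would take the `app.match` branch of the version check.
needs_unwrap(v::VersionNumber) = !(v < v"1.12-")

needs_unwrap(v"1.11.5")          # false: a pre-1.12 release
needs_unwrap(v"1.12.0-DEV.100")  # true: a 1.12 development build
needs_unwrap(v"1.12.0-rc1")      # true: a 1.12 release candidate
needs_unwrap(v"1.12.0")          # true: the 1.12 release

# Comparing against plain v"1.12" instead would put prereleases on the
# pre-1.12 side, because a prerelease sorts below its release:
v"1.12.0-rc1" < v"1.12"          # true
```

Writing the bound as `v"1.12-"` rather than `v"1.12"` keeps 1.12 prerelease builds on the same branch as the 1.12 release.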
@@ -204,12 +219,19 @@ end
function any_matches_primitive(applicable, C, M, world)
Note that `any_matches_primitive` is already defined, and it accesses Julia internals in similar ways.
Thanks @yebai for the follow-up. After some thought and investigation, I now lean toward this PR over #1168 for the scope of issue #1169. Use of
Before merging, there is one particular case I want to think through. Maybe it is not that important, but I would like to be clear. Comparing on four versions of Mooncake: on
using Mooncake
struct A end
struct B end
helper(::A) = A()
Mooncake.@mooncake_overlay helper(::A) = B()
primitive_wrapper(x::A) = helper(x)
Mooncake.@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{
typeof(primitive_wrapper), A
}
caller() = primitive_wrapper(A())
caller_arg(x::A) = primitive_wrapper(x)
interp = Mooncake.MooncakeInterpreter(Mooncake.DefaultCtx, Mooncake.ReverseMode)
for sig in (
Tuple{typeof(primitive_wrapper), A},
Tuple{typeof(caller)},
Tuple{typeof(caller_arg), A},
)
ir, rt = Base.code_ircode_by_type(sig; interp)[1]
println(sig, " => ", rt)
end
Results:
This result indicates that #1115 introduced a behavior change, and #1168 reverted it. @AstitvaAggarwal and I had an offline chat and we'll dig deeper on this.
Documents how `@mooncake_overlay` and `@is_primitive` compose, including the supported case (same function/argtypes) and the unsupported case (overlay reachable only from inside a primitive's body). Adds a short section to known_limitations.md and a cross-reference from defining_rules.md. Regression test exercises overlay propagation through downstream dispatch for the same-signature path.
Cuts redundant explanatory layers in the developer doc:
- Drops the standalone "Drift between rules and overlays" section (the contract is restated in-place at the two referring points; the typeassert/TypeError mechanism is already covered in the const-folding section).
- Removes the nested abstract_call_gf_by_type details block and the concrete-but-not-singleton aside.
- Compresses the "why NativeInterpreter at primitive boundaries" rationale from five paragraphs to one.
- Hides the five-step const-folding walk in a collapsible details block; the headline summary stays visible.
- Folds the redundant "why inference sees A" mechanism into one inline sentence.
- Renames "supported corner case" to "Direct overlay on a primitive signature" (neutral framing).
- Fixes pullback arity in the worked example (2 slots, not 3).
- Softens "primitive call survives into the IR" to "survives inlining"; const-folding is covered separately.
- Adds GitHub source links for the .jl files referenced in the details blocks, and an @ref for widen_rettype_callmeta.
Page goes from ~210 lines to ~130. Documenter build passes with no unresolved refs.
@AstitvaAggarwal I added a developer doc on primitives and overlay, can you give it a read and let me know your thoughts?
Looks good, I don't think anything is incorrect.
Tighten the direct-overlay-on-primitive section as not-recommended, split the unsupported case into singleton vs non-singleton failure shapes (clarifying where the typeassert catches and where it can't), sharpen the contrast between the two overlay-plus-primitive sections, and align terminology in the known-limitations summary.
`MooncakeInterpreter` previously routed every primitive call site through `NativeInterpreter` for `CallMeta`, but `NativeInterpreter` is overlay-blind: when a primitive has a `@mooncake_overlay` that changes its return type, the inferred type at the call site was wrong, and downstream dispatch compiled against the original type. Detect overlay matches via `Method.external_mt` and route them through `@invoke` instead, keeping the `NativeInterpreter` fast path for non-overlay primitives. Fixes #1169. Alternative to #1168.
CI Summary — GitHub Actions
Documentation Preview
Mooncake.jl documentation for PR #1170 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR1170/
Performance
Performance Ratio:
Ratio of the time to compute the gradient to the time to compute the function.
Warning: results are very approximate! See here for more context.