Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,28 @@ function Core.Compiler.abstract_call_gf_by_type(
# primitives, and continue that search down the callee tree. That extra work is
# unnecessary for a primitive with a hand-written rule.
#
# Exception: `NativeInterpreter` is overlay-blind, so fall back to
# overlay-aware default dispatch when a `@mooncake_overlay` applies (issue #1169).
#
# `noinline_callmeta` below then blocks inlining/const-folding so the primitive
# call stays in the caller IR and Mooncake can dispatch its `rrule!!` at runtime.
# See PR #1115 for more discussion.
native_interp = CC.NativeInterpreter(interp.world)
ret = CC.abstract_call_gf_by_type(
native_interp, f, arginfo, si, atype, sv, max_methods
)
ret = if any_matches_overlay(applicable)
@invoke CC.abstract_call_gf_by_type(
interp::CC.AbstractInterpreter,
f::Any,
arginfo::CC.ArgInfo,
si::CC.StmtInfo,
atype::Any,
sv::CC.AbsIntState,
max_methods::Int,
)
else
native_interp = CC.NativeInterpreter(interp.world)
CC.abstract_call_gf_by_type(
native_interp, f, arginfo, si, atype, sv, max_methods
)
end
@static if VERSION < v"1.12-"
call = ret::CC.CallMeta
# Keep primitives in caller IR by blocking const-folding and inlining
Expand Down Expand Up @@ -204,12 +219,19 @@ end

function any_matches_primitive(applicable, C, M, world)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that there is already any_matches_primitive defined, which accesses Julia internals in similar ways.

for app in applicable
if VERSION < v"1.12-"
sig = app.spec_types
else
sig = app.match.spec_types
match = VERSION < v"1.12-" ? app : app.match
if is_primitive(C, M, match.spec_types, world)
return true
end
if is_primitive(C, M, sig, world)
end
false
end

function any_matches_overlay(applicable)
for app in applicable
match = VERSION < v"1.12-" ? app : app.match
method = match.method
if isdefined(method, :external_mt) && method.external_mt === mooncake_method_table
return true
end
end
Expand Down
16 changes: 16 additions & 0 deletions test/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@ non_primitive(x) = sin(x)

Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(a_primitive),Float64}

# Issue #1169: a `@mooncake_overlay` on a primitive must propagate to the
# inferred return type at the call site, otherwise downstream dispatch
# compiles against the un-overlaid type.
struct OverlayA end
struct OverlayB end
overlay_switch() = OverlayA()
Mooncake.@mooncake_overlay overlay_switch() = OverlayB()
Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(overlay_switch)}
overlay_caller() = overlay_switch()

contains_primitive(x) = @inline a_primitive(x)
contains_non_primitive(x) = @inline non_primitive(x)
contains_primitive_behind_call(x) = @inline contains_primitive(x)
Expand Down Expand Up @@ -123,6 +133,12 @@ end
@test val == 1.0
@test grad[2] == [2.0, 0.0, 1.0]
end

@testset "1169 - overlay propagates to inferred return type" begin
interp = Mooncake.MooncakeInterpreter(DefaultCtx, ReverseMode)
sig = Tuple{typeof(overlay_caller)}
@test Base.code_ircode_by_type(sig; interp)[1][2] == OverlayB
end
end

@testset "Config(empty_cache=true)" begin
Expand Down
Loading