Skip to content

max_norm_perturbation kwarg#1173

Open
AstitvaAggarwal wants to merge 8 commits into
mainfrom
Astitva/issue-765-test-rule-max-perturbation
Open

max_norm_perturbation kwarg#1173
AstitvaAggarwal wants to merge 8 commits into
mainfrom
Astitva/issue-765-test-rule-max-perturbation

Conversation

@AstitvaAggarwal
Copy link
Copy Markdown
Member

@AstitvaAggarwal AstitvaAggarwal commented May 13, 2026

closes #765

CI Summary — GitHub Actions

Documentation Preview

Mooncake.jl documentation for PR #1173 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR1173/

Performance

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌───────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                 Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                String │   String │   String │      String │  String │      String │ String │
├───────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│              sum_1000 │ 298.0 ns │     1.66 │        1.61 │   0.413 │        1.38 │    6.5 │
│             _sum_1000 │   1.2 μs │     10.7 │        1.04 │  2970.0 │        27.0 │   1.43 │
│          sum_sin_1000 │  5.79 μs │     4.25 │        4.62 │    2.37 │        13.0 │    2.3 │
│         _sum_sin_1000 │  9.36 μs │     2.25 │        1.11 │   156.0 │        8.05 │   1.31 │
│              kron_sum │ 375.0 μs │     8.33 │        2.62 │    8.37 │       240.0 │   15.1 │
│         kron_view_sum │ 414.0 μs │     8.27 │        4.05 │    20.7 │       265.0 │   14.6 │
│ naive_map_sin_cos_exp │  2.93 μs │      2.7 │        1.26 │ missing │        5.59 │   1.67 │
│       map_sin_cos_exp │  2.92 μs │      4.5 │        2.78 │    1.41 │        4.68 │    2.1 │
│ broadcast_sin_cos_exp │   3.0 μs │     4.18 │        1.72 │    2.58 │        1.04 │   1.64 │
│            simple_mlp │ 648.0 μs │      3.2 │        1.93 │    1.12 │        6.51 │   2.05 │
│                gp_lml │ 272.0 μs │     8.43 │        2.41 │    7.35 │     missing │   4.29 │
│    large_single_block │ 504.0 ns │     8.05 │        2.26 │  3940.0 │        25.3 │   1.79 │
└───────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘

@AstitvaAggarwal AstitvaAggarwal marked this pull request as ready for review May 13, 2026 16:13
@sunxd3
Copy link
Copy Markdown
Collaborator

sunxd3 commented May 14, 2026

thanks, I'll take a look today

@AstitvaAggarwal AstitvaAggarwal requested a review from sunxd3 May 14, 2026 10:53
Copy link
Copy Markdown
Collaborator

@sunxd3 sunxd3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the effort

Comment thread src/test_utils.jl
ε_list = filter(
ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation,
[1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

worth add a check here for the filtered ε_list: if too few left or empty

Comment thread src/test_utils.jl Outdated
x = map(primal, x_ẋ)
= map(tangent, x_ẋ)
x = map(primal, x_ẋ)
= map(x -> normalize_tangent(tangent(x)), x_ẋ)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a nit on style, maybe do something like normalize_tangent ∘ tangent instead of the anonymous function?

Comment thread src/test_utils.jl
Comment on lines 911 to 925
"""
test_rule(
rng::AbstractRNG,
x...;
interface_only::Bool=false,
is_primitive::Bool=true,
perf_flag::Symbol=:none,
mode::Union{Nothing,Type{ForwardMode},Type{ReverseMode}}=nothing,
debug_mode::Bool=false,
unsafe_perturb::Bool=false,
print_results=true,
output_tangent=nothing,
atol=1e-3,
rtol=1e-3
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update this too? it was outdated before this PR but worth fixing with it

Comment thread test/test_utils.jl Outdated
print_results=false,
max_norm_perturbation=1.0e-3,
)
@test !ts.anynonpass
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will fix this, its actually a field in Test.DefaultTestSet i was using it as @test !ts.anynonpass to assert that all internal tests passed.

Comment thread test/test_utils.jl Outdated
@test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Vector{Float64}}) == 0
end
@testset "max_norm_perturbation kwarg for testing rules" begin
rng = Xoshiro(1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels inconsistent coding style

StableRNG(123), sqrt, 0.005; is_primitive=true, max_norm_perturbation=1e-3
)
end
end
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this overlap with the one in test_utils.jl?

Comment thread src/test_utils.jl
ε_list = filter(
ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation,
[1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check here also

@AstitvaAggarwal
Copy link
Copy Markdown
Member Author

thanks for the reviews @sunxd3! I have tried addressing all of them.

@AstitvaAggarwal AstitvaAggarwal requested a review from sunxd3 May 15, 2026 11:34
Comment thread src/test_utils.jl Outdated
Comment on lines +478 to +486
function _check_ε_list(ε_list, max_norm_perturbation)
isempty(ε_list) && throw(
ArgumentError(
"max_norm_perturbation=$max_norm_perturbation filters out all finite-difference " *
"step sizes; the smallest available is 1e-8, below which floating-point rounding " *
"errors dominate the estimate.",
),
)
end
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a style nit: we can just check the value of max_norm_perturbation directly before

ε_list = filter(
        ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation,
        [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
    )

and error if too small, try to keep it a oneliner?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

restrict finite differences in test_rule

2 participants