From f3edb3f8129529ddb00eae3818a429c9d638bd3b Mon Sep 17 00:00:00 2001 From: AstitvaAggarwal Date: Wed, 13 May 2026 16:38:20 +0100 Subject: [PATCH 1/6] max_norm_perturbation kwarg --- src/test_utils.jl | 55 +++++++++++++++++++++++++++++++++++++--------- test/test_utils.jl | 17 ++++++++++++++ 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index a24478da13..742cfab1e0 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -477,22 +477,31 @@ end # Assumes that the interface has been tested, and we can simply check for numerical issues. function test_frule_correctness( - rng::AbstractRNG, x_ẋ...; frule, unsafe_perturb::Bool, rtol=1e-3, atol=1e-3 + rng::AbstractRNG, + x_ẋ...; + frule, + unsafe_perturb::Bool, + rtol=1e-3, + atol=1e-3, + max_norm_perturbation::Union{Nothing,Float64}=nothing, ) - @nospecialize rng x_ẋ + @nospecialize rng x_ẋ - x_ẋ = map(_deepcopy, x_ẋ) # defensive copy + x_ẋ = map(_deepcopy, x_ẋ) # defensive copy # Run original function on deep-copies of inputs. - x = map(primal, x_ẋ) - ẋ = map(tangent, x_ẋ) + x = map(primal, x_ẋ) + ẋ = map(normalize_tangent ∘ tangent, x_ẋ) x_primal = _deepcopy(x) y_primal = x_primal[1](x_primal[2:end]...) # Use finite differences to estimate Frechet derivative. Compute the estimate at a range # of different step sizes. We'll just require that one of them ends up being close to # what AD gives. - ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] + ε_list = filter( + ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, + [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], + ) fd_results = Vector{Any}(undef, length(ε_list)) for (n, ε) in enumerate(ε_list) x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) @@ -567,6 +576,7 @@ function test_rrule_correctness( output_tangent=nothing, rtol=1e-3, atol=1e-3, + max_norm_perturbation::Union{Nothing,Float64}=nothing, ) @nospecialize rng x_x̄ @@ -586,7 +596,10 @@ function test_rrule_correctness( # Use finite differences to estimate vjps. Compute the estimate at a range of different # step sizes. We'll just require that one of them ends up being close to what AD gives. - ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] + ε_list = filter( + ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, + [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], + ) fd_results = Vector{Any}(undef, length(ε_list)) for (n, ε) in enumerate(ε_list) x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) @@ -959,9 +972,15 @@ signature associated to `x` corresponds to a primitive, a hand-written rule will Should usually be left `false` -- consult the docstring for `_add_to_primal` for more info on when you might wish to set it to `true`. - `output_tangent=nothing`: final output tangent to initialize reverse mode with for testing - the correctnes of reverse rules. + the correctness of reverse rules. - `atol=1e-3`: absolute tolerance for correctness check of the Frechet derivatives. - `rtol=1e-3`: relative tolerance for correctness check of the Frechet derivatives. +- `max_norm_perturbation::Union{Nothing,Float64}=nothing`: if provided, only finite-difference + step sizes `ε ≤ max_norm_perturbation` are used. Set this when the function is only + defined on a restricted domain (e.g. `log`, `sqrt`, `cholesky`) and large perturbations + would step outside it. The tangent direction `ẋ` is normalised to unit length before + finite differences are computed, so this bound directly controls the size of the step + in input space. """ function test_rule( rng::AbstractRNG, @@ -978,6 +997,7 @@ function test_rule( rtol=1e-3, frule=nothing, rrule=nothing, + max_norm_perturbation::Union{Nothing,Float64}=nothing, ) # Take a copy of `x` to ensure that we do not mutate the original. x = deepcopy(x) @@ -1037,11 +1057,26 @@ function test_rule( # Test that answers are numerically correct / consistent. @testset "Correctness" begin if test_fwd && !interface_only - test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb, atol, rtol) + test_frule_correctness( + rng, + x_ẋ...; + frule, + unsafe_perturb, + atol, + rtol, + max_norm_perturbation, + ) end if test_rvs && !interface_only test_rrule_correctness( - rng, x_x̄...; rrule, unsafe_perturb, output_tangent, atol, rtol + rng, + x_x̄...; + rrule, + unsafe_perturb, + output_tangent, + atol, + rtol, + max_norm_perturbation, ) end end diff --git a/test/test_utils.jl b/test/test_utils.jl index d779c331f0..4f2233446a 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -147,4 +147,21 @@ @test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Float64}) == 0 @test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Vector{Float64}}) == 0 end + @testset "max_norm_perturbation kwarg for testing rules" begin + rng = Xoshiro(1) + # log is undefined for x ≤ 0: without a bound, the ε=1e-2 step evaluates + # log(0.005 - 0.01) = log(-0.005) and throws a DomainError. + @test_throws DomainError log(0.005 - 1.0e-2) + + # With a bound of 1e-3, all ε values keep x+ε*ẋ > 0 and the rule test passes. + ts = TestUtils.test_rule( + rng, + log, + 0.005; + is_primitive=false, + print_results=false, + max_norm_perturbation=1.0e-3, + ) + @test !ts.anynonpass + end end From b5356b6c7cae5efa0f86c53e4c77add14fe0517c Mon Sep 17 00:00:00 2001 From: AstitvaAggarwal Date: Wed, 13 May 2026 17:04:15 +0100 Subject: [PATCH 2/6] domain restricted tests --- test/rules/low_level_maths.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/rules/low_level_maths.jl b/test/rules/low_level_maths.jl index be2f8425fe..cb23add8af 100644 --- a/test/rules/low_level_maths.jl +++ b/test/rules/low_level_maths.jl @@ -133,4 +133,19 @@ @test !is_primitive(C, M, Tuple{typeof(/),T,T}, world) @test !is_primitive(C, M, Tuple{typeof(\),T,T}, world) end + + @testset "near-boundary domain-restricted functions" begin + for T in [Float32, Float64] + test_rule( + StableRNG(123), log, T(0.005); is_primitive=true, max_norm_perturbation=1e-3 + ) + test_rule( + StableRNG(123), + sqrt, + T(0.005); + is_primitive=true, + max_norm_perturbation=1e-3, + ) + end + end end From e1de0109b7ee85b451a6defe8a1564cf26693d26 Mon Sep 17 00:00:00 2001 From: AstitvaAggarwal Date: Wed, 13 May 2026 17:07:17 +0100 Subject: [PATCH 3/6] explicit lambda --- src/test_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 742cfab1e0..52cc21d616 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -491,7 +491,7 @@ function test_frule_correctness( # Run original function on deep-copies of inputs. x = map(primal, x_ẋ) - ẋ = map(normalize_tangent ∘ tangent, x_ẋ) + ẋ = map(x -> normalize_tangent(tangent(x)), x_ẋ) x_primal = _deepcopy(x) y_primal = x_primal[1](x_primal[2:end]...) From 99d34a0000f9331ab59fcbc467cd40a2b8f613b7 Mon Sep 17 00:00:00 2001 From: AstitvaAggarwal Date: Thu, 14 May 2026 11:48:40 +0100 Subject: [PATCH 4/6] reduce tests --- test/rules/low_level_maths.jl | 16 ++++------------ test/test_utils.jl | 4 ---- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/test/rules/low_level_maths.jl b/test/rules/low_level_maths.jl index cb23add8af..03f367db42 100644 --- a/test/rules/low_level_maths.jl +++ b/test/rules/low_level_maths.jl @@ -135,17 +135,9 @@ end @testset "near-boundary domain-restricted functions" begin - for T in [Float32, Float64] - test_rule( - StableRNG(123), log, T(0.005); is_primitive=true, max_norm_perturbation=1e-3 - ) - test_rule( - StableRNG(123), - sqrt, - T(0.005); - is_primitive=true, - max_norm_perturbation=1e-3, - ) - end + test_rule(StableRNG(123), log, 0.005; is_primitive=true, max_norm_perturbation=1e-3) + test_rule( + StableRNG(123), sqrt, 0.005; is_primitive=true, max_norm_perturbation=1e-3 + ) end end diff --git a/test/test_utils.jl b/test/test_utils.jl index 4f2233446a..bbdffe00fc 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -149,10 +149,6 @@ end @testset "max_norm_perturbation kwarg for testing rules" begin rng = Xoshiro(1) - # log is undefined for x ≤ 0: without a bound, the ε=1e-2 step evaluates - # log(0.005 - 0.01) = log(-0.005) and throws a DomainError. - @test_throws DomainError log(0.005 - 1.0e-2) - # With a bound of 1e-3, all ε values keep x+ε*ẋ > 0 and the rule test passes. ts = TestUtils.test_rule( rng, From 09f8888c5f03e5af50fc663d5b07dc8c00b219bf Mon Sep 17 00:00:00 2001 From: AstitvaAggarwal Date: Fri, 15 May 2026 10:26:59 +0100 Subject: [PATCH 5/6] changes fmor reviews --- src/test_utils.jl | 19 +++++++++++++++++-- test/rules/low_level_maths.jl | 1 - test/test_utils.jl | 11 ++++------- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 52cc21d616..0bab60e6c5 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -475,6 +475,16 @@ function _diff(p::P, q::P) where {P} return increment!!(_scale(-1.0, t2), t1) end +function _check_ε_list(ε_list, max_norm_perturbation) + isempty(ε_list) && throw( + ArgumentError( + "max_norm_perturbation=$max_norm_perturbation filters out all finite-difference " * + "step sizes; the smallest available is 1e-8, below which floating-point rounding " * + "errors dominate the estimate.", + ), + ) +end + # Assumes that the interface has been tested, and we can simply check for numerical issues. function test_frule_correctness( rng::AbstractRNG, @@ -491,7 +501,7 @@ function test_frule_correctness( # Run original function on deep-copies of inputs. x = map(primal, x_ẋ) - ẋ = map(x -> normalize_tangent(tangent(x)), x_ẋ) + ẋ = map(normalize_tangent ∘ tangent, x_ẋ) x_primal = _deepcopy(x) y_primal = x_primal[1](x_primal[2:end]...) @@ -502,6 +512,7 @@ function test_frule_correctness( ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], ) + _check_ε_list(ε_list, max_norm_perturbation) fd_results = Vector{Any}(undef, length(ε_list)) for (n, ε) in enumerate(ε_list) x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) @@ -600,6 +611,7 @@ function test_rrule_correctness( ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], ) + _check_ε_list(ε_list, max_norm_perturbation) fd_results = Vector{Any}(undef, length(ε_list)) for (n, ε) in enumerate(ε_list) x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) @@ -921,7 +933,10 @@ __get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs) print_results=true, output_tangent=nothing, atol=1e-3, - rtol=1e-3 + rtol=1e-3, + frule=nothing, + rrule=nothing, + max_norm_perturbation::Union{Nothing,Float64}=nothing, ) Run standardised tests on the `rule` for `x`. diff --git a/test/rules/low_level_maths.jl b/test/rules/low_level_maths.jl index 03f367db42..b7b50c86c7 100644 --- a/test/rules/low_level_maths.jl +++ b/test/rules/low_level_maths.jl @@ -135,7 +135,6 @@ end @testset "near-boundary domain-restricted functions" begin - test_rule(StableRNG(123), log, 0.005; is_primitive=true, max_norm_perturbation=1e-3) test_rule( StableRNG(123), sqrt, 0.005; is_primitive=true, max_norm_perturbation=1e-3 ) diff --git a/test/test_utils.jl b/test/test_utils.jl index bbdffe00fc..448c6506a5 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -148,16 +148,13 @@ @test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Vector{Float64}}) == 0 end @testset "max_norm_perturbation kwarg for testing rules" begin - rng = Xoshiro(1) - # With a bound of 1e-3, all ε values keep x+ε*ẋ > 0 and the rule test passes. - ts = TestUtils.test_rule( - rng, - log, + TestUtils.test_rule( + StableRNG(123), + x -> log(x), 0.005; is_primitive=false, print_results=false, - max_norm_perturbation=1.0e-3, + max_norm_perturbation=1e-3, ) - @test !ts.anynonpass end end From 697369532381a06373a48ec6f6817b0966a19410 Mon Sep 17 00:00:00 2001 From: AstitvaAggarwal Date: Fri, 15 May 2026 16:19:24 +0100 Subject: [PATCH 6/6] remove error function --- src/test_utils.jl | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 0bab60e6c5..50f860d286 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -475,16 +475,6 @@ function _diff(p::P, q::P) where {P} return increment!!(_scale(-1.0, t2), t1) end -function _check_ε_list(ε_list, max_norm_perturbation) - isempty(ε_list) && throw( - ArgumentError( - "max_norm_perturbation=$max_norm_perturbation filters out all finite-difference " * - "step sizes; the smallest available is 1e-8, below which floating-point rounding " * - "errors dominate the estimate.", - ), - ) -end - # Assumes that the interface has been tested, and we can simply check for numerical issues. function test_frule_correctness( rng::AbstractRNG, @@ -508,11 +498,17 @@ function test_frule_correctness( # Use finite differences to estimate Frechet derivative. Compute the estimate at a range # of different step sizes. We'll just require that one of them ends up being close to # what AD gives. + !isnothing(max_norm_perturbation) && + max_norm_perturbation < 1e-8 && + throw( + ArgumentError( + "max_norm_perturbation=$max_norm_perturbation < 1e-8; the smallest available step size is 1e-8.", + ), + ) ε_list = filter( ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], ) - _check_ε_list(ε_list, max_norm_perturbation) fd_results = Vector{Any}(undef, length(ε_list)) for (n, ε) in enumerate(ε_list) x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) @@ -607,11 +603,17 @@ function test_rrule_correctness( # Use finite differences to estimate vjps. Compute the estimate at a range of different # step sizes. We'll just require that one of them ends up being close to what AD gives. + !isnothing(max_norm_perturbation) && + max_norm_perturbation < 1e-8 && + throw( + ArgumentError( + "max_norm_perturbation=$max_norm_perturbation < 1e-8; the smallest available step size is 1e-8.", + ), + ) ε_list = filter( ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], ) - _check_ε_list(ε_list, max_norm_perturbation) fd_results = Vector{Any}(undef, length(ε_list)) for (n, ε) in enumerate(ε_list) x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb)