diff --git a/src/test_utils.jl b/src/test_utils.jl index a24478da13..50f860d286 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -477,22 +477,38 @@ end # Assumes that the interface has been tested, and we can simply check for numerical issues. function test_frule_correctness( - rng::AbstractRNG, x_ẋ...; frule, unsafe_perturb::Bool, rtol=1e-3, atol=1e-3 + rng::AbstractRNG, + x_ẋ...; + frule, + unsafe_perturb::Bool, + rtol=1e-3, + atol=1e-3, + max_norm_perturbation::Union{Nothing,Float64}=nothing, ) - @nospecialize rng x_ẋ + @nospecialize rng x_ẋ - x_ẋ = map(_deepcopy, x_ẋ) # defensive copy + x_ẋ = map(_deepcopy, x_ẋ) # defensive copy # Run original function on deep-copies of inputs. - x = map(primal, x_ẋ) - ẋ = map(tangent, x_ẋ) + x = map(primal, x_ẋ) + ẋ = map(normalize_tangent ∘ tangent, x_ẋ) x_primal = _deepcopy(x) y_primal = x_primal[1](x_primal[2:end]...) # Use finite differences to estimate Frechet derivative. Compute the estimate at a range # of different step sizes. We'll just require that one of them ends up being close to # what AD gives. - ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] + !isnothing(max_norm_perturbation) && + max_norm_perturbation < 1e-8 && + throw( + ArgumentError( + "max_norm_perturbation=$max_norm_perturbation < 1e-8; the smallest available step size is 1e-8.", + ), + ) + ε_list = filter( + ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, + [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], + ) fd_results = Vector{Any}(undef, length(ε_list)) for (n, ε) in enumerate(ε_list) x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) @@ -567,6 +583,7 @@ function test_rrule_correctness( output_tangent=nothing, rtol=1e-3, atol=1e-3, + max_norm_perturbation::Union{Nothing,Float64}=nothing, ) @nospecialize rng x_x̄ @@ -586,7 +603,17 @@ function test_rrule_correctness( # Use finite differences to estimate vjps. Compute the estimate at a range of different # step sizes. We'll just require that one of them ends up being close to what AD gives. - ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] + !isnothing(max_norm_perturbation) && + max_norm_perturbation < 1e-8 && + throw( + ArgumentError( + "max_norm_perturbation=$max_norm_perturbation < 1e-8; the smallest available step size is 1e-8.", + ), + ) + ε_list = filter( + ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, + [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], + ) fd_results = Vector{Any}(undef, length(ε_list)) for (n, ε) in enumerate(ε_list) x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) @@ -908,7 +935,10 @@ __get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs) print_results=true, output_tangent=nothing, atol=1e-3, - rtol=1e-3 + rtol=1e-3, + frule=nothing, + rrule=nothing, + max_norm_perturbation::Union{Nothing,Float64}=nothing, ) Run standardised tests on the `rule` for `x`. @@ -959,9 +989,15 @@ signature associated to `x` corresponds to a primitive, a hand-written rule will Should usually be left `false` -- consult the docstring for `_add_to_primal` for more info on when you might wish to set it to `true`. - `output_tangent=nothing`: final output tangent to initialize reverse mode with for testing - the correctnes of reverse rules. + the correctness of reverse rules. - `atol=1e-3`: absolute tolerance for correctness check of the Frechet derivatives. - `rtol=1e-3`: relative tolerance for correctness check of the Frechet derivatives. +- `max_norm_perturbation::Union{Nothing,Float64}=nothing`: if provided, only finite-difference + step sizes `ε ≤ max_norm_perturbation` are used. Set this when the function is only + defined on a restricted domain (e.g. `log`, `sqrt`, `cholesky`) and large perturbations + would step outside it. The tangent direction `ẋ` is normalised to unit length before + finite differences are computed, so this bound directly controls the size of the step + in input space. """ function test_rule( rng::AbstractRNG, @@ -978,6 +1014,7 @@ function test_rule( rtol=1e-3, frule=nothing, rrule=nothing, + max_norm_perturbation::Union{Nothing,Float64}=nothing, ) # Take a copy of `x` to ensure that we do not mutate the original. x = deepcopy(x) @@ -1037,11 +1074,26 @@ function test_rule( # Test that answers are numerically correct / consistent. @testset "Correctness" begin if test_fwd && !interface_only - test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb, atol, rtol) + test_frule_correctness( + rng, + x_ẋ...; + frule, + unsafe_perturb, + atol, + rtol, + max_norm_perturbation, + ) end if test_rvs && !interface_only test_rrule_correctness( - rng, x_x̄...; rrule, unsafe_perturb, output_tangent, atol, rtol + rng, + x_x̄...; + rrule, + unsafe_perturb, + output_tangent, + atol, + rtol, + max_norm_perturbation, ) end end diff --git a/test/rules/low_level_maths.jl b/test/rules/low_level_maths.jl index be2f8425fe..b7b50c86c7 100644 --- a/test/rules/low_level_maths.jl +++ b/test/rules/low_level_maths.jl @@ -133,4 +133,10 @@ @test !is_primitive(C, M, Tuple{typeof(/),T,T}, world) @test !is_primitive(C, M, Tuple{typeof(\),T,T}, world) end + + @testset "near-boundary domain-restricted functions" begin + test_rule( + StableRNG(123), sqrt, 0.005; is_primitive=true, max_norm_perturbation=1e-3 + ) + end end diff --git a/test/test_utils.jl b/test/test_utils.jl index d779c331f0..448c6506a5 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -147,4 +147,14 @@ @test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Float64}) == 0 @test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Vector{Float64}}) == 0 end + @testset "max_norm_perturbation kwarg for testing rules" begin + TestUtils.test_rule( + StableRNG(123), + x -> log(x), + 0.005; + is_primitive=false, + print_results=false, + max_norm_perturbation=1e-3, + ) + end end