Skip to content
65 changes: 56 additions & 9 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,22 +477,37 @@ end

# Assumes that the interface has been tested, and we can simply check for numerical issues.
function test_frule_correctness(
rng::AbstractRNG, x_ẋ...; frule, unsafe_perturb::Bool, rtol=1e-3, atol=1e-3
rng::AbstractRNG,
x_ẋ...;
frule,
unsafe_perturb::Bool,
rtol=1e-3,
atol=1e-3,
max_fd_step::Union{Nothing,Real}=nothing,
)
@nospecialize rng x_ẋ
@nospecialize rng x_ẋ

x_ẋ = map(_deepcopy, x_ẋ) # defensive copy
x_ẋ = map(_deepcopy, x_ẋ) # defensive copy

# Run original function on deep-copies of inputs.
x = map(primal, x_ẋ)
= map(tangent, x_ẋ)
x = map(primal, x_ẋ)
= map(normalize_tangent ∘ tangent, x_ẋ)
x_primal = _deepcopy(x)
y_primal = x_primal[1](x_primal[2:end]...)

# Use finite differences to estimate Frechet derivative. Compute the estimate at a range
# of different step sizes. We'll just require that one of them ends up being close to
# what AD gives.
ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
if max_fd_step !== nothing
ε_list = filter(≤(max_fd_step), ε_list)
length(ε_list) ≥ 2 || throw(
ArgumentError(
"max_fd_step=$max_fd_step leaves fewer than two FD steps; the fixed " *
"grid ends at 1e-7, so the smallest usable cap is 1e-6.",
),
)
end
fd_results = Vector{Any}(undef, length(ε_list))
for (n, ε) in enumerate(ε_list)
x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb)
Expand Down Expand Up @@ -567,6 +582,7 @@ function test_rrule_correctness(
output_tangent=nothing,
rtol=1e-3,
atol=1e-3,
max_fd_step::Union{Nothing,Real}=nothing,
)
@nospecialize rng x_x̄

Expand All @@ -587,6 +603,15 @@ function test_rrule_correctness(
# Use finite differences to estimate vjps. Compute the estimate at a range of different
# step sizes. We'll just require that one of them ends up being close to what AD gives.
ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
if max_fd_step !== nothing
ε_list = filter(≤(max_fd_step), ε_list)
length(ε_list) ≥ 2 || throw(
ArgumentError(
"max_fd_step=$max_fd_step leaves fewer than two FD steps; the fixed " *
"grid ends at 1e-7, so the smallest usable cap is 1e-6.",
),
)
end
fd_results = Vector{Any}(undef, length(ε_list))
for (n, ε) in enumerate(ε_list)
x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb)
Expand Down Expand Up @@ -908,7 +933,10 @@ __get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs)
print_results=true,
output_tangent=nothing,
atol=1e-3,
rtol=1e-3
rtol=1e-3,
frule=nothing,
rrule=nothing,
Comment thread
sunxd3 marked this conversation as resolved.
max_fd_step::Union{Nothing,Real}=nothing,
)

Run standardised tests on the `rule` for `x`.
Expand Down Expand Up @@ -959,9 +987,18 @@ signature associated to `x` corresponds to a primitive, a hand-written rule will
Should usually be left `false` -- consult the docstring for `_add_to_primal` for more
info on when you might wish to set it to `true`.
- `output_tangent=nothing`: final output tangent to initialize reverse mode with for testing
the correctnes of reverse rules.
the correctness of reverse rules.
- `atol=1e-3`: absolute tolerance for correctness check of the Frechet derivatives.
- `rtol=1e-3`: relative tolerance for correctness check of the Frechet derivatives.
- `frule=nothing`: if provided, use this callable as the forward rule instead of building one
from the interpreter. Useful for testing a hand-written `frule!!` directly.
- `rrule=nothing`: if provided, use this callable as the reverse rule instead of building one
from the interpreter. Useful for testing a hand-written `rrule!!` directly.
- `max_fd_step::Union{Nothing,Real}=nothing`: cap on finite-difference step sizes; only
`ε ≤ max_fd_step` are used. Each argument's tangent is unit-normalised independently,
so each argument is perturbed by at most `max_fd_step` in L2 norm. Set this for
domain-restricted functions (`log`, `sqrt`, `cholesky`) to keep perturbations inside
the domain. The FD grid ends at `1e-8`, and at least two step sizes must remain after capping, so the smallest usable cap is `1e-7`.
"""
function test_rule(
rng::AbstractRNG,
Expand All @@ -978,6 +1015,7 @@ function test_rule(
rtol=1e-3,
frule=nothing,
rrule=nothing,
max_fd_step::Union{Nothing,Real}=nothing,
)
# Take a copy of `x` to ensure that we do not mutate the original.
x = deepcopy(x)
Expand Down Expand Up @@ -1037,11 +1075,20 @@ function test_rule(
# Test that answers are numerically correct / consistent.
@testset "Correctness" begin
if test_fwd && !interface_only
test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb, atol, rtol)
test_frule_correctness(
rng, x_ẋ...; frule, unsafe_perturb, atol, rtol, max_fd_step
)
end
if test_rvs && !interface_only
test_rrule_correctness(
rng, x_x̄...; rrule, unsafe_perturb, output_tangent, atol, rtol
rng,
x_x̄...;
rrule,
unsafe_perturb,
output_tangent,
atol,
rtol,
max_fd_step,
)
end
end
Expand Down
4 changes: 4 additions & 0 deletions test/rules/low_level_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,8 @@
@test !is_primitive(C, M, Tuple{typeof(/),T,T}, world)
@test !is_primitive(C, M, Tuple{typeof(\),T,T}, world)
end

@testset "near-boundary domain-restricted functions" begin
# The input 0.005 is close to the boundary of sqrt's domain (x ≥ 0). Capping the
# finite-difference step at 1e-3 keeps every perturbed input positive, per the
# `max_fd_step` documentation in `test_rule` (perturbation is at most the cap).
test_rule(StableRNG(123), sqrt, 0.005; is_primitive=true, max_fd_step=1e-3)
end
end
Comment thread
sunxd3 marked this conversation as resolved.
13 changes: 13 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,17 @@
@test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Float64}) == 0
@test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Vector{Float64}}) == 0
end
@testset "max_fd_step kwarg for testing rules" begin
# log is a primitive in Mooncake, so we wrap it in a lambda to test the derived-rule
# path. The input 0.005 is close to the boundary of log's domain (x > 0), so we cap
# the FD step at 1e-3 to avoid perturbing into x ≤ 0.
TestUtils.test_rule(
StableRNG(123),
x -> log(x),
Comment thread
sunxd3 marked this conversation as resolved.
0.005;
is_primitive=false,
print_results=false,
max_fd_step=1e-3,
)
end
end
Loading