Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 61 additions & 11 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,24 +475,44 @@ function _diff(p::P, q::P) where {P}
return increment!!(_scale(-1.0, t2), t1)
end

function _check_ε_list(ε_list, max_norm_perturbation)
isempty(ε_list) && throw(
ArgumentError(
"max_norm_perturbation=$max_norm_perturbation filters out all finite-difference " *
"step sizes; the smallest available is 1e-8, below which floating-point rounding " *
"errors dominate the estimate.",
),
)
end
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a style nit: we can just check the value of max_norm_perturbation directly before

ε_list = filter(
        ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation,
        [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
    )

and error if too small, try to keep it a oneliner?


# Assumes that the interface has been tested, and we can simply check for numerical issues.
function test_frule_correctness(
rng::AbstractRNG, x_ẋ...; frule, unsafe_perturb::Bool, rtol=1e-3, atol=1e-3
rng::AbstractRNG,
x_ẋ...;
frule,
unsafe_perturb::Bool,
rtol=1e-3,
atol=1e-3,
max_norm_perturbation::Union{Nothing,Float64}=nothing,
)
@nospecialize rng x_ẋ
@nospecialize rng x_ẋ

x_ẋ = map(_deepcopy, x_ẋ) # defensive copy
x_ẋ = map(_deepcopy, x_ẋ) # defensive copy

# Run original function on deep-copies of inputs.
x = map(primal, x_ẋ)
= map(tangent, x_ẋ)
x = map(primal, x_ẋ)
= map(normalize_tangent ∘ tangent, x_ẋ)
x_primal = _deepcopy(x)
y_primal = x_primal[1](x_primal[2:end]...)

# Use finite differences to estimate Frechet derivative. Compute the estimate at a range
# of different step sizes. We'll just require that one of them ends up being close to
# what AD gives.
ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
ε_list = filter(
ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation,
[1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check here also

_check_ε_list(ε_list, max_norm_perturbation)
fd_results = Vector{Any}(undef, length(ε_list))
for (n, ε) in enumerate(ε_list)
x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb)
Expand Down Expand Up @@ -567,6 +587,7 @@ function test_rrule_correctness(
output_tangent=nothing,
rtol=1e-3,
atol=1e-3,
max_norm_perturbation::Union{Nothing,Float64}=nothing,
)
@nospecialize rng x_x̄

Expand All @@ -586,7 +607,11 @@ function test_rrule_correctness(

# Use finite differences to estimate vjps. Compute the estimate at a range of different
# step sizes. We'll just require that one of them ends up being close to what AD gives.
ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
ε_list = filter(
ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation,
[1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

worth add a check here for the filtered ε_list: if too few left or empty

_check_ε_list(ε_list, max_norm_perturbation)
fd_results = Vector{Any}(undef, length(ε_list))
for (n, ε) in enumerate(ε_list)
x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb)
Expand Down Expand Up @@ -908,7 +933,10 @@ __get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs)
print_results=true,
output_tangent=nothing,
atol=1e-3,
rtol=1e-3
rtol=1e-3,
frule=nothing,
rrule=nothing,
max_norm_perturbation::Union{Nothing,Float64}=nothing,
)
Comment on lines 925 to 942
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update this too? it was outdated before this PR but worth fixing with it


Run standardised tests on the `rule` for `x`.
Expand Down Expand Up @@ -959,9 +987,15 @@ signature associated to `x` corresponds to a primitive, a hand-written rule will
Should usually be left `false` -- consult the docstring for `_add_to_primal` for more
info on when you might wish to set it to `true`.
- `output_tangent=nothing`: final output tangent to initialize reverse mode with for testing
the correctnes of reverse rules.
the correctness of reverse rules.
- `atol=1e-3`: absolute tolerance for correctness check of the Frechet derivatives.
- `rtol=1e-3`: relative tolerance for correctness check of the Frechet derivatives.
- `max_norm_perturbation::Union{Nothing,Float64}=nothing`: if provided, only finite-difference
step sizes `ε ≤ max_norm_perturbation` are used. Set this when the function is only
defined on a restricted domain (e.g. `log`, `sqrt`, `cholesky`) and large perturbations
would step outside it. The tangent direction `ẋ` is normalised to unit length before
finite differences are computed, so this bound directly controls the size of the step
in input space.
"""
function test_rule(
rng::AbstractRNG,
Expand All @@ -978,6 +1012,7 @@ function test_rule(
rtol=1e-3,
frule=nothing,
rrule=nothing,
max_norm_perturbation::Union{Nothing,Float64}=nothing,
)
# Take a copy of `x` to ensure that we do not mutate the original.
x = deepcopy(x)
Expand Down Expand Up @@ -1037,11 +1072,26 @@ function test_rule(
# Test that answers are numerically correct / consistent.
@testset "Correctness" begin
if test_fwd && !interface_only
test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb, atol, rtol)
test_frule_correctness(
rng,
x_ẋ...;
frule,
unsafe_perturb,
atol,
rtol,
max_norm_perturbation,
)
end
if test_rvs && !interface_only
test_rrule_correctness(
rng, x_x̄...; rrule, unsafe_perturb, output_tangent, atol, rtol
rng,
x_x̄...;
rrule,
unsafe_perturb,
output_tangent,
atol,
rtol,
max_norm_perturbation,
)
end
end
Expand Down
6 changes: 6 additions & 0 deletions test/rules/low_level_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,10 @@
@test !is_primitive(C, M, Tuple{typeof(/),T,T}, world)
@test !is_primitive(C, M, Tuple{typeof(\),T,T}, world)
end

@testset "near-boundary domain-restricted functions" begin
test_rule(
StableRNG(123), sqrt, 0.005; is_primitive=true, max_norm_perturbation=1e-3
)
end
end
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this overlap with the one in test_utils.jl?

10 changes: 10 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,14 @@
@test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Float64}) == 0
@test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Vector{Float64}}) == 0
end
@testset "max_norm_perturbation kwarg for testing rules" begin
TestUtils.test_rule(
StableRNG(123),
x -> log(x),
0.005;
is_primitive=false,
print_results=false,
max_norm_perturbation=1e-3,
)
end
end
Loading