Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,22 +477,31 @@ end

# Assumes that the interface has been tested, and we can simply check for numerical issues.
function test_frule_correctness(
rng::AbstractRNG, x_ẋ...; frule, unsafe_perturb::Bool, rtol=1e-3, atol=1e-3
rng::AbstractRNG,
x_ẋ...;
frule,
unsafe_perturb::Bool,
rtol=1e-3,
atol=1e-3,
max_norm_perturbation::Union{Nothing,Float64}=nothing,
)
@nospecialize rng x_ẋ
@nospecialize rng x_ẋ

x_ẋ = map(_deepcopy, x_ẋ) # defensive copy
x_ẋ = map(_deepcopy, x_ẋ) # defensive copy

# Run original function on deep-copies of inputs.
x = map(primal, x_ẋ)
= map(tangent, x_ẋ)
x = map(primal, x_ẋ)
= map(x -> normalize_tangent(tangent(x)), x_ẋ)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A style nit: maybe write this as `normalize_tangent ∘ tangent` instead of using the anonymous function?

x_primal = _deepcopy(x)
y_primal = x_primal[1](x_primal[2:end]...)

# Use finite differences to estimate Frechet derivative. Compute the estimate at a range
# of different step sizes. We'll just require that one of them ends up being close to
# what AD gives.
ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
ε_list = filter(
ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation,
[1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a check here as well, for the case where the filtered `ε_list` ends up empty or too short.

fd_results = Vector{Any}(undef, length(ε_list))
for (n, ε) in enumerate(ε_list)
x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb)
Expand Down Expand Up @@ -567,6 +576,7 @@ function test_rrule_correctness(
output_tangent=nothing,
rtol=1e-3,
atol=1e-3,
max_norm_perturbation::Union{Nothing,Float64}=nothing,
)
@nospecialize rng x_x̄

Expand All @@ -586,7 +596,10 @@ function test_rrule_correctness(

# Use finite differences to estimate vjps. Compute the estimate at a range of different
# step sizes. We'll just require that one of them ends up being close to what AD gives.
ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
ε_list = filter(
ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation,
[1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8],
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is worth adding a check here on the filtered `ε_list`, in case too few step sizes are left or the list is empty.

fd_results = Vector{Any}(undef, length(ε_list))
for (n, ε) in enumerate(ε_list)
x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb)
Expand Down Expand Up @@ -959,9 +972,15 @@ signature associated to `x` corresponds to a primitive, a hand-written rule will
Should usually be left `false` -- consult the docstring for `_add_to_primal` for more
info on when you might wish to set it to `true`.
- `output_tangent=nothing`: final output tangent to initialize reverse mode with for testing
the correctnes of reverse rules.
the correctness of reverse rules.
- `atol=1e-3`: absolute tolerance for correctness check of the Frechet derivatives.
- `rtol=1e-3`: relative tolerance for correctness check of the Frechet derivatives.
- `max_norm_perturbation::Union{Nothing,Float64}=nothing`: if provided, only finite-difference
step sizes `ε ≤ max_norm_perturbation` are used. Set this when the function is only
defined on a restricted domain (e.g. `log`, `sqrt`, `cholesky`) and large perturbations
would step outside it. The tangent direction `ẋ` is normalised to unit length before
finite differences are computed, so this bound directly controls the size of the step
in input space.
"""
function test_rule(
rng::AbstractRNG,
Expand All @@ -978,6 +997,7 @@ function test_rule(
rtol=1e-3,
frule=nothing,
rrule=nothing,
max_norm_perturbation::Union{Nothing,Float64}=nothing,
)
# Take a copy of `x` to ensure that we do not mutate the original.
x = deepcopy(x)
Expand Down Expand Up @@ -1037,11 +1057,26 @@ function test_rule(
# Test that answers are numerically correct / consistent.
@testset "Correctness" begin
if test_fwd && !interface_only
test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb, atol, rtol)
test_frule_correctness(
rng,
x_ẋ...;
frule,
unsafe_perturb,
atol,
rtol,
max_norm_perturbation,
)
end
if test_rvs && !interface_only
test_rrule_correctness(
rng, x_x̄...; rrule, unsafe_perturb, output_tangent, atol, rtol
rng,
x_x̄...;
rrule,
unsafe_perturb,
output_tangent,
atol,
rtol,
max_norm_perturbation,
)
end
end
Expand Down
7 changes: 7 additions & 0 deletions test/rules/low_level_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,11 @@
@test !is_primitive(C, M, Tuple{typeof(/),T,T}, world)
@test !is_primitive(C, M, Tuple{typeof(\),T,T}, world)
end

# Exercise `test_rule` on `log` and `sqrt` at the input 0.005, passing
# `max_norm_perturbation=1e-3` so that only finite-difference step sizes ε ≤ 1e-3 are used.
# NOTE(review): presumably the larger default steps (e.g. 1e-2) would push the perturbed
# input below zero, outside the domain of these functions -- confirm.
@testset "near-boundary domain-restricted functions" begin
test_rule(StableRNG(123), log, 0.005; is_primitive=true, max_norm_perturbation=1e-3)
test_rule(
StableRNG(123), sqrt, 0.005; is_primitive=true, max_norm_perturbation=1e-3
)
end
end
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this overlap with the one in test_utils.jl?

13 changes: 13 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,17 @@
@test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Float64}) == 0
@test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Vector{Float64}}) == 0
end
@testset "max_norm_perturbation kwarg for testing rules" begin
rng = Xoshiro(1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like an inconsistent coding style.

# With a bound of 1e-3, all ε values keep x+ε*ẋ > 0 and the rule test passes.
ts = TestUtils.test_rule(
rng,
log,
0.005;
is_primitive=false,
print_results=false,
max_norm_perturbation=1.0e-3,
)
@test !ts.anynonpass
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will fix this. It's actually a field in `Test.DefaultTestSet`; I was using it as `@test !ts.anynonpass` to assert that all internal tests passed.

end
end
Loading