-
Notifications
You must be signed in to change notification settings - Fork 34
max_norm_perturbation kwarg #1173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
f3edb3f
b5356b6
e1de010
ae6ce5a
99d34a0
09f8888
6973695
4c920e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -477,22 +477,31 @@ end | |
|
|
||
| # Assumes that the interface has been tested, and we can simply check for numerical issues. | ||
| function test_frule_correctness( | ||
| rng::AbstractRNG, x_ẋ...; frule, unsafe_perturb::Bool, rtol=1e-3, atol=1e-3 | ||
| rng::AbstractRNG, | ||
| x_ẋ...; | ||
| frule, | ||
| unsafe_perturb::Bool, | ||
| rtol=1e-3, | ||
| atol=1e-3, | ||
| max_norm_perturbation::Union{Nothing,Float64}=nothing, | ||
| ) | ||
| @nospecialize rng x_ẋ | ||
| @nospecialize rng x_ẋ | ||
|
|
||
| x_ẋ = map(_deepcopy, x_ẋ) # defensive copy | ||
| x_ẋ = map(_deepcopy, x_ẋ) # defensive copy | ||
|
|
||
| # Run original function on deep-copies of inputs. | ||
| x = map(primal, x_ẋ) | ||
| ẋ = map(tangent, x_ẋ) | ||
| x = map(primal, x_ẋ) | ||
| ẋ = map(x -> normalize_tangent(tangent(x)), x_ẋ) | ||
| x_primal = _deepcopy(x) | ||
| y_primal = x_primal[1](x_primal[2:end]...) | ||
|
|
||
| # Use finite differences to estimate Frechet derivative. Compute the estimate at a range | ||
| # of different step sizes. We'll just require that one of them ends up being close to | ||
| # what AD gives. | ||
| ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] | ||
| ε_list = filter( | ||
| ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, | ||
| [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check here also |
||
| fd_results = Vector{Any}(undef, length(ε_list)) | ||
| for (n, ε) in enumerate(ε_list) | ||
| x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) | ||
|
|
@@ -567,6 +576,7 @@ function test_rrule_correctness( | |
| output_tangent=nothing, | ||
| rtol=1e-3, | ||
| atol=1e-3, | ||
| max_norm_perturbation::Union{Nothing,Float64}=nothing, | ||
| ) | ||
| @nospecialize rng x_x̄ | ||
|
|
||
|
|
@@ -586,7 +596,10 @@ function test_rrule_correctness( | |
|
|
||
| # Use finite differences to estimate vjps. Compute the estimate at a range of different | ||
| # step sizes. We'll just require that one of them ends up being close to what AD gives. | ||
| ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] | ||
| ε_list = filter( | ||
| ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, | ||
| [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. worth adding a check here for the filtered `ε_list` |
||
| fd_results = Vector{Any}(undef, length(ε_list)) | ||
| for (n, ε) in enumerate(ε_list) | ||
| x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) | ||
|
|
@@ -959,9 +972,15 @@ signature associated to `x` corresponds to a primitive, a hand-written rule will | |
| Should usually be left `false` -- consult the docstring for `_add_to_primal` for more | ||
| info on when you might wish to set it to `true`. | ||
| - `output_tangent=nothing`: final output tangent to initialize reverse mode with for testing | ||
| the correctnes of reverse rules. | ||
| the correctness of reverse rules. | ||
| - `atol=1e-3`: absolute tolerance for correctness check of the Frechet derivatives. | ||
| - `rtol=1e-3`: relative tolerance for correctness check of the Frechet derivatives. | ||
| - `max_norm_perturbation::Union{Nothing,Float64}=nothing`: if provided, only finite-difference | ||
| step sizes `ε ≤ max_norm_perturbation` are used. Set this when the function is only | ||
| defined on a restricted domain (e.g. `log`, `sqrt`, `cholesky`) and large perturbations | ||
| would step outside it. The tangent direction `ẋ` is normalised to unit length before | ||
| finite differences are computed, so this bound directly controls the size of the step | ||
| in input space. | ||
| """ | ||
| function test_rule( | ||
| rng::AbstractRNG, | ||
|
|
@@ -978,6 +997,7 @@ function test_rule( | |
| rtol=1e-3, | ||
| frule=nothing, | ||
| rrule=nothing, | ||
| max_norm_perturbation::Union{Nothing,Float64}=nothing, | ||
| ) | ||
| # Take a copy of `x` to ensure that we do not mutate the original. | ||
| x = deepcopy(x) | ||
|
|
@@ -1037,11 +1057,26 @@ function test_rule( | |
| # Test that answers are numerically correct / consistent. | ||
| @testset "Correctness" begin | ||
| if test_fwd && !interface_only | ||
| test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb, atol, rtol) | ||
| test_frule_correctness( | ||
| rng, | ||
| x_ẋ...; | ||
| frule, | ||
| unsafe_perturb, | ||
| atol, | ||
| rtol, | ||
| max_norm_perturbation, | ||
| ) | ||
| end | ||
| if test_rvs && !interface_only | ||
| test_rrule_correctness( | ||
| rng, x_x̄...; rrule, unsafe_perturb, output_tangent, atol, rtol | ||
| rng, | ||
| x_x̄...; | ||
| rrule, | ||
| unsafe_perturb, | ||
| output_tangent, | ||
| atol, | ||
| rtol, | ||
| max_norm_perturbation, | ||
| ) | ||
| end | ||
| end | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -133,4 +133,11 @@ | |
| @test !is_primitive(C, M, Tuple{typeof(/),T,T}, world) | ||
| @test !is_primitive(C, M, Tuple{typeof(\),T,T}, world) | ||
| end | ||
|
|
||
| @testset "near-boundary domain-restricted functions" begin | ||
| test_rule(StableRNG(123), log, 0.005; is_primitive=true, max_norm_perturbation=1e-3) | ||
| test_rule( | ||
| StableRNG(123), sqrt, 0.005; is_primitive=true, max_norm_perturbation=1e-3 | ||
| ) | ||
| end | ||
| end | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this overlap with the one in test_utils.jl? |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -147,4 +147,17 @@ | |
| @test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Float64}) == 0 | ||
| @test TestUtils.count_allocs(Mooncake.fdata_type, Tuple{Vector{Float64}}) == 0 | ||
| end | ||
| @testset "max_norm_perturbation kwarg for testing rules" begin | ||
| rng = Xoshiro(1) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this feels inconsistent coding style |
||
| # With a bound of 1e-3, all ε values keep x+ε*ẋ > 0 and the rule test passes. | ||
| ts = TestUtils.test_rule( | ||
| rng, | ||
| log, | ||
| 0.005; | ||
| is_primitive=false, | ||
| print_results=false, | ||
| max_norm_perturbation=1.0e-3, | ||
| ) | ||
| @test !ts.anynonpass | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need this?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will fix this, it's actually a field in `Test.DefaultTestSet` |
||
| end | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a nit on style, maybe do something like
`normalize_tangent ∘ tangent` instead of the anonymous function?