-
Notifications
You must be signed in to change notification settings - Fork 34
max_norm_perturbation kwarg #1173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
f3edb3f
b5356b6
e1de010
ae6ce5a
99d34a0
09f8888
6973695
4c920e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -475,24 +475,44 @@ function _diff(p::P, q::P) where {P} | |
| return increment!!(_scale(-1.0, t2), t1) | ||
| end | ||
|
|
||
| function _check_ε_list(ε_list, max_norm_perturbation) | ||
| isempty(ε_list) && throw( | ||
| ArgumentError( | ||
| "max_norm_perturbation=$max_norm_perturbation filters out all finite-difference " * | ||
| "step sizes; the smallest available is 1e-8, below which floating-point rounding " * | ||
| "errors dominate the estimate.", | ||
| ), | ||
| ) | ||
| end | ||
|
|
||
| # Assumes that the interface has been tested, and we can simply check for numerical issues. | ||
| function test_frule_correctness( | ||
| rng::AbstractRNG, x_ẋ...; frule, unsafe_perturb::Bool, rtol=1e-3, atol=1e-3 | ||
| rng::AbstractRNG, | ||
| x_ẋ...; | ||
| frule, | ||
| unsafe_perturb::Bool, | ||
| rtol=1e-3, | ||
| atol=1e-3, | ||
| max_norm_perturbation::Union{Nothing,Float64}=nothing, | ||
| ) | ||
| @nospecialize rng x_ẋ | ||
| @nospecialize rng x_ẋ | ||
|
|
||
| x_ẋ = map(_deepcopy, x_ẋ) # defensive copy | ||
| x_ẋ = map(_deepcopy, x_ẋ) # defensive copy | ||
|
|
||
| # Run original function on deep-copies of inputs. | ||
| x = map(primal, x_ẋ) | ||
| ẋ = map(tangent, x_ẋ) | ||
| x = map(primal, x_ẋ) | ||
| ẋ = map(normalize_tangent ∘ tangent, x_ẋ) | ||
| x_primal = _deepcopy(x) | ||
| y_primal = x_primal[1](x_primal[2:end]...) | ||
|
|
||
| # Use finite differences to estimate Frechet derivative. Compute the estimate at a range | ||
| # of different step sizes. We'll just require that one of them ends up being close to | ||
| # what AD gives. | ||
| ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] | ||
| ε_list = filter( | ||
| ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, | ||
| [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Check here also. |
||
| _check_ε_list(ε_list, max_norm_perturbation) | ||
| fd_results = Vector{Any}(undef, length(ε_list)) | ||
| for (n, ε) in enumerate(ε_list) | ||
| x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) | ||
|
|
@@ -567,6 +587,7 @@ function test_rrule_correctness( | |
| output_tangent=nothing, | ||
| rtol=1e-3, | ||
| atol=1e-3, | ||
| max_norm_perturbation::Union{Nothing,Float64}=nothing, | ||
| ) | ||
| @nospecialize rng x_x̄ | ||
|
|
||
|
|
@@ -586,7 +607,11 @@ function test_rrule_correctness( | |
|
|
||
| # Use finite differences to estimate vjps. Compute the estimate at a range of different | ||
| # step sizes. We'll just require that one of them ends up being close to what AD gives. | ||
| ε_list = [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] | ||
| ε_list = filter( | ||
| ε -> isnothing(max_norm_perturbation) || ε <= max_norm_perturbation, | ||
| [1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8], | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Worth adding a check here for the filtered `ε_list`. |
||
| _check_ε_list(ε_list, max_norm_perturbation) | ||
| fd_results = Vector{Any}(undef, length(ε_list)) | ||
| for (n, ε) in enumerate(ε_list) | ||
| x′_l = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) | ||
|
|
@@ -908,7 +933,10 @@ __get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs) | |
| print_results=true, | ||
| output_tangent=nothing, | ||
| atol=1e-3, | ||
| rtol=1e-3 | ||
| rtol=1e-3, | ||
| frule=nothing, | ||
| rrule=nothing, | ||
| max_norm_perturbation::Union{Nothing,Float64}=nothing, | ||
| ) | ||
|
Comment on lines
925
to
942
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Update this too? It was already outdated before this PR, but it is worth fixing along with it. |
||
|
|
||
| Run standardised tests on the `rule` for `x`. | ||
|
|
@@ -959,9 +987,15 @@ signature associated to `x` corresponds to a primitive, a hand-written rule will | |
| Should usually be left `false` -- consult the docstring for `_add_to_primal` for more | ||
| info on when you might wish to set it to `true`. | ||
| - `output_tangent=nothing`: final output tangent to initialize reverse mode with for testing | ||
| the correctnes of reverse rules. | ||
| the correctness of reverse rules. | ||
| - `atol=1e-3`: absolute tolerance for correctness check of the Frechet derivatives. | ||
| - `rtol=1e-3`: relative tolerance for correctness check of the Frechet derivatives. | ||
| - `max_norm_perturbation::Union{Nothing,Float64}=nothing`: if provided, only finite-difference | ||
| step sizes `ε ≤ max_norm_perturbation` are used. Set this when the function is only | ||
| defined on a restricted domain (e.g. `log`, `sqrt`, `cholesky`) and large perturbations | ||
| would step outside it. The tangent direction `ẋ` is normalised to unit length before | ||
| finite differences are computed, so this bound directly controls the size of the step | ||
| in input space. | ||
| """ | ||
| function test_rule( | ||
| rng::AbstractRNG, | ||
|
|
@@ -978,6 +1012,7 @@ function test_rule( | |
| rtol=1e-3, | ||
| frule=nothing, | ||
| rrule=nothing, | ||
| max_norm_perturbation::Union{Nothing,Float64}=nothing, | ||
| ) | ||
| # Take a copy of `x` to ensure that we do not mutate the original. | ||
| x = deepcopy(x) | ||
|
|
@@ -1037,11 +1072,26 @@ function test_rule( | |
| # Test that answers are numerically correct / consistent. | ||
| @testset "Correctness" begin | ||
| if test_fwd && !interface_only | ||
| test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb, atol, rtol) | ||
| test_frule_correctness( | ||
| rng, | ||
| x_ẋ...; | ||
| frule, | ||
| unsafe_perturb, | ||
| atol, | ||
| rtol, | ||
| max_norm_perturbation, | ||
| ) | ||
| end | ||
| if test_rvs && !interface_only | ||
| test_rrule_correctness( | ||
| rng, x_x̄...; rrule, unsafe_perturb, output_tangent, atol, rtol | ||
| rng, | ||
| x_x̄...; | ||
| rrule, | ||
| unsafe_perturb, | ||
| output_tangent, | ||
| atol, | ||
| rtol, | ||
| max_norm_perturbation, | ||
| ) | ||
| end | ||
| end | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -133,4 +133,10 @@ | |
| @test !is_primitive(C, M, Tuple{typeof(/),T,T}, world) | ||
| @test !is_primitive(C, M, Tuple{typeof(\),T,T}, world) | ||
| end | ||
|
|
||
| @testset "near-boundary domain-restricted functions" begin | ||
| test_rule( | ||
| StableRNG(123), sqrt, 0.005; is_primitive=true, max_norm_perturbation=1e-3 | ||
| ) | ||
| end | ||
| end | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this overlap with the one in `test_utils.jl`? |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a style nit: we can just check the value of `max_norm_perturbation` directly before the `filter` call and error if it is too small. Try to keep it a one-liner?