diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 01a1a80e6..1e1c84713 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -101,6 +101,7 @@ jobs: - ForwardDiff - GTPSA - Mooncake + - Mooncake-old - PolyesterForwardDiff - ReverseDiff - SparsityDetector diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index d60b733ec..2b8ccd7cc 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -5,7 +5,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.16...main) +## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.17...main) + +## [0.7.17](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.16...DifferentiationInterface-v0.7.17) + +### Fixed + +- Make DI compatible with latest Mooncake friendly tangents ([#1001](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/1001)) +- Add docstrings to the result anlysis methods for sparse matrix preparations ([#984](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/984)) +- Make wrong-mode pushforward/pullback return the correct array type ([#974](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/974)) ## [0.7.16](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...DifferentiationInterface-v0.7.16) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 3e01e2c03..2fd7884ce 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,6 +1,6 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -version = "0.7.16" +version = "0.7.17" authors = ["Guillaume Dalle", "Adrian Hill"] [deps] @@ -71,7 +71,7 @@ ForwardDiff = "0.10.36,1" GPUArraysCore = "0.2" GTPSA = "1.4.0" LinearAlgebra = "1" -Mooncake = "0.5.1 - 0.5.24" +Mooncake = "0.5.1" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 3513d548c..db3cdebbc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -29,14 +29,15 @@ using Mooncake: NoRData, primal, _copy_output, - _copy_to_output!!, - tangent_to_primal!! + _copy_to_output!! const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}} DI.check_available(::AnyAutoMooncake{C}) where {C} = true DI.inner_preparation_behavior(::AutoMooncakeForward) = DI.PrepareInnerSimple() +@inline new_friendly_tangents() = isdefined(Mooncake, :FriendlyTangentCache) + include("utils.jl") include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index b22d8d49b..b6ec0c452 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -9,10 +9,19 @@ function call_and_return(f!::F, y, x, contexts...) where {F} return y end +function adaptive_tangent_to_primal!!(primal, tangent) + @static if new_friendly_tangents() + # TODO: optimize performance by allocating cache during prep + return Mooncake.tangent_to_friendly!!(primal, tangent) + else + return Mooncake.tangent_to_primal!!(primal, tangent) + end +end + function zero_tangent_or_primal(x, backend::AnyAutoMooncake) if get_config(backend).friendly_tangents # zero(x) but safer - return tangent_to_primal!!(_copy_output(x), zero_tangent(x)) + return adaptive_tangent_to_primal!!(_copy_output(x), zero_tangent(x)) else return zero_tangent(x) end diff --git a/DifferentiationInterface/test/Back/Mooncake-old/Project.toml b/DifferentiationInterface/test/Back/Mooncake-old/Project.toml new file mode 100644 index 000000000..8ed2b0166 --- /dev/null +++ b/DifferentiationInterface/test/Back/Mooncake-old/Project.toml @@ -0,0 +1,13 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +Mooncake = "<0.5.25" diff --git a/DifferentiationInterface/test/Back/Mooncake-old/test.jl b/DifferentiationInterface/test/Back/Mooncake-old/test.jl new file mode 120000 index 000000000..ec8c9b78c --- /dev/null +++ b/DifferentiationInterface/test/Back/Mooncake-old/test.jl @@ -0,0 +1 @@ +../Mooncake/test.jl \ No newline at end of file diff --git a/DifferentiationInterface/test/Back/Mooncake/Project.toml b/DifferentiationInterface/test/Back/Mooncake/Project.toml index 35595c9d1..8d659dc40 100644 --- a/DifferentiationInterface/test/Back/Mooncake/Project.toml +++ b/DifferentiationInterface/test/Back/Mooncake/Project.toml @@ -11,3 +11,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [sources] DifferentiationInterface = { path = "../../.." } + +[compat] +Mooncake = ">=0.5.25" \ No newline at end of file diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index d531e542a..eb80af551 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -74,9 +74,12 @@ test_differentiation( @test grad.B == ps.A end -test_differentiation( - backends[3:4], - nomatrix(static_scenarios()); - logging = LOGGING, - excluded = SECOND_ORDER -) +# see https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/986 +if pkgversion(Mooncake) < v"0.5.25" + test_differentiation( + backends[3:4], + nomatrix(static_scenarios()); + logging = LOGGING, + excluded = SECOND_ORDER + ) +end