diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 8c13ab6a77..4b12b01889 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -89,7 +89,7 @@ jobs:
       fail-fast: false
       matrix:
         test_group: [
-          {test_type: 'ext', label: 'differentiation_interface'},
+          # {test_type: 'ext', label: 'differentiation_interface'},
           {test_type: 'ext', label: 'dynamic_expressions'},
           {test_type: 'ext', label: 'flux'},
           {test_type: 'ext', label: 'function_wrappers'},
@@ -100,7 +100,7 @@ jobs:
           {test_type: 'integration_testing', label: 'array'},
           {test_type: 'integration_testing', label: 'bijectors'},
           {test_type: 'integration_testing', label: 'diff_tests'},
-          {test_type: 'integration_testing', label: 'diffeq'},
+          # {test_type: 'integration_testing', label: 'diffeq'},
           {test_type: 'integration_testing', label: 'dispatch_doctor'},
           {test_type: 'integration_testing', label: 'distributions'},
           {test_type: 'integration_testing', label: 'dynamicppl'},
diff --git a/docs/Project.toml b/docs/Project.toml
index 4c4c1c40da..ca50a2962a 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -2,7 +2,6 @@
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
@@ -12,7 +11,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 [compat]
 AllocCheck = "0.2.0"
 ChainRulesCore = "1"
-DifferentiationInterface = "0.7.16"
 Documenter = "1"
 JET = "0.9, 0.10, 0.11"
 Mooncake = "0.5, 0.6"
diff --git a/docs/src/developer_documentation/custom_tangent_type.md b/docs/src/developer_documentation/custom_tangent_type.md
index a8a066401f..ad0ff284a5 100644
--- a/docs/src/developer_documentation/custom_tangent_type.md
+++ b/docs/src/developer_documentation/custom_tangent_type.md
@@ -34,7 +34,6 @@ A *recursive type* is a struct that contains itself (directly or indirectly) as
 
 ```@setup custom_tangent_type
 using Mooncake: Mooncake
-using DifferentiationInterface
 using JET
 using AllocCheck
 using Test
@@ -183,7 +182,8 @@ With these, you can now differentiate simple functions:
 
 ```@example custom_tangent_type
 a = A(1.0)
-val, grad = DifferentiationInterface.value_and_gradient(f1, AutoMooncake(; config=nothing), a)
+cache = Mooncake.prepare_gradient_cache(f1, a)
+val, (_, grad) = Mooncake.value_and_gradient!!(cache, f1, a)
 ```
 
 Another example:
@@ -194,7 +194,8 @@ function prod_x(a::A{T}) where {T}
     return a.a === nothing ? a_val : a_val * prod_x(a.a)
 end
 sum_a = A(1.0, A(2.0, A(3.0)))
-val_f5, grad_f5 = DifferentiationInterface.value_and_gradient(prod_x, AutoMooncake(; config=nothing), sum_a)
+cache_prod_x = Mooncake.prepare_gradient_cache(prod_x, sum_a)
+val_f5, (_, grad_f5) = Mooncake.value_and_gradient!!(cache_prod_x, prod_x, sum_a)
 ```
 
 Depending on your use case, this may be sufficient.
diff --git a/docs/src/tutorial.md b/docs/src/tutorial.md
index 4ed3b26908..e5be8ba24f 100644
--- a/docs/src/tutorial.md
+++ b/docs/src/tutorial.md
@@ -7,7 +7,7 @@ There are two ways to compute gradients with Mooncake.jl:
 We recommend the former to start with, especially if you want to experiment with other automatic differentiation packages.
 
-```@example tuto
+```julia
 import DifferentiationInterface as DI
 import Mooncake
 ```
 
@@ -17,7 +17,7 @@ import Mooncake
 DifferentiationInterface.jl (or DI for short) provides a common entry point for every automatic differentiation package in Julia.
 To specify that you want to use Mooncake.jl, just create the right "backend" object (with an optional [`Mooncake.Config`](@ref)):
 
-```@example tuto
+```julia
 backend = DI.AutoMooncake(; config=nothing)
 ```
 
@@ -32,7 +32,7 @@ such that the tangent of a `ComplexF64` is a `ComplexF64`,
 or the tangent of a `Symmetric` is a `Symmetric`,
 set `friendly_tangents=true` in the config:
 
-```@example tuto
+```julia
 backend = DI.AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true))
 ```
 
@@ -40,26 +40,26 @@ backend = DI.AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true))
 
 Suppose you want to differentiate the following function
 
-```@example tuto
+```julia
 f(x) = sum(abs2, x)
 ```
 
 on the following input
 
-```@example tuto
+```julia
 x = float.(1:3)
 ```
 
 The naive way is to simply call [`DI.gradient`](@extref DifferentiationInterface.gradient):
 
-```@example tuto
+```julia
 DI.gradient(f, backend, x) # slow, do not do this
 ```
 
 This returns the correct gradient, but it is very slow because it includes the time taken by Mooncake.jl to compute a differentiation rule for `f` (see [Mooncake.jl's Rule System](@ref)).
 If you anticipate you will need more than one gradient, it is better to call [`DI.prepare_gradient`](@extref DifferentiationInterface.prepare_gradient) on a typical (e.g. random) input first:
 
-```@example tuto
+```julia
 typical_x = rand(3)
 prep = DI.prepare_gradient(f, backend, typical_x)
 ```
@@ -68,20 +68,20 @@ The typical input should have the same size and type as the actual inputs we wil
 As for the contents of the preparation result, they do not matter.
 What matters is that it captures everything you need for `DI.gradient` to be fast:
 
-```@example tuto
+```julia
 DI.gradient(f, prep, backend, x) # fast
 ```
 
 For optimal speed, you can provide storage space for the gradient and call [`DI.gradient!`](@extref DifferentiationInterface.gradient!) instead:
 
-```@example tuto
+```julia
 grad = similar(x)
 DI.gradient!(f, grad, prep, backend, x) # very fast
 ```
 
 If you also need the value of the function, check out [`DI.value_and_gradient`](@extref DifferentiationInterface.value_and_gradient) or [`DI.value_and_gradient!`](@extref DifferentiationInterface.value_and_gradient!):
 
-```@example tuto
+```julia
 DI.value_and_gradient(f, prep, backend, x)
 ```
 
@@ -91,14 +91,14 @@ What should you do if your function takes more than one input argument?
 Well, DI can still handle it, _assuming that you only want the derivative with respect to one of them_ (the first one, by convention).
 For instance, consider the function
 
-```@example tuto
+```julia
 g(x, a, b) = a * f(x) + b
 ```
 
 You can easily compute the gradient with respect to `x`, while keeping `a` and `b` fixed.
 To do that, just wrap these two arguments inside [`DI.Constant`](@extref DifferentiationInterface.Constant), like so:
 
-```@example tuto
+```julia
 typical_a, typical_b = 1.0, 1.0
 prep = DI.prepare_gradient(g, backend, typical_x, DI.Constant(typical_a), DI.Constant(typical_b))
 
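For reference, the prepared rule above is then reused by passing fresh `DI.Constant` wrappers alongside each new input. A minimal sketch, consistent with the DI calls shown in this diff (the concrete values of `x`, `a`, and `b` are illustrative):

```julia
import DifferentiationInterface as DI
import Mooncake

backend = DI.AutoMooncake(; config=nothing)
f(x) = sum(abs2, x)
g(x, a, b) = a * f(x) + b

typical_x, typical_a, typical_b = rand(3), 1.0, 1.0
prep = DI.prepare_gradient(g, backend, typical_x, DI.Constant(typical_a), DI.Constant(typical_b))

# Gradient with respect to x only; a and b are held constant.
x, a, b = float.(1:3), 2.0, 3.0
DI.gradient(g, prep, backend, x, DI.Constant(a), DI.Constant(b))  # == 2a .* x
```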
@@ -113,7 +113,7 @@ If one of your additional arguments behaves like a scratch space in memory (inst
 Now what if you care about the derivatives with respect to every argument?
 You can always go back to the single-argument case by putting everything inside a tuple:
-```@example tuto
+```julia
 g_tup(xab) = xab[2] * f(xab[1]) + xab[3]
 prep = DI.prepare_gradient(g_tup, backend, (typical_x, typical_a, typical_b))
 DI.value_and_gradient(g_tup, prep, backend, (x, a, b))
 ```
@@ -126,7 +126,7 @@ You can also use the native API of Mooncake.jl, discussed below.
 Going through DI allows you to compute other kinds of derivatives, like (reverse-mode) Jacobian matrices.
 The syntax is very similar:
 
-```@example tuto
+```julia
 h(x) = cos.(x) .* sin.(reverse(x))
 prep = DI.prepare_jacobian(h, backend, x)
 DI.jacobian(h, prep, backend, x)
@@ -134,6 +134,10 @@ DI.jacobian(h, prep, backend, x)
 
 ## Mooncake.jl API
 
+```@example mooncake_api
+import Mooncake
+```
+
 ### Mooncake.jl Functions
 
 Mooncake.jl provides the following core differentiation functions:
@@ -161,3 +165,29 @@ Mooncake.jl discusses Frechet derivatives and their adjoints, as described in de
 
 !!! info
     For a detailed mathematical treatment of these concepts, see [Algorithmic Differentiation](@ref), particularly the sections on [Derivatives](@ref).
+
+### Single Argument
+
+```@example mooncake_api
+f(x) = sum(abs2, x)
+x = float.(1:3)
+
+# Prepare the differentiation rule once (handles compilation)
+cache = Mooncake.prepare_gradient_cache(f, x)
+
+# Compute value and gradient (fast on repeated calls)
+val, (_, grad) = Mooncake.value_and_gradient!!(cache, f, x)
+(val, grad)
+```
+
+### Multiple Arguments
+
+To differentiate with respect to all arguments, pack them into a tuple:
+
+```@example mooncake_api
+g(xab) = xab[2] * f(xab[1]) + xab[3]
+a, b = 2.0, 3.0
+cache_g = Mooncake.prepare_gradient_cache(g, (x, a, b))
+val_g, (_, grad_g) = Mooncake.value_and_gradient!!(cache_g, g, (x, a, b))
+(val_g, grad_g)
+```
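The `grad_g` in the new "Multiple Arguments" section above is the tangent of the argument tuple, so its components line up with `(x, a, b)`. A minimal sketch of unpacking it, assuming (as the examples in this diff suggest) that Mooncake represents the tangent of a tuple as a tuple of component tangents:

```julia
import Mooncake

f(x) = sum(abs2, x)
g(xab) = xab[2] * f(xab[1]) + xab[3]

x = float.(1:3)
a, b = 2.0, 3.0

cache_g = Mooncake.prepare_gradient_cache(g, (x, a, b))
val_g, (_, grad_g) = Mooncake.value_and_gradient!!(cache_g, g, (x, a, b))

# One tangent component per tuple element.
dx, da, db = grad_g
dx  # gradient w.r.t. x: 2a .* x == [4.0, 8.0, 12.0]
da  # gradient w.r.t. a: f(x) == 14.0
db  # gradient w.r.t. b: 1.0
```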
diff --git a/ext/MooncakeDynamicExpressionsExt.jl b/ext/MooncakeDynamicExpressionsExt.jl
index 2d9e683b38..a7b80a81de 100644
--- a/ext/MooncakeDynamicExpressionsExt.jl
+++ b/ext/MooncakeDynamicExpressionsExt.jl
@@ -84,7 +84,9 @@ function set_children!(t::TangentNode{Tv,D}, fdata::Tuple{Vararg{Any,deg}}) wher
 end
 
 function DE.extract_gradient(
-    gradient::Mooncake.Tangent{@NamedTuple{tree::TN,metadata::Mooncake.NoTangent}},
+    gradient::Mooncake.PossiblyMutableTangent{
+        @NamedTuple{tree::TN,metadata::Mooncake.NoTangent}
+    },
     tree::Expression{T},
 ) where {Tv,TN<:TangentNode{Tv},T}
     return DE.extract_gradient(gradient.fields.tree, DE.get_tree(tree))
diff --git a/test/ext/dynamic_expressions/Project.toml b/test/ext/dynamic_expressions/Project.toml
index f8fe0995aa..d80d7c3afb 100644
--- a/test/ext/dynamic_expressions/Project.toml
+++ b/test/ext/dynamic_expressions/Project.toml
@@ -3,7 +3,6 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
 DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
-Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/ext/dynamic_expressions/dynamic_expressions.jl b/test/ext/dynamic_expressions/dynamic_expressions.jl
index 974094bad2..1670613c25 100644
--- a/test/ext/dynamic_expressions/dynamic_expressions.jl
+++ b/test/ext/dynamic_expressions/dynamic_expressions.jl
@@ -6,7 +6,6 @@ using Mooncake
 using Mooncake: Mooncake, prepare_gradient_cache, value_and_gradient!!
 using Mooncake.TestUtils
 using Mooncake.TestUtils: test_rule, test_data
-using Optim: Optim
 using DynamicExpressions
 using DynamicExpressions: Nullable
 using StableRNGs: StableRNG
@@ -135,40 +134,6 @@ end
     end
 end
 
-@testset "Use in DynamicExpressions + Optim optimization" begin
-    let
-        operators = OperatorEnum(1 => (cos, sin, exp), 2 => (+, -, *, /))
-        x1 = Expression(Node{Float64}(; feature=1); operators)
-        x2 = Expression(Node{Float64}(; feature=2); operators)
-        init = x1 * exp(0.7 + 0.5 * x1) + 0.9 * x2
-        target = x1 * exp(0.3 + (-0.2) * x1) + 1.5 * x2
-        X = randn(StableRNG(0), 2, 128)
-        y = target(X)
-
-        f = let X = X, y = y
-            function (ex)
-                pred = ex(X)
-                return sum(i -> abs2(pred[i] - y[i]), axes(X, 2))
-            end
-        end
-
-        cache = prepare_gradient_cache(f, init)
-        g! = let cache = cache, f = f
-            function (G, ex)
-                y, grad = value_and_gradient!!(cache, f, ex)
-                d_f, d_ex = grad
-                G .= extract_gradient(d_ex, ex)
-                return nothing
-            end
-        end
-        ex0 = copy(init)
-        result = Optim.optimize(f, g!, ex0, Optim.BFGS())
-        constants_final = get_scalar_constants(result.minimizer)[1]
-        constants_target = get_scalar_constants(get_tree(target))[1]
-        @test isapprox(constants_final, constants_target, atol=1e-5)
-    end
-end
-
 @testset "TestUtils systematic tests - $(T)" for T in [Float32, Float64]
     let
         operators = OperatorEnum(
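The removed testset above also served as a usage reference for driving Optim with Mooncake gradients. For downstream readers, a minimal sketch of the same pattern on a plain `Vector` input (the `quadratic` objective is hypothetical, and Optim.jl must be available, since it is no longer a test dependency):

```julia
using Mooncake: prepare_gradient_cache, value_and_gradient!!
using Optim: Optim

# Hypothetical least-squares objective, standing in for the expression loss above.
quadratic(x) = sum(abs2, x .- (1.0:3.0))

x0 = zeros(3)

# Derive and compile the rule once, outside the optimization loop.
cache = prepare_gradient_cache(quadratic, x0)

# Optim expects a mutating gradient: copy the input tangent into G, discarding
# the tangent with respect to the function itself.
g! = let cache = cache
    (G, x) -> begin
        _, (_, dx) = value_and_gradient!!(cache, quadratic, x)
        G .= dx
        return nothing
    end
end

result = Optim.optimize(quadratic, g!, x0, Optim.BFGS())
Optim.minimizer(result)  # ≈ [1.0, 2.0, 3.0]
```

As in the removed testset, capturing `cache` in a `let` block keeps the closure from referencing a non-const global.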