Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ jobs:
fail-fast: false
matrix:
test_group: [
{test_type: 'ext', label: 'differentiation_interface'},
# {test_type: 'ext', label: 'differentiation_interface'},
{test_type: 'ext', label: 'dynamic_expressions'},
{test_type: 'ext', label: 'flux'},
{test_type: 'ext', label: 'function_wrappers'},
Expand All @@ -100,7 +100,7 @@ jobs:
{test_type: 'integration_testing', label: 'array'},
{test_type: 'integration_testing', label: 'bijectors'},
{test_type: 'integration_testing', label: 'diff_tests'},
{test_type: 'integration_testing', label: 'diffeq'},
# {test_type: 'integration_testing', label: 'diffeq'},
{test_type: 'integration_testing', label: 'dispatch_doctor'},
{test_type: 'integration_testing', label: 'distributions'},
{test_type: 'integration_testing', label: 'dynamicppl'},
Expand Down
2 changes: 0 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
Expand All @@ -12,7 +11,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
[compat]
AllocCheck = "0.2.0"
ChainRulesCore = "1"
DifferentiationInterface = "0.7.16"
Documenter = "1"
JET = "0.9, 0.10, 0.11"
Mooncake = "0.5, 0.6"
7 changes: 4 additions & 3 deletions docs/src/developer_documentation/custom_tangent_type.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ A *recursive type* is a struct that contains itself (directly or indirectly) as

```@setup custom_tangent_type
using Mooncake: Mooncake
using DifferentiationInterface
using JET
using AllocCheck
using Test
Expand Down Expand Up @@ -183,7 +182,8 @@ With these, you can now differentiate simple functions:

```@example custom_tangent_type
a = A(1.0)
val, grad = DifferentiationInterface.value_and_gradient(f1, AutoMooncake(; config=nothing), a)
cache = Mooncake.prepare_gradient_cache(f1, a)
val, (_, grad) = Mooncake.value_and_gradient!!(cache, f1, a)
```

Another example:
Expand All @@ -194,7 +194,8 @@ function prod_x(a::A{T}) where {T}
return a.a === nothing ? a_val : a_val * prod_x(a.a)
end
sum_a = A(1.0, A(2.0, A(3.0)))
val_f5, grad_f5 = DifferentiationInterface.value_and_gradient(prod_x, AutoMooncake(; config=nothing), sum_a)
cache_prod_x = Mooncake.prepare_gradient_cache(prod_x, sum_a)
val_f5, (_, grad_f5) = Mooncake.value_and_gradient!!(cache_prod_x, prod_x, sum_a)
```

Depending on your use case, this may be sufficient.
Expand Down
58 changes: 44 additions & 14 deletions docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ There are two ways to compute gradients with Mooncake.jl:

We recommend the former to start with, especially if you want to experiment with other automatic differentiation packages.

```@example tuto
```julia
import DifferentiationInterface as DI
import Mooncake
```
Expand All @@ -17,7 +17,7 @@ import Mooncake
DifferentiationInterface.jl (or DI for short) provides a common entry point for every automatic differentiation package in Julia.
To specify that you want to use Mooncake.jl, just create the right "backend" object (with an optional [`Mooncake.Config`](@ref)):

```@example tuto
```julia
backend = DI.AutoMooncake(; config=nothing)
```

Expand All @@ -32,34 +32,34 @@ such that the tangent of a `ComplexF64` is a `ComplexF64`,
or the tangent of a `Symmetric` is a `Symmetric`,
set `friendly_tangents=true` in the config:

```@example tuto
```julia
backend = DI.AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true))
```

### Single argument

Suppose you want to differentiate the following function

```@example tuto
```julia
f(x) = sum(abs2, x)
```

on the following input

```@example tuto
```julia
x = float.(1:3)
```

The naive way is to simply call [`DI.gradient`](@extref DifferentiationInterface.gradient):

```@example tuto
```julia
DI.gradient(f, backend, x) # slow, do not do this
```

This returns the correct gradient, but it is very slow because it includes the time taken by Mooncake.jl to compute a differentiation rule for `f` (see [Mooncake.jl's Rule System](@ref)).
If you anticipate you will need more than one gradient, it is better to call [`DI.prepare_gradient`](@extref DifferentiationInterface.prepare_gradient) on a typical (e.g. random) input first:

```@example tuto
```julia
typical_x = rand(3)
prep = DI.prepare_gradient(f, backend, typical_x)
```
Expand All @@ -68,20 +68,20 @@ The typical input should have the same size and type as the actual inputs we wil
As for the contents of the preparation result, they do not matter.
What matters is that it captures everything you need for `DI.gradient` to be fast:

```@example tuto
```julia
DI.gradient(f, prep, backend, x) # fast
```

For optimal speed, you can provide storage space for the gradient and call [`DI.gradient!`](@extref DifferentiationInterface.gradient!) instead:

```@example tuto
```julia
grad = similar(x)
DI.gradient!(f, grad, prep, backend, x) # very fast
```

If you also need the value of the function, check out [`DI.value_and_gradient`](@extref DifferentiationInterface.value_and_gradient) or [`DI.value_and_gradient!`](@extref DifferentiationInterface.value_and_gradient!):

```@example tuto
```julia
DI.value_and_gradient(f, prep, backend, x)
```

Expand All @@ -91,14 +91,14 @@ What should you do if your function takes more than one input argument?
Well, DI can still handle it, _assuming that you only want the derivative with respect to one of them_ (the first one, by convention).
For instance, consider the function

```@example tuto
```julia
g(x, a, b) = a * f(x) + b
```

You can easily compute the gradient with respect to `x`, while keeping `a` and `b` fixed.
To do that, just wrap these two arguments inside [`DI.Constant`](@extref DifferentiationInterface.Constant), like so:

```@example tuto
```julia
typical_a, typical_b = 1.0, 1.0
prep = DI.prepare_gradient(g, backend, typical_x, DI.Constant(typical_a), DI.Constant(typical_b))

Expand All @@ -113,7 +113,7 @@ If one of your additional arguments behaves like a scratch space in memory (inst
Now what if you care about the derivatives with respect to every argument?
You can always go back to the single-argument case by putting everything inside a tuple:

```@example tuto
```julia
g_tup(xab) = xab[2] * f(xab[1]) + xab[3]
prep = DI.prepare_gradient(g_tup, backend, (typical_x, typical_a, typical_b))
DI.value_and_gradient(g_tup, prep, backend, (x, a, b))
Expand All @@ -126,14 +126,18 @@ You can also use the native API of Mooncake.jl, discussed below.
Going through DI allows you to compute other kinds of derivatives, like (reverse-mode) Jacobian matrices.
The syntax is very similar:

```@example tuto
```julia
h(x) = cos.(x) .* sin.(reverse(x))
prep = DI.prepare_jacobian(h, backend, x)
DI.jacobian(h, prep, backend, x)
```

## Mooncake.jl API

```@example mooncake_api
import Mooncake
```

### Mooncake.jl Functions

Mooncake.jl provides the following core differentiation functions:
Expand Down Expand Up @@ -161,3 +165,29 @@ Mooncake.jl discusses Frechet derivatives and their adjoints, as described in de
!!! info
For a detailed mathematical treatment of these concepts, see [Algorithmic Differentiation](@ref), particularly the sections on [Derivatives](@ref).

### Single Argument

```@example mooncake_api
f(x) = sum(abs2, x)
x = float.(1:3)

# Prepare the differentiation rule once (handles compilation)
cache = Mooncake.prepare_gradient_cache(f, x)

# Compute value and gradient (fast on repeated calls)
val, (_, grad) = Mooncake.value_and_gradient!!(cache, f, x)
(val, grad)
```

### Multiple Arguments

To differentiate with respect to all arguments, pack them into a tuple:

```@example mooncake_api
g(xab) = xab[2] * f(xab[1]) + xab[3]
a, b = 2.0, 3.0
cache_g = Mooncake.prepare_gradient_cache(g, (x, a, b))
val_g, (_, grad_g) = Mooncake.value_and_gradient!!(cache_g, g, (x, a, b))
(val_g, grad_g)
```

4 changes: 3 additions & 1 deletion ext/MooncakeDynamicExpressionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ function set_children!(t::TangentNode{Tv,D}, fdata::Tuple{Vararg{Any,deg}}) wher
end

function DE.extract_gradient(
gradient::Mooncake.Tangent{@NamedTuple{tree::TN,metadata::Mooncake.NoTangent}},
gradient::Mooncake.PossiblyMutableTangent{
@NamedTuple{tree::TN,metadata::Mooncake.NoTangent}
},
tree::Expression{T},
) where {Tv,TN<:TangentNode{Tv},T}
return DE.extract_gradient(gradient.fields.tree, DE.get_tree(tree))
Expand Down
1 change: 0 additions & 1 deletion test/ext/dynamic_expressions/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
35 changes: 0 additions & 35 deletions test/ext/dynamic_expressions/dynamic_expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using Mooncake
using Mooncake: Mooncake, prepare_gradient_cache, value_and_gradient!!
using Mooncake.TestUtils
using Mooncake.TestUtils: test_rule, test_data
using Optim: Optim
using DynamicExpressions
using DynamicExpressions: Nullable
using StableRNGs: StableRNG
Expand Down Expand Up @@ -135,40 +134,6 @@ end
end
end

@testset "Use in DynamicExpressions + Optim optimization" begin
let
operators = OperatorEnum(1 => (cos, sin, exp), 2 => (+, -, *, /))
x1 = Expression(Node{Float64}(; feature=1); operators)
x2 = Expression(Node{Float64}(; feature=2); operators)
init = x1 * exp(0.7 + 0.5 * x1) + 0.9 * x2
target = x1 * exp(0.3 + (-0.2) * x1) + 1.5 * x2
X = randn(StableRNG(0), 2, 128)
y = target(X)

f = let X = X, y = y
function (ex)
pred = ex(X)
return sum(i -> abs2(pred[i] - y[i]), axes(X, 2))
end
end

cache = prepare_gradient_cache(f, init)
g! = let cache = cache, f = f
function (G, ex)
y, grad = value_and_gradient!!(cache, f, ex)
d_f, d_ex = grad
G .= extract_gradient(d_ex, ex)
return nothing
end
end
ex0 = copy(init)
result = Optim.optimize(f, g!, ex0, Optim.BFGS())
constants_final = get_scalar_constants(result.minimizer)[1]
constants_target = get_scalar_constants(get_tree(target))[1]
@test isapprox(constants_final, constants_target, atol=1e-5)
end
end

@testset "TestUtils systematic tests - $(T)" for T in [Float32, Float64]
let
operators = OperatorEnum(
Expand Down
Loading