diff --git a/docs/Project.toml b/docs/Project.toml index a63831446..a08ec184b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,6 +7,7 @@ DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -15,6 +16,7 @@ IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" @@ -35,13 +37,14 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Calculus = "0.5" -ComponentArrays = "0.15" +ComponentArrays = "0.15.34" DataInterpolations = "3.10, 4, 5, 6, 7, 8" DelayDiffEq = "5" DelimitedFiles = "1" DiffEqCallbacks = "2.24, 3, 4" DiffEqNoiseProcess = "5.14" DifferentialEquations = "7" +DifferentiationInterface = "0.6, 0.7" Documenter = "1" Enzyme = "0.12, 0.13" Flux = "0.14, 0.15, 0.16" @@ -50,6 +53,7 @@ IterTools = "1" Lux = "1" LuxCUDA = "0.3" MLUtils = "0.4" +Mooncake = "0.5" Optimization = "3.9, 4, 5" OptimizationOptimJL = "0.2, 0.3, 0.4" OptimizationOptimisers = "0.2, 0.3" diff --git a/docs/src/Benchmark.md b/docs/src/Benchmark.md index 3ae61d3d6..2387da218 100644 --- a/docs/src/Benchmark.md +++ b/docs/src/Benchmark.md @@ -40,7 +40,8 @@ Quick summary: import OrdinaryDiffEq as ODE import Lux import SciMLSensitivity as SMS -import Zygote +import Mooncake +import DifferentiationInterface as DI import BenchmarkTools import Random import ComponentArrays as CA @@ -80,7 +81,9 @@ for sensealg in (SMS.InterpolatingAdjoint(autojacvec = SMS.ZygoteVJP()), return loss end - t = BenchmarkTools.@belapsed Zygote.gradient($loss_neuralode, $u0, $ps, $st) + backend = DI.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) + loss_ps = p -> loss_neuralode(u0, p, st) + t = BenchmarkTools.@belapsed DI.gradient($loss_ps, $backend, $ps) println("$(sensealg) took $(t)s") end diff --git a/docs/src/examples/dde/delay_diffeq.md b/docs/src/examples/dde/delay_diffeq.md index 47e0d9f4b..950bb1eda 100644 --- a/docs/src/examples/dde/delay_diffeq.md +++ b/docs/src/examples/dde/delay_diffeq.md @@ -10,6 +10,7 @@ import Optimization as OPT import SciMLSensitivity as SMS import OptimizationPolyalgorithms as OPA import DelayDiffEq as DDE +import Mooncake # Define the same LV equation, but including a delay parameter function delay_lotka_volterra!(du, u, h, p, t) @@ -35,7 +36,7 @@ prob_dde = DDE.DDEProblem(delay_lotka_volterra!, u0, h, (0.0, 10.0), function predict_dde(p) return Array(ODE.solve(prob_dde, DDE.MethodOfSteps(ODE.Tsit5()); - u0, p, saveat = 0.1, sensealg = SMS.ReverseDiffAdjoint())) + u0, p, saveat = 0.1)) end loss_dde(p) = sum(abs2, x - 1 for x in predict_dde(p)) @@ -50,14 +51,18 @@ callback = function (state, l; doplot = false) return false end -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss_dde(x), adtype) optprob = 
OPT.OptimizationProblem(optf, p) result_dde = OPT.solve(optprob, OPA.PolyOpt(); maxiters = 300, callback) ``` -Notice that we chose `sensealg = ReverseDiffAdjoint()` to utilize the ReverseDiff.jl -reverse-mode to handle the delay differential equation. +The `sensealg` is left at its default. For DDEs the automatic choice is +[`ForwardDiffSensitivity`](@ref) (which differentiates through +`MethodOfSteps` via dual numbers) for problems with fewer than 100 +parameters, and [`ReverseDiffAdjoint`](@ref) for larger ones — +[continuous adjoints](@ref sensitivity_diffeq) are not yet defined for +DDEs, so the discretize-then-optimize methods are the only option. We define a callback to display the solution at the current parameters for each step of the training: @@ -76,7 +81,7 @@ end We use `Optimization.solve` to optimize the parameters for our loss function: ```@example dde -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss_dde(x), adtype) optprob = OPT.OptimizationProblem(optf, p) result_dde = OPT.solve(optprob, OPA.PolyOpt(); callback) diff --git a/docs/src/examples/hybrid_jump/bouncing_ball.md b/docs/src/examples/hybrid_jump/bouncing_ball.md index 22a4a50f8..9e70c1f6e 100644 --- a/docs/src/examples/hybrid_jump/bouncing_ball.md +++ b/docs/src/examples/hybrid_jump/bouncing_ball.md @@ -13,6 +13,7 @@ import OptimizationPolyalgorithms as OPA import SciMLSensitivity as SMS import OrdinaryDiffEq as ODE import DiffEqCallbacks as DEC +import Mooncake function f(du, u, p, t) du[1] = u[2] @@ -44,11 +45,15 @@ the value 20: function loss(θ) sol = ODE.solve(prob, ODE.Tsit5(), p = [9.8, θ[1]]; callback) target = 20.0 - abs2(sol[end][1] - target) + # Use `last(sol.u)[1]` instead of `sol[end][1]` — Mooncake's pullback for + # `getindex(::ODESolution, end)` currently has a `BoundsError` bug + # (`SciMLBaseMooncakeExt._scatter_pullback`). Indexing the underlying + # `sol.u::Vector{Vector{Float64}}` directly avoids the bad path. 
+ abs2(last(sol.u)[1] - target) end loss([0.8]) -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype) optprob = OPT.OptimizationProblem(optf, [0.8]) @time res = OPT.solve(optprob, OPA.PolyOpt(), maxiters = 300) diff --git a/docs/src/examples/hybrid_jump/hybrid_diffeq.md b/docs/src/examples/hybrid_jump/hybrid_diffeq.md index 937d734d3..d46de7caf 100644 --- a/docs/src/examples/hybrid_jump/hybrid_diffeq.md +++ b/docs/src/examples/hybrid_jump/hybrid_diffeq.md @@ -12,6 +12,7 @@ import ComponentArrays as CA import Random import SciMLSensitivity as SMS import Lux +import Mooncake import OrdinaryDiffEq as ODE import Plots import Optimization as OPT @@ -50,9 +51,7 @@ cb = DEC.PresetTimeCallback(dosetimes, affect!, save_positions = (false, false)) function predict_n_ode(p) _prob = ODE.remake(prob; p) - Array(ODE.solve(_prob, ODE.Tsit5(); u0 = z0, p, callback = cb, saveat = t, - sensealg = SMS.ReverseDiffAdjoint()))[1:2, :] - #Array(solve(prob,Tsit5();u0=z0,p,saveat=t))[1:2,:] + Array(ODE.solve(_prob, ODE.Tsit5(); u0 = z0, p, callback = cb, saveat = t))[1:2, :] end function loss_n_ode(p, _) @@ -73,7 +72,10 @@ cba = function (state, l; doplot = false) #callback function to observe training end res = OPT.solve( - OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode, OPT.AutoZygote()), + OPT.OptimizationProblem( + OPT.OptimizationFunction( + loss_n_ode, OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) + ), CA.ComponentArray(ps)), OPO.Adam(0.05); callback = cba, maxiters = 1000) ``` diff --git a/docs/src/examples/neural_ode/simplechains.md b/docs/src/examples/neural_ode/simplechains.md index 1bf167771..6fd59a1a9 100644 --- a/docs/src/examples/neural_ode/simplechains.md +++ b/docs/src/examples/neural_ode/simplechains.md @@ -1,5 +1,6 @@ # Faster Neural Ordinary Differential Equations with SimpleChains + [SimpleChains](https://github.com/PumasAI/SimpleChains.jl) has demonstrated performance boosts of ~5x and ~30x when compared to other mainstream deep learning frameworks like Pytorch for the training and evaluation in the specific case of small neural networks. For the nitty-gritty details, as well as, some SciML related videos around the need and applications of such a library, we can refer to this [blogpost](https://julialang.org/blog/2022/04/simple-chains/). As for doing Scientific Machine Learning, how do we even begin with training neural ODEs with any generic deep learning library? 
## Training Data diff --git a/docs/src/examples/ode/exogenous_input.md b/docs/src/examples/ode/exogenous_input.md index 5b40bc364..b3d0c89f9 100644 --- a/docs/src/examples/ode/exogenous_input.md +++ b/docs/src/examples/ode/exogenous_input.md @@ -49,6 +49,7 @@ import OptimizationPolyalgorithms as OPA import OptimizationOptimisers as OPO import Plots import Random +import Mooncake rng = Random.default_rng() tspan = (0.1, 10.0) @@ -93,7 +94,7 @@ function loss(p) return sum(abs2.(y[1:N] .- sol')) / N end -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype) optprob = OPT.OptimizationProblem(optf, CA.ComponentArray{Float64}(p_model)) diff --git a/docs/src/examples/ode/second_order_adjoints.md b/docs/src/examples/ode/second_order_adjoints.md index 4ec2d9fe6..00e4c0112 100644 --- a/docs/src/examples/ode/second_order_adjoints.md +++ b/docs/src/examples/ode/second_order_adjoints.md @@ -13,6 +13,7 @@ optimization, while `KrylovTrustRegion` will utilize a Krylov-based method with Hessian-vector products (never forming the Hessian) for large parameter optimizations. + ```@example secondorderadjoints import SciMLSensitivity as SMS import Lux @@ -23,6 +24,7 @@ import OrdinaryDiffEq as ODE import Plots import Random import OptimizationOptimJL as OOJ +import Mooncake u0 = Float32[2.0; 0.0] datasize = 30 @@ -83,13 +85,14 @@ callback = function (state, l; doplot = false) return l < 0.01 end -adtype = OPT.AutoZygote() -optf = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) - -optprob1 = OPT.OptimizationProblem(optf, ps) +adtype1 = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) +optf1 = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype1) +optprob1 = OPT.OptimizationProblem(optf1, ps) pstart = OPT.solve(optprob1, OPO.Adam(0.01); callback, maxiters = 100).u -optprob2 = OPT.OptimizationProblem(optf, pstart) +adtype2 = OPT.AutoZygote() +optf2 = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype2) +optprob2 = OPT.OptimizationProblem(optf2, pstart) pmin = OPT.solve(optprob2, OOJ.NewtonTrustRegion(); callback, maxiters = 200) ``` diff --git a/docs/src/examples/ode/second_order_neural.md b/docs/src/examples/ode/second_order_neural.md index 41eb6c128..5d12266df 100644 --- a/docs/src/examples/ode/second_order_neural.md +++ b/docs/src/examples/ode/second_order_neural.md @@ -29,6 +29,7 @@ import OptimizationOptimisers as OPO import RecursiveArrayTools import Random import ComponentArrays as CA +import Mooncake u0 = Float32[0.0; 2.0] du0 = Float32[0.0; 0.0] @@ -61,7 +62,7 @@ callback = function (state, l) l < 0.01 end -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss_n_ode(x), adtype) optprob = OPT.OptimizationProblem(optf, ps) diff --git a/docs/src/examples/optimal_control/feedback_control.md b/docs/src/examples/optimal_control/feedback_control.md index 228b548d9..6aa414b6e 100644 --- a/docs/src/examples/optimal_control/feedback_control.md +++ b/docs/src/examples/optimal_control/feedback_control.md @@ -4,6 +4,7 @@ You can also mix a known differential equation and a neural differential equation, so that the parameters and the neural network are estimated simultaneously! 
+ We will assume that we know the dynamics of the second equation (linear dynamics), and our goal is to find a neural network that is dependent on the current state of the dynamical system that will control the second diff --git a/docs/src/examples/optimal_control/optimal_control.md b/docs/src/examples/optimal_control/optimal_control.md index 53bca29a7..1995ff3ee 100644 --- a/docs/src/examples/optimal_control/optimal_control.md +++ b/docs/src/examples/optimal_control/optimal_control.md @@ -43,7 +43,6 @@ import Optimization as OPT import OptimizationOptimJL as OOJ import OptimizationOptimisers as OPO import SciMLSensitivity as SMS -import Zygote import Plots import Statistics import Random diff --git a/docs/src/examples/pde/brusselator.md b/docs/src/examples/pde/brusselator.md index d8be744da..6ab28840c 100644 --- a/docs/src/examples/pde/brusselator.md +++ b/docs/src/examples/pde/brusselator.md @@ -156,7 +156,7 @@ First, we have to define and configure the neural network that has to be used fo ```@example bruss import Lux, Random, Optimization as OPT, OptimizationOptimJL as OOJ, - SciMLSensitivity as SMS, Zygote + SciMLSensitivity as SMS, Mooncake model = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1)) rng = Random.default_rng() @@ -223,12 +223,13 @@ function loss_fn(ps, _) end ``` -Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's `Optimization.jl` tools, and gradients are computed via automatic differentiation using `AutoZygote()` from `SciMLSensitivity`: +Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's `Optimization.jl` tools, and gradients are computed via automatic differentiation using Mooncake through the `SciMLSensitivity` adjoint chain: ```@example bruss println("[Training] Starting optimization...") import OptimizationOptimisers as OPO -optf = OPT.OptimizationFunction(loss_fn, SMS.AutoZygote()) +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) +optf = OPT.OptimizationFunction(loss_fn, adtype) optprob = OPT.OptimizationProblem(optf, ps_init) loss_history = Float32[] diff --git a/docs/src/examples/pde/pde_constrained.md b/docs/src/examples/pde/pde_constrained.md index d26090fcf..f3bef4cfc 100644 --- a/docs/src/examples/pde/pde_constrained.md +++ b/docs/src/examples/pde/pde_constrained.md @@ -6,7 +6,7 @@ This example uses a prediction model to optimize the one-dimensional Heat Equati ```@example pde import SciMLSensitivity as SMS import DelimitedFiles, Plots -import OrdinaryDiffEq as ODE, Optimization as OPT, OptimizationPolyalgorithms as OPA, Zygote +import OrdinaryDiffEq as ODE, Optimization as OPT, OptimizationPolyalgorithms as OPA, Mooncake # Problem setup parameters: Lx = 10.0 @@ -92,7 +92,7 @@ cb((; u = ps), loss(ps)) # Testing callback function Plots.scatter(sol[:, end], label = "Truth", size = (800, 500)) Plots.plot!(PRED[end][:, end], lw = 2, label = "Prediction") -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype) optprob = OPT.OptimizationProblem(optf, ps) @@ -107,7 +107,7 @@ res = OPT.solve(optprob, OPA.PolyOpt(), callback = cb) ```@example pde2 import SciMLSensitivity as SMS import DelimitedFiles, Plots -import OrdinaryDiffEq as ODE, Optimization as OPT, OptimizationPolyalgorithms as OPA, Zygote +import OrdinaryDiffEq as ODE, 
Optimization as OPT, OptimizationPolyalgorithms as OPA, Mooncake ``` ### Parameters @@ -283,7 +283,7 @@ The resulting best parameters are stored in `res` and `res.u` returns the parameters that minimize the cost function. ```@example pde2 -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype) optprob = OPT.OptimizationProblem(optf, ps) diff --git a/docs/src/examples/sde/SDE_control.md b/docs/src/examples/sde/SDE_control.md index c44cc836b..cc2a0de49 100644 --- a/docs/src/examples/sde/SDE_control.md +++ b/docs/src/examples/sde/SDE_control.md @@ -18,6 +18,7 @@ to ultimately prepare and stabilize the qubit in the excited state. Before getting to the explanation, here's some code to start with. We will follow a full explanation of the definition and training process: + ```@example # load packages import SciMLSensitivity as SMS, Optimization as OPT, OptimizationOptimisers as OPO diff --git a/docs/src/examples/sde/optimization_sde.md b/docs/src/examples/sde/optimization_sde.md index 7f040c8f0..a9bc10cb5 100644 --- a/docs/src/examples/sde/optimization_sde.md +++ b/docs/src/examples/sde/optimization_sde.md @@ -94,7 +94,8 @@ function cb2(st, l) end ``` -We can then use `Optimization.solve` to fit the SDE: +We can then use `Optimization.solve` to fit the SDE. + ```@example sde import Optimization as OPT, Zygote, OptimizationOptimisers as OPO @@ -178,7 +179,8 @@ end Let's optimize ```@example sde -adtype = OPT.AutoZygote() +import Mooncake +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss_sde(x), adtype) optprob = OPT.OptimizationProblem(optf, p) diff --git a/docs/src/faq.md b/docs/src/faq.md index e84b0bbfa..cc72817a1 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -48,17 +48,20 @@ Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6), This is exactly the inner core Enzyme call and if this fails, that is the issue that needs to be fixed. -And similarly, for out-of-place functions the Zygote isolation is as follows: +And similarly, for out-of-place functions the Mooncake isolation is as follows: ```julia +import Mooncake p = prob.p y = prob.u0 f = prob.f λ = zero(prob.u0) -_dy, back = Zygote.pullback(y, p) do u, p - vec(f(u, p, t)) -end -tmp1, tmp2 = back(λ) +# Build the Mooncake pullback for the inner-rhs evaluation `f(u, p)` and +# apply the cotangent `λ` to recover the seed gradients `tmp1` (wrt `y`) +# and `tmp2` (wrt `p`). +g = (u, p) -> vec(f(u, p, t)) +cache = Mooncake.prepare_pullback_cache(g, y, p) +_dy, (_, tmp1, tmp2) = Mooncake.value_and_pullback!!(cache, λ, g, y, p) ``` ## How do I use custom parameter types with adjoint sensitivity analysis? diff --git a/docs/src/getting_started.md b/docs/src/getting_started.md index 37e8466a8..79baacb07 100644 --- a/docs/src/getting_started.md +++ b/docs/src/getting_started.md @@ -82,23 +82,25 @@ solution at time `t=1` with respect to `p[1]`. ## Reverse-Mode Automatic Differentiation -[The `solve` function is automatically compatible with AD systems like Zygote.jl](https://docs.sciml.ai/SciMLSensitivity/stable/) +[The `solve` function is automatically compatible with reverse-mode AD systems like Mooncake.jl](https://docs.sciml.ai/SciMLSensitivity/stable/) and thus there is no machinery that is necessary to use other than to put `solve` inside -a function that is differentiated by Zygote. 
For example, the following computes the solution +a function that is differentiated by Mooncake. For example, the following computes the solution to an ODE and computes the gradient of a loss function (the sum of the ODE's output at each timepoint with dt=0.1) via the adjoint method: ```@example diffode -import Zygote +import Mooncake +import DifferentiationInterface as DI -function sum_of_solution(u0, p) - _prob = ODE.remake(prob; u0, p) +function sum_of_solution(p) + _prob = ODE.remake(prob; p) sum(ODE.solve(_prob, ODE.Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1)) end -du01, dp1 = Zygote.gradient(sum_of_solution, u0, p) +backend = DI.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) +dp1 = DI.gradient(sum_of_solution, backend, p) ``` -Zygote.jl's automatic differentiation system is overloaded to allow SciMLSensitivity.jl +Mooncake.jl's automatic differentiation system is overloaded to allow SciMLSensitivity.jl to redefine the way the derivatives are computed, allowing trade-offs between numerical stability, memory, and compute performance, similar to how ODE solver algorithms are chosen. @@ -107,16 +109,16 @@ chosen. The algorithms for differentiation calculation are called `AbstractSensitivityAlgorithms`, or `sensealg`s for short. These are chosen by passing the `sensealg` keyword argument into solve. -Let's demonstrate this by choosing the `QuadratureAdjoint` `sensealg` for the differentiation of +Let's demonstrate this by choosing the `GaussAdjoint` `sensealg` for the differentiation of this system: ```@example diffode -function sum_of_solution(u0, p) - _prob = ODE.remake(prob; u0, p) +function sum_of_solution(p) + _prob = ODE.remake(prob; p) sum(ODE.solve(_prob, ODE.Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1, sensealg = SMS.GaussAdjoint())) end -du01, dp1 = Zygote.gradient(sum_of_solution, u0, p) +dp1 = DI.gradient(sum_of_solution, backend, p) ``` Here this computes the derivative of the output with respect to the initial diff --git a/docs/src/index.md b/docs/src/index.md index f2396ee30..6f7443c8c 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -49,6 +49,8 @@ solve(prob, args...; sensealg = SMS.InterpolatingAdjoint(), checkpoints = sol.t, `solve` is fully compatible with automatic differentiation libraries like: + - [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl) (recommended for reverse-mode) + - [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) - [Zygote.jl](https://fluxml.ai/Zygote.jl/stable/) - [ReverseDiff.jl](https://juliadiff.org/ReverseDiff.jl/) - [Tracker.jl](https://github.com/FluxML/Tracker.jl) diff --git a/docs/src/manual/differential_equation_sensitivities.md b/docs/src/manual/differential_equation_sensitivities.md index 6a376eabb..035bb6f65 100644 --- a/docs/src/manual/differential_equation_sensitivities.md +++ b/docs/src/manual/differential_equation_sensitivities.md @@ -24,7 +24,8 @@ Take for example this simple differential equation solve on Lotka-Volterra: ```julia import SciMLSensitivity as SMS import OrdinaryDiffEq as ODE -import Zygote +import Mooncake +import DifferentiationInterface as DI function fiip(du, u, p, t) du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] @@ -34,31 +35,32 @@ p = [1.5, 1.0, 3.0, 1.0]; u0 = [1.0; 1.0]; prob = ODE.ODEProblem(fiip, u0, (0.0, 10.0), p) sol = ODE.solve(prob, ODE.Tsit5()) -loss(u0, p) = sum(ODE.solve(prob, ODE.Tsit5(); u0, p, saveat = 0.1)) -du0, dp = Zygote.gradient(loss, u0, p) +loss(p) = sum(ODE.solve(prob, ODE.Tsit5(); p, saveat = 0.1)) +backend = DI.AutoMooncake(; config 
= Mooncake.Config(; friendly_tangents = true)) +dp = DI.gradient(loss, backend, p) ``` This will compute the gradient of the loss function "sum of the values of the -solution to the ODE at timepoints dt=0.1" using an adjoint method, where `du0` -is the derivative of the loss function with respect to the initial condition -and `dp` is the derivative of the loss function with respect to the parameters. +solution to the ODE at timepoints dt=0.1" using an adjoint method, where `dp` +is the derivative of the loss function with respect to the parameters. -Because the gradient is calculated by `Zygote.gradient` and Zygote.jl is one of -the compatible AD libraries, this derivative calculation will be captured -by the `sensealg` system, and one of SciMLSensitivity.jl's adjoint overloads -will be used to compute the derivative. By default, if the `sensealg` keyword -argument is not defined, then a smart polyalgorithm is used to automatically -determine the most appropriate method for a given equation. +Because the gradient is calculated through DifferentiationInterface.jl with +`AutoMooncake` and Mooncake.jl is one of the compatible AD libraries, this +derivative calculation will be captured by the `sensealg` system, and one of +SciMLSensitivity.jl's adjoint overloads will be used to compute the +derivative. By default, if the `sensealg` keyword argument is not defined, +then a smart polyalgorithm is used to automatically determine the most +appropriate method for a given equation. Likewise, the `sensealg` argument can be given to directly control the method by which the derivative is computed. For example: ```julia -function loss(u0, p) - sum(ODE.solve(prob, ODE.Tsit5(); u0, p, saveat = 0.1, +function loss(p) + sum(ODE.solve(prob, ODE.Tsit5(); p, saveat = 0.1, sensealg = SMS.ForwardSensitivity())) end -du0, dp = Zygote.gradient(loss, u0, p) +dp = DI.gradient(loss, backend, p) ``` would do reverse-mode automatic differentiation of the loss function, but when reversing diff --git a/docs/src/tutorials/chaotic_ode.md b/docs/src/tutorials/chaotic_ode.md index f7da53ff2..efd48dc04 100644 --- a/docs/src/tutorials/chaotic_ode.md +++ b/docs/src/tutorials/chaotic_ode.md @@ -23,7 +23,8 @@ can be seen, for instance, when solving the [Lorenz system](https://en.wikipedia ```@example chaosode import OrdinaryDiffEq as ODE import SciMLSensitivity as SMS -import Zygote +import Mooncake +import DifferentiationInterface as DI import Plots import Random @@ -132,7 +133,7 @@ function G(p) sensealg = SMS.ForwardLSS(; g)) sum(getindex.(_sol.u, 3)) end -dp1 = Zygote.gradient(p -> G(p), p) +dp1 = DI.gradient(p -> G(p), DI.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), p) ``` Alternatively, we can define the `ForwardLSSProblem` and solve it diff --git a/docs/src/tutorials/data_parallel.md b/docs/src/tutorials/data_parallel.md index 8ffb8e31b..cdeddc1dc 100644 --- a/docs/src/tutorials/data_parallel.md +++ b/docs/src/tutorials/data_parallel.md @@ -89,11 +89,13 @@ interface. The following is a full copy-paste example for the multithreading. Distributed and GPU minibatching are described below. 
+ ```@example dataparallel import OrdinaryDiffEq as ODE import Optimization as OPT import OptimizationOptimisers as OPO import SciMLSensitivity as SMS +import Zygote pa = [1.0] u0 = [3.0] θ = [u0; pa] diff --git a/docs/src/tutorials/parameter_estimation_ode.md b/docs/src/tutorials/parameter_estimation_ode.md index 259b3af1e..f3ea1a0ea 100644 --- a/docs/src/tutorials/parameter_estimation_ode.md +++ b/docs/src/tutorials/parameter_estimation_ode.md @@ -10,7 +10,7 @@ import OrdinaryDiffEq as ODE import Optimization as OPT import OptimizationPolyalgorithms as OPA import SciMLSensitivity as SMS -import Zygote +import Mooncake import Plots function lotka_volterra!(du, u, p, t) @@ -54,7 +54,7 @@ callback = function (state, l) return false end -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype) optprob = OPT.OptimizationProblem(optf, p) @@ -78,7 +78,7 @@ import OrdinaryDiffEq as ODE import Optimization as OPT import OptimizationPolyalgorithms as OPA import SciMLSensitivity as SMS -import Zygote +import Mooncake import Plots function lotka_volterra!(du, u, p, t) @@ -148,7 +148,7 @@ end Let's optimize the model. ```@example optode -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype) optprob = OPT.OptimizationProblem(optf, p) diff --git a/docs/src/tutorials/training_tips/divergence.md b/docs/src/tutorials/training_tips/divergence.md index e72abee53..ef46a37e0 100644 --- a/docs/src/tutorials/training_tips/divergence.md +++ b/docs/src/tutorials/training_tips/divergence.md @@ -30,6 +30,7 @@ A full example making use of this trick is: ```@example divergence import OrdinaryDiffEq as ODE, SciMLSensitivity as SMS, SciMLBase, Optimization as OPT, OptimizationOptimisers as OPO, Plots +import Mooncake function lotka_volterra!(du, u, p, t) rab, wol = u @@ -65,7 +66,7 @@ function loss(p) end pinit = [1.2, 0.8, 2.5, 0.8] -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype) optprob = OPT.OptimizationProblem(optf, pinit) @@ -74,9 +75,9 @@ res = OPT.solve(optprob, OPO.Adam(), maxiters = 1000) # res = OPT.solve(optprob,NLopt.LD_LBFGS(), maxiters = 1000) ### errors! ``` -You might notice that `AutoZygote` (default) fails for the above `OPT.solve` call -with Optim's optimizers, which happens because of Zygote's behavior for zero gradients, in -which case it returns `nothing`. To avoid such issues, you can just use a different version +You might notice that some AD backends fail for the above `OPT.solve` call +with Optim's optimizers, which can happen due to how reverse-mode AD handles zero gradients +(e.g., returning `nothing`). To avoid such issues, you can just use a different version of the same check which compares the size of the obtained solution and the data we have, shown below, which is easier to AD. diff --git a/docs/src/tutorials/training_tips/local_minima.md b/docs/src/tutorials/training_tips/local_minima.md index d3cda6c6a..8ff982486 100644 --- a/docs/src/tutorials/training_tips/local_minima.md +++ b/docs/src/tutorials/training_tips/local_minima.md @@ -19,7 +19,7 @@ on `(0,5.0)`. 
Naively, we use the same training strategy as before: import SciMLSensitivity as SMS import OrdinaryDiffEq as ODE, ComponentArrays as CA, Optimization as OPT, OptimizationOptimisers as OPO -import Lux, Plots, Random, Zygote +import Lux, Plots, Random, Mooncake rng = Random.default_rng() u0 = Float32[2.0; 0.0] @@ -75,7 +75,7 @@ function callback(state, l; doplot = false) return false end -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) optprob = OPT.OptimizationProblem(optf, pinit) @@ -102,7 +102,7 @@ function predict_neuralode(p) Array(sol) end -adtype = OPT.AutoZygote() +adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype) optprob = OPT.OptimizationProblem(optf, pinit) @@ -163,7 +163,7 @@ one could use a mix of (3) and (4), or breaking up the trajectory into chunks an import SciMLSensitivity as SMS import OrdinaryDiffEq as ODE, ComponentArrays as CA, Optimization as OPT, OptimizationOptimisers as OPO -import Lux, Plots, Random, Zygote +import Lux, Plots, Random, Mooncake #Starting example with tspan (0, 5) u0 = Float32[2.0; 0.0] @@ -217,7 +217,7 @@ predict_n_ode(p_init) loss_n_ode(p_init, nothing) res = OPT.solve( - OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode, OPT.AutoZygote()), p_init), + OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode, OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))), p_init), OPO.Adam(0.05); callback, maxiters = 1000) function predict_n_ode2(p) @@ -246,7 +246,7 @@ end #Here we reset the IC back to the original and train only the NODE parameters u0 = Float32[2.0; 0.0] res = OPT.solve( - OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode2, OPT.AutoZygote()), p_init.p), + OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode2, OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))), p_init.p), OPO.Adam(0.05); callback = callback2, maxiters = 1000) #Now use the same technique for a longer tspan (0, 10) @@ -266,11 +266,11 @@ prob = ODE.ODEProblem(dudt, u0, tspan) p_init = CA.ComponentArray(; u0, p) res = OPT.solve( - OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode, OPT.AutoZygote()), p_init), + OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode, OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))), p_init), OPO.Adam(0.05); callback, maxiters = 1000) res = OPT.solve( - OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode2, OPT.AutoZygote()), p_init.p), + OPT.OptimizationProblem(OPT.OptimizationFunction(loss_n_ode2, OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))), p_init.p), OPO.Adam(0.05); callback = callback2, maxiters = 1000) ``` diff --git a/docs/src/tutorials/training_tips/multiple_nn.md b/docs/src/tutorials/training_tips/multiple_nn.md index 98032aa89..56274a87e 100644 --- a/docs/src/tutorials/training_tips/multiple_nn.md +++ b/docs/src/tutorials/training_tips/multiple_nn.md @@ -10,6 +10,7 @@ The following is a fully working demo on the Fitzhugh-Nagumo ODE: import SciMLSensitivity as SMS import Lux, ComponentArrays as CA, Optimization as OPT, OptimizationOptimJL as OOJ, OptimizationOptimisers as OPO, OrdinaryDiffEq as ODE, Random +import Mooncake rng = Random.default_rng() Random.seed!(rng, 1) @@ -77,7 +78,7 @@ callback(θ, l) = begin end false end -adtype = OPT.AutoZygote() +adtype = 
OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype) optprob = OPT.OptimizationProblem(optf, p)
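
For reference, the two gradient entry points used throughout the examples above share a common shape. The following is a minimal, self-contained sketch of that pattern, not code from any of the files in this diff: the toy `decay!` problem, its parameter values, and the `loss` helper are illustrative assumptions, while the package calls mirror the ones introduced above.

```julia
import OrdinaryDiffEq as ODE
import Optimization as OPT
import OptimizationOptimisers as OPO
import SciMLSensitivity as SMS  # supplies the adjoint rules behind `solve`
import DifferentiationInterface as DI
import Mooncake

# Toy ODE: exponential decay with a single tunable rate parameter.
decay!(du, u, p, t) = (du[1] = -p[1] * u[1])
prob = ODE.ODEProblem(decay!, [1.0], (0.0, 1.0), [2.0])

# Loss: keep the first state near 0.5 at the saved time points.
loss(p) = sum(abs2, ODE.solve(prob, ODE.Tsit5(); p, saveat = 0.1)[1, :] .- 0.5)

# Pattern 1: a direct gradient through DifferentiationInterface.
backend = DI.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
dp = DI.gradient(loss, backend, [2.0])

# Pattern 2: the same backend driving an Optimization.jl fit.
adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
optf = OPT.OptimizationFunction((p, _) -> loss(p), adtype)
optprob = OPT.OptimizationProblem(optf, [2.0])
res = OPT.solve(optprob, OPO.Adam(0.05); maxiters = 100)
```

Both paths route the differentiation of the `solve` call through SciMLSensitivity's `sensealg` machinery, so the same adjoint selection applies whether the gradient is requested directly or inside an optimizer loop.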