Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 55 additions & 40 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,62 @@
module SciMLBaseChainRulesCoreExt

using SciMLBase
using SciMLBase: getobserved
using SciMLBase: getobserved, ODEProblem, _remake_ode_inner
import ChainRulesCore
import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad
import ChainRulesCore: NoTangent, ZeroTangent, AbstractZero, Tangent,
@non_differentiable, zero_tangent, rrule_via_ad, backing
using SymbolicIndexingInterface
using RecursiveArrayTools: AbstractVectorOfArray

@inline function _remake_ode_inner_split_cotangent(Δ, f, u0, tspan, p)
Δ_nt = if Δ isa Tangent
b = backing(Δ)
b isa NamedTuple ? b :
(b isa Tuple && length(b) == 1 && b[1] isa NamedTuple ? b[1] : NamedTuple())
elseif Δ isa NamedTuple
Δ
else
NamedTuple()
end
get_cot(field::Symbol) = (Δ_nt isa NamedTuple && haskey(Δ_nt, field)) ?
Δ_nt[field] : nothing
f_cot = (f === missing) ? NoTangent() : get_cot(:f)
u0_cot = (u0 === missing) ? NoTangent() : get_cot(:u0)
tspan_cot = (tspan === missing) ? NoTangent() : get_cot(:tspan)
p_cot = (p === missing) ? NoTangent() : get_cot(:p)
prob_cot = (
f = (f === missing) ? get_cot(:f) : nothing,
u0 = (u0 === missing) ? get_cot(:u0) : nothing,
tspan = (tspan === missing) ? get_cot(:tspan) : nothing,
p = (p === missing) ? get_cot(:p) : nothing,
kwargs = nothing,
problem_type = nothing,
)
return prob_cot, f_cot, u0_cot, tspan_cot, p_cot
end

function ChainRulesCore.rrule(
::typeof(_remake_ode_inner),
prob::ODEProblem, f, u0, tspan, p, kwargs,
interpret_symbolicmap, build_initializeprob, use_defaults,
lazy_initialization, _kwargs
)
new_prob = _remake_ode_inner(
prob, f, u0, tspan, p, kwargs,
interpret_symbolicmap, build_initializeprob, use_defaults,
lazy_initialization, _kwargs
)
function _remake_ode_inner_pullback(Δ)
prob_cot, f_cot, u0_cot, tspan_cot, p_cot = _remake_ode_inner_split_cotangent(
Δ, f, u0, tspan, p
)
return (NoTangent(), prob_cot, f_cot, u0_cot, tspan_cot, p_cot,
NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(),
NoTangent())
end
return new_prob, _remake_ode_inner_pullback
end

@non_differentiable SciMLBase.checkkwargs(kwargshandle)

# numargs and isinplace use `methods()` for runtime reflection and are not differentiable.
Expand Down Expand Up @@ -163,51 +213,16 @@ function ChainRulesCore.rrule(
RODESolutionAdjoint
end

# EnsembleSolution rrule with full support for various gradient types
# Matches the Zygote extension implementation for consistency
function ChainRulesCore.rrule(
::Type{EnsembleSolution}, sim, time, converged, stats = nothing
)
out = EnsembleSolution(sim, time, converged, stats)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [
[
p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
for j in 1:size(p̄)[end - 1]
] for i in 1:size(p̄)[end]
]
return (
NoTangent(),
EnsembleSolution(arrarr, 0.0, true, stats),
NoTangent(),
NoTangent(),
NoTangent(),
)
end
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
return (
NoTangent(),
EnsembleSolution(p̄, 0.0, true, stats),
NoTangent(),
NoTangent(),
NoTangent(),
)
end
function EnsembleSolution_adjoint(p̄::AbstractVectorOfArray)
return (
NoTangent(),
EnsembleSolution(p̄, 0.0, true, stats),
NoTangent(),
NoTangent(),
NoTangent(),
)
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
return (NoTangent(), p̄, NoTangent(), NoTangent(), NoTangent())
end
function EnsembleSolution_adjoint(p̄::NamedTuple)
return (NoTangent(), p̄.u, NoTangent(), NoTangent(), NoTangent())
end
function EnsembleSolution_adjoint(p̄::Tangent)
return (NoTangent(), backing(p̄).u, NoTangent(), NoTangent(), NoTangent())
end
return out, EnsembleSolution_adjoint
end

Expand Down
36 changes: 0 additions & 36 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,6 @@ end

@adjoint function EnsembleSolution(sim, time, converged, stats)
out = EnsembleSolution(sim, time, converged, stats)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [
[
p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
for j in 1:size(p̄)[end - 1]
] for i in 1:size(p̄)[end]
]
(EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
(EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::RecursiveArrayTools.AbstractVectorOfArray)
(EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
(p̄, nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::NamedTuple)
(p̄.u, nothing, nothing, nothing)
end
Expand Down Expand Up @@ -190,24 +172,6 @@ end
VA[sym], NonlinearSolution_getindex_pullback
end

@adjoint function ODESolution{
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15,
}(
u,
args...
) where {
T1, T2, T3, T4, T5, T6, T7, T8,
T9, T10, T11, T12, T13, T14, T15,
}
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(
u, args...
),
ODESolutionAdjoint
end

@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(
u,
Expand Down
76 changes: 52 additions & 24 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,18 @@ function remake(
lazy_initialization = nothing,
_kwargs...
)
return _remake_ode_inner(
prob, f, u0, tspan, p, kwargs,
interpret_symbolicmap, build_initializeprob, use_defaults,
lazy_initialization, values(_kwargs)
)
end

function _remake_ode_inner(
prob::ODEProblem, f, u0, tspan, p, kwargs,
interpret_symbolicmap, build_initializeprob, use_defaults,
lazy_initialization, _kwargs
)
if tspan === missing
tspan = prob.tspan
end
Expand Down Expand Up @@ -382,9 +394,11 @@ function remake(
ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...)
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
if initialization_data !== nothing
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end
Expand Down Expand Up @@ -600,9 +614,11 @@ function remake(
SDEProblem{iip}(f, newu0, tspan, newp; noise, noise_rate_prototype, seed, kwargs...)
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
if initialization_data !== nothing
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end
Expand Down Expand Up @@ -667,9 +683,11 @@ function remake(
)
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
if initialization_data !== nothing
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end
Expand Down Expand Up @@ -764,9 +782,11 @@ function remake(
)
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
if initialization_data !== nothing
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end
Expand Down Expand Up @@ -826,9 +846,11 @@ function remake(
DAEProblem{iip}(f, du0, newu0, tspan, newp; differential_vars, kwargs...)
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
if initialization_data !== nothing
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end
Expand Down Expand Up @@ -953,9 +975,11 @@ function remake(
)
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
if initialization_data !== nothing
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end
Expand Down Expand Up @@ -1000,9 +1024,11 @@ function remake(
SteadyStateProblem{isinplace(prob)}(f = f, u0 = newu0, p = newp; kwargs...)
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
if initialization_data !== nothing
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end
Expand Down Expand Up @@ -1048,9 +1074,11 @@ function remake(
)
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
if initialization_data !== nothing
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end
Expand Down
85 changes: 85 additions & 0 deletions test/downstream/ensemble_remake_reverse_mode_adjoints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using SciMLBase, OrdinaryDiffEq, Test
using SciMLBase: EnsembleProblem, EnsembleSerial, EnsembleSolution
using Zygote, ForwardDiff
import ChainRulesCore

@testset "EnsembleSolution constructor pulls NamedTuple cotangent" begin
f(u, p, t) = -u
prob = ODEProblem(f, [1.0], (0.0, 1.0))
sols = [solve(prob, Tsit5(); saveat = 0.5) for _ in 1:3]
arrarr = [[copy(s.u[j]) for j in eachindex(s.u)] for s in sols]

_, back = Zygote.pullback(EnsembleSolution, sols, 0.0, true, nothing)
sim_cot, t_cot, c_cot, s_cot = back((u = arrarr,))
@test sim_cot == arrarr
@test t_cot === nothing && c_cot === nothing && s_cot === nothing

_, pb = ChainRulesCore.rrule(EnsembleSolution, sols, 0.0, true, nothing)
cot = pb((u = arrarr,))
@test cot[2] == arrarr
@test cot[1] === ChainRulesCore.NoTangent()
@test all(cot[i] === ChainRulesCore.NoTangent() for i in 3:5)

cot_t = pb(ChainRulesCore.Tangent{Any}(; u = arrarr))
@test cot_t[2] == arrarr
end

@testset "remake(::ODEProblem; u0) gradient parity" begin
f(u, p, t) = u
base_prob = ODEProblem(f, [0.0, 0.0], (0.0, 1.0), [1.0])
loss(p) = (q = remake(base_prob, u0 = [p[1] * 2, p[1] + 5]); sum(abs2, q.u0))
p0 = [3.0]
@test Zygote.gradient(loss, p0)[1] ≈ ForwardDiff.gradient(loss, p0) rtol=1e-6
end

@testset "remake(::ODEProblem; p) gradient parity" begin
f(u, p, t) = p[1] * u
base_prob = ODEProblem(f, [1.0], (0.0, 1.0), [0.5])
loss(p) = (q = remake(base_prob, p = [p[1] * 3]); sum(abs2, q.p))
p0 = [2.0]
@test Zygote.gradient(loss, p0)[1] ≈ ForwardDiff.gradient(loss, p0) rtol=1e-6
end

@testset "remake field-pass-through gradient parity" begin
f(u, p, t) = u
base_prob = ODEProblem(f, [1.0], (0.0, 1.0), [1.0])
loss(p) = (q = remake(base_prob, u0 = [p[1]]); sum(abs2, q.u0))
p0 = [2.5]
@test Zygote.gradient(loss, p0)[1] ≈ ForwardDiff.gradient(loss, p0) rtol=1e-6
end

@testset "_remake_ode_inner rrule cotangent distribution" begin
base_prob = ODEProblem((u, p, t) -> u, [1.0, 2.0], (0.0, 1.0), [3.0])
Δ_u0 = [10.0, 20.0]
Δ = ChainRulesCore.Tangent{Any}(;
f = ChainRulesCore.NoTangent(),
u0 = Δ_u0,
tspan = ChainRulesCore.NoTangent(),
p = ChainRulesCore.NoTangent(),
kwargs = ChainRulesCore.NoTangent(),
problem_type = ChainRulesCore.NoTangent(),
)

# u0 supplied → cotangent flows to the u0 positional.
_, pb = ChainRulesCore.rrule(
SciMLBase._remake_ode_inner,
base_prob, missing, [9.9, 8.8], missing, missing, missing,
true, Val{true}, false, nothing, NamedTuple()
)
cot = pb(Δ)
@test length(cot) == 12
@test cot[1] === ChainRulesCore.NoTangent()
@test cot[4] == Δ_u0
@test cot[2].u0 === nothing
@test all(cot[i] === ChainRulesCore.NoTangent() for i in 7:12)

# u0 not supplied → cotangent accumulates onto prob.u0.
_, pb2 = ChainRulesCore.rrule(
SciMLBase._remake_ode_inner,
base_prob, missing, missing, missing, [99.0], missing,
true, Val{true}, false, nothing, NamedTuple()
)
cot2 = pb2(Δ)
@test cot2[4] === ChainRulesCore.NoTangent()
@test cot2[2].u0 == Δ_u0
end
Loading
Loading