fix(ad): EnsembleSolution cotangent shape + redundant remake reconstructions#1347
Draft
ChrisRackauckas-Claude wants to merge 2 commits into
Conversation
- EnsembleSolution constructor adjoint/rrule: return cotangent for `sim`
as a plain Vector matching the constructor argument's shape, not
re-wrapped in another EnsembleSolution. Add `Tangent` overload for
cotangents arriving from the upstream `Array(::AbstractVectorOfArray)`
adjoint.
- Delete the stale `@adjoint ODESolution{T1...T15}` (15 type params,
ODESolution has 16 since `saved_subsystem` was added; the matching
ChainRulesCore rrule for the 16-param case already covers it).
- `remake(::ODEProblem; ...)` and 7 sibling overloads: guard the
`maybe_eager_initialize_problem` + `@reset prob.u0/p` block on
`initialization_data !== nothing` so the no-init-data path stays at a
single ODEProblem construction.
- Lower `remake(prob::ODEProblem; ...)` into a positional
`_remake_ode_inner(prob, f, u0, tspan, p, kwargs, ...)` helper and
attach the ChainRulesCore.rrule to the helper. Reverse-mode AD now
flows kwarg cotangents naturally without going through Zygote's
kwarg-cotangent plumbing. Cotangent distribution: if a kwarg was
passed (not `missing`), its cotangent goes to that positional;
otherwise it accumulates onto the corresponding `prob` field.
Tests in `test/downstream/ensemble_remake_reverse_mode_adjoints.jl`:
- Direct cotangent-shape assertions on the EnsembleSolution
constructor adjoint (Zygote and ChainRulesCore paths).
- Zygote-vs-ForwardDiff parity for `remake(::ODEProblem; u0)`,
`remake(::ODEProblem; p)`, and field-pass-through.
- Direct shape assertions on the `_remake_ode_inner` rrule's 12-tuple
return covering both kwarg-supplied and prob-fall-through routing.
Co-Authored-By: Chris Rackauckas <[email protected]>
b96c61c to
b1a3308
Compare
Drop the AbstractArray{T,N} reshape, AbstractArray{<:AbstractArray,1},
AbstractVectorOfArray, and EnsembleSolution-self dispatches. The realistic
AD chain feeds a partial NamedTuple cotangent (u = ...) from the upstream
Array(::AbstractVectorOfArray) adjoint; only the NamedTuple (Zygote ext)
and NamedTuple/Tangent (ChainRules ext) branches are reachable from the
realistic chain. The dropped branches were defensive coverage for direct
callers no production path uses.
Test rewritten to feed a NamedTuple cotangent directly, matching what the
upstream adjoint produces.
Co-Authored-By: Chris Rackauckas <[email protected]>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Reverse-mode-AD correctness fixes that surfaced while debugging
Zygote.gradientthroughsolve(::EnsembleProblem, alg, EnsembleGPUArray)(where the gradient was coming back exactly 2x the correct value vs. ForwardDiff and finite differences, when it didn't error out earlier).The full chain now matches ForwardDiff numerically with no Zygote internals patched and no piracy. Verified end-to-end on JLArrays:
1. EnsembleSolution constructor cotangent shape
Both
@adjoint EnsembleSolution(...)andChainRulesCore.rrule(::Type{EnsembleSolution}, ...)returned the cotangent forsimre-wrapped as anotherEnsembleSolution. SinceEnsembleSolution <: AbstractVectorOfArray, itsgetindex(es, i)is flat scalar indexing. Downstream pullbacks that walked the cotangent by row index received scalars where they expected per-row tangents, surfacing asDimensionMismatchdeeper in the trace.Fix: return a plain
Vectorfor thesimcotangent. Also handleTangentinput explicitly (struct-tangent from upstreamArray(::AbstractVectorOfArray)adjoint).2. Stale
@adjoint ODESolution{T1...T15}The signature in
SciMLBaseZygoteExt.jl:199lists 15 type params, but the currentODESolutionconstructor has 16. Dead code. Delete.3.
remake(::ODEProblem; ...)redundant reconstructionsThe remake body's
@reset prob.u0/plines run unconditionally even in the no-init-data path. Each@resetreconstructs the (mutable) ODEProblem; on a mutable struct, Zygote'sJnewre-accumulates gradient via thegrad_mutcache — exact 2x doubling per remake.Fix: guard on
initialization_data !== nothing. Applied to all 8 affected remake overloads.4. Positional
_remake_ode_innerhelper + rrule (the architectural piece)Zygote.chain_rrule_kwdrops kwarg cotangents at(nothing, nothing, dxs...), so an rrule onremake(prob::ODEProblem; kwargs...)cannot propagate gradients to kwargs likeu0=expr(p). Lower the kwargremake(prob::ODEProblem; ...)into a positional_remake_ode_inner(prob, f, u0, tspan, p, kwargs, ...)and define the rrule on the positional helper. If a kwarg was passed (notmissing), its cotangent goes to that positional; otherwise it accumulates ontoprob.field. Control flags getNoTangent. The kwarg gradients now flow naturally because the rrule operates on a non-kwarg signature.This is the piece that closes the full chain. Once it's in, no Zygote internals need patching for
EnsembleGPUArrayreverse-mode AD to match ForwardDiff.Tests
test/downstream/ensemble_remake_reverse_mode_adjoints.jl(no SciMLSensitivity dep — exercises SciMLBase-side cotangent flow directly):remake(::ODEProblem; u0)/; p/ field-pass-through gradient parity vs ForwardDiff (3 tests)_remake_ode_innerrrule cotangent distribution (7 tests)All 16 pass.
Companion PR
Pairs with SciML/RecursiveArrayTools.jl#587 (Array adjoint cotangent shape). Together they make
Zygote.gradientthroughsolve(::EnsembleProblem, alg, EnsembleGPUArray)match ForwardDiff/finite differences exactly.Future work (out of scope)
The positional-helper pattern (#4) should be applied to SDEProblem / DDEProblem / SDDEProblem / DAEProblem / NonlinearProblem / SteadyStateProblem / BVProblem
remakeoverloads too. ODEProblem ships here as the proof of concept; happy to scale up in a follow-up if the architecture is approved.Please ignore until reviewed by @ChrisRackauckas.
Test plan
test/downstream/adjoints.jlcontinues to pass