fix(ad): EnsembleSolution cotangent shape + redundant remake reconstructions #1347

Draft
ChrisRackauckas-Claude wants to merge 2 commits into SciML:master from ChrisRackauckas-Claude:fix-ensemble-rrule-and-remake-double-construction
ChrisRackauckas-Claude commented May 8, 2026

Summary

Reverse-mode-AD correctness fixes that surfaced while debugging Zygote.gradient through solve(::EnsembleProblem, alg, EnsembleGPUArray) (where the gradient was coming back exactly 2x the correct value vs. ForwardDiff and finite differences, when it didn't error out earlier).

The full chain now matches ForwardDiff numerically with no Zygote internals patched and no piracy. Verified end-to-end on JLArrays:

Forward: 20.2634
Zygote:      Float32[37.526844, 150.10738]
ForwardDiff: Float32[37.52689, 150.10756]
Ratio Zg/FD: Float32[0.9999988, 0.9999988]   # 1.0 within Float32 eps

1. EnsembleSolution constructor cotangent shape

Both @adjoint EnsembleSolution(...) and ChainRulesCore.rrule(::Type{EnsembleSolution}, ...) returned the cotangent for sim re-wrapped as another EnsembleSolution. Since EnsembleSolution <: AbstractVectorOfArray, its getindex(es, i) is flat scalar indexing. Downstream pullbacks that walked the cotangent by row index received scalars where they expected per-row tangents, surfacing as DimensionMismatch deeper in the trace.

Fix: return a plain Vector for the sim cotangent. Also handle Tangent input explicitly (struct-tangent from upstream Array(::AbstractVectorOfArray) adjoint).
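
The shape mismatch can be illustrated without any SciML packages. The sketch below uses a hypothetical `FlatRows` type (not SciMLBase's `EnsembleSolution`) that wraps per-trajectory vectors but indexes like a flat array; it is only meant to show why a row-indexing consumer sees a scalar when handed the wrapper and a full row when handed a plain `Vector`.

```julia
# Toy stand-in for a container that wraps per-trajectory vectors but
# indexes like a flat array. Hypothetical type, for illustration only.
struct FlatRows <: AbstractVector{Float64}
    rows::Vector{Vector{Float64}}
end

Base.size(w::FlatRows) = (sum(length, w.rows),)

# A linear index walks the concatenated scalars, not the rows.
function Base.getindex(w::FlatRows, i::Int)
    for r in w.rows
        i <= length(r) && return r[i]
        i -= length(r)
    end
    throw(BoundsError(w, i))
end

rows = [[1.0, 2.0], [3.0, 4.0]]
wrapped = FlatRows(rows)

from_wrapper = wrapped[1]  # a scalar: first element of the first row
from_vector = rows[1]      # a vector: the whole first row
```

A pullback that expects `cotangent[i]` to be the tangent of trajectory `i` gets the right object only in the second case, which is the motivation for returning the `sim` cotangent as a plain `Vector`.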

2. Stale @adjoint ODESolution{T1...T15}

The signature in SciMLBaseZygoteExt.jl:199 lists 15 type params, but the current ODESolution constructor has 16. Dead code. Delete.

3. remake(::ODEProblem; ...) redundant reconstructions

The remake body's @reset prob.u0/p lines run unconditionally, even in the no-init-data path. Each @reset reconstructs the (mutable) ODEProblem; on a mutable struct, Zygote's Jnew re-accumulates the gradient via the grad_mut cache, so every extra reconstruction doubles the gradient exactly (2x per remake).

Fix: guard on initialization_data !== nothing. Applied to all 8 affected remake overloads.
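
A minimal sketch of the guard, assuming a toy problem type and a construction counter. `ToyProblem`, `toy_remake`, and `CONSTRUCTIONS` are hypothetical names; the second construction stands in for the eager-initialization and `@reset` block, and nothing here is the actual SciMLBase code.

```julia
# Counts constructions so the effect of the guard is visible.
const CONSTRUCTIONS = Ref(0)

mutable struct ToyProblem
    u0::Vector{Float64}
    p::Vector{Float64}
    function ToyProblem(u0, p)
        CONSTRUCTIONS[] += 1
        new(u0, p)
    end
end

# Hypothetical remake: builds the new problem once, and rebuilds it a second
# time only when there is initialization data to apply.
function toy_remake(prob::ToyProblem; u0 = prob.u0, p = prob.p,
                    initialization_data = nothing)
    newprob = ToyProblem(u0, p)
    if initialization_data !== nothing
        # Stand-in for the eager-initialization + `@reset` block.
        newprob = ToyProblem(initialization_data.u0, newprob.p)
    end
    return newprob
end

prob = ToyProblem([1.0], [2.0])

CONSTRUCTIONS[] = 0
toy_remake(prob; p = [3.0])
no_init_count = CONSTRUCTIONS[]   # no-init-data path

CONSTRUCTIONS[] = 0
toy_remake(prob; initialization_data = (u0 = [0.5],))
init_count = CONSTRUCTIONS[]      # path that applies initialization data
```

With the guard in place, the no-init-data path performs a single construction, so there is no second constructor call for the AD system to differentiate through.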

4. Positional _remake_ode_inner helper + rrule (the architectural piece)

Zygote.chain_rrule_kw drops kwarg cotangents at (nothing, nothing, dxs...), so an rrule on remake(prob::ODEProblem; kwargs...) cannot propagate gradients to kwargs like u0=expr(p). Lower the kwarg remake(prob::ODEProblem; ...) into a positional _remake_ode_inner(prob, f, u0, tspan, p, kwargs, ...) and define the rrule on the positional helper. If a kwarg was passed (not missing), its cotangent goes to that positional; otherwise it accumulates onto prob.field. Control flags get NoTangent. The kwarg gradients now flow naturally because the rrule operates on a non-kwarg signature.

This is the piece that closes the full chain. Once it's in, no Zygote internals need patching for EnsembleGPUArray reverse-mode AD to match ForwardDiff.
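
The routing rule can be sketched in plain Julia. The example below is a hand-written pullback in the shape of an rrule, not the PR's actual `_remake_ode_inner` rule: `ToyProb`, `toy_remake`, `toy_remake_inner`, and `toy_remake_inner_pullback` are hypothetical, only `u0` and `p` are modeled, and `nothing` plays the role that `NoTangent` would play in ChainRulesCore.

```julia
# Toy problem type; the helper described above also takes f, tspan, and flags.
struct ToyProb
    u0::Vector{Float64}
    p::Vector{Float64}
end

# Positional helper: `missing` means "kwarg not passed, fall back to prob".
function toy_remake_inner(prob::ToyProb, u0, p)
    new_u0 = u0 === missing ? prob.u0 : u0
    new_p = p === missing ? prob.p : p
    return ToyProb(new_u0, new_p)
end

# Keyword front-end that only forwards to the positional helper.
toy_remake(prob::ToyProb; u0 = missing, p = missing) = toy_remake_inner(prob, u0, p)

# Returns the primal and a closure mapping the output cotangent `dy`
# (a NamedTuple with fields u0 and p) to one cotangent per positional input.
function toy_remake_inner_pullback(prob::ToyProb, u0, p)
    y = toy_remake_inner(prob, u0, p)
    function pullback(dy)
        # A supplied argument receives its own cotangent ...
        d_u0 = u0 === missing ? nothing : dy.u0
        d_p = p === missing ? nothing : dy.p
        # ... otherwise the cotangent lands on the matching prob field.
        d_prob = (u0 = u0 === missing ? dy.u0 : nothing,
                  p = p === missing ? dy.p : nothing)
        return (d_prob, d_u0, d_p)
    end
    return y, pullback
end

prob = ToyProb([1.0, 2.0], [3.0])
y, back = toy_remake_inner_pullback(prob, [10.0, 20.0], missing)
d_prob, d_u0, d_p = back((u0 = [1.0, 1.0], p = [5.0]))
```

Here `u0` was supplied, so its cotangent goes to the `u0` positional, while `p` was left as `missing`, so its cotangent lands on `d_prob.p`. Because the rule is attached to a positional signature, no keyword-argument cotangent handling is involved.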

Tests

test/downstream/ensemble_remake_reverse_mode_adjoints.jl (no SciMLSensitivity dep — exercises SciMLBase-side cotangent flow directly):

  • EnsembleSolution constructor cotangent shape (4 tests)
  • EnsembleSolution rrule cotangent shape (2 tests)
  • remake(::ODEProblem; u0) / ; p / field-pass-through gradient parity vs ForwardDiff (3 tests)
  • _remake_ode_inner rrule cotangent distribution (7 tests)

All 16 pass.

Companion PR

Pairs with SciML/RecursiveArrayTools.jl#587 (Array adjoint cotangent shape). Together they make Zygote.gradient through solve(::EnsembleProblem, alg, EnsembleGPUArray) match ForwardDiff/finite differences exactly.

Future work (out of scope)

The positional-helper pattern (item 4 above) should be applied to the SDEProblem / DDEProblem / SDDEProblem / DAEProblem / NonlinearProblem / SteadyStateProblem / BVProblem remake overloads too. ODEProblem ships here as the proof of concept; happy to scale up in a follow-up if the architecture is approved.

Please ignore until reviewed by @ChrisRackauckas.

Test plan

  • Existing test/downstream/adjoints.jl continues to pass
  • New testsets pass (16 tests)
  • No regressions in remake-related suites

- EnsembleSolution constructor adjoint/rrule: return cotangent for `sim`
  as a plain Vector matching the constructor argument's shape, not
  re-wrapped in another EnsembleSolution. Add `Tangent` overload for
  cotangents arriving from the upstream `Array(::AbstractVectorOfArray)`
  adjoint.

- Delete the stale `@adjoint ODESolution{T1...T15}` (15 type params,
  ODESolution has 16 since `saved_subsystem` was added; the matching
  ChainRulesCore rrule for the 16-param case already covers it).

- `remake(::ODEProblem; ...)` and 7 sibling overloads: guard the
  `maybe_eager_initialize_problem` + `@reset prob.u0/p` block on
  `initialization_data !== nothing` so the no-init-data path stays at a
  single ODEProblem construction.

- Lower `remake(prob::ODEProblem; ...)` into a positional
  `_remake_ode_inner(prob, f, u0, tspan, p, kwargs, ...)` helper and
  attach the ChainRulesCore.rrule to the helper. Reverse-mode AD now
  flows kwarg cotangents naturally without going through Zygote's
  kwarg-cotangent plumbing. Cotangent distribution: if a kwarg was
  passed (not `missing`), its cotangent goes to that positional;
  otherwise it accumulates onto the corresponding `prob` field.

Tests in `test/downstream/ensemble_remake_reverse_mode_adjoints.jl`:
  - Direct cotangent-shape assertions on the EnsembleSolution
    constructor adjoint (Zygote and ChainRulesCore paths).
  - Zygote-vs-ForwardDiff parity for `remake(::ODEProblem; u0)`,
    `remake(::ODEProblem; p)`, and field-pass-through.
  - Direct shape assertions on the `_remake_ode_inner` rrule's 12-tuple
    return covering both kwarg-supplied and prob-fall-through routing.

Co-Authored-By: Chris Rackauckas <[email protected]>
ChrisRackauckas-Claude force-pushed the fix-ensemble-rrule-and-remake-double-construction branch from b96c61c to b1a3308 on May 8, 2026 18:24
Drop the AbstractArray{T,N} reshape, AbstractArray{<:AbstractArray,1},
AbstractVectorOfArray, and EnsembleSolution-self dispatches. The realistic
AD chain feeds a partial NamedTuple cotangent (u = ...) from the upstream
Array(::AbstractVectorOfArray) adjoint; only the NamedTuple (Zygote ext)
and NamedTuple/Tangent (ChainRules ext) branches are reachable from the
realistic chain. The dropped branches were defensive coverage for direct
callers no production path uses.

Test rewritten to feed a NamedTuple cotangent directly, matching what the
upstream adjoint produces.

Co-Authored-By: Chris Rackauckas <[email protected]>