diff --git a/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl index 3388294fbc..51126d35bf 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl @@ -445,13 +445,13 @@ end return quote $setup if tab.explicit_first_stage - if integrator.f isa SplitFunction && tab.fsal && !repeat_step && !integrator.last_stepfail + if integrator.f isa SplitFunction && issplit(alg) && tab.fsal && !repeat_step && !integrator.last_stepfail f_impl(zs[1], integrator.uprev, p, integrator.t) zs[1] .*= dt else @.. broadcast = false zs[1] = dt * integrator.fsalfirst end - if integrator.f isa SplitFunction + if integrator.f isa SplitFunction && issplit(alg) @.. broadcast = false ks[1] = dt * integrator.fsalfirst - zs[1] end $stages_efs_true @@ -482,7 +482,7 @@ end $adaptive - if integrator.f isa SplitFunction + if integrator.f isa SplitFunction && issplit(alg) integrator.f(integrator.fsallast, u, p, t + dt) elseif tab.explicit_fsallast integrator.f(integrator.fsallast, u, p, t + tab.fsallast_c * dt) @@ -690,12 +690,12 @@ end return quote $setup if tab.explicit_first_stage - if integrator.f isa SplitFunction + if integrator.f isa SplitFunction && issplit(alg) $z1 = dt * f_impl(uprev, p, t) else $z1 = dt * integrator.fsalfirst end - if integrator.f isa SplitFunction + if integrator.f isa SplitFunction && issplit(alg) $k1 = dt * integrator.fsalfirst - $z1 end $stages_efs_true @@ -726,7 +726,7 @@ end $adaptive - if integrator.f isa SplitFunction + if integrator.f isa SplitFunction && issplit(alg) integrator.k[1] = integrator.fsalfirst integrator.fsallast = integrator.f(u, p, t + dt) integrator.k[2] = integrator.fsallast diff --git a/lib/OrdinaryDiffEqSDIRK/test/sdirk_convergence_tests.jl b/lib/OrdinaryDiffEqSDIRK/test/sdirk_convergence_tests.jl index 40b4398898..46959e5b1a 100644 --- a/lib/OrdinaryDiffEqSDIRK/test/sdirk_convergence_tests.jl +++ b/lib/OrdinaryDiffEqSDIRK/test/sdirk_convergence_tests.jl @@ -183,3 +183,27 @@ end sim_iip = test_convergence(dts, prob_iip, ARS343()) @test sim_iip.𝒪est[:l∞] ≈ 3 atol = testTol end + +# Regression test: Kvaerno3/4/5 with SplitODEProblem must integrate the full RHS (f1+f2), +# not just f1. These are non-IMEX (issplit=false) methods, so f.f2 must flow through +# fsalfirst rather than being split off into the explicit ks arrays (which have Ae=be=0). +# f = f1 + f2 = -u + 2u = u => exact solution u(t) = exp(t) * u0 +@testset "Kvaerno SplitODEProblem" begin + dts = 1 .// 2 .^ (8:-1:4) + + f1_oop = (u, p, t) -> -u + f2_oop = (u, p, t) -> 2u + ff_oop = SplitFunction(f1_oop, f2_oop; analytic = (u0, p, t) -> exp(t) * u0) + prob_oop = SplitODEProblem(ff_oop, 1.0, (0.0, 1.0)) + + f1_iip! = (du, u, p, t) -> (du .= -u) + f2_iip! = (du, u, p, t) -> (du .= 2u) + ff_iip = SplitFunction(f1_iip!, f2_iip!; analytic = (u0, p, t) -> exp(t) .* u0) + prob_iip = SplitODEProblem(ff_iip, [1.0, 0.5], (0.0, 1.0)) + + sim_oop = test_convergence(dts, prob_oop, Kvaerno4()) + @test sim_oop.𝒪est[:l∞] ≈ 4 atol = testTol + + sim_iip = test_convergence(dts, prob_iip, Kvaerno4()) + @test sim_iip.𝒪est[:l∞] ≈ 4 atol = testTol +end