diff --git a/lib/DiffEqBase/Project.toml b/lib/DiffEqBase/Project.toml index fbdcce670f0..c78a7791f91 100644 --- a/lib/DiffEqBase/Project.toml +++ b/lib/DiffEqBase/Project.toml @@ -108,6 +108,7 @@ Random = "1.9" RecursiveArrayTools = "4" Reexport = "1.0" ReverseDiff = "1" +RecursiveArrayToolsRaggedArrays = "1" SafeTestsets = "0.1" SciMLBase = "3" SciMLLogging = "1.9.1" @@ -143,6 +144,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayToolsRaggedArrays = "c384ba91-639a-44ca-823a-e1d3691ab84a" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -153,4 +155,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["DiffEqCallbacks", "Distributed", "Measurements", "Unitful", "FlexUnits", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Test", "Distributions", "DynamicQuantities", "Aqua"] +test = ["DiffEqCallbacks", "Distributed", "Measurements", "Unitful", "FlexUnits", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Pkg", "Random", "RecursiveArrayToolsRaggedArrays","ReverseDiff", "StaticArrays", "SafeTestsets", "Test", "Distributions", "DynamicQuantities", "Aqua"] diff --git a/lib/DiffEqBase/src/common_defaults.jl b/lib/DiffEqBase/src/common_defaults.jl index 85d44e4c529..813d24fe366 100644 --- a/lib/DiffEqBase/src/common_defaults.jl +++ b/lib/DiffEqBase/src/common_defaults.jl @@ -18,6 +18,10 @@ function UNITLESS_ABS2(x::RecursiveArrayTools.ArrayPartition) return mapreduce(UNITLESS_ABS2, abs2_and_sum, x.x, init = zero(real(value(eltype(x))))) end +function UNITLESS_ABS2(x::RecursiveArrayTools.AbstractRaggedVectorOfArray) + return mapreduce(UNITLESS_ABS2, +, x.u; init = zero(real(eltype(x)))) +end + UNITLESS_ABS2(f::F, x::Number) where {F} = abs2(f(x)) function UNITLESS_ABS2(f::F, x::AbstractArray) where {F} return mapreduce( @@ -37,6 +41,7 @@ recursive_length(u::Number) = length(u) recursive_length(u::AbstractArray{<:AbstractArray}) = sum(recursive_length, u) recursive_length(u::RecursiveArrayTools.ArrayPartition) = sum(recursive_length, u.x) recursive_length(u::RecursiveArrayTools.VectorOfArray) = sum(recursive_length, u.u) +recursive_length(u::RecursiveArrayTools.AbstractRaggedVectorOfArray) = sum(recursive_length, u.u; init = 0) function recursive_length( u::AbstractArray{ <:StaticArraysCore.StaticArray{S, <:Number}, @@ -102,6 +107,7 @@ function ODE_DEFAULT_NORM( u::Union{ AbstractArray, RecursiveArrayTools.AbstractVectorOfArray, + RecursiveArrayTools.AbstractRaggedVectorOfArray, }, t ) diff --git a/lib/DiffEqBase/test/ode_default_norm.jl b/lib/DiffEqBase/test/ode_default_norm.jl index 0d447802dcf..4118c3cc281 100644 --- a/lib/DiffEqBase/test/ode_default_norm.jl +++ b/lib/DiffEqBase/test/ode_default_norm.jl @@ -1,4 +1,4 @@ -using Test, RecursiveArrayTools, StaticArrays, ForwardDiff +using Test, RecursiveArrayTools, RecursiveArrayToolsRaggedArrays, StaticArrays, ForwardDiff using DiffEqBase: UNITLESS_ABS2, recursive_length, ODE_DEFAULT_NORM @@ -54,3 +54,11 @@ u8_ref = 1.2909944487358056 @test ODE_DEFAULT_NORM(u8, 4.0) ≈ u8_ref @test ODE_DEFAULT_NORM(u8, ForwardDiff.Dual{:b}(4.0, true)) isa Float64 @test ODE_DEFAULT_NORM(u8, ForwardDiff.Dual{:b}(4.0, true)) ≈ u8_ref + +r = RaggedVectorOfArray([ones(3), ones(3)]) # 6 ones +@test UNITLESS_ABS2(r) ≈ 6.0 +@test recursive_length(r) == 6 +# RMS norm of 6 ones = sqrt(6/6) = 1 +@test ODE_DEFAULT_NORM(r, 0.0) ≈ 1.0 +# Unnormalised Euclidean norm would be sqrt(6) ≈ 2.449 — make sure we don't get that +@test ODE_DEFAULT_NORM(r, 0.0) < 2.0