From 0232d87824e851f384e01a382e025890c36ef5ef Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:07:18 +0800 Subject: [PATCH 1/9] if generated Val(dim) fix check --- src/plan.jl | 57 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index 8ccd4b9..d6cdcbb 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -175,7 +175,7 @@ end throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))")) elseif size(p) != size(X) throw(DimensionMismatch("plan has size $(size(p)), but input array has size $(size(X))")) - elseif !(p.region == N || p.region == 1:N) + elseif !(p.region == 1:N || p.region == 1) throw(DimensionMismatch("Plan region is outside array dimensions.")) end @@ -195,10 +195,7 @@ end resize!(ibuf, n) cg = p.callgraph[dim] - Rpre_{dim} = CartesianIndices(sz[1:dim-1]) - Rpost_{dim} = CartesianIndices(sz[dim+1:N]) - - fft_along_dim!(out, ibuf, obuf, cg, dir, Rpre_{dim}, Rpost_{dim}) + fft_along_dim!(out, ibuf, obuf, cg, dir, Val(dim)) end return out @@ -214,11 +211,12 @@ function LinearAlgebra.mul!( Base.require_one_based_indexing(out, X) if size(out) != size(X) throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))")) + elseif length(p.region) != M || !issorted(p.region; lt=(<=)) + throw(DimensionMismatch("Region is invalid.")) elseif M > N || first(p.region) < 1 || last(p.region) > N throw(DimensionMismatch("Plan region is outside array dimensions.")) end - sz = size(X) max_sz = maximum(Base.Fix1(size, out), p.region) obuf = Vector{T}(undef, max_sz) ibuf = Vector{T}(undef, max_sz) @@ -228,32 +226,49 @@ function LinearAlgebra.mul!( copyto!(out, X) # operate in-place on output array - # don't use generated functions because this cannot be type-stable anyway - for dim in 1:M - pdim = p.region[dim] - n = size(out, pdim) - resize!(obuf, n) - resize!(ibuf, n) - cg = p.callgraph[dim] - - Rpre = CartesianIndices(sz[1:pdim-1]) - Rpost = CartesianIndices(sz[pdim+1:N]) + if @generated + quote + k = 1 + # region is assumed to be pre-sorted during planning + Base.Cartesian.@nexprs $N dim -> begin + if p.region[k] == dim + n = size(out, dim) + resize!(obuf, n) + resize!(ibuf, n) + cg = p.callgraph[k] + + fft_along_dim!(out, ibuf, obuf, cg, dir, Val(dim)) + + k = min(k + 1, M) + end + end + end + else + for dim in 1:M + pdim = p.region[dim] + n = size(out, pdim) + resize!(obuf, n) + resize!(ibuf, n) + cg = p.callgraph[dim] - fft_along_dim!(out, ibuf, obuf, cg, dir, Rpre, Rpost) + fft_along_dim!(out, ibuf, obuf, cg, dir, Val(pdim)) + end end return out end function fft_along_dim!( - A::AbstractArray, + A::AbstractArray{U,N}, ibuf::Vector{T}, obuf::Vector{T}, cg::CallGraph{T}, d::Direction, - Rpre::CartesianIndices{M}, Rpost::CartesianIndices -) where {T <: Complex{<:AbstractFloat}, M} + ::Val{dim} +) where {T <: Complex{<:AbstractFloat}, U, N, dim} + sz = size(A) + Rpre = CartesianIndices(sz[1:dim-1]) + Rpost = CartesianIndices(sz[dim+1:N]) t = cg[1].type - dim = M + 1 cols = eachindex(axes(A, dim), ibuf, obuf) for Ipost in Rpost, Ipre in Rpre From 200dd831f70c56f1bdcab945ec29da4d0ef4a4c5 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:07:18 +0800 Subject: [PATCH 2/9] another for 1d plan --- src/plan.jl | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index d6cdcbb..8023eed 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -139,24 +139,36 @@ end #### 1D plan ND array function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_cx{T,1}, x::AbstractArray{T,N}) where {T,U,N} Base.require_one_based_indexing(x, y) - if axes(x) != axes(y) - throw(DimensionMismatch("input array has axes $(axes(x)), but output array has axes $(axes(y))")) + + ax_x, ax_y = axes(x), axes(y) + if ax_x != ax_y + throw(DimensionMismatch("input array has axes $ax_x, but output array has axes $ax_y")) + end + + R1 = p.region[] + plen, xlen = size(p, 1), size(x, R1) + if plen != xlen + throw(DimensionMismatch("plan has size $plen, but input array has size $xlen along region $R1")) end - if size(p, 1) != size(x, p.region[]) - throw(DimensionMismatch("plan has size $(size(p, 1)), but input array has size $(size(x, p.region[])) along region $(p.region[])")) + + if @generated + quote + Base.Cartesian.@nif $N d -> (d == R1) dim -> (_mul_loop!(y, x, p, Val(dim))) + end + else + _mul_loop!(y, x, p, Val(R1)) end - Rpre = CartesianIndices(size(x)[1:p.region[]-1]) - Rpost = CartesianIndices(size(x)[p.region[]+1:end]) - _mul_loop!(y, x, Rpre, Rpost, p) return y end function _mul_loop!( y::AbstractArray{U,N}, x::AbstractArray{T,N}, - Rpre::CartesianIndices, - Rpost::CartesianIndices, - p::FFTAPlan_cx{T,1}) where {T,U,N} + p::FFTAPlan_cx{T,1}, + ::Val{R} +) where {T,U,N,R} + Rpre = CartesianIndices(size(x)[1:R-1]) + Rpost = CartesianIndices(size(x)[R+1:N]) for Ipost in Rpost, Ipre in Rpre @views fft!(y[Ipre,:,Ipost], x[Ipre,:,Ipost], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1) end From f0841d631597380388db5d9f027745f7b00571d6 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Mon, 2 Mar 2026 22:23:10 +0800 Subject: [PATCH 3/9] don't accidentally mutate when sorting sort copy rename region for type stability --- src/plan.jl | 98 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 33 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index 8023eed..1389b01 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -4,7 +4,9 @@ abstract type FFTAPlan{T,N} <: AbstractFFTs.Plan{T} end struct FFTAInvPlan{T,N} <: FFTAPlan{T,N} end -struct FFTAPlan_cx{T,N,R<:Union{Int,AbstractVector{Int}}} <: FFTAPlan{T,N} +const RegionTypes{N} = Union{Int,AbstractVector{Int},NTuple{N,Int}} + +struct FFTAPlan_cx{T,N,R<:RegionTypes{N}} <: FFTAPlan{T,N} callgraph::NTuple{N,CallGraph{T}} region::R dir::Direction @@ -13,11 +15,11 @@ end function FFTAPlan_cx{T,N}( cg::NTuple{N,CallGraph{T}}, r::R, dir::Direction, pinv::FFTAInvPlan{T,N} -) where {T,N,R<:Union{Int,AbstractVector{Int}}} +) where {T,N,R<:RegionTypes{N}} FFTAPlan_cx{T,N,R}(cg, r, dir, pinv) end -struct FFTAPlan_re{T,N,R<:Union{Int,AbstractVector{Int}}} <: FFTAPlan{T,N} +struct FFTAPlan_re{T,N,R<:RegionTypes{N}} <: FFTAPlan{T,N} callgraph::NTuple{N,CallGraph{T}} region::R dir::Direction @@ -27,7 +29,7 @@ end function FFTAPlan_re{T,N}( cg::NTuple{N,CallGraph{T}}, r::R, dir::Direction, pinv::FFTAInvPlan{T,N}, flen::Int -) where {T,N,R<:Union{Int,AbstractVector{Int}}} +) where {T,N,R<:RegionTypes{N}} FFTAPlan_re{T,N,R}(cg, r, dir, pinv, flen) end @@ -46,37 +48,62 @@ Base.size(p::FFTAPlan{<:Any,N}) where N = ntuple(Base.Fix1(size, p), Val{N}()) Base.complex(p::FFTAPlan_re{T,N,R}) where {T,N,R} = FFTAPlan_cx{T,N,R}(p.callgraph, p.region, p.dir, p.pinv) -AbstractFFTs.plan_fft(x::AbstractArray{T,N}, region::R; kwargs...) where {T<:Complex,N,R} = +function _sort(region::T)::T where {N,T<:NTuple{N,Int}} + if N == 2 + minmax(region[1], region[2]) + elseif N == 3 + t1, t2, t3 = region + t1, t2 = minmax(t1, t2) + t2, t3 = minmax(t2, t3) + t1, t2 = minmax(t1, t2) + return (t1, t2, t3) + else + @static VERSION >= v"1.12" ? sort(region) : NTuple{N}(sort!(collect(region))) + end +end + +_sort(region::T) where T<:RegionTypes = issorted(region) ? copy(region) : sort(region) + +AbstractFFTs.plan_fft(x::AbstractArray{T,N}, region; kwargs...) where {T<:Complex,N} = _plan_fft(x, region, FFT_FORWARD; kwargs...) -AbstractFFTs.plan_bfft(x::AbstractArray{T,N}, region::R; kwargs...) where {T<:Complex,N,R} = +AbstractFFTs.plan_bfft(x::AbstractArray{T,N}, region; kwargs...) where {T<:Complex,N} = _plan_fft(x, region, FFT_BACKWARD; kwargs...) -function _plan_fft(x::AbstractArray{T,N}, region::R, dir::Direction; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...) where {T<:Complex,N,R} - FFTN = length(region) - if FFTN == 1 +function _plan_fft( + x::AbstractArray{T,N}, + region::RegionTypes, + dir::Direction; + BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs... +) where {T<:Complex,N} + M = length(region) + if M == 1 R1 = Int(region[]) g = CallGraph{T}(size(x, R1), BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{T,1}() return FFTAPlan_cx{T,1,Int}((g,), R1, dir, pinv) - elseif FFTN == 2 - sort!(region) - g1 = CallGraph{T}(size(x, region[1]), BLUESTEIN_CUTOFF) - g2 = CallGraph{T}(size(x, region[2]), BLUESTEIN_CUTOFF) + elseif M == 2 + R2 = _sort(region) + g1 = CallGraph{T}(size(x, R2[1]), BLUESTEIN_CUTOFF) + g2 = CallGraph{T}(size(x, R2[2]), BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{T,2}() - return FFTAPlan_cx{T,2,R}((g1, g2), region, dir, pinv) + return FFTAPlan_cx{T,2,typeof(R2)}((g1, g2), R2, dir, pinv) else - sort!(region) - return FFTAPlan_cx{T,FFTN,R}( - ntuple(i -> CallGraph{T}(size(x, region[i]), BLUESTEIN_CUTOFF), Val(FFTN)), - region, dir, FFTAInvPlan{T,FFTN}() + RM = _sort(region) + return FFTAPlan_cx{T,M,typeof(RM)}( + ntuple(i -> CallGraph{T}(size(x, RM[i]), BLUESTEIN_CUTOFF), Val(M)), + RM, dir, FFTAInvPlan{T,M}() ) end end -function AbstractFFTs.plan_rfft(x::AbstractArray{T,N}, region::R; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...) where {T<:Real,N,R} - FFTN = length(region) - if FFTN == 1 +function AbstractFFTs.plan_rfft( + x::AbstractArray{T,N}, + region::RegionTypes; + BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs... +) where {T<:Real,N} + M = length(region) + if M == 1 R1 = Int(region[]) n = size(x, R1) # For even length problems, we solve the real problem with @@ -86,20 +113,25 @@ function AbstractFFTs.plan_rfft(x::AbstractArray{T,N}, region::R; BLUESTEIN_CUTO g = CallGraph{Complex{T}}(nn, BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{Complex{T},1}() return FFTAPlan_re{Complex{T},1,Int}((g,), R1, FFT_FORWARD, pinv, n) - elseif FFTN == 2 - sort!(region) - g1 = CallGraph{Complex{T}}(size(x, region[1]), BLUESTEIN_CUTOFF) - g2 = CallGraph{Complex{T}}(size(x, region[2]), BLUESTEIN_CUTOFF) + elseif M == 2 + R2 = _sort(region) + g1 = CallGraph{Complex{T}}(size(x, R2[1]), BLUESTEIN_CUTOFF) + g2 = CallGraph{Complex{T}}(size(x, R2[2]), BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{Complex{T},2}() - return FFTAPlan_re{Complex{T},2,R}((g1, g2), region, FFT_FORWARD, pinv, size(x, region[1])) + return FFTAPlan_re{Complex{T},2,typeof(R2)}((g1, g2), R2, FFT_FORWARD, pinv, size(x, R2[1])) else throw(ArgumentError("only supports 1D and 2D FFTs")) end end -function AbstractFFTs.plan_brfft(x::AbstractArray{T,N}, len, region::R; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...) where {T,N,R} - FFTN = length(region) - if FFTN == 1 +function AbstractFFTs.plan_brfft( + x::AbstractArray{T,N}, + len::Int, + region::RegionTypes; + BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs... +) where {T,N} + M = length(region) + if M == 1 # For even length problems, we solve the real problem with # two n/2 complex FFTs followed by a butterfly. For odd size # problems, we just solve the problem as a single complex @@ -108,12 +140,12 @@ function AbstractFFTs.plan_brfft(x::AbstractArray{T,N}, len, region::R; BLUESTEI g = CallGraph{T}(nn, BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{T,1}() return FFTAPlan_re{T,1,Int}((g,), R1, FFT_BACKWARD, pinv, len) - elseif FFTN == 2 - sort!(region) + elseif M == 2 + R2 = _sort(region) g1 = CallGraph{T}(len, BLUESTEIN_CUTOFF) - g2 = CallGraph{T}(size(x, region[2]), BLUESTEIN_CUTOFF) + g2 = CallGraph{T}(size(x, R2[2]), BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{T,2}() - return FFTAPlan_re{T,2,R}((g1, g2), region, FFT_BACKWARD, pinv, len) + return FFTAPlan_re{T,2,typeof(R2)}((g1, g2), R2, FFT_BACKWARD, pinv, len) else throw(ArgumentError("only supports 1D and 2D FFTs")) end From 531b131e3a396ac979d21d361b18a4356db48139 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Mon, 2 Mar 2026 22:23:10 +0800 Subject: [PATCH 4/9] nospecialize --- src/plan.jl | 78 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 28 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index 1389b01..d27bf79 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -207,33 +207,43 @@ function _mul_loop!( end #### ND plan ND array -@generated function LinearAlgebra.mul!( +function LinearAlgebra.mul!( out::AbstractArray{U,N}, p::FFTAPlan_cx{T,N}, X::AbstractArray{T,N} ) where {T,U,N} + Base.require_one_based_indexing(out, X) + if size(out) != size(X) + throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))")) + elseif size(p) != size(X) + throw(DimensionMismatch("plan has size $(size(p)), but input array has size $(size(X))")) + elseif !(p.region == 1:N || p.region == 1) + throw(DimensionMismatch("Plan region is outside array dimensions.")) + end - quote - Base.require_one_based_indexing(out, X) - if size(out) != size(X) - throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))")) - elseif size(p) != size(X) - throw(DimensionMismatch("plan has size $(size(p)), but input array has size $(size(X))")) - elseif !(p.region == 1:N || p.region == 1) - throw(DimensionMismatch("Plan region is outside array dimensions.")) - end + sz = size(X) + max_sz = maximum(sz) + obuf = Vector{T}(undef, max_sz) + ibuf = Vector{T}(undef, max_sz) + sizehint!(obuf, max_sz) # not guaranteed but hopefully prevents allocations + sizehint!(ibuf, max_sz) + dir = p.dir - sz = size(X) - max_sz = maximum(sz) - obuf = Vector{T}(undef, max_sz) - ibuf = Vector{T}(undef, max_sz) - sizehint!(obuf, max_sz) # not guaranteed but hopefully prevents allocations - sizehint!(ibuf, max_sz) - dir = p.dir + copyto!(out, X) # operate in-place on output array - copyto!(out, X) # operate in-place on output array + if @generated + quote + Base.Cartesian.@nexprs $N dim -> begin + n = size(out, dim) + resize!(obuf, n) + resize!(ibuf, n) + cg = p.callgraph[dim] - Base.Cartesian.@nexprs $N dim -> begin + fft_along_dim!(out, ibuf, obuf, cg, dir, Val(dim)) + end + end + else + for dim in 1:N n = size(out, dim) resize!(obuf, n) resize!(ibuf, n) @@ -241,9 +251,9 @@ end fft_along_dim!(out, ibuf, obuf, cg, dir, Val(dim)) end - - return out end + + return out end #### MD plan ND array (M begin - if p.region[k] == dim + if region[k] == dim n = size(out, dim) resize!(obuf, n) resize!(ibuf, n) - cg = p.callgraph[k] + cg = callgraphs[k] fft_along_dim!(out, ibuf, obuf, cg, dir, Val(dim)) k = min(k + 1, M) end end + return nothing end else for dim in 1:M - pdim = p.region[dim] + pdim = region[dim] n = size(out, pdim) resize!(obuf, n) resize!(ibuf, n) - cg = p.callgraph[dim] + cg = callgraphs[dim] fft_along_dim!(out, ibuf, obuf, cg, dir, Val(pdim)) end end - - return out end function fft_along_dim!( From 1845e835cce5da1a02a220f692f146e264d570ea Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Mon, 2 Mar 2026 22:23:10 +0800 Subject: [PATCH 5/9] CartesianIndices ntuple --- src/plan.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index d27bf79..27930ae 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -199,8 +199,8 @@ function _mul_loop!( p::FFTAPlan_cx{T,1}, ::Val{R} ) where {T,U,N,R} - Rpre = CartesianIndices(size(x)[1:R-1]) - Rpost = CartesianIndices(size(x)[R+1:N]) + Rpre = CartesianIndices(ntuple(Base.Fix1(size, x), Val(R - 1))) + Rpost = CartesianIndices(ntuple(i -> size(x, R + i), Val(N - R))) for Ipost in Rpost, Ipre in Rpre @views fft!(y[Ipre,:,Ipost], x[Ipre,:,Ipost], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1) end @@ -331,9 +331,8 @@ function fft_along_dim!( ::Val{dim} ) where {T <: Complex{<:AbstractFloat}, U, N, dim} - sz = size(A) - Rpre = CartesianIndices(sz[1:dim-1]) - Rpost = CartesianIndices(sz[dim+1:N]) + Rpre = CartesianIndices(ntuple(Base.Fix1(size, A), Val(dim - 1))) + Rpost = CartesianIndices(ntuple(i -> size(A, dim + i), Val(N - dim))) t = cg[1].type cols = eachindex(axes(A, dim), ibuf, obuf) From 209bedfe1a9807d08ad0a36a90b40955610b49a2 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Mon, 2 Mar 2026 22:23:10 +0800 Subject: [PATCH 6/9] add tests for mutated dims --- test/argument_checking.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/argument_checking.jl b/test/argument_checking.jl index c692c00..c368655 100644 --- a/test/argument_checking.jl +++ b/test/argument_checking.jl @@ -126,3 +126,22 @@ end @test_throws DomainError size(p_r, -1) end end + +@testset "Invalid / mutated dims" verbose=true begin + @testset "Extra elements" begin + for n in 3:5 + x = rand(ComplexF64, ntuple(Returns(2), n)) + p1 = plan_fft(x, [1:n-1;]) + push!(p1.region, n) + @test_throws DimensionMismatch("Region is invalid.") p1 * x + end + end + @testset "Unsorted dims" begin + for n in 3:5 + x = rand(ComplexF64, ntuple(Returns(2), n)) + p2 = plan_fft(x, [1:n-1;]) + p2.region[1:2] = [2, 1] + @test_throws DimensionMismatch("Region is invalid.") p2 * x + end + end +end From 09e2f1996444fa96d9c72c15e3e09b6973debe36 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:52:06 +0800 Subject: [PATCH 7/9] fix ci --- src/plan.jl | 63 +++++++++++++++++++++++---------- test/argument_checking.jl | 4 +-- test/ndim/minimal_complex.jl | 4 +-- test/twodim/complex_backward.jl | 16 ++++++--- test/twodim/complex_forward.jl | 16 ++++++--- test/twodim/real_backward.jl | 2 +- test/twodim/real_forward.jl | 2 +- 7 files changed, 74 insertions(+), 33 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index 27930ae..06709a8 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -49,16 +49,20 @@ Base.size(p::FFTAPlan{<:Any,N}) where N = ntuple(Base.Fix1(size, p), Val{N}()) Base.complex(p::FFTAPlan_re{T,N,R}) where {T,N,R} = FFTAPlan_cx{T,N,R}(p.callgraph, p.region, p.dir, p.pinv) function _sort(region::T)::T where {N,T<:NTuple{N,Int}} - if N == 2 - minmax(region[1], region[2]) - elseif N == 3 - t1, t2, t3 = region - t1, t2 = minmax(t1, t2) - t2, t3 = minmax(t2, t3) - t1, t2 = minmax(t1, t2) - return (t1, t2, t3) + @static if VERSION >= v"1.12" + sort(region) else - @static VERSION >= v"1.12" ? sort(region) : NTuple{N}(sort!(collect(region))) + if N == 2 + minmax(region[1], region[2]) + elseif N == 3 + t1, t2, t3 = region + t1, t2 = minmax(t1, t2) + t2, t3 = minmax(t2, t3) + t1, t2 = minmax(t1, t2) + (t1, t2, t3) + else + NTuple{N}(sort!(collect(region))) + end end end @@ -78,7 +82,7 @@ function _plan_fft( ) where {T<:Complex,N} M = length(region) if M == 1 - R1 = Int(region[]) + R1 = Int(region[1]) g = CallGraph{T}(size(x, R1), BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{T,1}() return FFTAPlan_cx{T,1,Int}((g,), R1, dir, pinv) @@ -87,10 +91,10 @@ function _plan_fft( g1 = CallGraph{T}(size(x, R2[1]), BLUESTEIN_CUTOFF) g2 = CallGraph{T}(size(x, R2[2]), BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{T,2}() - return FFTAPlan_cx{T,2,typeof(R2)}((g1, g2), R2, dir, pinv) + return FFTAPlan_cx{T,2}((g1, g2), R2, dir, pinv) else RM = _sort(region) - return FFTAPlan_cx{T,M,typeof(RM)}( + return FFTAPlan_cx{T,M}( ntuple(i -> CallGraph{T}(size(x, RM[i]), BLUESTEIN_CUTOFF), Val(M)), RM, dir, FFTAInvPlan{T,M}() ) @@ -104,7 +108,7 @@ function AbstractFFTs.plan_rfft( ) where {T<:Real,N} M = length(region) if M == 1 - R1 = Int(region[]) + R1 = Int(region[1]) n = size(x, R1) # For even length problems, we solve the real problem with # two n/2 complex FFTs followed by a butterfly. For odd size @@ -118,7 +122,7 @@ function AbstractFFTs.plan_rfft( g1 = CallGraph{Complex{T}}(size(x, R2[1]), BLUESTEIN_CUTOFF) g2 = CallGraph{Complex{T}}(size(x, R2[2]), BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{Complex{T},2}() - return FFTAPlan_re{Complex{T},2,typeof(R2)}((g1, g2), R2, FFT_FORWARD, pinv, size(x, R2[1])) + return FFTAPlan_re{Complex{T},2}((g1, g2), R2, FFT_FORWARD, pinv, size(x, R2[1])) else throw(ArgumentError("only supports 1D and 2D FFTs")) end @@ -135,7 +139,7 @@ function AbstractFFTs.plan_brfft( # For even length problems, we solve the real problem with # two n/2 complex FFTs followed by a butterfly. For odd size # problems, we just solve the problem as a single complex - R1 = Int(region[]) + R1 = Int(region[1]) nn = iseven(len) ? len >> 1 : len g = CallGraph{T}(nn, BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{T,1}() @@ -145,7 +149,7 @@ function AbstractFFTs.plan_brfft( g1 = CallGraph{T}(len, BLUESTEIN_CUTOFF) g2 = CallGraph{T}(size(x, R2[2]), BLUESTEIN_CUTOFF) pinv = FFTAInvPlan{T,2}() - return FFTAPlan_re{T,2,typeof(R2)}((g1, g2), R2, FFT_BACKWARD, pinv, len) + return FFTAPlan_re{T,2}((g1, g2), R2, FFT_BACKWARD, pinv, len) else throw(ArgumentError("only supports 1D and 2D FFTs")) end @@ -177,7 +181,7 @@ function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_cx{T,1}, x::Abstr throw(DimensionMismatch("input array has axes $ax_x, but output array has axes $ax_y")) end - R1 = p.region[] + R1 = only(p.region) plen, xlen = size(p, 1), size(x, R1) if plen != xlen throw(DimensionMismatch("plan has size $plen, but input array has size $xlen along region $R1")) @@ -217,7 +221,7 @@ function LinearAlgebra.mul!( throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))")) elseif size(p) != size(X) throw(DimensionMismatch("plan has size $(size(p)), but input array has size $(size(X))")) - elseif !(p.region == 1:N || p.region == 1) + elseif !region_isvalid(p.region, N) throw(DimensionMismatch("Plan region is outside array dimensions.")) end @@ -265,7 +269,7 @@ function LinearAlgebra.mul!( Base.require_one_based_indexing(out, X) if size(out) != size(X) throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))")) - elseif length(p.region) != M || !issorted(p.region; lt=(<=)) + elseif !region_isvalid(p.region, M, N) throw(DimensionMismatch("Region is invalid.")) elseif M > N || first(p.region) < 1 || last(p.region) > N throw(DimensionMismatch("Plan region is outside array dimensions.")) @@ -347,6 +351,27 @@ function fft_along_dim!( end end +region_isvalid(r::Int, N::Int, _::Int=0) = r == N == 1 +region_isvalid(r::AbstractVector{Int}, N::Int) = r == 1:N +region_isvalid(r::AbstractRange{Int}, M::Int, _::Int) = issorted(r) && length(r) == M +function region_isvalid(r::NTuple{M,Int}, N::Int) where M + isvalid = M == N + for i in 1:M + isvalid &= (r[i] == i) + end + isvalid +end +function region_isvalid(r::Union{AbstractVector{Int},NTuple{<:Any,Int}}, M::Int, _::Int) + isvalid = length(r) == M + maybe_p = Iterators.peel(r) + isnothing(maybe_p) && return isvalid + p, rest = maybe_p + for n in rest + isvalid = isvalid && (p < n) + p = n + end + isvalid +end ## * ### Complex diff --git a/test/argument_checking.jl b/test/argument_checking.jl index c368655..d8e582e 100644 --- a/test/argument_checking.jl +++ b/test/argument_checking.jl @@ -130,7 +130,7 @@ end @testset "Invalid / mutated dims" verbose=true begin @testset "Extra elements" begin for n in 3:5 - x = rand(ComplexF64, ntuple(Returns(2), n)) + x = rand(ComplexF64, ntuple(i -> 2, n)) p1 = plan_fft(x, [1:n-1;]) push!(p1.region, n) @test_throws DimensionMismatch("Region is invalid.") p1 * x @@ -138,7 +138,7 @@ end end @testset "Unsorted dims" begin for n in 3:5 - x = rand(ComplexF64, ntuple(Returns(2), n)) + x = rand(ComplexF64, ntuple(i -> 2, n)) p2 = plan_fft(x, [1:n-1;]) p2.region[1:2] = [2, 1] @test_throws DimensionMismatch("Region is invalid.") p2 * x diff --git a/test/ndim/minimal_complex.jl b/test/ndim/minimal_complex.jl index 059aea6..21e4c74 100644 --- a/test/ndim/minimal_complex.jl +++ b/test/ndim/minimal_complex.jl @@ -2,8 +2,8 @@ using FFTA, Test @testset "Basic ND checks" begin for sz in ((3, 5, 7), (4, 14, 9), (103, 5, 13), (26, 33, 35, 4), ntuple(i -> 3, 5)) - x = ones(sz) - @test fft(x) ≈ setindex!(zeros(sz), prod(sz), 1) + x = ones(ComplexF64, sz) + @test fft(x, Tuple(1:ndims(x))) ≈ setindex!(zeros(sz), prod(sz), 1) end y = zeros((3, 3, 3)) diff --git a/test/twodim/complex_backward.jl b/test/twodim/complex_backward.jl index c417454..268207c 100644 --- a/test/twodim/complex_backward.jl +++ b/test/twodim/complex_backward.jl @@ -24,11 +24,19 @@ end end end -@testset "2D plan, ND array. Size: $n" for n in 1:64 - x = randn(ComplexF64, n, n + 1, n + 2) +@testset "$(N)D plan, $(N+1)D array" for N in 2:3 + rg = N == 2 ? (1:64) : (1:16) + dims_lst = [[1,2], [1,3], [2,3]] + if N == 3 + foreach(v -> push!(v, 4), dims_lst) + end + @testset "against $(N)D arrays with mapslices, r=$r" for r in dims_lst + for n in rg + x = randn(ComplexF64, ntuple(i -> n + (i - 1), N + 1)) - @testset "against 1D array with mapslices, r=$r" for r in [[1,2], [1,3], [2,3]] - @test bfft(x, r) == mapslices(bfft, x; dims = r) + t = Tuple(r) # test tuple region argument + @test bfft(x, t) == bfft(x, r) == mapslices(bfft, x; dims = r) + end end end diff --git a/test/twodim/complex_forward.jl b/test/twodim/complex_forward.jl index 6375de9..6798d13 100644 --- a/test/twodim/complex_forward.jl +++ b/test/twodim/complex_forward.jl @@ -26,11 +26,19 @@ end end end -@testset "2D plan, ND array. Size: $n" for n in 1:64 - x = randn(ComplexF64, n, n + 1, n + 2) +@testset "$(N)D plan, $(N+1)D array" for N in 2:3 + rg = N == 2 ? (1:64) : (1:16) + dims_lst = [[1,2], [1,3], [2,3]] + if N == 3 + foreach(v -> push!(v, 4), dims_lst) + end + @testset "against $(N)D arrays with mapslices, r=$r" for r in dims_lst + for n in rg + x = randn(ComplexF64, ntuple(i -> n + (i - 1), N + 1)) - @testset "against 1D array with mapslices, r=$r" for r in [[1,2], [1,3], [2,3]] - @test fft(x, r) == mapslices(fft, x; dims = r) + t = Tuple(r) # test tuple region argument + @test fft(x, t) == fft(x, r) == mapslices(fft, x; dims = r) + end end end diff --git a/test/twodim/real_backward.jl b/test/twodim/real_backward.jl index 246ebb6..8913edf 100644 --- a/test/twodim/real_backward.jl +++ b/test/twodim/real_backward.jl @@ -31,7 +31,7 @@ end @test x ≈ irfft(rfft(x,r), size(x,r[1]), r) end - @testset "against 2D array with mapslices, r=$r" for r in [[1,2], [1,3], [2,3]] + @testset "against 2D arrays with mapslices, r=$r" for r in [[1,2], [1,3], [2,3]] y = rfft(x, r) @test brfft(y, size(x, r[1]), r) == mapslices(t -> brfft(t, size(x, r[1])), y; dims = r) end diff --git a/test/twodim/real_forward.jl b/test/twodim/real_forward.jl index 1e91131..aded588 100644 --- a/test/twodim/real_forward.jl +++ b/test/twodim/real_forward.jl @@ -30,7 +30,7 @@ end @testset "2D plan, ND array. Size: $n" for n in 1:64 x = randn(n, n + 1, n + 2) - @testset "against 1D array with mapslices, r=$r" for r in [[1,2], [1,3], [2,3]] + @testset "against 2D arrays with mapslices, r=$r" for r in [[1,2], [1,3], [2,3]] @test rfft(x, r) == mapslices(rfft, x; dims = r) end end From 7dde72a796794bc9491401fcd756e69977efe2e4 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:52:06 +0800 Subject: [PATCH 8/9] hoist --- src/plan.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/plan.jl b/src/plan.jl index 06709a8..ef8e979 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -205,8 +205,10 @@ function _mul_loop!( ) where {T,U,N,R} Rpre = CartesianIndices(ntuple(Base.Fix1(size, x), Val(R - 1))) Rpost = CartesianIndices(ntuple(i -> size(x, R + i), Val(N - R))) + cg = p.callgraph[1] + t = cg[1].type for Ipost in Rpost, Ipre in Rpre - @views fft!(y[Ipre,:,Ipost], x[Ipre,:,Ipost], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1) + @views fft!(y[Ipre,:,Ipost], x[Ipre,:,Ipost], 1, 1, p.dir, t, cg, 1) end end From 97d018cb76b7813760508536f8c3414ebda5b321 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Thu, 5 Mar 2026 13:13:24 +0800 Subject: [PATCH 9/9] ignore peel public access check --- test/qa/explicit_imports.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/qa/explicit_imports.jl b/test/qa/explicit_imports.jl index 340dcdd..428672f 100644 --- a/test/qa/explicit_imports.jl +++ b/test/qa/explicit_imports.jl @@ -20,7 +20,10 @@ import ExplicitImports # No non-public accesses in FFTA (ie. no `... MyPkg._non_public_internal_func(...)`) # AbstractFFTs requires subtyping of `Plan` but it is not public # This is an upstream bug in AbstractFFTs.jl - @test ExplicitImports.check_all_qualified_accesses_are_public(FFTA; ignore = (:Plan, :require_one_based_indexing, :Fix1, :Cartesian)) === nothing + @test ExplicitImports.check_all_qualified_accesses_are_public( + FFTA; + ignore=(:Plan, :require_one_based_indexing, :Fix1, :Cartesian, :peel) + ) === nothing # No self-qualified accesses in FFTA (ie. no `... FFTA.func(...)`) @test ExplicitImports.check_no_self_qualified_accesses(FFTA) === nothing