diff --git a/Project.toml b/Project.toml index 1d92d411d..9baa163ce 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "1.9.3" projects = ["test", "docs", "perf", "examples", "res/wrap"] [deps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -25,6 +26,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -37,6 +39,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" SpecialFunctionsExt = "SpecialFunctions" [compat] +AbstractFFTs = "1" Adapt = "4.5" BFloat16s = "0.5, 0.6" CEnum = "0.4, 0.5" @@ -56,6 +59,7 @@ Printf = "1" Random = "1" Random123 = "1.7.1" RandomNumbers = "1.6.0" +Reexport = "1.2.2" SHA = "0.7, 1" ScopedValues = "1.3.0" SpecialFunctions = "2" diff --git a/lib/mpsgraphs/MPSGraphs.jl b/lib/mpsgraphs/MPSGraphs.jl index 86a5d44d9..19f7cfce9 100644 --- a/lib/mpsgraphs/MPSGraphs.jl +++ b/lib/mpsgraphs/MPSGraphs.jl @@ -48,5 +48,6 @@ include("operations.jl") include("random.jl") include("matmul.jl") +include("fft.jl") end diff --git a/lib/mpsgraphs/fft.jl b/lib/mpsgraphs/fft.jl new file mode 100644 index 000000000..a1e4bdb39 --- /dev/null +++ b/lib/mpsgraphs/fft.jl @@ -0,0 +1,40 @@ + +## FFT Descriptor Creation + +""" + MPSGraphFFTDescriptor(; inverse=false, scalingMode=MPSGraphFFTScalingModeNone) + +Create an MPSGraphFFTDescriptor with the specified parameters. +""" +function MPSGraphFFTDescriptor(; inverse::Bool = false, scalingMode::MPSGraphFFTScalingMode = MPSGraphFFTScalingModeNone) + obj = @objc [MPSGraphFFTDescriptor alloc]::id{MPSGraphFFTDescriptor} + desc = MPSGraphFFTDescriptor(obj) + desc.inverse = inverse + desc.scalingMode = scalingMode + return desc +end + +## MPSGraph FFT operations +function fastFourierTransformWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, axes::NSArray, descriptor::MPSGraphFFTDescriptor, name = "fft") + obj = @objc [graph::id{MPSGraph} fastFourierTransformWithTensor:tensor::id{MPSGraphTensor} + axes:axes::id{NSArray} + descriptor:descriptor::id{MPSGraphFFTDescriptor} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +function realToHermiteanFFTWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, axes::NSArray, descriptor::MPSGraphFFTDescriptor, name = "rfft") + obj = @objc [graph::id{MPSGraph} realToHermiteanFFTWithTensor:tensor::id{MPSGraphTensor} + axes:axes::id{NSArray} + descriptor:descriptor::id{MPSGraphFFTDescriptor} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +function HermiteanToRealFFTWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, axes::NSArray, descriptor::MPSGraphFFTDescriptor, name = "irfft") + obj = @objc [graph::id{MPSGraph} HermiteanToRealFFTWithTensor:tensor::id{MPSGraphTensor} + axes:axes::id{NSArray} + descriptor:descriptor::id{MPSGraphFFTDescriptor} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end diff --git a/src/Metal.jl b/src/Metal.jl index 9fc025053..8e62bff8d 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -14,6 +14,8 @@ import ObjectiveC: is_macos import KernelAbstractions using ScopedValues +using Reexport: @reexport + include("version.jl") include("storage_type.jl") @@ -68,6 +70,7 @@ include("mapreduce.jl") include("accumulate.jl") include("indexing.jl") include("random.jl") +include("fft.jl") # KernelAbstractions include("MetalKernels.jl") diff --git a/src/fft.jl b/src/fft.jl new file mode 100644 index 000000000..939412d53 --- /dev/null +++ b/src/fft.jl @@ -0,0 +1,321 @@ +# FFT operations using MPSGraph +# Implements AbstractFFTs.jl interface for MtlArray + +using .MPSGraphs: MPSGraph, MPSGraphFFTDescriptor, HermiteanToRealFFTWithTensor, realToHermiteanFFTWithTensor, + fastFourierTransformWithTensor, placeholderTensor, MPSGraphTensorData, MPSGraphTensor + +@reexport using AbstractFFTs +import AbstractFFTs: plan_fft, plan_fft!, plan_bfft, plan_bfft!, plan_ifft, + plan_rfft, plan_brfft, plan_inv, normalization, fft, bfft, ifft, rfft, irfft, + Plan, ScaledPlan + +# supported types for FFT using MPSGraphs +const FFTComplex = Union{ComplexF32, ComplexF16} +const FFTReal = Union{Float32, Float16} +const FFTNumber = Union{FFTReal, FFTComplex} + +# mtlfloat is like Base.float but converts Integers to Float32 instead +# of to Float64 which is unsupported on all Apple Silicon GPUs +mtlfloat(x) = float(x) +mtlfloat(x::Integer) = Float32(x) +mtlfloat(x::Complex{<:Integer}) = ComplexF32(x) +mtlfloat(::Type{<:Integer}) = Float32 +mtlfloat(::Type{<:Complex{<:Integer}}) = ComplexF32 + +mtlfftfloat(x) = _mtlfftfloat(mtlfloat(x)) +_mtlfftfloat(::Type{T}) where {T<:FFTNumber} = T +_mtlfftfloat(::Type{T}) where {T} = error("type $T not supported") +_mtlfftfloat(x::T) where {T} = _mtlfftfloat(T)(x) + +realfloat(x::MtlArray{<:FFTReal}) = x +realfloat(x::MtlArray{T}) where {T<:Real} = copy1(mtlfftfloat(T), x) +realfloat(::MtlArray{T}) where {T} = error("type $T not supported") + +complexfloat(x::MtlArray{<:FFTComplex}) = x +complexfloat(x::MtlArray{T}) where {T<:Complex} = copy1(mtlfftfloat(T), x) +complexfloat(x::MtlArray{T}) where {T<:Real} = copy1(mtlfftfloat(complex(T)), x) +complexfloat(::MtlArray{T}) where {T} = error("type $T not supported") + +function copy1(::Type{T}, x::MtlArray{<:Any, N, S}) where {T, N, S} + y = MtlArray{T, N, S}(undef, map(length, axes(x))) + y .= broadcast(xi -> convert(T, xi), x) +end + +## plan structure + +""" + MtlFFTPlan{T, S, backward, inplace, N, R} <: AbstractFFTs.Plan{S} + +`T` is the output type +`S` is the input ("source") type + +`backward` is a boolean flag +`inplace` is a boolean flag + +`N` is the number of dimensions + +GPU FFT plan for Metal using MPSGraph's fastFourierTransformWithTensor. +""" +mutable struct MtlFFTPlan{T <: FFTNumber, S <: FFTNumber, backward, inplace, N, R} <: Plan{S} + input_size::NTuple{N, Int} + output_size::NTuple{N, Int} + region::NTuple{R, Int} + pinv::ScaledPlan{T} + + function MtlFFTPlan{T, S, backward, inplace, N, R}(input_size::NTuple{N, Int}, output_size::NTuple{N, Int}, region::NTuple{R, Int}) where {T <: FFTNumber, S <: FFTNumber, backward, inplace, N, R} + # Validate region + if any(i -> region[i] >= region[i+1], 1:R-1) + throw(ArgumentError("region must be an increasing sequence")) + end + if any(region .< 1 .|| region .> N) + throw(ArgumentError("region can only refer to valid dimensions")) + end + backward isa Bool || throw(ArgumentError("FFT backward argument must be a Bool")) + inplace isa Bool || throw(ArgumentError("FFT inplace argument must be a Bool")) + + return new{T, S, backward, inplace, N, R}(input_size, output_size, region) + end +end + +function showfftdims(io, sz, T) + if isempty(sz) + print(io,"0-dimensional") + elseif length(sz) == 1 + print(io, sz[1], "-element") + else + print(io, join(sz, "×")) + end + print(io, " MtlArray of ", T) +end + +function Base.show(io::IO, p::MtlFFTPlan{T, S, backward, inplace}) where {T, S, backward, inplace} + print(io, "MPSGraph FFT ", + inplace ? "in-place " : "", + S == T ? "$T " : "$(S)-to-$(T) ", + backward ? "backward " : "forward ", + "plan for ") + showfftdims(io, p.input_size, S) +end + +# plan properties +Base.size(p::MtlFFTPlan) = p.input_size +AbstractFFTs.fftdims(p::MtlFFTPlan) = p.region + +## AbstractFFTs interface implementation + +# promote to a complex floating-point type (out-of-place only), +# so implementations only need Complex{Float} methods +for f in (:fft, :bfft, :ifft) + pf = Symbol("plan_", f) + @eval begin + $f(x::MtlArray{<:Real}, region=1:ndims(x)) = $f(complexfloat(x), region) + $pf(x::MtlArray{<:Real}, region) = $pf(complexfloat(x), region) + $f(x::MtlArray{<:Complex{<:Union{Integer,Rational}}}, region=1:ndims(x)) = $f(complexfloat(x), region) + $pf(x::MtlArray{<:Complex{<:Union{Integer,Rational}}}, region) = $pf(complexfloat(x), region) + end +end +rfft(x::MtlArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(realfloat(x), region) +plan_rfft(x::MtlArray{<:Real}, region) = plan_rfft(realfloat(x), region) + +function irfft(x::MtlArray{<:Union{Real,Integer,Rational}}, d::Integer, region=1:ndims(x)) + irfft(complexfloat(x), d, region) +end + + +# forward plans are `plan_fft`, and backward (unnormalized) plans are `plan_bfft` +# inplace functions have a "!", inverse (normalized) plans are handled via plan_inv +for inplace in (true, false), backward in (true, false) + dir_str = backward ? "b" : "" + inplace_str = inplace ? "!" : "" + f = Symbol(:plan_, dir_str, :fft, inplace_str) + + @eval begin + # untyped `region` argument + Base.@constprop :aggressive function $f(x::MtlArray{T, N}, region) where {T <: FFTComplex, N} + R = length(region) + region = NTuple{R,Int}(region) + $f(x, region) + end + + # actually create the MtlFFTPlan + $f(x::MtlArray{T, N}, region::NTuple{R, Int}) where {T <: FFTComplex, N, R} = MtlFFTPlan{T, T, $backward, $inplace, N, R}(size(x), size(x), region) + end +end + +# out-of-place real-to-complex +Base.@constprop :aggressive function plan_rfft(x::MtlArray{T, N}, region) where {T <: FFTReal, N} + R = length(region) + region = NTuple{R,Int}(region) + + plan_rfft(x, region) +end + +function plan_rfft(x::MtlArray{T, N}, region::NTuple{R, Int}) where {T <: FFTReal, N, R} + backward = false + inplace = false + + xdims = size(x) + ydims = Base.setindex(xdims, div(xdims[region[1]], 2) + 1, region[1]) + MtlFFTPlan{complex(T), T, backward, inplace, N, R}(size(x), (ydims...,), region) +end + +# out-of-place complex-to-real +Base.@constprop :aggressive function plan_brfft(x::MtlArray{T, N}, d::Integer, region) where {T <: FFTComplex, N} + R = length(region) + region = NTuple{R,Int}(region) + + plan_brfft(x, d, region) +end + +function plan_brfft(x::MtlArray{T, N}, d::Integer, region::NTuple{R, Int}) where {T <: FFTComplex, N, R} + backward = true + inplace = false + + xdims = size(x) + ydims = Base.setindex(xdims, d, region[1]) + + MtlFFTPlan{real(T), T, backward, inplace, N, R}(size(x), ydims, region) +end + +function plan_inv(p::MtlFFTPlan{T, S, true, inplace, N, R}) where {T <: FFTNumber, S <: FFTNumber, inplace, N, R} + ScaledPlan(MtlFFTPlan{S, T, false, inplace, N, R}(p.output_size, p.input_size, p.region), + normalization(real(T), p.output_size, p.region)) +end + +function plan_inv(p::MtlFFTPlan{T, S, false, inplace, N, R}) where {T <: FFTNumber, S <: FFTNumber, inplace, N, R} + ScaledPlan(MtlFFTPlan{S, T, true, inplace, N, R}(p.output_size, p.input_size, p.region), + normalization(real(S), p.input_size, p.region)) +end + +## plan execution + +function assert_applicable(p::MtlFFTPlan{T, S}, X::MtlArray{S}) where {T, S} + (size(X) == p.input_size) || + throw(ArgumentError("MtlFFT plan applied to wrong-size input")) +end + +function assert_applicable(p::MtlFFTPlan{T, S, backward, inplace}, X::MtlArray{S}, + Y::MtlArray{T}) where {T, S, backward, inplace} + assert_applicable(p, X) + if size(Y) != p.output_size + throw(ArgumentError("MtlFFT plan applied to wrong-size output")) + elseif inplace != (pointer(X) == pointer(Y)) + throw(ArgumentError(string("MtlFFT ", + inplace ? "in-place" : "out-of-place", + " plan applied to ", + inplace ? "out-of-place" : "in-place", + " data"))) + end +end + +# Cache key for FFT graphs - includes all structural parameters +struct FFTGraphKey + input_size::Tuple{Vararg{Int}} + output_size::Tuple{Vararg{Int}} + eltype_input::DataType + eltype_output::DataType + ndims::Int + region::Tuple{Vararg{Int}} + backward::Bool +end +# Build graph key from FFT plan parameters +function FFTGraphKey(p::MtlFFTPlan{T, S, backward, inplace, N, R}) where {T, S, backward, inplace, N, R} + FFTGraphKey( + p.input_size, p.output_size, + S, T, + N, p.region, + backward + ) +end + +# Cached graph with all tensors needed for execution +struct CachedFFTGraph + graph::MPSGraph + placeholder::MPSGraphTensor + result::MPSGraphTensor +end +function CachedFFTGraph(key::FFTGraphKey) + graph = MPSGraph() + + # Create symbolic placeholder with the input shape and type + placeholder = placeholderTensor(graph, key.input_size, key.eltype_input) + + # Create FFT descriptor - don't use MPSGraph scaling, AbstractFFTs handles it for us + fft_desc = MPSGraphFFTDescriptor(; inverse = key.backward) + + # Convert Julia 1-indexed axis to Metal 0-indexed axis + # Due to shape reversal in placeholderTensor, we need to compute the correct axis + # Julia axis i -> Metal axis (N - i) for N-dimensional array + axes = NSArray([NSNumber(Int(key.ndims - ax)) for ax in key.region]) + + # Select the MPSGraph FFT operation based on input/output element types + fft_fn = if key.eltype_input <: Complex && key.eltype_output <: Complex + fastFourierTransformWithTensor + elseif key.eltype_input <: Real && key.eltype_output <: Complex + realToHermiteanFFTWithTensor + else # complex input, real output + HermiteanToRealFFTWithTensor + end + + # Create FFT operation + result = fft_fn(graph, placeholder, axes, fft_desc) + + CachedFFTGraph(graph, placeholder, result) +end + +# Thread-safe graph cache with lock +const _fft_graph_cache = Dict{FFTGraphKey, CachedFFTGraph}() +const _fft_graph_cache_lock = ReentrantLock() + +@autoreleasepool function _fft!(p::MtlFFTPlan{T, S, backward, inplace, N}, x, y) where {T <: FFTNumber, S <: FFTNumber, N, backward, inplace} + # Get or create cached graph + key = FFTGraphKey(p) + cached = @lock _fft_graph_cache_lock get!(_fft_graph_cache, key) do + CachedFFTGraph(key) + end + + # Build feed and result dictionaries with current data + feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( + cached.placeholder => MPSGraphTensorData(x) + ) + + resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}( + cached.result => MPSGraphTensorData(y) + ) + + cmdbuf = MPS.MPSCommandBuffer(global_queue(device())) + MPS.encode!(cmdbuf, cached.graph, NSDictionary(feeds), NSDictionary(resultdict), nil, MPSGraphs.default_exec_desc()) + commit!(cmdbuf) + wait_completed(cmdbuf) + + return y +end + +## high-level integrations + +function LinearAlgebra.mul!(y::MtlArray{T, N}, p::MtlFFTPlan{T, S, backward, inplace, N}, x::MtlArray{S, N}) where {T, S, backward, inplace, N} + assert_applicable(p, x, y) + + _fft!(p, x, y) + return y +end + +function Base.:(*)(p::MtlFFTPlan{T, S, backward, true}, x::MtlArray{S}) where {T, S, backward} + assert_applicable(p, x) + + _fft!(p, x, x) + return x +end +function Base.:(*)(p::MtlFFTPlan{T, S, backward, false}, x::MtlArray{S1, M}) where {T, S, backward, S1, M} + z = if S1 != S + # Convert to the expected input type. + copy1(S, x) + else + x + end + assert_applicable(p, z) + + y = MtlArray{T, M}(undef, p.output_size) + _fft!(p, z, y) + return y +end diff --git a/test/Project.toml b/test/Project.toml index aef2cc151..322e6edac 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,7 @@ [deps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" diff --git a/test/fft.jl b/test/fft.jl new file mode 100644 index 000000000..9651f8e45 --- /dev/null +++ b/test/fft.jl @@ -0,0 +1,467 @@ +using FFTW +using AbstractFFTs +using LinearAlgebra + +# FFTW does not support Float16, so we provide shims for CPU reference + +function AbstractFFTs.fft(x::Array{ComplexF16}, dims...) + return Array{ComplexF16}(fft(Array{ComplexF32}(x), dims...)) +end +function AbstractFFTs.ifft(x::Array{ComplexF16}, dims...) + return Array{ComplexF16}(ifft(Array{ComplexF32}(x), dims...)) +end +function AbstractFFTs.bfft(x::Array{ComplexF16}, dims...) + return Array{ComplexF16}(bfft(Array{ComplexF32}(x), dims...)) +end +function AbstractFFTs.rfft(x::Array{Float16}, dims...) + return Array{ComplexF16}(rfft(Array{Float32}(x), dims...)) +end +function AbstractFFTs.irfft(x::Array{ComplexF16}, d::Integer, dims...) + return Array{Float16}(irfft(Array{ComplexF32}(x), d, dims...)) +end +function AbstractFFTs.brfft(x::Array{ComplexF16}, d::Integer, dims...) + return Array{Float16}(brfft(Array{ComplexF32}(x), d, dims...)) +end +struct WrappedFloat16Operator + op +end +Base.:*(A::WrappedFloat16Operator, b::Array{Float16}) = Array{Float16}(A.op * Array{Float32}(b)) +Base.:*(A::WrappedFloat16Operator, b::Array{Complex{Float16}}) = Array{Complex{Float16}}(A.op * Array{Complex{Float32}}(b)) +function LinearAlgebra.mul!(C::Array{Float16}, A::WrappedFloat16Operator, B::Array{Float16}, α, β) + C32 = Array{Float32}(C) + B32 = Array{Float32}(B) + mul!(C32, A.op, B32, α, β) + C .= C32 +end +function LinearAlgebra.mul!(C::Array{Complex{Float16}}, A::WrappedFloat16Operator, B::Array{Complex{Float16}}, α, β) + C32 = Array{Complex{Float32}}(C) + B32 = Array{Complex{Float32}}(B) + mul!(C32, A.op, B32, α, β) + C .= C32 +end + +function AbstractFFTs.plan_fft!(x::Array{Complex{Float16}}, dims...) + y = similar(x, Complex{Float32}) + WrappedFloat16Operator(plan_fft!(y, dims...)) +end +function AbstractFFTs.plan_bfft!(x::Array{Complex{Float16}}, dims...) + y = similar(x, Complex{Float32}) + WrappedFloat16Operator(plan_bfft!(y, dims...)) +end +function AbstractFFTs.plan_ifft!(x::Array{Complex{Float16}}, dims...) + y = similar(x, Complex{Float32}) + WrappedFloat16Operator(plan_ifft!(y, dims...)) +end + +function AbstractFFTs.plan_fft(x::Array{Complex{Float16}}, dims...) + y = similar(x, Complex{Float32}) + WrappedFloat16Operator(plan_fft(y, dims...)) +end +function AbstractFFTs.plan_bfft(x::Array{Complex{Float16}}, dims...) + y = similar(x, Complex{Float32}) + WrappedFloat16Operator(plan_bfft(y, dims...)) +end +function AbstractFFTs.plan_ifft(x::Array{Complex{Float16}}, dims...) + y = similar(x, Complex{Float32}) + WrappedFloat16Operator(plan_ifft(y, dims...)) +end +function AbstractFFTs.plan_rfft(x::Array{Float16}, dims...) + y = similar(x, Float32) + WrappedFloat16Operator(plan_rfft(y, dims...)) +end +function AbstractFFTs.plan_irfft(x::Array{Complex{Float16}}, dims...) + y = similar(x, Complex{Float32}) + WrappedFloat16Operator(plan_irfft(y, dims...)) +end +function AbstractFFTs.plan_brfft(x::Array{Complex{Float16}}, dims...) + y = similar(x, Complex{Float32}) + WrappedFloat16Operator(plan_brfft(y, dims...)) +end + +# Tolerance functions based on type precision +rtol(::Type{Float16}) = 1.0e-2 +rtol(::Type{Float32}) = 1.0e-5 +rtol(::Type{Float64}) = 1.0e-12 +rtol(::Type{I}) where {I<:Integer} = rtol(Metal.mtlfloat(I)) +rtol(::Type{Complex{T}}) where {T} = rtol(T) +# Test dimensions +N1 = 8 +N2 = 32 +N3 = 64 +N4 = 8 + +if MPS.is_supported(device()) + + ## complex FFT tests + function complex_out_of_place(X::AbstractArray{T, N}) where {T <: Complex, N} + fftw_X = fft(X) + d_X = MtlArray(X) + + # Forward FFT with @inferred + p = @inferred plan_fft(d_X) + d_Y = p * d_X + Y = Array(d_Y) + @test isapprox(Y, fftw_X, rtol = rtol(T)) + + # Inverse FFT + pinv = plan_ifft(d_Y) + d_Z = pinv * d_Y + Z = Array(d_Z) + @test isapprox(Z, X, rtol = rtol(T)) + + pinv2 = inv(p) + d_Z = pinv2 * d_Y + Z = Array(d_Z) + @test isapprox(Z, X, rtol = rtol(T)) + + # Backward FFT (unnormalized inverse) + pinvb = @inferred plan_bfft(d_Y) + d_Z = pinvb * d_Y + Z = Array(d_Z) ./ length(d_Z) + @test isapprox(Z, X, rtol = rtol(T)) + end + + function complex_in_place(X::AbstractArray{T, N}) where {T <: Complex, N} + fftw_X = fft(X) + d_X = MtlArray(copy(X)) + + # In-place forward FFT + p = @inferred plan_fft!(d_X) + p * d_X + Y = Array(d_X) + @test isapprox(Y, fftw_X, rtol = rtol(T)) + + # In-place inverse FFT + pinv = plan_ifft!(d_X) + pinv * d_X + Z = Array(d_X) + @test isapprox(Z, X, rtol = rtol(T)) + + # Reset and test bfft! + p * d_X + pinvb = @inferred plan_bfft!(d_X) + pinvb * d_X + Z = Array(d_X) ./ length(X) + @test isapprox(Z, X, rtol = rtol(T)) + end + + function complex_batched(X::AbstractArray{T, N}, region) where {T <: Complex, N} + fftw_X = fft(X, region) + d_X = MtlArray(X) + + p = plan_fft(d_X, region) + d_Y = p * d_X + d_X2 = reshape(d_X, (size(d_X)..., 1)) + @test_throws ArgumentError p * d_X2 + + Y = Array(d_Y) + @test isapprox(Y, fftw_X, rtol = rtol(T)) + + pinv = plan_ifft(d_Y, region) + d_Z = pinv * d_Y + Z = Array(d_Z) + @test isapprox(Z, X, rtol = rtol(T)) + + ldiv!(d_Z, p, d_Y) + Z = collect(d_Z) + @test isapprox(Z, X, rtol = rtol(T)) + end + + @testset "Complex FFT" begin + @testset for T in [ComplexF16, ComplexF32] + @testset "simple" begin + @testset "$(n)D" for n = 1:3 + sz = 40 + dims = ntuple(i -> sz, n) + @test testf(fft!, rand(T, dims)) + @test testf(ifft!, rand(T, dims)) + @test testf(bfft!, rand(T, dims)) + + @test testf(fft, rand(T, dims)) + @test testf(ifft, rand(T, dims)) + @test testf(bfft, rand(T, dims)) + end + end + + @testset "1D" begin + X = rand(T, N1) + complex_out_of_place(X) + end + + @testset "1D in-place" begin + X = rand(T, N1) + complex_in_place(X) + end + + @testset "2D" begin + X = rand(T, N1, N2) + complex_out_of_place(X) + end + + @testset "2D in-place" begin + X = rand(T, N1, N2) + complex_in_place(X) + end + + @testset "3D" begin + X = rand(T, N1, N2, N3) + complex_out_of_place(X) + end + + @testset "3D in-place" begin + X = rand(T, N1, N2, N3) + complex_in_place(X) + end + + @testset "Batch 1D" begin + dims = (N1, N2) + X = rand(T, dims) + complex_batched(X, 1) + + X = rand(T, dims) + complex_batched(X, 2) + + X = rand(T, dims) + complex_batched(X, (1, 2)) + end + + @testset "Batch 2D (in 3D)" begin + dims = (N1, N2, N3) + for region in [(1, 2), (2, 3), (1, 3)] + X = rand(T, dims) + complex_batched(X, region) + end + + X = rand(T, dims) + @test_throws ArgumentError complex_batched(X, (3, 1)) + end + @testset "Batch 2D (in 4D)" begin + dims = (N1, N2, N3, N4) + for region in [(1, 2), (1, 4), (3, 4), (1, 3), (2, 3), (2,), (3,)] + X = rand(T, dims) + complex_batched(X, region) + end + X = rand(T, dims) + complex_batched(X, (2, 4)) + end + end + end + + ## real FFT tests + function real_out_of_place(X::AbstractArray{T, N}) where {T <: Real, N} + fftw_X = rfft(X) + d_X = MtlArray(X) + + # Forward rfft with @inferred + p = @inferred plan_rfft(d_X) + d_Y = p * d_X + Y = Array(d_Y) + @test isapprox(Y, fftw_X, rtol = rtol(T)) + + # Inverse rfft + pinv = plan_irfft(d_Y, size(X, 1)) + d_Z = pinv * d_Y + Z = Array(d_Z) + @test isapprox(Z, X, rtol = rtol(T)) + + pinv2 = inv(p) + d_Z = pinv2 * d_Y + Z = Array(d_Z) + @test isapprox(Z, X, rtol = rtol(T)) + + pinv3 = inv(pinv) + d_W = pinv3 * d_X + W = Array(d_W) + @test isapprox(W, Y, rtol = rtol(T)) + + # Backward rfft (unnormalized) + pinvb = @inferred plan_brfft(d_Y, size(X, 1)) + d_Z = pinvb * d_Y + Z = Array(d_Z) ./ length(X) + @test isapprox(Z, X, rtol = rtol(T)) + end + + function real_batched(X::AbstractArray{T, N}, region) where {T <: Real, N} + fftw_X = rfft(X, region) + d_X = MtlArray(X) + + p = plan_rfft(d_X, region) + d_Y = p * d_X + Y = Array(d_Y) + @test isapprox(Y, fftw_X, rtol = rtol(T)) + + pinv = plan_irfft(d_Y, size(X, region[1]), region) + d_Z = pinv * d_Y + Z = Array(d_Z) + @test isapprox(Z, X, rtol = rtol(T)) + end + + @testset "Real FFT" begin + @testset for T in [Float16, Float32] + @testset "1D" begin + X = rand(T, N1) + real_out_of_place(X) + end + + @testset "Batch 1D" begin + dims = (N1, N2) + X = rand(T, dims) + real_batched(X, 1) + + X = rand(T, dims) + real_batched(X, 2) + + X = rand(T, dims) + real_batched(X, (1, 2)) + end + + @testset "2D" begin + X = rand(T, N1, N2) + real_out_of_place(X) + end + + @testset "Batch 2D (in 3D)" begin + dims = (N1, N2, N3) + for region in [(1, 2), (2, 3), (1, 3)] + X = rand(T, dims) + real_batched(X, region) + end + + X = rand(T, dims) + @test_throws ArgumentError real_batched(X, (3, 1)) + end + + @testset "Batch 2D (in 4D)" begin + dims = (N1,N2,N3,N4) + for region in [(1,2),(1,4),(3,4),(1,3),(2,3)] + X = rand(T, dims) + real_batched(X, region) + end + X = rand(T, dims) + real_batched(X, (2, 4)) + end + + @testset "3D" begin + X = rand(T, N1, N2, N3) + real_out_of_place(X) + end + end + end + + ## complex integer + function out_of_place(X::AbstractArray{T,N}) where {T <: Complex{<:Integer},N} + fftw_X = fft(ComplexF32.(X)) + d_X = MtlArray(X) + p = plan_fft(d_X) + d_Y = p * d_X + Y = collect(d_Y) + @test isapprox(Y, fftw_X, rtol = rtol(T)) + + d_Y = fft(d_X) + Y = collect(d_Y) + @test isapprox(Y, fftw_X, rtol = rtol(T)) + end + + @testset "1D $T" for T in [Complex{Int32}, Complex{Int64}] + dims = (N1,) + X = rand(T, dims) + out_of_place(X) + end + + + ## real integer + function out_of_place(X::AbstractArray{T,N}) where {T <: Integer,N} + fftw_X = rfft(Float32.(X)) + d_X = MtlArray(X) + p = plan_rfft(d_X) + d_Y = p * d_X + Y = collect(d_Y) + @test isapprox(Y, fftw_X, rtol = rtol(T)) + + d_Y = rfft(d_X) + Y = collect(d_Y) + @test isapprox(Y, fftw_X, rtol = rtol(T)) + end + + @testset "1D $T" for T in [Int32, Int64] + X = rand(T, N1) + out_of_place(X) + end + + + ## Additional Tests + @testset "Plan Properties" begin + x = MtlArray(randn(ComplexF32, 64, 64)) + p = plan_fft(x) + @test size(p) == (64, 64) + @test fftdims(p) == (1, 2) + + p2 = plan_fft(x, 1) + @test fftdims(p2) == (1,) + end + + @testset "mul! Interface" begin + x = MtlArray(randn(ComplexF32, 32, 32)) + y = similar(x) + p = plan_fft(x) + mul!(y, p, x) + @test isapprox(Array(y), fft(Array(x)), rtol = 1.0e-4) + + # Real FFT mul! + xr = MtlArray(randn(Float32, 32, 32)) + yr = MtlArray{ComplexF32}(undef, 17, 32) + pr = plan_rfft(xr) + mul!(yr, pr, xr) + @test isapprox(Array(yr), rfft(Array(xr)), rtol = 1.0e-4) + end + + @testset "Plan Reuse" begin + x1 = MtlArray(randn(ComplexF32, 64, 64)) + x2 = MtlArray(randn(ComplexF32, 64, 64)) + + p = plan_fft(x1) + y1 = p * x1 + y2 = p * x2 + + @test isapprox(Array(y1), fft(Array(x1)), rtol = 1.0e-4) + @test isapprox(Array(y2), fft(Array(x2)), rtol = 1.0e-4) + end + + @testset "Type Restrictions" begin + # ComplexF32 should work + x32 = MtlArray(randn(ComplexF32, 32, 32)) + @test plan_fft(x32) isa Metal.MtlFFTPlan + + # ComplexF16 should work + x16 = MtlArray(ComplexF16.(randn(ComplexF32, 32, 32))) + @test plan_fft(x16) isa Metal.MtlFFTPlan + + # Float32 rfft should work + xr32 = MtlArray(randn(Float32, 32, 32)) + @test plan_rfft(xr32) isa Metal.MtlFFTPlan + + # Float16 rfft should work + xr16 = MtlArray(Float16.(randn(Float32, 32, 32))) + @test plan_rfft(xr16) isa Metal.MtlFFTPlan + end + + @testset "Invalid Dimensions" begin + x = MtlArray(randn(ComplexF32, 32, 32)) + @test_throws ArgumentError plan_fft(x, 3) # Only 2 dimensions + @test_throws ArgumentError plan_fft(x, 0) # Invalid dimension + end + + @testset "Odd Sizes" begin + # Odd-sized irfft + x_cpu = randn(Float32, 63, 64) + y_cpu = rfft(x_cpu, 1) + y_gpu = MtlArray(y_cpu) + + z_cpu = irfft(y_cpu, 63, 1) + z_gpu = Array(irfft(y_gpu, 63, 1)) + + @test size(z_gpu) == (63, 64) + @test isapprox(z_cpu, z_gpu, rtol = 1.0e-4) + end + +end # MPS.is_supported(device())