Add FFT support via AbstractFFTs interface #713
KaanKesginLW wants to merge 41 commits into JuliaGPU:main
Codecov Report
❌ Patch coverage is
Additional details and impacted files

| Coverage Diff | main | #713 | +/- |
|---|---|---|---|
| Coverage | 83.09% | 82.99% | -0.10% |
| Files | 62 | 64 | +2 |
| Lines | 2851 | 2999 | +148 |
| Hits | 2369 | 2489 | +120 |
| Misses | 482 | 510 | +28 |
Metal Benchmarks
| Benchmark suite | Current: 9bd038f | Previous: b94fd4b | Ratio |
|---|---|---|---|
| array/accumulate/Float32/1d | 1063959 ns | 1098958 ns | 0.97 |
| array/accumulate/Float32/dims=1 | 1615375 ns | 1554708 ns | 1.04 |
| array/accumulate/Float32/dims=1L | 10382979 ns | 9848583.5 ns | 1.05 |
| array/accumulate/Float32/dims=2 | 2040354 ns | 1886771 ns | 1.08 |
| array/accumulate/Float32/dims=2L | 7545979.5 ns | 7256459 ns | 1.04 |
| array/accumulate/Int64/1d | 1287958 ns | 1261958 ns | 1.02 |
| array/accumulate/Int64/dims=1 | 1886709 ns | 1824291.5 ns | 1.03 |
| array/accumulate/Int64/dims=1L | 12201292 ns | 11664208.5 ns | 1.05 |
| array/accumulate/Int64/dims=2 | 2271083.5 ns | 2170333.5 ns | 1.05 |
| array/accumulate/Int64/dims=2L | 9885208 ns | 10120062.5 ns | 0.98 |
| array/broadcast | 549000.5 ns | 605916 ns | 0.91 |
| array/construct | 6750 ns | 6292 ns | 1.07 |
| array/permutedims/2d | 1216916.5 ns | 1168125 ns | 1.04 |
| array/permutedims/3d | 1807000 ns | 1673084 ns | 1.08 |
| array/permutedims/4d | 2506313 ns | 2365959 ns | 1.06 |
| array/private/copy | 821834 ns | 545792 ns | 1.51 |
| array/private/copyto!/cpu_to_gpu | 724208 ns | 802916 ns | 0.90 |
| array/private/copyto!/gpu_to_cpu | 674583.5 ns | 801917 ns | 0.84 |
| array/private/copyto!/gpu_to_gpu | 522249.5 ns | 634458 ns | 0.82 |
| array/private/iteration/findall/bool | 1464854 ns | 1402750 ns | 1.04 |
| array/private/iteration/findall/int | 1588875 ns | 1564021 ns | 1.02 |
| array/private/iteration/findfirst/bool | 2058334 ns | 2055916 ns | 1.00 |
| array/private/iteration/findfirst/int | 2117375 ns | 2064479.5 ns | 1.03 |
| array/private/iteration/findmin/1d | 2597812.5 ns | 2499959 ns | 1.04 |
| array/private/iteration/findmin/2d | 1850812.5 ns | 1790791 ns | 1.03 |
| array/private/iteration/logical | 2697166 ns | 2631896 ns | 1.02 |
| array/private/iteration/scalar | 3268062 ns | 5047625 ns | 0.65 |
| array/random/rand/Float32 | 840396 ns | 582958 ns | 1.44 |
| array/random/rand/Int64 | 926000 ns | 775667 ns | 1.19 |
| array/random/rand!/Float32 | 543354 ns | 574750 ns | 0.95 |
| array/random/rand!/Int64 | 532709 ns | 550792 ns | 0.97 |
| array/random/randn/Float32 | 1045959 ns | 1006937.5 ns | 1.04 |
| array/random/randn!/Float32 | 721688 ns | 755666 ns | 0.96 |
| array/reductions/mapreduce/Float32/1d | 808125 ns | 1029500 ns | 0.78 |
| array/reductions/mapreduce/Float32/dims=1 | 815875 ns | 840875 ns | 0.97 |
| array/reductions/mapreduce/Float32/dims=1L | 1360958 ns | 1324000 ns | 1.03 |
| array/reductions/mapreduce/Float32/dims=2 | 848500 ns | 860875 ns | 0.99 |
| array/reductions/mapreduce/Float32/dims=2L | 1820166.5 ns | 1799541 ns | 1.01 |
| array/reductions/mapreduce/Int64/1d | 1311062.5 ns | 1374875 ns | 0.95 |
| array/reductions/mapreduce/Int64/dims=1 | 1123083 ns | 1097625 ns | 1.02 |
| array/reductions/mapreduce/Int64/dims=1L | 2005208 ns | 2002854 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2 | 1162042 ns | 1145000 ns | 1.01 |
| array/reductions/mapreduce/Int64/dims=2L | 3647916 ns | 3614000 ns | 1.01 |
| array/reductions/reduce/Float32/1d | 832500 ns | 1028437.5 ns | 0.81 |
| array/reductions/reduce/Float32/dims=1 | 808125 ns | 832667 ns | 0.97 |
| array/reductions/reduce/Float32/dims=1L | 1360333 ns | 1318416.5 ns | 1.03 |
| array/reductions/reduce/Float32/dims=2 | 844270.5 ns | 853041.5 ns | 0.99 |
| array/reductions/reduce/Float32/dims=2L | 1816916.5 ns | 1810250 ns | 1.00 |
| array/reductions/reduce/Int64/1d | 1317166 ns | 1516958 ns | 0.87 |
| array/reductions/reduce/Int64/dims=1 | 1130667 ns | 1095375 ns | 1.03 |
| array/reductions/reduce/Int64/dims=1L | 2042083 ns | 2023499.5 ns | 1.01 |
| array/reductions/reduce/Int64/dims=2 | 1170625 ns | 1240750 ns | 0.94 |
| array/reductions/reduce/Int64/dims=2L | 4132354.5 ns | 4233875 ns | 0.98 |
| array/shared/copy | 215000 ns | 252417 ns | 0.85 |
| array/shared/copyto!/cpu_to_gpu | 83625 ns | 80750 ns | 1.04 |
| array/shared/copyto!/gpu_to_cpu | 84000 ns | 80667 ns | 1.04 |
| array/shared/copyto!/gpu_to_gpu | 84667 ns | 83083 ns | 1.02 |
| array/shared/iteration/findall/bool | 1478334 ns | 1427208.5 ns | 1.04 |
| array/shared/iteration/findall/int | 1586958 ns | 1559875 ns | 1.02 |
| array/shared/iteration/findfirst/bool | 1672604 ns | 1649000 ns | 1.01 |
| array/shared/iteration/findfirst/int | 1713729 ns | 1672458 ns | 1.02 |
| array/shared/iteration/findmin/1d | 2206812.5 ns | 2115583 ns | 1.04 |
| array/shared/iteration/findmin/2d | 1871021 ns | 1792625 ns | 1.04 |
| array/shared/iteration/logical | 2259416 ns | 2292167 ns | 0.99 |
| array/shared/iteration/scalar | 204562.5 ns | 199958 ns | 1.02 |
| integration/byval/reference | 1582563 ns | 1544250 ns | 1.02 |
| integration/byval/slices=1 | 1595791 ns | 1560229.5 ns | 1.02 |
| integration/byval/slices=2 | 2704083 ns | 2598333.5 ns | 1.04 |
| integration/byval/slices=3 | 19409291.5 ns | 8092333 ns | 2.40 |
| integration/metaldevrt | 859791 ns | 868125 ns | 0.99 |
| kernel/indexing | 473937.5 ns | 592667 ns | 0.80 |
| kernel/indexing_checked | 500021 ns | 598292 ns | 0.84 |
| kernel/launch | 12625 ns | 11791.5 ns | 1.07 |
| kernel/rand | 538000 ns | 570709 ns | 0.94 |
| latency/import | 1446676250 ns | 1425597062.5 ns | 1.01 |
| latency/precompile | 25920507833 ns | 25453724708 ns | 1.02 |
| latency/ttfp | 2381656291.5 ns | 2341177208 ns | 1.02 |
| metal/synchronization/context | 19750 ns | 19667 ns | 1.00 |
| metal/synchronization/stream | 19375 ns | 18459 ns | 1.05 |

This comment was automatically generated by workflow using github-action-benchmark.
    end

    @inline function _unsafe_execute!(f, p::MtlFFTPlan{T, S, backward, inplace, N}, x, y) where {T <: FFTNumber, S <: FFTNumber, N, backward, inplace}
        graph = MPSGraph()
Does each execution create a new MPSGraph? That adds quite some latency; it would be great to cache the graph somehow.
They can be added in a different PR
Claude was instructed to implement this like the MPSGraphs matmul caching
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds FFT support for MtlArray by implementing the AbstractFFTs.jl interface on top of MPSGraph FFT operations, along with a comprehensive test suite and required dependency wiring.
Changes:
- Implemented AbstractFFTs plans and execution for Metal-backed arrays using cached MPSGraph FFT graphs.
- Added MPSGraphs bindings for FFT descriptor creation and FFT operations.
- Added FFT-focused tests (including Float16 CPU shims) and updated project dependencies.
Reviewed changes
Copilot reviewed 6 out of 7 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| src/fft.jl | Implements AbstractFFTs plans/execution for MtlArray, including MPSGraph graph caching. |
| lib/mpsgraphs/fft.jl | Adds Objective-C bindings for MPSGraph FFT descriptor and FFT ops. |
| lib/mpsgraphs/MPSGraphs.jl | Includes the new FFT bindings module. |
| src/Metal.jl | Wires FFT module into the main package and adds Reexport. |
| Project.toml | Adds AbstractFFTs and Reexport deps/compat. |
| test/Project.toml | Adds AbstractFFTs and FFTW for testing. |
| test/fft.jl | Adds a large FFT test suite and Float16 CPU reference shims. |

    # Get or create cached graph
    function _get_cached_graph!(graph_cache_lock, graph_cache, key::FFTGraphKey)
        # Fast path: check cache without lock (safe for reads)
        cached = get(graph_cache, key, nothing)
        if cached !== nothing
            return cached
        end

        # Slow path: acquire lock and build graph
        @lock graph_cache_lock get!(graph_cache, key) do
            CachedFFTGraph(key)
        end
    end

Reading from a Dict concurrently with writes is not thread-safe in Julia; the lockless get(graph_cache, ...) can race with the get! mutation and lead to memory corruption/crashes. Protect all Dict accesses (including reads) with the lock, or switch to a thread-safe cache strategy (e.g., always @lock around get! / get, or use a concurrency-safe map implementation if available in this codebase).
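
As a rough illustration of the "protect all Dict accesses" suggestion, the sketch below takes the lock for reads as well as writes. `GraphKey`, `CachedGraph`, `GRAPH_CACHE`, `GRAPH_CACHE_LOCK`, `BUILD_COUNT`, and `get_cached_graph!` are hypothetical stand-ins invented for this example, not the PR's `FFTGraphKey` / `CachedFFTGraph` code; only Base Julia is used.

```julia
# Hypothetical stand-ins for the PR's key and cached-graph types.
struct GraphKey
    size::NTuple{2, Int}
    inverse::Bool
end

struct CachedGraph
    key::GraphKey
end

const GRAPH_CACHE = Dict{GraphKey, CachedGraph}()
const GRAPH_CACHE_LOCK = ReentrantLock()
const BUILD_COUNT = Threads.Atomic{Int}(0)  # counts how many entries get built

function get_cached_graph!(key::GraphKey)
    # Every access to the Dict, lookups included, happens while holding the lock,
    # so a lookup can never overlap with an insertion from another task.
    lock(GRAPH_CACHE_LOCK) do
        get!(GRAPH_CACHE, key) do
            Threads.atomic_add!(BUILD_COUNT, 1)
            CachedGraph(key)
        end
    end
end

# Many tasks request the same key; the entry should be built only once.
key = GraphKey((64, 64), false)
tasks = [Threads.@spawn get_cached_graph!(key) for _ in 1:32]
graphs = fetch.(tasks)
```

The trade-off is that cache hits now pay for an uncontended lock acquisition, which is likely small next to the cost of building or running a graph, though that is an assumption rather than something measured here.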
@maleadt This failure happened because AppleAccelerate.jl was loaded in the flopscomp example test, which happened to be on the same runner as the FFT tests. Should we just remove AppleAccelerate? Also, Copilot seems to think that reading a Dict while writes are potentially happening isn't thread-safe. If that's true, we'll need to fix the MPSGraph caching too.
Yeah I guess. It's unfortunate that AppleAccelerate breaks stuff.
Yes, Dict isn't threadsafe. You'd want to wrap those concurrent accesses in a ReentrantLock.
Adds FFT support for MtlArray via the AbstractFFTs.jl interface.
HEAVILY based on CUDA.jl's AbstractFFTs.jl interface implementation using MPSGraph functionality.
Performance
Benchmarked on Apple M2 Max with 30-core GPU against FFTW.jl on CPU:
Example Usage
Close #270