Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ export MetalBackend

include("deprecated.jl")

include("warmup.jl")

include("precompile.jl")

end # module
19 changes: 15 additions & 4 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
try
dev = device()
return supports_family(dev, MTL.MTLGPUFamilyApple7) &&
supports_family(dev, MTL.MTLGPUFamilyMetal3)
supports_family(dev, MTL.MTLGPUFamilyMetal3)
catch
return false
end
end
else
# Becomes `nothing` once it has been determined that the device is on macOS
const _functional = Ref{Union{Nothing,Bool}}(false)
const _functional = Ref{Union{Nothing, Bool}}(false)

function functional()
if isnothing(_functional[])
Expand All @@ -24,6 +24,10 @@ else
end
end

# Async warmup system to reduce first-kernel JIT compilation latency.
# `_warmup_task` holds the background warmup task once `__init__()` has started one;
# it stays `nothing` when no warmup was launched.
const _warmup_task = Ref{Union{Nothing, Task}}(nothing)
# Opt-out switch read from the "warmup" preference (defaults to `true`).
# NOTE(review): presumably resolved at precompile time by Preferences.jl, so changing
# the preference would require recompilation — confirm.
const _warmup_enabled = @load_preference("warmup", true)

function __init__()
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
precompiling && return
Expand Down Expand Up @@ -63,7 +67,7 @@ function __init__()
_functional[] = nothing # VERSION <= v"1.12.0-DEV.1421"
end
catch err
@error "Failed to load Metal" exception=(err,catch_backtrace())
@error "Failed to load Metal" exception = (err, catch_backtrace())
return
end

Expand All @@ -72,10 +76,17 @@ function __init__()
if isdefined(Base, :active_repl_backend) && !isnothing(Base.active_repl_backend)
push!(Base.active_repl_backend.ast_transforms, synchronize_metal_tasks)
end

# Start async warmup to reduce first-kernel JIT compilation latency.
# Only run with multiple threads - with a single thread, the async task would
# block the main thread due to Julia's cooperative task runtime.
return if functional() && _warmup_enabled && Threads.nthreads() > 1
_warmup_task[] = errormonitor(@async _warmup_compilation())
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_warmup_task[] = errormonitor(@async _warmup_compilation())
_warmup_task[] = errormonitor(Threads.@spawn _warmup_compilation())

`@async` pins the task to the same thread as its parent, so use `Threads.@spawn` to let it run on another thread.

end
end

function synchronize_metal_tasks(ex)
quote
return quote
try
$(ex)
finally
Expand Down
76 changes: 76 additions & 0 deletions src/warmup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Async warmup to reduce first-kernel JIT compilation latency
#
# The first GPU kernel in a Metal.jl session takes ~1.75s due to one-time JIT
# compilation of GPUCompiler internals. By starting a minimal kernel compilation
# in the background during __init__(), we can reduce this to 0.035-0.20s for the
# user's first actual kernel—a 9-50x improvement.
#
# NOTE: Warmup only runs when multiple threads are available (Threads.nthreads() > 1).
# With a single thread, async warmup would block the main thread due to Julia's
# cooperative task runtime, potentially hurting perceived latency.

# Smallest possible kernel whose compilation still exercises the whole pipeline:
# each in-bounds thread writes a single zero into `a`.
function _warmup_kernel!(a)
    idx = thread_position_in_grid().x
    # Guard against threads that fall outside the array
    idx <= length(a) && (a[idx] = 0.0f0)
    return nothing
end

# Compile `_warmup_kernel!` once so that later kernel launches skip the one-time
# JIT cost. Intended to run as a background task started from `__init__()`.
#
# This is a best-effort optimization: any failure is logged at debug level and
# otherwise ignored. Always returns `nothing`.
function _warmup_compilation()
    try
        # Minimal allocation - just need something to compile the kernel against
        arr = MtlArray{Float32}(undef, 1)
        try
            # launch=false compiles but doesn't execute - fastest warmup path
            @metal launch = false _warmup_kernel!(arr)
        finally
            # Release the scratch buffer even when compilation throws
            unsafe_free!(arr)
        end
    catch err
        # Non-critical: keep failures out of the user's way but leave a trace
        # for anyone debugging with JULIA_DEBUG enabled
        @debug "Metal warmup compilation failed" exception = (err, catch_backtrace())
    end
    return nothing
end

"""
    Metal.warmup(; blocking::Bool=true)

Wait for Metal.jl's background compilation warmup to finish.

Launching the first GPU kernel of a session pays a one-time JIT compilation cost of
roughly 1.7 seconds. When Julia runs with more than one thread (e.g. `julia -t auto`),
Metal.jl starts a warmup task while the package loads so that this cost is absorbed in
the background. This function lets callers synchronize with that task.

# Keywords

- `blocking::Bool=true`: when `true`, block until the warmup task has finished; when
  `false`, return immediately and leave the task running in the background.

# Returns

Always `nothing`. If no warmup task was started (non-functional GPU, warmup disabled,
or a single-threaded session), the call does nothing.

# Examples

Use it ahead of timing-sensitive code to get reproducible measurements:

```julia
using Metal
Metal.warmup()           # block until the background warmup is done
@time @metal kernel!(a)  # consistently fast (~0.035s rather than ~1.7s)
```

# Notes

- Background warmup is only started with multiple threads; on a single thread the
  async task would block the main thread because Julia tasks are cooperative.
- Calling this is never required for correctness, only for consistent timings.
- Most programs need not call it at all: the background warmup usually finishes
  during ordinary setup work such as loading or preprocessing data.
"""
function warmup(; blocking::Bool = true)
    pending = _warmup_task[]
    # `pending` is `nothing` when warmup was never started; then there is nothing to wait on
    if blocking && !isnothing(pending)
        wait(pending)
    end
    return nothing
end
68 changes: 68 additions & 0 deletions test/warmup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
@testset "warmup" begin
@testset "warmup API" begin
    # Whatever the thread configuration, every call form must yield `nothing`
    @test isnothing(Metal.warmup())
    @test isnothing(Metal.warmup(; blocking = false))
    @test isnothing(Metal.warmup(; blocking = true))

    # Repeated calls must remain harmless
    @test isnothing(Metal.warmup())
    @test isnothing(Metal.warmup())
end

@testset "kernel compilation after warmup" begin
    Metal.warmup()

    # User-style kernel: each in-bounds thread stores its own index
    function test_kernel!(a)
        idx = thread_position_in_grid().x
        idx <= length(a) && (a[idx] = Float32(idx))
        return nothing
    end

    buf = MtlArray{Float32}(undef, 256)
    @metal threads = 256 test_kernel!(buf)
    synchronize()

    # Spot-check the first, middle and last elements on the host
    host = Array(buf)
    @test host[1] == 1.0f0
    @test host[128] == 128.0f0
    @test host[256] == 256.0f0
end

@testset "concurrent kernel compilation" begin
    Metal.warmup()

    # Two distinct kernels, so each one needs its own compilation
    function kernel_add!(a)
        idx = thread_position_in_grid().x
        idx <= length(a) && (a[idx] += 1.0f0)
        return nothing
    end

    function kernel_mul!(a)
        idx = thread_position_in_grid().x
        idx <= length(a) && (a[idx] *= 2.0f0)
        return nothing
    end

    added = MtlArray(ones(Float32, 64))
    doubled = MtlArray(ones(Float32, 64))

    # Compile and launch both kernels back to back, then wait for the GPU
    @metal threads = 64 kernel_add!(added)
    @metal threads = 64 kernel_mul!(doubled)
    synchronize()

    # 1 + 1 and 1 * 2 must both give 2
    @test Array(added)[1] == 2.0f0
    @test Array(doubled)[1] == 2.0f0
end
end