Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ export MetalBackend

include("deprecated.jl")

include("warmup.jl")

include("precompile.jl")

end # module
9 changes: 9 additions & 0 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ else
end
end

# Async warmup system to reduce first-kernel JIT compilation latency
#
# `_warmup_task` holds the background warmup task, or `nothing` while no warmup has been
# started; `__init__()` assigns it when the device is functional and warmup is enabled.
const _warmup_task = Ref{Union{Nothing, Task}}(nothing)
# Opt-out switch read from the "warmup" preference; defaults to `true` when unset.
# NOTE(review): this `const` is evaluated at module top level, so a changed preference
# presumably only takes effect after the package is recompiled/reloaded — confirm.
const _warmup_enabled = @load_preference("warmup", true)

function __init__()
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
precompiling && return
Expand Down Expand Up @@ -72,6 +76,11 @@ function __init__()
if isdefined(Base, :active_repl_backend) && !isnothing(Base.active_repl_backend)
push!(Base.active_repl_backend.ast_transforms, synchronize_metal_tasks)
end

# start async warmup to reduce first-kernel JIT compilation latency
if functional() && _warmup_enabled
_warmup_task[] = errormonitor(@async _warmup_compilation())
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_warmup_task[] = errormonitor(@async _warmup_compilation())
_warmup_task[] = errormonitor(Threads.@spawn _warmup_compilation())

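`@async` pins the task to the same thread as its parent, so the warmup would compete with the main task instead of running in parallel; `Threads.@spawn` lets it be scheduled on any available thread.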

end
end

function synchronize_metal_tasks(ex)
Expand Down
71 changes: 71 additions & 0 deletions src/warmup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Async warmup to reduce first-kernel JIT compilation latency
#
# The first GPU kernel in a Metal.jl session takes ~1.75s due to one-time JIT
# compilation of GPUCompiler internals. By starting a minimal kernel compilation
# in the background during __init__(), we can reduce this to 0.035-0.20s for the
# user's first actual kernel—a 9-50x improvement.

export warmup

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should fix the benchmark error.

Suggested change
export warmup

# Smallest useful kernel: one bounds-checked store per thread. Compiling it drives the
# whole kernel compilation pipeline without doing any meaningful work.
function _warmup_kernel!(a)
    idx = thread_position_in_grid().x
    # Threads positioned past the end of `a` must not write.
    idx <= length(a) && (a[idx] = 0.0f0)
    return nothing
end

# Compile `_warmup_kernel!` once so that the compilation pipeline is already JIT-compiled
# when the user launches their first kernel. Called from `__init__()` as a background task.
#
# Warmup is a best-effort optimization: a failure never propagates to the caller, but it
# is reported at debug level so it can be diagnosed instead of vanishing silently.
# Always returns `nothing`.
function _warmup_compilation()
    try
        # Minimal allocation - just need to trigger compilation
        arr = MtlArray{Float32}(undef, 1)
        try
            # launch=false compiles but doesn't execute - fastest warmup path
            @metal launch=false _warmup_kernel!(arr)
        finally
            # Release the buffer eagerly, even when compilation throws.
            unsafe_free!(arr)
        end
    catch err
        # Non-critical: keep going, but leave a trace for anyone debugging startup.
        @debug "Metal.jl warmup compilation failed" exception=(err, catch_backtrace())
    end
    return nothing
end

"""
warmup(; blocking::Bool=true)

Ensure the GPU compilation pipeline is warmed up.

The first GPU kernel in a Metal.jl session incurs a one-time JIT compilation overhead
of ~1.7 seconds. Metal.jl automatically starts warming up in the background when the
package is loaded. This function allows you to explicitly wait for warmup to complete.

If `blocking=true` (default), waits for warmup to complete before returning.
If `blocking=false`, returns immediately while warmup continues in background.

# When to use

Call `warmup()` before timing-sensitive code to ensure consistent benchmark results:

```julia
using Metal
Metal.warmup() # wait for warmup to complete
@time @metal kernel!(a) # consistently fast (~0.035s, not ~1.7s)
```

# Note

You never need to call this function for correctness—only for consistent timing.
Most users will never need to call this explicitly, as the background warmup will
complete during normal program setup (loading data, preprocessing, etc.).
"""
function warmup(; blocking::Bool=true)
    pending = _warmup_task[]
    # `pending === nothing` means warmup was never started (non-functional GPU or
    # disabled); in that case, as in the non-blocking case, there is nothing to wait for.
    if blocking && pending !== nothing
        wait(pending)
    end
    return nothing
end
66 changes: 66 additions & 0 deletions test/warmup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Tests for the background warmup started in `Metal.__init__()` and the public
# `Metal.warmup` entry point. They assume the default configuration (warmup enabled).
@testset "warmup" begin
    @testset "warmup task started" begin
        # Warmup should have been started during __init__
        @test Metal._warmup_task[] !== nothing
        @test Metal._warmup_enabled
    end

    @testset "warmup API" begin
        # Non-blocking call should return immediately
        @test Metal.warmup(blocking=false) === nothing

        # Blocking call should wait and return nothing
        @test Metal.warmup() === nothing
        @test Metal.warmup(blocking=true) === nothing
    end

    @testset "warmup task completion" begin
        # After calling warmup(), task should be done
        Metal.warmup()
        task = Metal._warmup_task[]
        @test istaskdone(task)
        @test !istaskfailed(task)
    end

    @testset "warmup accelerates compilation" begin
        # After warmup, kernel compilation should be fast
        Metal.warmup()

        function test_kernel!(a)
            i = thread_position_in_grid().x
            if i <= length(a)
                a[i] = 1.0f0
            end
            return nothing
        end

        a = MtlArray{Float32}(undef, 256)
        t = @elapsed @metal launch=false test_kernel!(a)

        # After warmup, compilation should be under 0.5s
        # (without warmup it would be ~1.7s)
        @test t < 0.5
    end

    @testset "concurrent kernel compilation" begin
        # Verify that concurrent compilations don't deadlock
        Metal.warmup()

        function k1!(a)
            a[1] = 1.0f0
            return nothing
        end
        function k2!(a)
            a[1] = 2.0f0
            return nothing
        end

        a = MtlArray{Float32}(undef, 1)

        t1 = @async @metal launch=false k1!(a)
        t2 = @async @metal launch=false k2!(a)

        # Should complete without deadlock (with timeout)
        @test timedwait(() -> istaskdone(t1) && istaskdone(t2), 10.0) == :ok

        # `istaskdone` is also true for tasks that threw, so check explicitly that both
        # compilations actually succeeded rather than merely terminated.
        @test !istaskfailed(t1)
        @test !istaskfailed(t2)
    end
end