99 changes: 65 additions & 34 deletions src/compiler/execution.jl
@@ -179,11 +179,12 @@ The output of this function is automatically cached, i.e. you can simply call `m
in a hot path without degrading performance. New code will be generated automatically when
the function changes, or when different types or keyword arguments are provided.
"""
function mtlfunction(f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT}
function mtlfunction(@nospecialize(f), @nospecialize(tt)=Tuple{}; name=nothing, kwargs...)
Review comment (Member): This despecialization breaks this inference test. But that seems to be on purpose, so maybe remove?

@testset "inference" begin
foo() = @metal dummy()
@inferred foo()
# with arguments, we call mtlconvert
kernel(a) = return
bar(a) = @metal kernel(a)
@inferred bar(MtlArray([1]))
end

This only shows up when the device_synchronize test is commented out.
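The interaction the reviewer describes can be reproduced outside the PR: once the arguments are `@nospecialize`'d, the static parameters `F` and `tt` disappear, which is also why the concrete return assertion is dropped further down in this diff. A minimal self-contained sketch (all names hypothetical, not Metal.jl API) showing that an abstract type assertion can still narrow inference:

```julia
# hypothetical illustration, not Metal.jl code
struct Kernel{F,TT}
    f::F
end

const _instances = Dict{UInt,Any}()

function lookup(@nospecialize(f), @nospecialize(tt))
    h = hash((Core.Typeof(f), tt))
    kernel = get!(_instances, h) do
        Kernel{Core.Typeof(f), tt}(f)
    end
    # the concrete parameters aren't statically known here, but an abstract
    # assertion still narrows the inferred return type from Any to Kernel
    return kernel::Kernel
end

lookup(sin, Tuple{Float64})  # inferred as Kernel rather than Any
```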

dev = device()
Base.@lock mtlfunction_lock begin
# compile the function
cache = compiler_cache(dev)
F = Core.Typeof(f)
source = methodinstance(F, tt)
config = compiler_config(dev; name, kwargs...)::MetalCompilerConfig
pipeline = GPUCompiler.cached_compilation(cache, source, config, compile, link)
@@ -197,7 +198,7 @@ function mtlfunction(f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT}
kernel = HostKernel{F,tt}(f, pipeline)
_kernel_instances[h] = kernel
end
return kernel::HostKernel{F,tt}
return kernel
end
end
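For context, the body above is an instance of the usual lock-plus-dictionary memoization idiom. A minimal sketch of that idiom, with a hypothetical `expensive_compile` standing in for the real compile/link pipeline:

```julia
# sketch only; `expensive_compile` and the cache layout are assumptions
const cache_lock = ReentrantLock()
const instances = Dict{UInt,Any}()

expensive_compile(f, tt) = (f, tt)  # stand-in for compile + link

function cached(@nospecialize(f), @nospecialize(tt))
    h = hash((Core.Typeof(f), tt))
    Base.@lock cache_lock begin
        # compile at most once per (function type, argument types) pair
        get!(() -> expensive_compile(f, tt), instances, h)
    end
end
```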

@@ -246,7 +247,7 @@ const _kernel_instances = Dict{UInt, Any}()
ex
end

@inline function encode_argument!(kernel, arg)
@inline function encode_argument!(@nospecialize(kernel), arg)
argtyp = typeof(arg)

# replace non-isbits arguments (they should be unused, or compilation
@@ -263,8 +264,20 @@ end
return argument_buffer
end

# Thin outer callable — specializes on HostKernel{F,TT} and args, but the body
# is trivial (one function call) so it's fast to compile for new kernel types.
@autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1,
queue=global_queue(device()))
_metal_launch(kernel, args, groups, threads, queue)
end

# Heavy body — @nospecialize, compiled once during precompilation and reused for
# all kernel types. All generic MTL operations (command buffer, validation, etc.) live here.
function _metal_launch(@nospecialize(kernel::HostKernel), @nospecialize(args::Tuple),
groups, threads, queue)
pipeline = kernel.pipeline::MTLComputePipelineState
f = kernel.f

gs = MTLSize(groups)
ts = MTLSize(threads)
(gs.width>0 && gs.height>0 && gs.depth>0) ||
@@ -284,42 +297,60 @@ end

kernel_state = KernelState(Random.rand(UInt32))

cmdbuf = MTLCommandBuffer(queue)
cmdbuf.label = "MTLCommandBuffer($(nameof(kernel.f)))"
cce = MTLComputeCommandEncoder(cmdbuf)
argument_buffers = try
MTL.set_function!(cce, kernel.pipeline)
bufs = encode_arguments!(cce, kernel, kernel_state, kernel.f, args...)
MTL.append_current_function!(cce, gs, ts)
bufs
finally
close(cce)
end
@autoreleasepool begin
Review comment (Member): Does this @autoreleasepool do anything? The parent function (@autoreleasepool function (kernel::HostKernel)(args...)) is already annotated with one.
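For reference, nested autorelease pools are legal: the inner pool simply drains earlier, releasing intermediate Objective-C objects before the outer pool would. A hedged sketch of the semantics, assuming ObjectiveC.jl's @autoreleasepool as used elsewhere in this file:

```julia
using ObjectiveC  # exports @autoreleasepool, as used in Metal.jl

@autoreleasepool function outer()
    # objects autoreleased here are drained when outer() returns ...
    @autoreleasepool begin
        # ... while objects autoreleased in this inner pool are drained
        # already at the end of this block. The nesting is not wrong, just
        # redundant unless the block creates many short-lived ObjC objects
        # worth releasing early.
    end
    return
end
```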

cmdbuf = MTLCommandBuffer(queue)
cmdbuf.label = "MTLCommandBuffer($(nameof(f)))"
cce = MTLComputeCommandEncoder(cmdbuf)
argument_buffers = try
MTL.set_function!(cce, pipeline)
bufs = _metal_encode(cce, kernel, kernel_state, f, args)
MTL.append_current_function!(cce, gs, ts)
bufs
finally
close(cce)
end

# the command buffer retains resources that are explicitly encoded (i.e. direct buffer
# arguments, or the buffers allocated for each other argument), but that doesn't keep
# other resources alive for which we've encoded the GPU address ourselves. since it's
# possible for buffers to go out of scope while the kernel is still running, which
# triggers validation failures, keep track of things we need to keep alive until the
# kernel has actually completed.
#
# TODO: is there a way to bind additional resources to the command buffer?
roots = [kernel.f, args]
MTL.on_completed(cmdbuf) do buf
empty!(roots)
foreach(free, argument_buffers)

# Check for errors
# XXX: we cannot do this nicely, e.g. throwing an `error` or reporting with `@error`
# because we're not allowed to switch tasks from this context.
if buf.status == MTL.MTLCommandBufferStatusError
Core.println("ERROR: Failed to submit command buffer: $(buf.error.localizedDescription)")
# during precompilation, skip GPU submission (which hangs) but keep the
# encoding path above to cache compilation of the argument encoding pipeline
if ccall(:jl_generating_output, Cint, ()) != 0
foreach(free, argument_buffers)
return
end
end

commit!(cmdbuf)
# the command buffer retains resources that are explicitly encoded (i.e. direct buffer
# arguments, or the buffers allocated for each other argument), but that doesn't keep
# other resources alive for which we've encoded the GPU address ourselves. since it's
# possible for buffers to go out of scope while the kernel is still running, which
# triggers validation failures, keep track of things we need to keep alive until the
# kernel has actually completed.
#
# TODO: is there a way to bind additional resources to the command buffer?
#
# collect to Vector so the callback closure type doesn't
# depend on the specific kernel's argument buffer tuple type
argument_buffer_vec = collect(argument_buffers)
roots = Any[f, args]
MTL.on_completed(cmdbuf) do buf
empty!(roots)
foreach(free, argument_buffer_vec)

# Check for errors
# XXX: we cannot do this nicely, e.g. throwing an `error` or reporting with `@error`
# because we're not allowed to switch tasks from this context.
if buf.status == MTL.MTLCommandBufferStatusError
Core.println("ERROR: Failed to submit command buffer: $(buf.error.localizedDescription)")
end
end

commit!(cmdbuf)
end
end
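The `roots` trick above is a general pattern for keeping Julia objects alive across an asynchronous operation: root them in the completion callback's closure so the GC cannot collect them, then drop the references once the callback fires. A self-contained sketch with a plain Task standing in for the command buffer (names hypothetical):

```julia
# keep `payload` alive for the duration of an asynchronous operation by
# rooting it in the completion callback's closure
function launch_async(payload...)
    roots = Any[payload...]
    return @async begin
        sleep(0.01)      # stand-in for the GPU executing the kernel
        # the captured `roots` vector guaranteed the objects survived
        # until this point; dropping the references re-enables GC
        empty!(roots)
    end
end

wait(launch_async(rand(3), "some buffer"))
```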

# Bridge: specializes on f and args types (encode_arguments! needs concrete types
# for ghost-type detection), but NOT on the full HostKernel type.
@inline _metal_encode(cce, @nospecialize(kernel), kernel_state, f, args::Tuple) =
encode_arguments!(cce, kernel, kernel_state, f, args...)
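Taken together, the three pieces form a specialization barrier: a thin specialized entry point, a despecialized heavy body, and a narrow bridge that re-specializes only where concrete types are genuinely needed. A generic, self-contained sketch of the pattern (names hypothetical, not Metal.jl API):

```julia
# thin outer entry point: specializes on the callable and argument types,
# but the body is a single call, so each new specialization compiles fast
launch(f, args...; opts...) = _launch_body(f, args; opts...)

# heavy body: with @nospecialize this compiles once and is reused for all
# f/args combinations; the expensive generic machinery belongs here
function _launch_body(@nospecialize(f), @nospecialize(args::Tuple); opts...)
    n = length(args)   # stand-in for generic setup/validation work
    _encode(f, args)   # drop back into specialized code only where needed
    return n
end

# bridge: re-specializes on f and args because this one step genuinely
# needs concrete types, while keeping its body small
@inline _encode(f, args::Tuple) = foreach(typeof, args)

launch(println, 1, 2.0)  # returns 2
```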

## Intra-warp Helpers

"""
49 changes: 37 additions & 12 deletions src/precompile.jl
@@ -3,18 +3,43 @@ using PrecompileTools: @setup_workload, @compile_workload
@setup_workload begin
metallib_file = joinpath(dirname(@__DIR__), "test", "dummy.metallib")

# parsing and writing metal libraries
metallib = parse(MetalLib, metallib_file)
sprint(write, metallib)
end
@compile_workload begin
# parsing and writing metal libraries
metallib = parse(MetalLib, metallib_file)
sprint(write, metallib)

# launch a trivial kernel to precompile the full pipeline:
# mtlfunction → GPUCompiler → LLVM IR → AIR → metallib → link → launch
# (GPU submission is skipped during precompilation, but the entire
# encoding path — encode_arguments!, command buffer setup, etc. — runs)
kernel() = return
@metal kernel()

precompile(compile, (CompilerJob,))
precompile(Tuple{typeof(GPUCompiler.finish_ir!), GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, LLVM.Module, LLVM.Function})
precompile(Tuple{typeof(GPUCompiler.finish_module!), GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, LLVM.Module, LLVM.Function})
precompile(Tuple{typeof(GPUCompiler.check_ir), GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, LLVM.Module})
precompile(Tuple{typeof(GPUCompiler.actual_compilation), Base.Dict{Any, Any}, Core.MethodInstance, UInt64, GPUCompiler.CompilerConfig{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, typeof(Metal.compile), typeof(Metal.link)})
# launch a realistic kernel with array arguments
a = MtlArray(Float32[1])
b = MtlArray(Float32[1])
c = MtlArray(Float32[0])
function precompile_vadd(a, b, c)
i = thread_position_in_grid().x
c[i] = a[i] + b[i]
return
end
@metal precompile_vadd(a, b, c)

# Worth the hassle
if isdefined(Base, :Compiler) && isdefined(Base.Compiler, :typeinf_local)
precompile(Tuple{typeof(Base.Compiler.typeinf_local), GPUCompiler.GPUInterpreter{Base.Compiler.CachedMethodTable{Base.Compiler.OverlayMethodTable}}, Base.Compiler.InferenceState, Base.Compiler.CurrentState})
# also exercise 2D arrays (common in real workloads)
a2 = MtlArray(Float32[1 1])
b2 = MtlArray(Float32[1 1])
c2 = MtlArray(Float32[0 0])
@metal precompile_vadd(a2, b2, c2)

# precompile MtlArray → Array copy-back
Array(c)
Array(c2)
end
end

# GPUCompiler macro-expansion utilities (compiled at every @metal callsite).
# split_kwargs is called with 3 Vector{Symbol} groups (MACRO/COMPILER/LAUNCH_KWARGS).
precompile(Tuple{typeof(GPUCompiler.split_kwargs), Tuple{Expr}, Vector{Symbol}, Vector{Symbol}, Vector{Symbol}})
precompile(Tuple{typeof(GPUCompiler.split_kwargs), Tuple{}, Vector{Symbol}, Vector{Symbol}, Vector{Symbol}})
precompile(Tuple{typeof(GPUCompiler.assign_args!), Expr, Vector{Any}})
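For readers unfamiliar with PrecompileTools, the structure above matters: setup code runs during precompilation, but only calls made inside @compile_workload end up cached. A minimal sketch of the pattern with a hypothetical workload, including the jl_generating_output check the launch path uses to avoid submitting real GPU work while precompiling:

```julia
using PrecompileTools: @setup_workload, @compile_workload

@setup_workload begin
    # setup runs during precompilation, but only calls made inside
    # @compile_workload below are compiled into the package cache
    data = Float32[1, 2, 3]              # hypothetical setup
    @compile_workload begin
        sum(data)                        # hypothetical workload to cache
        # the same check the launch path above uses to notice it is
        # running under precompilation, so GPU submission can be skipped
        if ccall(:jl_generating_output, Cint, ()) != 0
            # skip side effects that would hang during precompilation
        end
    end
end
```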