-
Notifications
You must be signed in to change notification settings - Fork 61
Improve time-to-first-kernel #747
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
maleadt
wants to merge
2
commits into
main
Choose a base branch
from
tb/ttfx
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -179,11 +179,12 @@ The output of this function is automatically cached, i.e. you can simply call `m | |
| in a hot path without degrading performance. New code will be generated automatically when | ||
| the function changes, or when different types or keyword arguments are provided. | ||
| """ | ||
| function mtlfunction(f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT} | ||
| function mtlfunction(@nospecialize(f), @nospecialize(tt)=Tuple{}; name=nothing, kwargs...) | ||
| dev = device() | ||
| Base.@lock mtlfunction_lock begin | ||
| # compile the function | ||
| cache = compiler_cache(dev) | ||
| F = Core.Typeof(f) | ||
| source = methodinstance(F, tt) | ||
| config = compiler_config(dev; name, kwargs...)::MetalCompilerConfig | ||
| pipeline = GPUCompiler.cached_compilation(cache, source, config, compile, link) | ||
|
|
@@ -197,7 +198,7 @@ function mtlfunction(f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT} | |
| kernel = HostKernel{F,tt}(f, pipeline) | ||
| _kernel_instances[h] = kernel | ||
| end | ||
| return kernel::HostKernel{F,tt} | ||
| return kernel | ||
| end | ||
| end | ||
|
|
||
|
|
@@ -246,7 +247,7 @@ const _kernel_instances = Dict{UInt, Any}() | |
| ex | ||
| end | ||
|
|
||
| @inline function encode_argument!(kernel, arg) | ||
| @inline function encode_argument!(@nospecialize(kernel), arg) | ||
| argtyp = typeof(arg) | ||
|
|
||
| # replace non-isbits arguments (they should be unused, or compilation | ||
|
|
@@ -263,8 +264,20 @@ end | |
| return argument_buffer | ||
| end | ||
|
|
||
| # Thin outer callable — specializes on HostKernel{F,TT} and args, but the body | ||
| # is trivial (one function call) so it's fast to compile for new kernel types. | ||
| @autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1, | ||
| queue=global_queue(device())) | ||
| _metal_launch(kernel, args, groups, threads, queue) | ||
| end | ||
|
|
||
| # Heavy body — @nospecialize, compiled once during precompilation and reused for | ||
| # all kernel types. All generic MTL operations (command buffer, validation, etc.) live here. | ||
| function _metal_launch(@nospecialize(kernel::HostKernel), @nospecialize(args::Tuple), | ||
| groups, threads, queue) | ||
| pipeline = kernel.pipeline::MTLComputePipelineState | ||
| f = kernel.f | ||
|
|
||
| gs = MTLSize(groups) | ||
| ts = MTLSize(threads) | ||
| (gs.width>0 && gs.height>0 && gs.depth>0) || | ||
|
|
@@ -284,42 +297,60 @@ end | |
|
|
||
| kernel_state = KernelState(Random.rand(UInt32)) | ||
|
|
||
| cmdbuf = MTLCommandBuffer(queue) | ||
| cmdbuf.label = "MTLCommandBuffer($(nameof(kernel.f)))" | ||
| cce = MTLComputeCommandEncoder(cmdbuf) | ||
| argument_buffers = try | ||
| MTL.set_function!(cce, kernel.pipeline) | ||
| bufs = encode_arguments!(cce, kernel, kernel_state, kernel.f, args...) | ||
| MTL.append_current_function!(cce, gs, ts) | ||
| bufs | ||
| finally | ||
| close(cce) | ||
| end | ||
| @autoreleasepool begin | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this |
||
| cmdbuf = MTLCommandBuffer(queue) | ||
| cmdbuf.label = "MTLCommandBuffer($(nameof(f)))" | ||
| cce = MTLComputeCommandEncoder(cmdbuf) | ||
| argument_buffers = try | ||
| MTL.set_function!(cce, pipeline) | ||
| bufs = _metal_encode(cce, kernel, kernel_state, f, args) | ||
| MTL.append_current_function!(cce, gs, ts) | ||
| bufs | ||
| finally | ||
| close(cce) | ||
| end | ||
|
|
||
| # the command buffer retains resources that are explicitly encoded (i.e. direct buffer | ||
| # arguments, or the buffers allocated for each other argument), but that doesn't keep | ||
| # other resources alive for which we've encoded the GPU address ourselves. since it's | ||
| # possible for buffers to go out of scope while the kernel is still running, which | ||
| # triggers validation failures, keep track of things we need to keep alive until the | ||
| # kernel has actually completed. | ||
| # | ||
| # TODO: is there a way to bind additional resources to the command buffer? | ||
| roots = [kernel.f, args] | ||
| MTL.on_completed(cmdbuf) do buf | ||
| empty!(roots) | ||
| foreach(free, argument_buffers) | ||
|
|
||
| # Check for errors | ||
| # XXX: we cannot do this nicely, e.g. throwing an `error` or reporting with `@error` | ||
| # because we're not allowed to switch tasks from this contexts. | ||
| if buf.status == MTL.MTLCommandBufferStatusError | ||
| Core.println("ERROR: Failed to submit command buffer: $(buf.error.localizedDescription)") | ||
| # during precompilation, skip GPU submission (which hangs) but keep the | ||
| # encoding path above to cache compilation of the argument encoding pipeline | ||
| if ccall(:jl_generating_output, Cint, ()) != 0 | ||
| foreach(free, argument_buffers) | ||
| return | ||
| end | ||
| end | ||
|
|
||
| commit!(cmdbuf) | ||
| # the command buffer retains resources that are explicitly encoded (i.e. direct buffer | ||
| # arguments, or the buffers allocated for each other argument), but that doesn't keep | ||
| # other resources alive for which we've encoded the GPU address ourselves. since it's | ||
| # possible for buffers to go out of scope while the kernel is still running, which | ||
| # triggers validation failures, keep track of things we need to keep alive until the | ||
| # kernel has actually completed. | ||
| # | ||
| # TODO: is there a way to bind additional resources to the command buffer? | ||
| # | ||
| # collect to Vector so the callback closure type doesn't | ||
| # depend on the specific kernel's argument buffer tuple type | ||
| argument_buffer_vec = collect(argument_buffers) | ||
| roots = Any[f, args] | ||
| MTL.on_completed(cmdbuf) do buf | ||
| empty!(roots) | ||
| foreach(free, argument_buffer_vec) | ||
|
|
||
| # Check for errors | ||
| # XXX: we cannot do this nicely, e.g. throwing an `error` or reporting with `@error` | ||
| # because we're not allowed to switch tasks from this contexts. | ||
| if buf.status == MTL.MTLCommandBufferStatusError | ||
| Core.println("ERROR: Failed to submit command buffer: $(buf.error.localizedDescription)") | ||
| end | ||
| end | ||
|
|
||
| commit!(cmdbuf) | ||
| end | ||
| end | ||
|
|
||
| # Bridge: specializes on f and args types (encode_arguments! needs concrete types | ||
| # for ghost-type detection), but NOT on the full HostKernel type. | ||
| @inline _metal_encode(cce, @nospecialize(kernel), kernel_state, f, args::Tuple) = | ||
| encode_arguments!(cce, kernel, kernel_state, f, args...) | ||
|
|
||
| ## Intra-warp Helpers | ||
|
|
||
| """ | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This despecialization breaks this inference test. But that seems to be on purpose so maybe remove?
Metal.jl/test/execution.jl
Lines 45 to 53 in 1d2f000
Only shows up when commenting out the
device_synchronizetest.