diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 663572140..239ebe0bd 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -265,15 +265,22 @@ end @autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1, queue=global_queue(device())) - threadgroupsPerGrid = MTLSize(groups) - threadsPerThreadgroup = MTLSize(threads) - (threadgroupsPerGrid.width>0 && threadgroupsPerGrid.height>0 && threadgroupsPerGrid.depth>0) || + gs = MTLSize(groups) + ts = MTLSize(threads) + (gs.width>0 && gs.height>0 && gs.depth>0) || throw(ArgumentError("All group dimensions should be non-zero")) - (threadsPerThreadgroup.width>0 && threadsPerThreadgroup.height>0 && threadsPerThreadgroup.depth>0) || + (ts.width>0 && ts.height>0 && ts.depth>0) || throw(ArgumentError("All thread dimensions should be non-zero")) - (threadsPerThreadgroup.width * threadsPerThreadgroup.height * threadsPerThreadgroup.depth) > kernel.pipeline.maxTotalThreadsPerThreadgroup && - throw(ArgumentError("Number of threads in group ($(threadsPerThreadgroup.width * threadsPerThreadgroup.height * threadsPerThreadgroup.depth)) should not exceed $(kernel.pipeline.maxTotalThreadsPerThreadgroup)")) + (ts.width * ts.height * ts.depth) > kernel.pipeline.maxTotalThreadsPerThreadgroup && + throw(ArgumentError("Number of threads in group ($(ts.width * ts.height * ts.depth)) should not exceed $(kernel.pipeline.maxTotalThreadsPerThreadgroup)")) + + (gs.width * ts.width) > typemax(UInt32) && + throw(ArgumentError("Total threads per grid in a dimension (threads.width($(gs.width)) * groups.width($(ts.width)) = $(gs.width * ts.width)) must not exceed $(typemax(UInt32))")) + (gs.height * ts.height) > typemax(UInt32) && + throw(ArgumentError("Total threads per grid in a dimension (threads.height($(gs.height)) * groups.height($(ts.height)) = $(gs.height * ts.height)) must not exceed $(typemax(UInt32))")) + (gs.depth * ts.depth) > typemax(UInt32) && + throw(ArgumentError("Total threads per grid in a dimension (threads.depth($(gs.depth)) * groups.depth($(ts.depth)) = $(gs.depth * ts.depth)) must not exceed $(typemax(UInt32))")) kernel_state = KernelState(Random.rand(UInt32)) @@ -283,7 +290,7 @@ end argument_buffers = try MTL.set_function!(cce, kernel.pipeline) bufs = encode_arguments!(cce, kernel, kernel_state, kernel.f, args...) - MTL.append_current_function!(cce, threadgroupsPerGrid, threadsPerThreadgroup) + MTL.append_current_function!(cce, gs, ts) bufs finally close(cce) diff --git a/test/execution.jl b/test/execution.jl index 1124f5abc..2749de80c 100644 --- a/test/execution.jl +++ b/test/execution.jl @@ -164,6 +164,10 @@ end @test_throws InexactError @metal groups=(-2) tester(bufferA) @test_throws ArgumentError @metal threads=(1025) tester(bufferA) @test_throws ArgumentError @metal threads=(1000,2) tester(bufferA) + + @test_throws ArgumentError @metal threads=(1024,1,1) groups=(4194304,1,1) tester(bufferA) + @test_throws ArgumentError @metal threads=(1,1024,1) groups=(1,4194304,1) tester(bufferA) + @test_throws ArgumentError @metal threads=(1,1,1024) groups=(1,1,4194304) tester(bufferA) end ############################################################################################