diff --git a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu index e9470b3a6b473..e5bd406ab5a94 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu @@ -35,6 +35,23 @@ typedef int (*nn_compute_source_index_fn_t)(const float, int, int); // nearest_neighbor_exact_bw_compute_source_index typedef int (*nn_bw_compute_source_index_fn_t)(const float, int, int); +// The original implementation for ROCm assumed that gridDim.x fully covers width2 +// and gridDim.y fully covers height2 (only the Z/nc dimension used a stride +// loop). The original code even had a TODO comment acknowledging this: +// "TODO: kernel implementation could stride on spatial dimension. We probably +// need to overhaul the kernel." +// +// On HIP/ROCm, gridDim.{x,y,z} * blockDim.{x,y,z} must be < 2^32 per +// dimension. For large spatial outputs, this constraint can be violated +// even when individual grid dimensions are within maxGridSize limits. +// +// The fix converts the X (width) and Y (height) dimensions to grid-stride +// loops, so the host-side launch template can clamp grid dimensions to +// safe values without losing coverage of the output tensor. +// +// This is safe on CUDA as well — when the grid already covers the full +// spatial range, each stride-loop body executes exactly once per thread. +// // see NOTE [ Nearest neighbor upsampling kernel implementation ] template C10_LAUNCH_BOUNDS_1(1024) @@ -48,34 +65,44 @@ __global__ void upsample_nearest2d_out_frame( const size_t width2, float height_scale, float width_scale) { - size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z; - int64_t w2 = ((int64_t) threadIdx.x) + blockIdx.x * blockDim.x; - int64_t h2 = threadIdx.y + blockIdx.y * blockDim.y; - if (w2 >= width2 || h2 >= height2) { - return; - } - - int64_t nc_stride = ((int64_t) blockDim.z) * gridDim.z; - - const size_t h1 = height1 == height2 - ? 
h2 - : nn_compute_source_index_fn(height_scale, h2, height1); - const size_t w1 = width1 == width2 - ? w2 - : nn_compute_source_index_fn(width_scale, w2, width1); - - size_t src_index = (nc_iter * height1 + h1) * width1 + w1; - size_t src_index_stride = nc_stride * width1 * height1; - size_t dst_index = (nc_iter * height2 + h2) * width2 + w2; - size_t dst_index_stride = nc_stride * width2 * height2; - - // iterating over - while (nc_iter < nc) { - odata[dst_index] = idata[src_index]; - dst_index += dst_index_stride; - src_index += src_index_stride; - nc_iter += nc_stride; + // Grid-stride loop over the width (X) dimension + for (int64_t w2 = static_cast(threadIdx.x) + + static_cast(blockIdx.x) * static_cast(blockDim.x); + w2 < static_cast(width2); + w2 += static_cast(blockDim.x) * static_cast(gridDim.x)) { + + // Grid-stride loop over the height (Y) dimension + for (int64_t h2 = static_cast(threadIdx.y) + + static_cast(blockIdx.y) * static_cast(blockDim.y); + h2 < static_cast(height2); + h2 += static_cast(blockDim.y) * static_cast(gridDim.y)) { + + const size_t h1 = height1 == height2 + ? static_cast(h2) + : nn_compute_source_index_fn(height_scale, h2, height1); + const size_t w1 = width1 == width2 + ? static_cast(w2) + : nn_compute_source_index_fn(width_scale, w2, width1); + + // Grid-stride loop over the batch*channels (Z) dimension + // (This was already a stride loop in the original kernel.) 
+ size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z; + int64_t nc_stride = static_cast(blockDim.z) * static_cast(gridDim.z); + + size_t src_index = (nc_iter * height1 + h1) * width1 + w1; + size_t src_index_stride = nc_stride * width1 * height1; + size_t dst_index = (nc_iter * height2 + static_cast(h2)) * width2 + static_cast(w2); + size_t dst_index_stride = nc_stride * width2 * height2; + + // iterating over batch*channels + while (nc_iter < nc) { + odata[dst_index] = idata[src_index]; + dst_index += dst_index_stride; + src_index += src_index_stride; + nc_iter += nc_stride; + } + } } } @@ -93,17 +120,19 @@ __global__ void upsample_nearest2d_nhwc_out_frame( float width_scale, const size_t out_numel) { - const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; + // Grid-stride loop to handle the case where the grid is clamped + // to satisfy HIP's gridDim.x * blockDim.x < 2^32 constraint. + for (int64_t index = static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + index < static_cast(out_numel); + index += static_cast(gridDim.x) * static_cast(blockDim.x)) { - if (index < out_numel) { const auto c = index % channels; const auto w2 = (index / channels) % width2; const auto h2 = (index / channels / width2) % height2; const auto n = index / channels / width2 / height2; - const size_t h1 = height1 == height2 ? h2 : nn_compute_source_index_fn(height_scale, h2, height1); const size_t w1 = width1 == width2 ? w2 : nn_compute_source_index_fn(width_scale, w2, width1); - odata[index] = idata[idx_cl(n, h1, w1, c, height1, width1, channels)]; } } @@ -202,6 +231,33 @@ __global__ void upsample_nearest2d_backward_nhwc_out_frame( } } +// On HIP/ROCm, gridDim.{x,y,z} * blockDim.{x,y,z} must be < 2^32 per +// dimension (the HSA AQL dispatch packet stores global work sizes as +// uint32_t). Violating this returns hipErrorInvalidConfiguration. 
+// +// Related code snippet in ROCm/ROCR-Runtime and ROCm/clr: +// runtime/hsa-runtime/inc/hsa.h in ROCR-Runtime +// `uint32_t hsa_kernel_dispatch_packet_t.grid_size_x` +// hipamd/src/hip_module.cpp in clr +// `size_t globalWorkSizeX = static_cast(gridDimX) * blockDimX;` +// rocclr/device/rocm/rocvirtual.cpp in clr +// `dispatchPacket.grid_size_x = sizes.dimensions() > 0 ? newGlobalSize[0] : 1;` +// +// This function has two kernel launch paths: +// +// (1) NHWC path: 1D grid launching upsample_nearest2d_nhwc_out_frame +// grid = ceil_div(output.numel(), 1024). For output.numel() near 2^32, +// grid * 1024 >= 2^32 → hits HIP limit. +// FIX: Clamp grid; the NHWC kernel must use a grid-stride loop. +// +// (2) Contiguous (NCHW) path: 3D grid launching upsample_nearest2d_out_frame +// grid_x covers output_width, grid_y covers output_height. +// The original code had a TORCH_CHECK for maxGridSize but NOT for the +// HIP-specific product overflow. +// FIX: Clamp grid_x, grid_y, and grid_z on ROCm. The corrected kernel adds +// grid-stride loops on X and Y so the clamped grid still covers the full output. +// The original TORCH_CHECK on maxGridSize is replaced with a +// proper clamp on ROCm; on CUDA, the original check is preserved. 
template static void upsample_nearest2d_out_cuda_template( const Tensor& output, @@ -234,6 +290,9 @@ static void upsample_nearest2d_out_cuda_template( return; } + // =================================================================== + // PATH 1: NHWC (channels-last) layout — 1D grid + // =================================================================== // heuristic: only use channels_last path when it's faster than the contiguous path if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 4 && \ output.is_contiguous(memory_format)) { @@ -247,11 +306,23 @@ static void upsample_nearest2d_out_cuda_template( const int64_t num_kernels = output.numel(); const int64_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); + int64_t grid = ceil_div(num_kernels, num_threads); +#ifdef USE_ROCM + // HIP does not support gridDim.x * blockDim.x >= 2^32. + // Clamp grid so that grid * num_threads stays below UINT32_MAX. + // The NHWC kernel must use a grid-stride loop to cover all elements. + { + constexpr int64_t kHipMaxGlobalWorkSize = 4294967295LL; // UINT32_MAX + int64_t safe_max_grid = kHipMaxGlobalWorkSize / num_threads; + grid = std::min(grid, safe_max_grid); + } +#endif + AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] { const scalar_t* idata = input.const_data_ptr(); scalar_t* odata = output.mutable_data_ptr(); upsample_nearest2d_nhwc_out_frame - <<>>( + <<>>( idata, odata, channels, @@ -266,6 +337,9 @@ static void upsample_nearest2d_out_cuda_template( C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } + // =================================================================== + // PATH 2: Contiguous (NCHW) layout — 3D grid + // =================================================================== else { // This is needed for non-contiguous tensors. Tensor output_c = output.is_contiguous() ? 
output : at::empty(output.sizes(), output.options()); @@ -293,7 +367,25 @@ static void upsample_nearest2d_out_cuda_template( int grid_y = ceil_div(output_height, block_y); int grid_z = std::min( maxGridSize[2], ceil_div(nc, (int64_t) block_z * 4)); - const dim3 grid(grid_x, grid_y, grid_z); + +#ifdef USE_ROCM + // HIP does not support gridDim.{x,y,z} * blockDim.{x,y,z} >= 2^32 per + // dimension. Clamp each grid dimension to satisfy this constraint. + // The corrected kernel uses grid-stride loops on X and Y (and already + // had a stride loop on Z), so the clamped grid still covers the full + // output tensor. + { + constexpr unsigned int kHipMaxGlobalWorkSize = 4294967295U; // UINT32_MAX + int safe_max_grid_x = static_cast(kHipMaxGlobalWorkSize / static_cast(block_x)); + int safe_max_grid_y = static_cast(kHipMaxGlobalWorkSize / static_cast(block_y)); + int safe_max_grid_z = static_cast(kHipMaxGlobalWorkSize / static_cast(block_z)); + grid_x = std::min(grid_x, safe_max_grid_x); + grid_y = std::min(grid_y, safe_max_grid_y); + grid_z = std::min(grid_z, safe_max_grid_z); + } +#else + // On CUDA, the original check is sufficient — CUDA's maxGridSize[0] is + // 2^31-1 and there is no per-dimension product-overflow constraint. // Error out on cases where grid_x & grid_y exceeds limit of launch config, as // the current kernel implementation doesn't loop over the two dimensions. // This is unlikely to happen. 
@@ -302,6 +394,9 @@ static void upsample_nearest2d_out_cuda_template( TORCH_CHECK( grid_x <= maxGridSize[0] && grid_y <= maxGridSize[1], "input tensor has spatial dimension larger than the kernel capacity"); +#endif + + const dim3 grid(grid_x, grid_y, grid_z); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_out_frame", [&] { using accscalar_t = at::acc_type; @@ -329,6 +424,14 @@ static void upsample_nearest2d_out_cuda_template( } } +// TODO: Same pattern as the forward: both the NHWC 1D-grid path and +// the contiguous 1D-grid path need a grid-size clamp +// (guarded by `#ifdef USE_ROCM`) to avoid hitting the HIP runtime limits. +// The backward NHWC path is less likely to trigger the issue because +// it already has `TORCH_CHECK(grad_input.numel() < INT_MAX)`, meaning +// `grid * num_threads < INT_MAX < 2^32`. But the contiguous path uses +// `size_t n = grad_input.numel() / nbatch` which could exceed +// `UINT32_MAX / bdim.x`, so the clamp is still necessary for completeness. template static void upsample_nearest2d_backward_out_cuda_template( const Tensor& grad_input,