Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,18 @@ static inline void launch_vectorized_kernel(
#endif
int bws = tws * num_threads();
int64_t grid = (N + bws - 1) / bws;
#if 0
// TODO: The change below needs to work with
// a grid-strided loop in `vectorized_elementwise_kernel`.
// Similar to: https://github.com/pytorch/pytorch/pull/169474
#ifdef USE_ROCM
// Clamp the grid to ensure total threads (grid * num_threads)
// does not exceed the uint32_t limit of the HSA AQL packet.
// Use 4294967295 (UINT32_MAX) as the ceiling.
int64_t max_safe_grid = 4294967295LL / num_threads();
grid = std::min(grid, max_safe_grid);
#endif
#endif
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you add an #if 0 block? For commentary?

Copy link
Copy Markdown
Author

@glen-amd glen-amd Mar 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. It's not required for the direct fix to the specific error we were working on, but it was identified as a potential (similar) issue while debugging the issue in question.
  2. As the TODO comment indicates, to be fully functional, this change would need some corresponding changes to the other function.
  3. So, I commented out this change (using #if 0) for now.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove from PR.
We can add it back when necessary.
Otherwise, looks good to me.

switch (vec_size) {
#ifdef USE_ROCM
case 16:
Expand Down
169 changes: 136 additions & 33 deletions aten/src/ATen/native/cuda/UpSampleNearest2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ typedef int (*nn_compute_source_index_fn_t)(const float, int, int);
// nearest_neighbor_exact_bw_compute_source_index
typedef int (*nn_bw_compute_source_index_fn_t)(const float, int, int);

// The original implementation for ROCm assumed that gridDim.x fully covers width2
// and gridDim.y fully covers height2 (only the Z/nc dimension used a stride
// loop). The original code even had a TODO comment acknowledging this:
// "TODO: kernel implementation could stride on spatial dimension. We probably
// need to overhaul the kernel."
//
// On HIP/ROCm, gridDim.{x,y,z} * blockDim.{x,y,z} must be < 2^32 per
// dimension. For large spatial outputs, this constraint can be violated
// even when individual grid dimensions are within maxGridSize limits.
//
// The fix converts the X (width) and Y (height) dimensions to grid-stride
// loops, so the host-side launch template can clamp grid dimensions to
// safe values without losing coverage of the output tensor.
//
// This is safe on CUDA as well — when the grid already covers the full
// spatial range, each stride-loop body executes exactly once per thread.
//
// see NOTE [ Nearest neighbor upsampling kernel implementation ]
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
Expand All @@ -48,34 +65,44 @@ __global__ void upsample_nearest2d_out_frame(
const size_t width2,
float height_scale,
float width_scale) {
size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z;
int64_t w2 = ((int64_t) threadIdx.x) + blockIdx.x * blockDim.x;
int64_t h2 = threadIdx.y + blockIdx.y * blockDim.y;

if (w2 >= width2 || h2 >= height2) {
return;
}

int64_t nc_stride = ((int64_t) blockDim.z) * gridDim.z;

const size_t h1 = height1 == height2
? h2
: nn_compute_source_index_fn(height_scale, h2, height1);
const size_t w1 = width1 == width2
? w2
: nn_compute_source_index_fn(width_scale, w2, width1);

size_t src_index = (nc_iter * height1 + h1) * width1 + w1;
size_t src_index_stride = nc_stride * width1 * height1;
size_t dst_index = (nc_iter * height2 + h2) * width2 + w2;
size_t dst_index_stride = nc_stride * width2 * height2;

// iterating over
while (nc_iter < nc) {
odata[dst_index] = idata[src_index];
dst_index += dst_index_stride;
src_index += src_index_stride;
nc_iter += nc_stride;
// Grid-stride loop over the width (X) dimension
for (int64_t w2 = static_cast<int64_t>(threadIdx.x) +
static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
w2 < static_cast<int64_t>(width2);
w2 += static_cast<int64_t>(blockDim.x) * static_cast<int64_t>(gridDim.x)) {

// Grid-stride loop over the height (Y) dimension
for (int64_t h2 = static_cast<int64_t>(threadIdx.y) +
static_cast<int64_t>(blockIdx.y) * static_cast<int64_t>(blockDim.y);
h2 < static_cast<int64_t>(height2);
h2 += static_cast<int64_t>(blockDim.y) * static_cast<int64_t>(gridDim.y)) {

const size_t h1 = height1 == height2
? static_cast<size_t>(h2)
: nn_compute_source_index_fn(height_scale, h2, height1);
const size_t w1 = width1 == width2
? static_cast<size_t>(w2)
: nn_compute_source_index_fn(width_scale, w2, width1);

// Grid-stride loop over the batch*channels (Z) dimension
// (This was already a stride loop in the original kernel.)
size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z;
int64_t nc_stride = static_cast<int64_t>(blockDim.z) * static_cast<int64_t>(gridDim.z);

size_t src_index = (nc_iter * height1 + h1) * width1 + w1;
size_t src_index_stride = nc_stride * width1 * height1;
size_t dst_index = (nc_iter * height2 + static_cast<size_t>(h2)) * width2 + static_cast<size_t>(w2);
size_t dst_index_stride = nc_stride * width2 * height2;

// iterating over batch*channels
while (nc_iter < nc) {
odata[dst_index] = idata[src_index];
dst_index += dst_index_stride;
src_index += src_index_stride;
nc_iter += nc_stride;
}
}
}
}

Expand All @@ -93,17 +120,19 @@ __global__ void upsample_nearest2d_nhwc_out_frame(
float width_scale,
const size_t out_numel) {

const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
// Grid-stride loop to handle the case where the grid is clamped
// to satisfy HIP's gridDim.x * blockDim.x < 2^32 constraint.
for (int64_t index = static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
static_cast<int64_t>(threadIdx.x);
index < static_cast<int64_t>(out_numel);
index += static_cast<int64_t>(gridDim.x) * static_cast<int64_t>(blockDim.x)) {

if (index < out_numel) {
const auto c = index % channels;
const auto w2 = (index / channels) % width2;
const auto h2 = (index / channels / width2) % height2;
const auto n = index / channels / width2 / height2;

const size_t h1 = height1 == height2 ? h2 : nn_compute_source_index_fn(height_scale, h2, height1);
const size_t w1 = width1 == width2 ? w2 : nn_compute_source_index_fn(width_scale, w2, width1);

odata[index] = idata[idx_cl(n, h1, w1, c, height1, width1, channels)];
}
}
Expand Down Expand Up @@ -202,6 +231,33 @@ __global__ void upsample_nearest2d_backward_nhwc_out_frame(
}
}

// On HIP/ROCm, gridDim.{x,y,z} * blockDim.{x,y,z} must be < 2^32 per
// dimension (the HSA AQL dispatch packet stores global work sizes as
// uint32_t). Violating this returns hipErrorInvalidConfiguration.
//
// Related code snippet in ROCm/ROCR-Runtime and ROCm/clr:
// runtime/hsa-runtime/inc/hsa.h in ROCR-Runtime
// `uint32_t hsa_kernel_dispatch_packet_t.grid_size_x`
// hipamd/src/hip_module.cpp in clr
// `size_t globalWorkSizeX = static_cast<size_t>(gridDimX) * blockDimX;`
// rocclr/device/rocm/rocvirtual.cpp in clr
// `dispatchPacket.grid_size_x = sizes.dimensions() > 0 ? newGlobalSize[0] : 1;`
//
// This function has two kernel launch paths:
//
// (1) NHWC path: 1D grid launching upsample_nearest2d_nhwc_out_frame
// grid = ceil_div(output.numel(), 1024). For output.numel() near 2^32,
// grid * 1024 >= 2^32 → hits HIP limit.
// FIX: Clamp grid; the NHWC kernel must use a grid-stride loop.
//
// (2) Contiguous (NCHW) path: 3D grid launching upsample_nearest2d_out_frame
// grid_x covers output_width, grid_y covers output_height.
// The original code had a TORCH_CHECK for maxGridSize but NOT for the
// HIP-specific product overflow.
// FIX: Clamp grid_x and grid_y on ROCm. The corrected kernel adds
// grid-stride loops on X and Y so the clamped grid still covers the full output.
// The original TORCH_CHECK on maxGridSize is replaced with a
// proper clamp on ROCm; on CUDA, the original check is preserved.
template<nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest2d_out_cuda_template(
const Tensor& output,
Expand Down Expand Up @@ -234,6 +290,9 @@ static void upsample_nearest2d_out_cuda_template(
return;
}

// ===================================================================
// PATH 1: NHWC (channels-last) layout — 1D grid
// ===================================================================
// heuristic: only use channels_last path when it's faster than the contiguous path
if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 4 && \
output.is_contiguous(memory_format)) {
Expand All @@ -247,11 +306,23 @@ static void upsample_nearest2d_out_cuda_template(
const int64_t num_kernels = output.numel();
const int64_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

int64_t grid = ceil_div(num_kernels, num_threads);
#ifdef USE_ROCM
// HIP does not support gridDim.x * blockDim.x >= 2^32.
// Clamp grid so that grid * num_threads stays below UINT32_MAX.
// The NHWC kernel must use a grid-stride loop to cover all elements.
{
constexpr int64_t kHipMaxGlobalWorkSize = 4294967295LL; // UINT32_MAX
int64_t safe_max_grid = kHipMaxGlobalWorkSize / num_threads;
grid = std::min(grid, safe_max_grid);
}
#endif

AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] {
const scalar_t* idata = input.const_data_ptr<scalar_t>();
scalar_t* odata = output.mutable_data_ptr<scalar_t>();
upsample_nearest2d_nhwc_out_frame<scalar_t, nn_compute_source_index_fn>
<<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
<<<grid, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
idata,
odata,
channels,
Expand All @@ -266,6 +337,9 @@ static void upsample_nearest2d_out_cuda_template(
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
// ===================================================================
// PATH 2: Contiguous (NCHW) layout — 3D grid
// ===================================================================
else {
// This is needed for non-contiguous tensors.
Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
Expand Down Expand Up @@ -293,7 +367,25 @@ static void upsample_nearest2d_out_cuda_template(
int grid_y = ceil_div(output_height, block_y);
int grid_z = std::min<int>(
maxGridSize[2], ceil_div(nc, (int64_t) block_z * 4));
const dim3 grid(grid_x, grid_y, grid_z);

#ifdef USE_ROCM
// HIP does not support gridDim.{x,y,z} * blockDim.{x,y,z} >= 2^32 per
// dimension. Clamp each grid dimension to satisfy this constraint.
// The corrected kernel uses grid-stride loops on X and Y (and already
// had a stride loop on Z), so the clamped grid still covers the full
// output tensor.
{
constexpr unsigned int kHipMaxGlobalWorkSize = 4294967295U; // UINT32_MAX
int safe_max_grid_x = static_cast<int>(kHipMaxGlobalWorkSize / static_cast<unsigned int>(block_x));
int safe_max_grid_y = static_cast<int>(kHipMaxGlobalWorkSize / static_cast<unsigned int>(block_y));
int safe_max_grid_z = static_cast<int>(kHipMaxGlobalWorkSize / static_cast<unsigned int>(block_z));
grid_x = std::min(grid_x, safe_max_grid_x);
grid_y = std::min(grid_y, safe_max_grid_y);
grid_z = std::min(grid_z, safe_max_grid_z);
}
#else
// On CUDA, the original check is sufficient — CUDA's maxGridSize[0] is
// 2^31-1 and there is no per-dimension product-overflow constraint.
// Error out on cases where grid_x & grid_y exceeds limit of launch config, as
// the current kernel implementation doesn't loop over the two dimensions.
// This is unlikely to happen.
Expand All @@ -302,6 +394,9 @@ static void upsample_nearest2d_out_cuda_template(
TORCH_CHECK(
grid_x <= maxGridSize[0] && grid_y <= maxGridSize[1],
"input tensor has spatial dimension larger than the kernel capacity");
#endif

const dim3 grid(grid_x, grid_y, grid_z);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
Expand Down Expand Up @@ -329,6 +424,14 @@ static void upsample_nearest2d_out_cuda_template(
}
}

// TODO: Same pattern as the forward: both the NHWC 1D-grid path and
// the contiguous 1D-grid path need a clamp of the grid size on ROCm
// (guarded by `#ifdef USE_ROCM`) to avoid hitting the HIP runtime limits.
// The backward NHWC path is less likely to trigger the issue because
// it already has `TORCH_CHECK(grad_input.numel() < INT_MAX)`, meaning
// `grid * num_threads < INT_MAX < 2^32`. But the contiguous path uses
// `size_t n = grad_input.numel() / nbatch` which could exceed
// `UINT32_MAX / bdim.x`, so the clamp is still necessary for completeness.
template<nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
static void upsample_nearest2d_backward_out_cuda_template(
const Tensor& grad_input,
Expand Down