@@ -35,7 +35,8 @@ __global__ inline void _float_to_paddedFP8rowwise_cuda_kernel(
   const int output_columns =
       ncols_aligned + (ncols + row_dim - 1) / row_dim * 8;

-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
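+  // Promote blockIdx.x to 64-bit before multiplying so the row index can't
+  // overflow a 32-bit int on very large grids.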
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   // for 1D case, unsqueezing needed
   if (nrows == 1) {
     const auto threads = (ncols + row_dim - 1) / row_dim;
@@ -96,10 +97,11 @@ __global__ inline void _get_padding_value_kernel(
     const int row_dim,
     const std::uint8_t* const __restrict__ input,
     int* const __restrict__ offsets) {
-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   const int row_ext = row_dim + 8;
   const auto threads = (ncols + row_ext - 1) / row_ext;
-  if (row > threads)
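+  // '>=' rather than '>': row == threads is already one bucket past the end
+  // of the input, so that thread must not dereference input_row.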
+  if (row >= threads)
     return;
   const std::uint8_t* const input_row = input + row * row_ext;
   int pad = *reinterpret_cast<const int*>(input_row + row_dim + 4);
@@ -114,7 +116,7 @@ __global__ inline void _single_thread_sum_padding_kernel(
     int* __restrict__ total_pad) {
   // this is to count the sum of padding in the first row of 2D input
   // in one kernel launch to remove multiple H to D Syncs.
-  const auto tid = (int)blockIdx.x * blockDim.x + threadIdx.x;
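+  // 64-bit math here too: with 32-bit overflow a wrapped tid could alias 0
+  // and let a second thread run the sequential walk below.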
+  const auto tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (tid != 0) {
     return;
   }
@@ -125,7 +127,12 @@ __global__ inline void _single_thread_sum_padding_kernel(
   while (offset + 4 <= ncols) {
     pad = *reinterpret_cast<const int*>(input + offset);
     if (pad < 0) {
-      offset += -pad * row_ext;
+      // Widen before negating so pad == INT_MIN doesn't overflow.
+      const int64_t step = -static_cast<int64_t>(pad) * row_ext;
+      if (step > ncols - offset) {
+        break;
+      }
+      offset += static_cast<int>(step);
     } else {
       total_pad[0] += pad;
       offset += row_ext;
@@ -153,7 +160,7 @@ __global__ inline void _PaddedFP8rowwise_to_float_1d_cuda_kernel(
       reinterpret_cast<const float*>(input_row + row_dim);
   const auto scale = input_row_scale[0];
   int pad = *reinterpret_cast<const int*>(&input_row_scale[1]);
-  pad = (pad > 0) ? pad : 0;
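+  // Clamp to [0, row_dim]: negative values encode the next padded bucket's
+  // index, and a pad larger than row_dim would wrap the loop bound below.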
+  pad = ::max(0, ::min(pad, row_dim));
   const auto pad_offset = offsets[row];
   output_t* output_row = output + row * row_dim - pad_offset;
   for (auto col = threadIdx.x; col < row_dim - pad; col += blockDim.x) {
@@ -176,7 +183,8 @@ __global__ inline void _PaddedFP8rowwise_to_float_2d_cuda_kernel(
   const int ebit = forward ? 4 : 5;
   const int bias = forward ? 15 : 31;

-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (row >= nrows) {
     return;
   }
@@ -189,7 +197,7 @@ __global__ inline void _PaddedFP8rowwise_to_float_2d_cuda_kernel(
   int pad = *reinterpret_cast<const int*>(&input_row_scale[1]);
   // if pad is negative it's used to indicate indices of the next padded
   // bucket
-  pad = (pad > 0) ? pad : 0;
+  pad = ::max(0, ::min(pad, row_dim));
   for (int bi = 0; bi < row_dim - pad; ++bi) {
     const auto output_ =
         hfp8_to_float(input_row[col + bi], ebit, bias) / input_row_scale[0];
@@ -208,6 +216,11 @@ Tensor _float_to_paddedFP8rowwise_gpu_t(
     const int64_t row_dim) {
   TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(input);
   CUDA_DEVICE_GUARD(input);
+  TORCH_CHECK(
+      row_dim > 0 && row_dim % 4 == 0,
+      "row_dim (",
+      row_dim,
+      ") must be a positive multiple of 4 to keep the appended scale/pad words 4-byte aligned");

   const auto input_sizes = input.sizes();
   const auto last_dim = input_sizes.size() - 1;
@@ -264,13 +277,26 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
   TENSOR_ON_CUDA_GPU(input);
   TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
   CUDA_DEVICE_GUARD(input);
+  TORCH_CHECK(
+      row_dim > 0 && row_dim % 4 == 0,
+      "row_dim (",
+      row_dim,
+      ") must be a positive multiple of 4 to keep the appended scale/pad words 4-byte aligned");

   const auto input_sizes = input.sizes();
   const auto last_dim = input_sizes.size() - 1;
   const int nrows = c10::size_to_dim_(last_dim, input_sizes);
   const int ncols = input_sizes[last_dim];
   const int row_ext = row_dim + 8;
-  int output_columns = ncols - (ncols + row_ext - 1) / row_ext * 8;
+  TORCH_CHECK(
+      ncols % row_ext == 0,
+      "ncols (",
+      ncols,
+      ") must be a multiple of row_ext (",
+      row_ext,
+      ")");
+  const int num_buckets = ncols / row_ext;
+  int output_columns = ncols - num_buckets * 8;
   // Global memory instructions support reading or writing words of size equal
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
   // data residing in global memory compiles to a single global memory
@@ -281,12 +307,9 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(

   constexpr int threads_per_block = 256;
   const auto num_blocks = cuda_calc_xblock_count(
-      (nrows == 1) ? (ncols + row_ext - 1) / row_ext + 1 : nrows,
-      threads_per_block);
+      (nrows == 1) ? num_buckets : nrows, threads_per_block);
   Tensor offsets = at::empty(
-      (nrows == 1) ? num_blocks * threads_per_block + 1
-                   : 0, // 4 = sizeof(float)
-      input.options().dtype(at::kInt));
+      (nrows == 1) ? num_buckets : 0, input.options().dtype(at::kInt));
   int total_pad = 0;
   if (nrows == 1) {
     FBGEMM_LAUNCH_KERNEL(
@@ -316,7 +339,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
           total_pad_tensor.data_ptr<int>());
       total_pad = total_pad_tensor[0].item<int>();
     } else {
-      total_pad = offsets[((ncols + row_ext - 1) / row_ext)].item<int>();
+      total_pad = offsets[num_buckets].item<int>();
     }
     output_columns -= total_pad;
   } else {
@@ -340,13 +363,6 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(

   if (nrows == 1) {
     // Use one thread block to work on 1 row for nrows == 1
-    TORCH_CHECK(
-        ncols % row_ext == 0,
-        "ncols (",
-        ncols,
-        ") must be multiple of ",
-        row_ext)
-    const int num_rows = ncols / row_ext;
     const int ebit = forward ? 4 : 5;
     const int bias = forward ? 15 : 31;
     constexpr int kMaxThreads = 1024;
@@ -356,7 +372,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
       output.scalar_type(), "PaddedFP8rowwise_to_float_1d_cuda_kernel", [&] {
         FBGEMM_LAUNCH_KERNEL(
             (_PaddedFP8rowwise_to_float_1d_cuda_kernel<scalar_t>),
-            num_rows,
+            num_buckets,
             threads_per_block,
             0,
             at::cuda::getCurrentCUDAStream(),
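
Why the widening in _single_thread_sum_padding_kernel matters: negating an int that holds INT_MIN is undefined behavior, since -INT_MIN is not representable in 32 bits, and the subsequent multiply by row_ext can overflow as well. A minimal host-side sketch of the hazard and of the fix applied above (not part of the patch; the constants are illustrative):

#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
  const int row_ext = 256 + 8; // row_dim + 8 metadata bytes, as in the kernels
  const int pad = INT_MIN;     // worst-case negative "next bucket" marker

  // const int bad = -pad * row_ext;          // UB: -INT_MIN overflows int
  const int64_t step =
      -static_cast<int64_t>(pad) * row_ext;   // widen first, then negate
  std::printf("step = %lld\n", static_cast<long long>(step));
  return 0;
}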
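For reference, the bucket arithmetic that the new num_buckets variable centralizes: each bucket of the padded layout is row_dim FP8 payload bytes followed by a 4-byte float scale and a 4-byte int pad count, giving row_ext = row_dim + 8. A small sketch under those assumptions (hypothetical values, not library code):

#include <cstdio>

int main() {
  const int row_dim = 256;          // FP8 payload bytes per bucket
  const int row_ext = row_dim + 8;  // payload + scale (4 B) + pad (4 B)
  const int ncols = 4 * row_ext;    // a 1D input holding four buckets

  const int num_buckets = ncols / row_ext;
  const int output_columns = ncols - num_buckets * 8; // strip the metadata

  std::printf("buckets=%d, payload bytes=%d\n", num_buckets, output_columns);
  return 0;
}

Because ncols % row_ext == 0 is now checked up front, the plain division ncols / row_ext agrees with the old ceiling division (ncols + row_ext - 1) / row_ext, which is what makes the simplification safe.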