
Commit a98dbb2

cyyever authored and meta-codesync[bot] committed
Fix OOB read in _get_padding_value_kernel (#5662)
Summary:
Pull Request resolved: #5662
X-link: https://github.com/facebookresearch/FBGEMM/pull/2603
Pull Request resolved: #5652

Reviewed By: henrylhtsang

Differential Revision: D101380325

Pulled By: q10

fbshipit-source-id: 4e5bab7afd3101c46e3f3f8a9b39042fe1dc5872
1 parent ba5ed4b · commit a98dbb2
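
Background for the fix, with an illustrative sketch below (a simplified stand-in for _get_padding_value_kernel, not the FBGEMM source; the kernel name read_pad_words and its parameters are hypothetical). In the 1-D padded FP8 rowwise layout, each bucket stores row_dim quantized bytes followed by a 4-byte float scale and a 4-byte int pad count, so an input of ncols bytes holds ncols / (row_dim + 8) buckets. The old guard `if (row > threads)` let the boundary thread with row == threads fall through and access memory one bucket past the valid range; the fix tightens the guard to `>=` and sizes the 1-D offsets buffer by the bucket count rather than by num_blocks * threads_per_block.

    #include <cstdint>

    // Illustrative sketch only: one thread reads the trailing pad word of one bucket.
    __global__ void read_pad_words(
        const int ncols,
        const int row_dim,
        const std::uint8_t* const __restrict__ input,
        int* const __restrict__ pads) {
      const std::int64_t row =
          static_cast<std::int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
      const int row_ext = row_dim + 8; // payload + 4-byte scale + 4-byte pad word
      const int num_buckets = (ncols + row_ext - 1) / row_ext;
      // With the old guard (row > num_buckets), the thread where row == num_buckets
      // fell through and read one bucket past the end of `input` (and wrote one
      // slot past the end of `pads`); `>=` excludes it.
      if (row >= num_buckets) {
        return;
      }
      const std::uint8_t* const bucket = input + row * row_ext;
      // The pad count sits after the payload and the 4-byte scale.
      pads[row] = *reinterpret_cast<const int*>(bucket + row_dim + 4);
    }

The widening of blockIdx.x to 64-bit before the multiply, applied throughout the commit, additionally keeps the computed thread index from overflowing a 32-bit int for very large launches.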

2 files changed: 92 additions & 23 deletions


fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu

Lines changed: 39 additions & 23 deletions
@@ -35,7 +35,8 @@ __global__ inline void _float_to_paddedFP8rowwise_cuda_kernel(
   const int output_columns =
       ncols_aligned + (ncols + row_dim - 1) / row_dim * 8;
 
-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   // for 1D case, unsqueezing needed
   if (nrows == 1) {
     const auto threads = (ncols + row_dim - 1) / row_dim;
@@ -96,10 +97,11 @@ __global__ inline void _get_padding_value_kernel(
     const int row_dim,
     const std::uint8_t* const __restrict__ input,
     int* const __restrict__ offsets) {
-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   const int row_ext = row_dim + 8;
   const auto threads = (ncols + row_ext - 1) / row_ext;
-  if (row > threads)
+  if (row >= threads)
     return;
   const std::uint8_t* const input_row = input + row * row_ext;
   int pad = *reinterpret_cast<const int*>(input_row + row_dim + 4);
@@ -114,7 +116,7 @@ __global__ inline void _single_thread_sum_padding_kernel(
     int* __restrict__ total_pad) {
   // this is to count the sum of padding in the first row of 2D input
   // in one kernel launch to remove multiple H to D Syncs.
-  const auto tid = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const auto tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (tid != 0) {
     return;
   }
@@ -125,7 +127,12 @@ __global__ inline void _single_thread_sum_padding_kernel(
   while (offset + 4 <= ncols) {
     pad = *reinterpret_cast<const int*>(input + offset);
     if (pad < 0) {
-      offset += -pad * row_ext;
+      // Widen before negating so pad == INT_MIN doesn't overflow.
+      const int64_t step = -static_cast<int64_t>(pad) * row_ext;
+      if (step > ncols - offset) {
+        break;
+      }
+      offset += static_cast<int>(step);
     } else {
       total_pad[0] += pad;
       offset += row_ext;
@@ -153,7 +160,7 @@ __global__ inline void _PaddedFP8rowwise_to_float_1d_cuda_kernel(
       reinterpret_cast<const float*>(input_row + row_dim);
   const auto scale = input_row_scale[0];
   int pad = *reinterpret_cast<const int*>(&input_row_scale[1]);
-  pad = (pad > 0) ? pad : 0;
+  pad = ::max(0, ::min(pad, row_dim));
   const auto pad_offset = offsets[row];
   output_t* output_row = output + row * row_dim - pad_offset;
   for (auto col = threadIdx.x; col < row_dim - pad; col += blockDim.x) {
@@ -176,7 +183,8 @@ __global__ inline void _PaddedFP8rowwise_to_float_2d_cuda_kernel(
   const int ebit = forward ? 4 : 5;
   const int bias = forward ? 15 : 31;
 
-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (row >= nrows) {
     return;
   }
@@ -189,7 +197,7 @@ __global__ inline void _PaddedFP8rowwise_to_float_2d_cuda_kernel(
   int pad = *reinterpret_cast<const int*>(&input_row_scale[1]);
   // if pad is negative it's used to indidate indices of the next padded
   // bucket
-  pad = (pad > 0) ? pad : 0;
+  pad = ::max(0, ::min(pad, row_dim));
   for (int bi = 0; bi < row_dim - pad; ++bi) {
     const auto output_ =
         hfp8_to_float(input_row[col + bi], ebit, bias) / input_row_scale[0];
@@ -208,6 +216,11 @@ Tensor _float_to_paddedFP8rowwise_gpu_t(
     const int64_t row_dim) {
   TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(input);
   CUDA_DEVICE_GUARD(input);
+  TORCH_CHECK(
+      row_dim > 0 && row_dim % 4 == 0,
+      "row_dim (",
+      row_dim,
+      ") must be a positive multiple of 4 to keep the appended scale/pad words 4-byte aligned");
 
   const auto input_sizes = input.sizes();
   const auto last_dim = input_sizes.size() - 1;
@@ -264,13 +277,26 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
   TENSOR_ON_CUDA_GPU(input);
   TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
   CUDA_DEVICE_GUARD(input);
+  TORCH_CHECK(
+      row_dim > 0 && row_dim % 4 == 0,
+      "row_dim (",
+      row_dim,
+      ") must be a positive multiple of 4 to keep the appended scale/pad words 4-byte aligned");
 
   const auto input_sizes = input.sizes();
   const auto last_dim = input_sizes.size() - 1;
   const int nrows = c10::size_to_dim_(last_dim, input_sizes);
   const int ncols = input_sizes[last_dim];
   const int row_ext = row_dim + 8;
-  int output_columns = ncols - (ncols + row_ext - 1) / row_ext * 8;
+  TORCH_CHECK(
+      ncols % row_ext == 0,
+      "ncols (",
+      ncols,
+      ") must be a multiple of row_ext (",
+      row_ext,
+      ")");
+  const int num_buckets = ncols / row_ext;
+  int output_columns = ncols - num_buckets * 8;
   // Global memory instructions support reading or writing words of size equal
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
   // data residing in global memory compiles to a single global memory
@@ -281,12 +307,9 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
 
   constexpr int threads_per_block = 256;
   const auto num_blocks = cuda_calc_xblock_count(
-      (nrows == 1) ? (ncols + row_ext - 1) / row_ext + 1 : nrows,
-      threads_per_block);
+      (nrows == 1) ? num_buckets : nrows, threads_per_block);
   Tensor offsets = at::empty(
-      (nrows == 1) ? num_blocks * threads_per_block + 1
-                   : 0, // 4 = sizeof(float)
-      input.options().dtype(at::kInt));
+      (nrows == 1) ? num_buckets : 0, input.options().dtype(at::kInt));
   int total_pad = 0;
   if (nrows == 1) {
     FBGEMM_LAUNCH_KERNEL(
@@ -316,7 +339,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
           total_pad_tensor.data_ptr<int>());
       total_pad = total_pad_tensor[0].item<int>();
     } else {
-      total_pad = offsets[((ncols + row_ext - 1) / row_ext)].item<int>();
+      total_pad = offsets[num_buckets].item<int>();
     }
     output_columns -= total_pad;
   } else {
@@ -340,13 +363,6 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
 
   if (nrows == 1) {
     // Use one thread block to work on 1 row for nrows == 1
-    TORCH_CHECK(
-        ncols % row_ext == 0,
-        "ncols (",
-        ncols,
-        ") must be multiple of ",
-        row_ext)
-    const int num_rows = ncols / row_ext;
     const int ebit = forward ? 4 : 5;
     const int bias = forward ? 15 : 31;
     constexpr int kMaxThreads = 1024;
@@ -356,7 +372,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
         output.scalar_type(), "PaddedFP8rowwise_to_float_1d_cuda_kernel", [&] {
           FBGEMM_LAUNCH_KERNEL(
               (_PaddedFP8rowwise_to_float_1d_cuda_kernel<scalar_t>),
-              num_rows,
+              num_buckets,
               threads_per_block,
               0,
               at::cuda::getCurrentCUDAStream(),

fbgemm_gpu/test/quantize/fp8_rowwise_test.py

Lines changed: 53 additions & 0 deletions
@@ -227,6 +227,59 @@ def test_quantize_and_dequantize_op_padded_fp8_rowwise(
 
         torch.testing.assert_allclose(dqcat, qref, rtol=0.1, atol=0.05)
 
+    @unittest.skipIf(*gpu_unavailable)
+    def test_padded_fp8_rowwise_input_validation(self) -> None:
+        fp32 = SparseType.FP32.as_int()
+        x = torch.rand(2, 32, device=torch.accelerator.current_accelerator())
+        # row_dim must be a positive multiple of 4.
+        for bad in (0, -4, 3, 6):
+            with self.assertRaises(RuntimeError):
+                torch.ops.fbgemm.FloatToPaddedFP8RowwiseQuantized(
+                    x, forward=True, row_dim=bad
+                )
+            with self.assertRaises(RuntimeError):
+                torch.ops.fbgemm.PaddedFP8RowwiseQuantizedToFloat(
+                    torch.zeros(
+                        2,
+                        24,
+                        device=torch.accelerator.current_accelerator(),
+                        dtype=torch.uint8,
+                    ),
+                    forward=True,
+                    row_dim=bad,
+                    output_dtype=fp32,
+                )
+        # Dequant ncols must be a multiple of row_dim + 8.
+        with self.assertRaises(RuntimeError):
+            torch.ops.fbgemm.PaddedFP8RowwiseQuantizedToFloat(
+                torch.zeros(
+                    2,
+                    25,
+                    device=torch.accelerator.current_accelerator(),
+                    dtype=torch.uint8,
+                ),
+                forward=True,
+                row_dim=16,
+                output_dtype=fp32,
+            )
+
+    @unittest.skipIf(*gpu_unavailable)
+    def test_padded_fp8_rowwise_1d_roundtrip(self) -> None:
+        # Exercises the nrows == 1 path where _get_padding_value_kernel used
+        # to read past the offsets buffer at the boundary thread.
+        fp32 = SparseType.FP32.as_int()
+        for row_dim, num_buckets in [(4, 1), (16, 7), (256, 33)]:
+            x = torch.rand(
+                row_dim * num_buckets, device=torch.accelerator.current_accelerator()
+            )
+            q = torch.ops.fbgemm.FloatToPaddedFP8RowwiseQuantized(
+                x, forward=True, row_dim=row_dim
+            )
+            dq = torch.ops.fbgemm.PaddedFP8RowwiseQuantizedToFloat(
+                q, forward=True, row_dim=row_dim, output_dtype=fp32
+            )
+            torch.testing.assert_close(dq.cpu(), x.cpu(), rtol=0.1, atol=0.05)
+
 
 if __name__ == "__main__":
     unittest.main()
