63 changes: 40 additions & 23 deletions fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu
@@ -35,7 +35,8 @@ __global__ inline void _float_to_paddedFP8rowwise_cuda_kernel(
const int output_columns =
ncols_aligned + (ncols + row_dim - 1) / row_dim * 8;

-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
// for 1D case, unsqueezing needed
if (nrows == 1) {
const auto threads = (ncols + row_dim - 1) / row_dim;
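Note for context: the arithmetic above encodes the padded FP8 rowwise layout. Each row is split into row_dim-byte buckets, and every bucket is followed by a 4-byte float scale and a 4-byte int pad word, so the quantized width is the aligned data width plus 8 bytes per bucket. A minimal host-side sketch of that size math, assuming `ncols_aligned` (defined just above this hunk) rounds ncols up to a whole bucket; the helper is illustrative, not part of FBGEMM:

```cpp
#include <cstdint>

// Illustrative sketch, not FBGEMM code: quantized width for the padded FP8
// rowwise layout. Each bucket holds row_dim FP8 bytes plus 8 metadata bytes
// (a float scale and an int pad count), and the data region is zero-padded
// up to a whole number of buckets.
constexpr int64_t quantized_columns(int64_t ncols, int64_t row_dim) {
  const int64_t num_buckets = (ncols + row_dim - 1) / row_dim; // ceil division
  const int64_t ncols_aligned = num_buckets * row_dim;         // padded data bytes
  return ncols_aligned + num_buckets * 8;                      // + scale/pad words
}

static_assert(quantized_columns(32, 16) == 48, "2 buckets -> 32 + 16");
static_assert(quantized_columns(33, 16) == 72, "3 buckets -> 48 + 24");
```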
@@ -96,10 +97,11 @@ __global__ inline void _get_padding_value_kernel(
const int row_dim,
const std::uint8_t* const __restrict__ input,
int* const __restrict__ offsets) {
-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const int row_ext = row_dim + 8;
const auto threads = (ncols + row_ext - 1) / row_ext;
-  if (row > threads)
+  if (row >= threads)
return;
const std::uint8_t* const input_row = input + row * row_ext;
int pad = *reinterpret_cast<const int*>(input_row + row_dim + 4);
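The `>` to `>=` change above is the out-of-bounds fix the new 1D test targets: with `threads` buckets, valid row indices are 0 .. threads - 1, so the thread with row == threads must be rejected before it dereferences `input + row * row_ext`. A tiny standalone illustration (helper names are hypothetical):

```cpp
#include <cassert>

// Hypothetical illustration of the guard change: `buckets` plays the role of
// `threads` in the kernel, and valid row indices are 0 .. buckets - 1.
bool old_guard_admits(int row, int buckets) { return !(row > buckets); }
bool new_guard_admits(int row, int buckets) { return !(row >= buckets); }

int main() {
  const int buckets = 4;
  assert(old_guard_admits(buckets, buckets));  // row == 4 slipped through and
                                               // would read bucket 4 of 0..3
  assert(!new_guard_admits(buckets, buckets)); // fixed guard rejects it
  return 0;
}
```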
@@ -114,7 +116,8 @@ __global__ inline void _single_thread_sum_padding_kernel(
int* __restrict__ total_pad) {
// this is to count the sum of padding in the first row of 2D input
// in one kernel launch to remove multiple H to D Syncs.
-  const auto tid = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const auto tid =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (tid != 0) {
return;
}
@@ -125,7 +128,12 @@ __global__ inline void _single_thread_sum_padding_kernel(
while (offset + 4 <= ncols) {
pad = *reinterpret_cast<const int*>(input + offset);
if (pad < 0) {
-      offset += -pad * row_ext;
+      // Widen before negating so pad == INT_MIN doesn't overflow.
+      const int64_t step = -static_cast<int64_t>(pad) * row_ext;
+      if (step > ncols - offset) {
+        break;
+      }
+      offset += static_cast<int>(step);
} else {
total_pad[0] += pad;
offset += row_ext;
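The widened negation above removes signed-overflow undefined behavior when the pad word read from the byte stream happens to be INT_MIN (in 32-bit arithmetic -INT_MIN is not representable), and the added bounds check keeps a corrupt negative marker from skipping past the buffer. A sketch of just the arithmetic:

```cpp
#include <cassert>
#include <climits>
#include <cstdint>

// Sketch of the overflow being avoided: in 32-bit arithmetic, -INT_MIN is
// undefined behavior because INT_MAX == -(INT_MIN + 1). Widening to int64_t
// first makes the negation well defined.
int64_t widened_step(int pad, int row_ext) {
  return -static_cast<int64_t>(pad) * row_ext; // safe even for pad == INT_MIN
}

int main() {
  assert(widened_step(INT_MIN, 24) == 2147483648LL * 24);
  return 0;
}
```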
@@ -153,7 +161,7 @@ __global__ inline void _PaddedFP8rowwise_to_float_1d_cuda_kernel(
reinterpret_cast<const float*>(input_row + row_dim);
const auto scale = input_row_scale[0];
int pad = *reinterpret_cast<const int*>(&input_row_scale[1]);
-  pad = (pad > 0) ? pad : 0;
+  pad = ::max(0, ::min(pad, row_dim));
const auto pad_offset = offsets[row];
output_t* output_row = output + row * row_dim - pad_offset;
for (auto col = threadIdx.x; col < row_dim - pad; col += blockDim.x) {
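Both dequant kernels now clamp `pad` into [0, row_dim] instead of only flooring it at 0 (the same change appears in the 2D kernel below). Since `pad` is read straight out of the quantized byte stream, the clamp keeps the per-bucket element count `row_dim - pad` within [0, row_dim] even for corrupt metadata. A one-line host-side sketch of the sanitization (the kernels use the CUDA `::max`/`::min` overloads):

```cpp
#include <algorithm>

// Sketch: sanitize the pad word read from untrusted quantized bytes so the
// per-bucket element count (row_dim - pad) always stays in [0, row_dim].
// Negative values are "next bucket" markers, not real pad counts.
int sanitize_pad(int pad, int row_dim) {
  return std::max(0, std::min(pad, row_dim));
}
```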
@@ -176,7 +184,8 @@ __global__ inline void _PaddedFP8rowwise_to_float_2d_cuda_kernel(
const int ebit = forward ? 4 : 5;
const int bias = forward ? 15 : 31;

-  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t row =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (row >= nrows) {
return;
}
@@ -189,7 +198,7 @@ __global__ inline void _PaddedFP8rowwise_to_float_2d_cuda_kernel(
int pad = *reinterpret_cast<const int*>(&input_row_scale[1]);
  // if pad is negative it's used to indicate the index of the next padded
  // bucket
-  pad = (pad > 0) ? pad : 0;
+  pad = ::max(0, ::min(pad, row_dim));
for (int bi = 0; bi < row_dim - pad; ++bi) {
const auto output_ =
hfp8_to_float(input_row[col + bi], ebit, bias) / input_row_scale[0];
@@ -208,6 +217,11 @@ Tensor _float_to_paddedFP8rowwise_gpu_t(
const int64_t row_dim) {
TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(input);
CUDA_DEVICE_GUARD(input);
+  TORCH_CHECK(
+      row_dim > 0 && row_dim % 4 == 0,
+      "row_dim (",
+      row_dim,
+      ") must be a positive multiple of 4 to keep the appended scale/pad words 4-byte aligned");

const auto input_sizes = input.sizes();
const auto last_dim = input_sizes.size() - 1;
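The new TORCH_CHECK (mirrored on the dequant side below) guards the reinterpret_casts in the kernels: the scale and pad words are read with 4-byte loads at byte offsets row_dim and row_dim + 4 within each bucket, so both offsets must be 4-byte aligned. A small sketch of that invariant (helper name is illustrative):

```cpp
#include <cstdint>

// Illustrative: the per-bucket metadata loads in the kernels are
//   *reinterpret_cast<const float*>(bucket + row_dim)     // scale
//   *reinterpret_cast<const int*>(bucket + row_dim + 4)   // pad
// Buckets start at multiples of row_ext = row_dim + 8, so every load is
// 4-byte aligned exactly when row_dim is a multiple of 4.
constexpr bool metadata_aligned(int64_t row_dim) {
  return row_dim > 0 && row_dim % 4 == 0;
}

static_assert(metadata_aligned(16));
static_assert(!metadata_aligned(6)); // 6-byte offset would misalign the float load
```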
@@ -264,13 +278,26 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
TENSOR_ON_CUDA_GPU(input);
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
CUDA_DEVICE_GUARD(input);
+  TORCH_CHECK(
+      row_dim > 0 && row_dim % 4 == 0,
+      "row_dim (",
+      row_dim,
+      ") must be a positive multiple of 4 to keep the appended scale/pad words 4-byte aligned");

const auto input_sizes = input.sizes();
const auto last_dim = input_sizes.size() - 1;
const int nrows = c10::size_to_dim_(last_dim, input_sizes);
const int ncols = input_sizes[last_dim];
const int row_ext = row_dim + 8;
-  int output_columns = ncols - (ncols + row_ext - 1) / row_ext * 8;
+  TORCH_CHECK(
+      ncols % row_ext == 0,
+      "ncols (",
+      ncols,
+      ") must be a multiple of row_ext (",
+      row_ext,
+      ")");
+  const int num_buckets = ncols / row_ext;
+  int output_columns = ncols - num_buckets * 8;
// Global memory instructions support reading or writing words of size equal
// to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
// data residing in global memory compiles to a single global memory
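On the dequant side the column math is now the exact inverse of the quantizer's: ncols must be a whole number of row_ext-byte buckets, and each bucket gives back row_dim data bytes once its 8 metadata bytes are dropped. A sketch mirroring the earlier helper:

```cpp
#include <cstdint>

// Illustrative inverse of quantized_columns() above: with ncols validated as
// a multiple of row_ext = row_dim + 8, each bucket sheds its 8 metadata
// bytes; the per-row total padding is subtracted from this afterwards.
constexpr int64_t dequantized_columns(int64_t ncols, int64_t row_dim) {
  const int64_t row_ext = row_dim + 8;
  const int64_t num_buckets = ncols / row_ext;
  return ncols - num_buckets * 8; // == num_buckets * row_dim
}

static_assert(dequantized_columns(72, 16) == 48); // 3 buckets of 16 data bytes
```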
@@ -281,12 +308,9 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(

constexpr int threads_per_block = 256;
const auto num_blocks = cuda_calc_xblock_count(
-      (nrows == 1) ? (ncols + row_ext - 1) / row_ext + 1 : nrows,
-      threads_per_block);
+      (nrows == 1) ? num_buckets : nrows, threads_per_block);
Tensor offsets = at::empty(
-      (nrows == 1) ? num_blocks * threads_per_block + 1
-                   : 0, // 4 = sizeof(float)
-      input.options().dtype(at::kInt));
+      (nrows == 1) ? num_buckets : 0, input.options().dtype(at::kInt));
int total_pad = 0;
if (nrows == 1) {
FBGEMM_LAUNCH_KERNEL(
@@ -316,7 +340,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
total_pad_tensor.data_ptr<int>());
total_pad = total_pad_tensor[0].item<int>();
} else {
-      total_pad = offsets[((ncols + row_ext - 1) / row_ext)].item<int>();
+      total_pad = offsets[num_buckets].item<int>();
}
output_columns -= total_pad;
} else {
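Sizing note on the nrows == 1 path: `offsets` is now allocated with exactly num_buckets entries, one raw pad word per bucket. The read of offsets[num_buckets] above assumes the lines folded out of this diff replace `offsets` with a complete prefix sum of length num_buckets + 1, where the last entry is the total padding and offsets[row] is the output shift for bucket `row`. A hedged host-side sketch of that assumed bookkeeping:

```cpp
#include <numeric>
#include <vector>

// Hedged sketch (the actual scan happens in lines collapsed out of this
// diff): a complete cumsum turns per-bucket pad counts into num_buckets + 1
// offsets, so offsets[num_buckets] is the total padding to trim from
// output_columns and offsets[row] shifts each bucket's output leftward.
std::vector<int> complete_cumsum(const std::vector<int>& pads) {
  std::vector<int> offsets(pads.size() + 1, 0);
  std::partial_sum(pads.begin(), pads.end(), offsets.begin() + 1);
  return offsets;
}
// e.g. pads {0, 3, 1} -> offsets {0, 0, 3, 4}; total padding = offsets[3] = 4.
```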
@@ -340,13 +364,6 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(

if (nrows == 1) {
// Use one thread block to work on 1 row for nrows == 1
-    TORCH_CHECK(
-        ncols % row_ext == 0,
-        "ncols (",
-        ncols,
-        ") must be multiple of ",
-        row_ext)
-    const int num_rows = ncols / row_ext;
const int ebit = forward ? 4 : 5;
const int bias = forward ? 15 : 31;
constexpr int kMaxThreads = 1024;
@@ -356,7 +373,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
output.scalar_type(), "PaddedFP8rowwise_to_float_1d_cuda_kernel", [&] {
FBGEMM_LAUNCH_KERNEL(
(_PaddedFP8rowwise_to_float_1d_cuda_kernel<scalar_t>),
-            num_rows,
+            num_buckets,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream(),
41 changes: 41 additions & 0 deletions fbgemm_gpu/test/quantize/fp8_rowwise_test.py
@@ -227,6 +227,47 @@ def test_quantize_and_dequantize_op_padded_fp8_rowwise(

torch.testing.assert_allclose(dqcat, qref, rtol=0.1, atol=0.05)

+    @unittest.skipIf(*gpu_unavailable)
+    def test_padded_fp8_rowwise_input_validation(self) -> None:
+        fp32 = SparseType.FP32.as_int()
+        x = torch.rand(2, 32, device="cuda")
+        # row_dim must be a positive multiple of 4.
+        for bad in (0, -4, 3, 6):
+            with self.assertRaises(RuntimeError):
+                torch.ops.fbgemm.FloatToPaddedFP8RowwiseQuantized(
+                    x, forward=True, row_dim=bad
+                )
+            with self.assertRaises(RuntimeError):
+                torch.ops.fbgemm.PaddedFP8RowwiseQuantizedToFloat(
+                    torch.zeros(2, 24, device="cuda", dtype=torch.uint8),
+                    forward=True,
+                    row_dim=bad,
+                    output_dtype=fp32,
+                )
+        # Dequant ncols must be a multiple of row_dim + 8.
+        with self.assertRaises(RuntimeError):
+            torch.ops.fbgemm.PaddedFP8RowwiseQuantizedToFloat(
+                torch.zeros(2, 25, device="cuda", dtype=torch.uint8),
+                forward=True,
+                row_dim=16,
+                output_dtype=fp32,
+            )
+
+    @unittest.skipIf(*gpu_unavailable)
+    def test_padded_fp8_rowwise_1d_roundtrip(self) -> None:
+        # Exercises the nrows == 1 path where _get_padding_value_kernel used
+        # to read past the offsets buffer at the boundary thread.
+        fp32 = SparseType.FP32.as_int()
+        for row_dim, num_buckets in [(4, 1), (16, 7), (256, 33)]:
+            x = torch.rand(row_dim * num_buckets, device="cuda")
+            q = torch.ops.fbgemm.FloatToPaddedFP8RowwiseQuantized(
+                x, forward=True, row_dim=row_dim
+            )
+            dq = torch.ops.fbgemm.PaddedFP8RowwiseQuantizedToFloat(
+                q, forward=True, row_dim=row_dim, output_dtype=fp32
+            )
+            torch.testing.assert_close(dq.cpu(), x.cpu(), rtol=0.1, atol=0.05)
+

if __name__ == "__main__":
unittest.main()