This document captures the full design analysis for implementing a fused kbit dequantization + GEMM kernel in bitsandbytes. It covers the existing kbit quantization implementation, the Marlin kernel architecture (as reference), and the complete design for the new GEMM kernel. A developer reading this should be able to implement the kernel without additional context.
- Existing kbit Implementation
- Marlin Kernel Architecture (Reference)
- GEMM Kernel Design
- Weight Storage Format and Repacking
- Inner Loop: Dequantization + MMA
- Persistent Kernel and Work Distribution
- Pipeline and Shared Memory
- Codebook and Absmax Handling
- Performance Analysis
- Kernel Dispatch and Python Integration
- File Organization and Build
- Error Budget
- Template Instantiations
- Future Considerations
The kbit quantization system lives on the feature/kbit-quantization branch.
It implements K-bit blockwise quantization for K=2,3,4,5 with blocksize=32
(one warp = one quantization block). It uses a codebook-based approach where
each element is mapped to the nearest entry in a 2^K-entry codebook, then
packed into K bit-plane words using warp-level CUDA primitives.
Currently, only standalone quantize and dequantize kernels exist. There is no fused GEMM. The goal of this design is to add a fused dequant+GEMM kernel that achieves high tensor core utilization at larger batch sizes.
The codebook is generated by create_normal_float_codebook(k) in
bitsandbytes/functional.py. It places 2^K reconstruction levels at the
expected values of N(0,1) within 2^K equiprobable bins, then normalizes to
[-1, 1]. The codebook is:
- Sorted ascending
- Roughly symmetric around 0
- Normalized so abs(max) == 1.0
- Cached per (k, device) pair
For K=4, this is conceptually similar to the existing NF4 datatype, though with minor numerical differences (the existing NF4 has an asymmetric zero trick).
The codebook is always stored as float32 and passed to CUDA kernels as
const float*. For the GEMM kernel, it will be converted to half precision
at kernel startup (see Section 8.1).
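The construction described above can be sketched with the Python stdlib alone. This is a conceptual reference, not the actual create_normal_float_codebook (the name normal_float_codebook and the exact edge-bin handling are assumptions): 2^K bin-conditional means of N(0,1), normalized to [-1, 1].

```python
import math
from statistics import NormalDist

def normal_float_codebook(k: int) -> list[float]:
    n = 1 << k
    nd = NormalDist()
    # Equiprobable bin edges; the outer edges are +/- infinity
    edges = [-math.inf] + [nd.inv_cdf(i / n) for i in range(1, n)] + [math.inf]
    # E[X | a < X < b] for X ~ N(0,1) is (pdf(a) - pdf(b)) / (cdf(b) - cdf(a))
    levels = [(nd.pdf(a) - nd.pdf(b)) / (nd.cdf(b) - nd.cdf(a))
              for a, b in zip(edges, edges[1:])]
    m = max(abs(v) for v in levels)
    return [v / m for v in levels]
```

The resulting list has the properties listed above: sorted ascending, symmetric, max magnitude exactly 1.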
Location: csrc/ops.cu, function kQuantizeBlockwise_kbit<T, K> (line 682).
Template parameters:
T: input type (half, __nv_bfloat16, float)
K: bit width (2, 3, 4, 5)
Launch config:
Block size: 256 threads (KBIT_THREADS_PER_BLOCK)
Grid: ceil(num_blocks / 8) where num_blocks = ceil(n / 32)
Each CUDA block has 8 warps, each warp processes one quantization block.
Algorithm per warp:
1. Each lane loads one element from A (lane_id maps 1:1 to element position)
2. Convert to float
3. Warp-reduce absmax via __shfl_down_sync butterfly reduction
4. Lane 0 broadcasts absmax to all lanes via __shfl_sync
5. Lane 0 writes absmax[warp_id]
6. Normalize: val / max(absmax, 1e-8)
7. Load codebook into lane registers: cb = codebook[lane_id] for lane < 2^K
8. Brute-force nearest-neighbor search:
- Loop i = 0..2^K-1
- Broadcast codebook[i] to all lanes via __shfl_sync(cb, i)
- Compare distance, track best index
9. Pack via __ballot_sync: for each bit b in 0..K-1,
packed[b] = __ballot_sync(0xFFFFFFFF, (best_idx >> b) & 1)
This produces K uint32 words where word b contains bit b of all 32 lanes.
10. Lanes 0..K-1 write their respective bit-plane word to
packed_out[warp_id * K + lane_id]
Key observations:
- The output is in "bit-plane" format: K uint32 words per block of 32 elements
- __ballot_sync collects one bit from all 32 lanes into a single uint32
- The packed data layout in memory is sequential: block 0's K words, then block 1's K words, etc.
- absmax is stored as float32 (later encoded to E4M4 on the Python side)
Location: csrc/ops.cu, function kDequantizeBlockwise_kbit_vec<T, K, BLOCKS_PER_WARP, ABSMAX_T> (line 753).
Template parameters:
T: output type (half, __nv_bfloat16, float)
K: bit width (2, 3, 4, 5)
BLOCKS_PER_WARP: number of quantization blocks processed per warp iteration (4)
ABSMAX_T: absmax storage type (unsigned char for E4M4, half for fp16)
Launch config:
Block size: 256 threads (8 warps)
Grid: ceil(num_warps / 8) where num_warps = ceil(num_blocks / BLOCKS_PER_WARP)
Algorithm per warp:
1. Load codebook into lane registers (once, amortized across BLOCKS_PER_WARP):
float cb = (lane_id < (1 << K)) ? codebook[lane_id] : 0.0f;
2. For each of BLOCKS_PER_WARP=4 blocks:
a. Load absmax via load_absmax<ABSMAX_T>(absmax, block_id)
- For unsigned char: calls decode_e4m4_absmax()
- For half: simple cast to float
b. Load K bit-plane words using shuffle broadcast:
for (bit = 0; bit < K; bit++) {
unsigned int word = (lane_id == bit) ? packed_in[block_id * K + bit] : 0;
packed[bit] = __shfl_sync(0xFFFFFFFF, word, bit);
}
Only lane `bit` reads from global memory; all other lanes receive
the value via shuffle broadcast. This minimizes global memory
transactions (K reads per block instead of K*32).
c. Unpack index: for each bit, extract that bit from the plane word
at the current lane's position, OR them together:
idx = 0;
for (bit = 0; bit < K; bit++)
idx |= ((packed[bit] >> lane_id) & 1) << bit;
d. Codebook lookup via shuffle:
float val = __shfl_sync(0xFFFFFFFF, cb, idx) * amax;
e. Write output: out[block_start + lane_id] = (T)val;
Key observations:
- The shuffle-based bit-plane loading pattern (step 2b) exploits the fact that each lane has a 1:1 correspondence with an element position. Only K lanes do global loads; the rest get data via shuffle. This is specific to the standalone dequant where threads map 1:1 to elements.
- In the GEMM kernel, this pattern CANNOT be used directly because threads are organized around tensor core fragment positions, not element positions. Instead, bit-plane words will be loaded into shared memory by the async pipeline, and each thread reads from shared memory for its specific column. This is discussed in detail in Section 5.
- BLOCKS_PER_WARP=4 amortizes the codebook register load across 4 blocks. In the GEMM kernel, the codebook is loaded once at kernel start and lives in a register for the entire kernel lifetime -- even better amortization.
Location: csrc/ops.cu, function decode_e4m4_absmax (line 722).
Format: 4-bit exponent + 4-bit mantissa with bias=11.
- Normal (e > 0): 2^(e - 11) * (1 + m/16)
- Subnormal (e = 0): 2^(1 - 11) * (m/16) = 2^(-10) * (m/16)
- Zero (e = 0, m = 0): 0.0
Range: approximately [9.8e-4, 31.0] for normal values, extending down to ~6.1e-5 with subnormals. Max relative error: 1/16 = 6.25% (from the 4-bit mantissa).
The decode implementation constructs an IEEE 754 float directly via bit manipulation, avoiding any floating-point arithmetic:
__device__ __forceinline__ float decode_e4m4_absmax(unsigned char raw) {
if (raw == 0) return 0.0f;
int e = raw >> 4;
int m = raw & 0xF;
if (e == 0) {
return ldexpf((float)m, 1 - E4M4_BIAS - 4); // subnormal
}
unsigned int ieee = (unsigned int)(e - E4M4_BIAS + 127) << 23
| (unsigned int)m << 19;
return __uint_as_float(ieee);
}
Cost: 1 comparison, 2 shifts, 1 OR, 1 add, 1 reinterpret. ~5 integer ALU ops.
The subnormal path uses ldexpf but is rarely taken in practice.
The Python-side encoding is in bitsandbytes/functional.py:
encode_absmax_e4m4() and decode_absmax_e4m4().
Storage savings: 1 byte per block of 32 elements vs 4 bytes for float32. This reduces absmax overhead from 0.125 bytes/element to 0.03125 bytes/element.
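As a standalone reference for the format, the following sketch is consistent with the spec above but not necessarily bit-identical to encode_absmax_e4m4/decode_absmax_e4m4 in functional.py (the brute-force encoder in particular is just for illustration):

```python
E4M4_BIAS = 11

def decode_e4m4(raw: int) -> float:
    # 4-bit exponent, 4-bit mantissa, bias 11 (mirrors decode_e4m4_absmax)
    if raw == 0:
        return 0.0
    e, m = raw >> 4, raw & 0xF
    if e == 0:  # subnormal: 2^(1-11) * (m/16)
        return (m / 16.0) * 2.0 ** (1 - E4M4_BIAS)
    return (1.0 + m / 16.0) * 2.0 ** (e - E4M4_BIAS)

def encode_e4m4(x: float) -> int:
    # Brute-force nearest code; fine for a one-off reference over 256 candidates
    return min(range(256), key=lambda r: abs(decode_e4m4(r) - x))
```

Decoded values are strictly monotone in the raw byte, which the E4M4 test stage relies on.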
// Pack: collect bit `bit` from all 32 lanes into one uint32
template <int K>
__device__ __forceinline__ void pack_kbit_warp(unsigned char qval, unsigned int* packed_words) {
for (int bit = 0; bit < K; bit++)
packed_words[bit] = __ballot_sync(0xFFFFFFFF, (qval >> bit) & 1);
}
// Unpack: reconstruct K-bit index for this lane from K bit-plane words
template <int K>
__device__ __forceinline__ unsigned char unpack_kbit_warp(const unsigned int* packed_words, int lane_id) {
unsigned char val = 0;
for (int bit = 0; bit < K; bit++)
val |= ((packed_words[bit] >> lane_id) & 1) << bit;
return val;
}
The pack operation uses __ballot_sync, which collects one bit from each of
the 32 lanes in a warp and assembles them into a single uint32 word.
The unpack operation does the reverse: for a given lane position, it extracts one bit from each of K plane words and assembles them into a K-bit index.
Both operations are O(K) in ALU ops. For K=4: 4 ballot_sync ops for packing, 4 shift+mask+OR ops for unpacking.
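A host-side Python model of the two helpers, with the ballot/shuffle semantics unrolled into loops (the names mirror the CUDA helpers, but this is a reference model, not kernel code):

```python
def pack_kbit_warp(indices, k):
    # Word `bit` collects bit `bit` of each lane's index,
    # as K __ballot_sync calls would
    assert len(indices) == 32
    return [sum(((indices[lane] >> bit) & 1) << lane for lane in range(32))
            for bit in range(k)]

def unpack_kbit_warp(words, lane, k):
    # Reassemble this lane's K-bit index from the K plane words
    return sum(((words[bit] >> lane) & 1) << bit for bit in range(k))
```

Pack followed by unpack is the identity for any 32 indices in [0, 2^K).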
Quantize: 12 variants (3 input types x 4 K values)
Dequantize: 24 variants (3 output types x 2 absmax types x 4 K values)
All instantiated via macros at the bottom of ops.cu (lines 821-869).
Three layers:
- bitsandbytes/_ops.py: torch.library op definitions with fake tensor implementations for torch.compile compatibility
- bitsandbytes/backends/cuda/ops.py: CUDA kernel dispatch -- maps dtype to C function name suffix, handles fp32->E4M4 absmax encoding
- csrc/pythonInterface.cpp: unmangled C++ wrappers calling the templates, then extern "C" wrappers calling those
The naming convention for C functions:
- Quantize: cquantize_kbit_{fp16,bf16,fp32}_k{2,3,4,5}
- Dequantize: cdequantize_kbit_{fp16,bf16,fp32}_{u8abs,fp16abs}_k{2,3,4,5}
The test suite (tests/test_kbit_quantization.py, ~1400 lines) covers:
- Stage 0: Pure Python reference (quantize_kbit_ref, dequantize_kbit_ref)
- Stage 4: CUDA quantize correctness (absmax, all dtypes, various sizes)
- Stage 5: CUDA dequantize correctness (matches ref, all dtypes, various sizes, error bounds)
- Stage 6: Error analysis on 1M+ elements (analytical bounds, MSE scaling, SQNR)
- Stage 7: Cross-validation against existing NF4
- Stage 8: Performance benchmarks (bandwidth utilization, throughput scaling, NF4 comparison)
- Python API tests (round-trip, all dtypes, custom codebook, various sizes)
- Output dtype correctness (bf16/fp32 vs fp16 baseline)
- Asymmetric codebook tests (all-positive, all-negative, skewed, non-uniform)
- E4M4 encode/decode tests (round-trip, subnormals, monotonicity, uniqueness)
The quantize kernel stores packed data in flat sequential order:
packed_out[warp_id * K + bit] = plane_word
For a tensor A of n elements:
num_blocks = ceil(n / 32)
packed_out has num_blocks * K uint32 words
Block i covers elements [32*i, 32*(i+1))
packed_out[i*K + 0] = bit-plane 0 of block i (bit 0 of all 32 elements)
packed_out[i*K + 1] = bit-plane 1 of block i
...
packed_out[i*K + K-1] = bit-plane K-1 of block i
For a weight matrix W[K_dim, N] flattened in row-major order:
Element (k, n) is at flat index k * N + n
It belongs to block floor((k * N + n) / 32)
This flat layout is NOT suitable for GEMM tiling. The repack kernel (Section 4) transforms it into a tiled layout.
The Marlin kernel in vllm (csrc/quantization/marlin/) is a highly optimized
mixed-precision GEMM for weight-only quantization. We use it as architectural
reference, not as code to copy.
Location: vllm/csrc/quantization/marlin/marlin_template.h
Tiling and SM partitioning (line 271-281): Marlin uses "stripe" partitioning where each threadblock processes a contiguous run of tiles from a linearized 2D work grid. This ensures good SM utilization for all shapes while minimizing cross-threadblock reductions.
4-stage async pipeline (line 916-923):
Uses cp.async to overlap global->shared memory transfers with computation.
The cp_async_wait<stages-2>() pattern ensures double-buffering.
Register double-buffering (line 927-939):
Shared memory reads alternate between two sets of register fragments
(frag_b_quant[k%2]), hiding the shared memory read latency.
On-the-fly dequantization (line 1236-1237):
INT4/INT8/FP4/FP8 values are dequantized in registers using lop3 and
prmt PTX instructions. This is purely arithmetic (no memory access).
For kbit, we replace this with codebook lookup (see Section 5).
Tensor core MMA (line 1278-1281):
Standard m16n8k16 instructions on dequantized fp16 fragments,
accumulating in fp32.
Scale application (line 1244-1270): Group-wise or channel-wise scales applied to dequantized FragB before MMA. Multiple code paths handle different group_blocks configurations. For kbit, this simplifies dramatically because our blocksize=32 aligns with TILE_K boundaries (see Section 8.2).
The stripe system (marlin_template.h:271-281, marlin.cu:362-516) solves the problem of filling all SMs when the 2D tile count is less than the SM count.
Example: 5 SMs, 3x3 tile grid (3 K-tiles x 3 N-columns):
Column: 0 1 2
K-tile 0: [0] [1] [3]
K-tile 1: [0] [2] [3]
K-tile 2: [1] [2] [4]
Numbers = which SM handles that tile.
The linearized tile sequence is distributed as contiguous "stripes" across SMs. Properties:
- Perfect load balance (each SM gets total_tiles/num_SMs +/- 1)
- Minimized reductions (each SM crosses at most one column boundary)
- Adaptive split-K (automatically splits K when N-tiles < num_SMs)
The reduction uses barrier_acquire/barrier_release on a locks array.
We chose NOT to implement Marlin-style stripes. Instead, we use a persistent kernel with explicit work assignment (see Section 6).
Location: marlin.cu:128-313
Two sets of thread configs:
- Small batch (thread_m_blocks=1): {128,128,256}, {64,128,128}, {128,64,128}
- Large batch (thread_m_blocks>1): {64,256,256}, {64,128,128}, {128,64,128} (values are {thread_k, thread_n, num_threads})
The dispatch tries configs in priority order, picks the first valid one (fits in shared memory, divides problem dimensions). If none work, reduces thread_m_blocks and retries.
For large M, Marlin splits M into parallel groups, each processed by a separate set of SMs.
| Aspect | Marlin | kbit GEMM |
|---|---|---|
| Dequant method | lop3 bit manipulation -> fp16 | Bit extraction -> codebook lookup -> scale |
| Codebook | None (linear INT4->FP16) | 4-32 entries via __shfl_sync |
| Scale granularity | Configurable group_blocks | Fixed: 1 E4M4 scale per 32 elements |
| K-tile alignment | Complex group boundary logic | Clean: TILE_K=64 = 2 blocks, no straddling |
| B tile in shmem | Standard INT4 size | Same for K=4, smaller for K=2,3 |
| Bit widths | 4 or 8 | 2, 3, 4, 5 |
| Zero points | Optional, complex logic | None (symmetric codebook) |
| Act-order | Supported (major complexity) | Not needed |
| Work distribution | Stripe partitioning | Persistent kernel + atomicAdd |
Compute C[M, N] = A[M, K_dim] * W_kbit[K_dim, N] where:
- A is in fp16 (or bf16)
- W is stored in kbit format (bit-plane packed indices + E4M4 absmax + codebook)
- C is in fp16 (or bf16)
The weight matrix W is quantized offline and stored in a GEMM-optimized tiled layout (produced by the repack kernel). The codebook is shared across all blocks.
TILE_M = variable (16, 32, 48, 64 depending on M; controlled by M_BLOCKS template param)
TILE_N = 128 (or 256 for large batch configs)
TILE_K = 64 (= 2 quantization blocks of 32 elements each)
TILE_K=64 was chosen over TILE_K=32 because:
- Doubles compute per shared memory load of A
- Better compute-to-load ratio in the transition zone (M=32-128)
- Only adds one extra absmax value per column per tile (trivial complexity)
- 2 MMA k-sub-tile pairs instead of 1, better pipeline utilization
With TILE_K=64, each K-tile spans exactly 2 kbit blocks (each 32 elements). Each column has 2 absmax values per K-tile. The absmax boundary falls exactly between k_sub=1 and k_sub=2 of the 4 MMA k-sub-tiles.
256 threads = 8 warps per thread block.
Warp layout (for TILE_M=64, TILE_N=128): 2 warps along M x 4 warps along N.
Each warp owns a 32x32 sub-tile of C.
For the m16n8k16 MMA instruction:
Each warp's 32x32 sub-tile = 2 M-blocks x 4 N-blocks = 8 MMA positions
With TILE_K=64 (4 k-sub-tiles of 16): 8 * 4 = 32 MMA ops per warp per K-tile
Per thread:
- Codebook: 1 half register (loaded at kernel start, lives for entire kernel)
- FragC accumulators: M_BLOCKS * N_BLOCKS Vec<float,4> fragments per thread (4 floats per MMA position). For the 32x32 warp sub-tile (M_BLOCKS=2, N_BLOCKS=4): 8 * 4 = 32 floats = 128 bytes per thread
- FragA: M_BLOCKS * Vec<half2,2> per k-sub-tile (double-buffered)
- FragB: Vec<half2,2> per N-block per k-sub-tile (not stored, consumed immediately)
- Bit-plane words: K uint32 temporaries
- Absmax: 2 half values per column group
Total estimated: ~40-50 registers per thread. Well within the 255 limit.
The quantize kernel (kQuantizeBlockwise_kbit) outputs packed data in flat
sequential order:
For a weight matrix W[K_dim, N] flattened to 1D:
Block i: elements [32*i .. 32*(i+1))
packed[i*K + bit] = bit-plane word for bit `bit` of block i
absmax[i] = max absolute value in block i (float32, later E4M4-encoded)
This layout is contiguous in memory but NOT optimized for GEMM tiling. A GEMM kernel loading a TILE_K x TILE_N region would need to gather from many non-contiguous locations.
The repack kernel transforms the flat layout into a tiled layout where each (k_tile, n_tile) region is contiguous in memory:
B_packed[k_tile][n_tile][col_within_tile][k_block_within_tile][bit_plane]
Dimensions:
k_tile: 0 .. K_dim/TILE_K - 1
n_tile: 0 .. N/TILE_N - 1
col_within_tile: 0 .. TILE_N - 1 (128 columns per N-tile)
k_block_within_tile: 0 .. TILE_K/32 - 1 (2 blocks per K-tile with TILE_K=64)
bit_plane: 0 .. K-1
Total words per tile: TILE_N * (TILE_K / 32) * K
For TILE_N=128, TILE_K=64, K=4: 128 * 2 * 4 = 1024 uint32 words = 4 KB
Absmax is stored separately in a matching tiled layout:
B_absmax[k_tile][n_tile][col_within_tile][k_block_within_tile]
Total bytes per tile: TILE_N * (TILE_K / 32) = 128 * 2 = 256 bytes (uint8)
The repack kernel is a simple gather/permutation kernel, run once when the model is loaded (not on the hot path). It maps:
Source: packed_flat[block_id * K + bit]
where block_id = (k * N + n) / 32 (for element (k, n) in row-major W)
Destination: packed_tiled[k_tile][n_tile][col][k_block][bit]
where k_tile = k / TILE_K
n_tile = n / TILE_N
col = n % TILE_N
k_block = (k % TILE_K) / 32
bit = 0..K-1
Similarly for absmax:
Source: absmax_flat[block_id]
Destination: absmax_tiled[k_tile][n_tile][col][k_block]
The repack kernel should also handle E4M4 encoding of absmax if it hasn't been done already.
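The destination addressing can be sanity-checked on the host. dest_word_index below is a hypothetical helper that linearizes B_packed[k_tile][n_tile][col][k_block][bit] exactly as written above, where a "block" here is the (column, 32-row) group of the tiled layout:

```python
TILE_K, TILE_N = 64, 128

def dest_word_index(k, n, bit, N, k_bits):
    # Linear offset of the bit-plane word for the block containing
    # row k of column n in the tiled layout
    k_tile, n_tile = k // TILE_K, n // TILE_N
    col, k_block = n % TILE_N, (k % TILE_K) // 32
    n_tiles = N // TILE_N
    kb = TILE_K // 32
    return (((k_tile * n_tiles + n_tile) * TILE_N + col) * kb + k_block) * k_bits + bit
```

Enumerating every (block, bit) pair shows the mapping is a bijection onto a dense word range, i.e. the repack is a pure permutation with no gaps or collisions.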
We keep the bit-plane format for the GEMM kernel rather than converting to contiguous K-bit packing. Reasons:
- Uniform across all K values: K=2,3,4,5 all work identically. Contiguous packing is awkward for K=3,5 (they don't divide 32 evenly, so boundary-crossing extraction would be needed).
- Same memory footprint: K words per block of 32 regardless of format. Both formats use exactly K * 4 bytes per 32 elements.
- Extraction cost is hidden: the bit-plane extraction (K shift+mask+OR per element) runs on the INT32 ALU, concurrent with tensor core MMA. The cost is effectively free in the steady state.
- No format conversion needed: the quantize kernel already produces bit-planes. Repacking only changes the tile layout, not the data format.
For the m16n8k16 MMA instruction (fp16 inputs, fp32 accumulation):
The B matrix (weights) in the MMA is k=16 x n=8. Per thread t (lane 0-31):
| Register | Row indices | Column |
|---|---|---|
| b[0] (half2) | k = 2*(t%4), 2*(t%4)+1 | n = t/4 |
| b[1] (half2) | k = 2*(t%4)+8, 2*(t%4)+9 | n = t/4 |
Key property: all 4 elements a thread needs are in the SAME column (n = t/4). The rows are at positions {2i, 2i+1, 2i+8, 2i+9} where i = t%4.
This means threads 4n, 4n+1, 4n+2, 4n+3 all access the same column n. When loading bit-plane words from shared memory, these 4 threads read the same K addresses -> shared memory broadcast (no bank conflict).
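This ownership pattern is easy to verify exhaustively on the host (a small sketch; b_fragment_positions is a hypothetical name):

```python
def b_fragment_positions(t):
    # m16n8k16 B fragment: lane t owns column t//4 and rows
    # {2i, 2i+1, 2i+8, 2i+9} with i = t%4
    i = t % 4
    return [2 * i, 2 * i + 1, 2 * i + 8, 2 * i + 9], t // 4
```

Enumerating all 32 lanes confirms: exactly 4 lanes per column, and the 4 lanes of one column together cover all 16 rows of the k=16 x n=8 fragment.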
In the standalone dequant kernel, bit-plane words are loaded from global
memory using the shuffle-broadcast trick (only lane bit loads, broadcasts
to all). This pattern DOES NOT WORK in the GEMM context because:
- Threads are not mapped 1:1 to elements -- they're mapped to tensor core fragment positions.
- Data is in shared memory (loaded by the async pipeline), not global memory.
- Multiple threads need the same bit-plane words (4 threads per column).
Instead, in the GEMM kernel, each thread reads K words directly from shared memory for its column's block:
// my_col: which N-column this thread handles in the current MMA sub-tile
// This is determined by the tensor core fragment layout: my_col = lane_id / 4
int my_col = (threadIdx.x % 32) / 4; // 0-7 for the 8 columns in m16n8k16
// Load K bit-plane words for this column's block
uint32_t planes[K_BITS];
#pragma unroll
for (int b = 0; b < K_BITS; b++)
planes[b] = sh_b[column_offset + b];
Since 4 threads share the same column (same my_col value), they all read
the same K addresses from shared memory. This is a 4-way broadcast, which
shared memory handles natively with no bank conflicts.
With 8 distinct columns per warp and K=4:
- 8 groups of 4 threads, each reading from different addresses
- 8 different banks accessed simultaneously -> zero conflicts
After loading the K bit-plane words into registers, each thread extracts indices for its 4 fragment rows:
int row_base = 2 * (lane_id % 4); // 0, 2, 4, or 6
int rows[4] = {row_base, row_base + 1, row_base + 8, row_base + 9};
half vals[4];
#pragma unroll
for (int r = 0; r < 4; r++) {
int idx = 0;
#pragma unroll
for (int b = 0; b < K_BITS; b++)
idx |= ((planes[b] >> rows[r]) & 1) << b;
// Codebook lookup + scale (see Section 5.4)
half cb_val = __shfl_sync(0xFFFFFFFF, cb_h, idx);
vals[r] = __hmul(cb_val, scale);
}
// Pack into FragB
half2 frag_b[2];
frag_b[0] = __halves2half2(vals[0], vals[1]);
frag_b[1] = __halves2half2(vals[2], vals[3]);
ALU cost per FragB (4 values, K=4):
- Index extraction: 4 elements * 4 bits = 16 shift+mask+OR ops (INT32)
- Codebook lookup: 4 __shfl_sync ops (shuffle unit)
- Scale: 4 __hmul ops (FP16 ALU)
- Pack: 2 __halves2half2 ops
All of these run on different functional units from the tensor core MMA, so they overlap with MMA execution.
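The whole extraction path can be emulated on the host. The sketch below packs one 32-element column block and reproduces each lane's four values (pack_planes/dequant_fragb are hypothetical reference names; k_off selects which k=16 half of the 32-element block a sub-tile reads, a detail the CUDA snippet above glosses over):

```python
def pack_planes(indices, k):
    # One packed word per bit-plane for a 32-element column block
    return [sum(((indices[r] >> b) & 1) << r for r in range(32))
            for b in range(k)]

def dequant_fragb(planes, lane, k, codebook, scale, k_off):
    # k_off = 0 for the first k=16 sub-tile of the block, 16 for the second
    base = 2 * (lane % 4)
    vals = []
    for r in (base, base + 1, base + 8, base + 9):
        idx = sum(((planes[b] >> (k_off + r)) & 1) << b for b in range(k))
        vals.append(codebook[idx] * scale)
    return vals
```

With an identity codebook (entry i = i), each lane recovers exactly the indices at its fragment rows, times the scale.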
The codebook is stored as a half-precision value in each lane's register:
// At kernel start (once):
int lane = threadIdx.x % 32;
half cb_h = (lane < (1 << K_BITS))
? __float2half(codebook[lane])
: __float2half(0.0f);
Lookup uses __shfl_sync with per-thread independent source lane:
half val = __shfl_sync(0xFFFFFFFF, cb_h, idx);
Each thread can request the value from any lane. The shuffle unit handles arbitrary per-thread source selection. Cost: 1 cycle, no memory access.
Why shuffle (not constant memory or shared memory):
- Constant memory: optimized for broadcast (all threads same address). With divergent indices (each thread wants a different codebook entry), it serializes -- up to 2^K sequential reads. Bad.
- Shared memory: works (no bank conflicts for K<=4 since entries fit in distinct banks), but adds shared memory traffic.
- Shuffle: 1 cycle, zero memory, perfect for this use case. Already proven in the existing dequant kernel.
For one K-tile (TILE_K=64, 4 sub-tiles of k=16):
for (int k_sub = 0; k_sub < 4; k_sub++) {
// Which kbit block does this sub-tile fall in?
// k_sub 0,1 -> block 0 (first 32 elements), k_sub 2,3 -> block 1
half scale = (k_sub < 2) ? absmax_h[0] : absmax_h[1];
// Load A fragments via ldmatrix (from shared memory)
FragA frag_a[M_BLOCKS];
for (int m = 0; m < M_BLOCKS; m++)
ldmatrix_a(frag_a[m], sh_a, m, k_sub);
// For each N-block in this warp's sub-tile:
for (int n = 0; n < N_BLOCKS; n++) {
// Load bit-plane words from shared memory
uint32_t planes[K_BITS];
load_b_planes(planes, sh_b, n, k_sub);
// Dequant: extract indices, codebook lookup, scale
half2 frag_b[2];
dequant_kbit_fragb<K_BITS>(planes, scale, cb_h, frag_b);
// MMA: accumulate across all M-blocks (A fragments reused)
for (int m = 0; m < M_BLOCKS; m++) {
mma_m16n8k16(frag_a[m], frag_b, frag_c[m][n]);
}
}
}
The key data reuse pattern:
- A fragments: loaded once per M-block, reused across all N-blocks
- B fragments: dequantized once per N-block, reused across all M-blocks
- Codebook register: loaded once at kernel start, reused forever
- Absmax: decoded once per block-of-32 per column, reused across M-blocks
For typical LLM shapes (N=4096-16384, M variable, K=4096-16384), the number of M-tiles * N-tiles is often less than the number of SMs:
| M | N | M/64 x N/128 | H100 SMs | Utilization |
|---|---|---|---|---|
| 128 | 4096 | 2 x 32 = 64 | 132 | 48% |
| 256 | 4096 | 4 x 32 = 128 | 132 | 97% |
| 128 | 8192 | 2 x 64 = 128 | 132 | 97% |
When utilization is below ~80%, we need split-K (multiple blocks share the same output tile, each handling a portion of K). The persistent kernel handles this naturally.
Launch exactly num_SMs blocks. Each block loops over assigned work items.
Work items are linearized as (m_tile, n_tile, k_chunk) triples:
Total work = m_tiles * n_tiles * k_chunks
where k_chunks = ceil(K_dim / TILE_K / tiles_per_chunk)
and tiles_per_chunk >= 8 (minimum for pipeline efficiency)
Work items are ordered so that all k_chunks for a given (m_tile, n_tile)
are contiguous in the linearized sequence.
Each block gets a contiguous range of work items:
int total_work = m_tiles * n_tiles * k_chunks;
int work_per_block = div_ceil(total_work, gridDim.x);
int my_start = blockIdx.x * work_per_block;
int my_end = min(my_start + work_per_block, total_work);
When consecutive work items for a block share the same output tile (same m_tile, n_tile), the accumulators persist across k_chunks. The block accumulates without writing to memory.
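The range computation can be modeled in a few lines (partition_work is a hypothetical host-side helper mirroring the div_ceil scheme above):

```python
def partition_work(m_tiles, n_tiles, k_chunks, num_blocks):
    # Contiguous [start, end) ranges over the linearized work-item space
    total = m_tiles * n_tiles * k_chunks
    per = -(-total // num_blocks)  # ceil division
    return [(min(b * per, total), min((b + 1) * per, total))
            for b in range(num_blocks)]
```

For the 128-tile/132-SM example above (2 x 32 output tiles, split-K of 4), every work item is covered exactly once and no block gets more than ceil(256/132) = 2 items; trailing blocks get empty ranges.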
When the output tile changes (or at the end), the block writes results:
int prev_mn = -1;
FragC frag_c[M_BLOCKS][N_BLOCKS][2];
for (int work_id = my_start; work_id < my_end; work_id++) {
int mn_id = work_id / k_chunks;
int k_chunk_id = work_id % k_chunks;
if (mn_id != prev_mn) {
if (prev_mn >= 0)
write_output(frag_c, prev_mn, ...);
zero_accumulators(frag_c);
prev_mn = mn_id;
}
// Process K-tiles for this chunk
process_k_range(k_chunk_id, frag_c, ...);
}
// Write final tile
if (prev_mn >= 0)
write_output(frag_c, prev_mn, ...);
Three cases based on whether the block owns the full K-range for its output tile:
bool i_own_k_start = (my_first_k_chunk == 0);
bool i_own_k_end = (my_last_k_chunk == k_chunks - 1);
if (i_own_k_start && i_own_k_end) {
// Full ownership: write fp16 directly to C
write_frag_fp16(frag_c, C, ...);
}
else if (i_own_k_start) {
// First contributor: overwrite fp32 workspace (acts as zero + write)
write_frag_fp32(frag_c, C_workspace, ...);
}
else {
// Subsequent contributor: atomicAdd fp32
atomic_add_frag_fp32(frag_c, C_workspace, ...);
}
No separate memset is needed: the first contributor overwrites the workspace.
When multiple blocks share an output tile, the last block to finish converts fp32 workspace to fp16 output. This is detected via an atomic counter:
// Per-tile done counter (in the workspace/locks array)
if (not_full_ownership) {
int count = atomicAdd(&tile_done_count[mn_id], 1);
if (count == num_contributors - 1) {
// I'm the last one: convert fp32 -> fp16
convert_tile_fp32_to_fp16(C_workspace, C, mn_id, ...);
}
}
The tile_done_count array is tiny: m_tiles * n_tiles ints.
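The protocol can be modeled sequentially on the host (a simplification: real blocks finish in arbitrary order, but the write-first/add/count structure is order-independent up to fp32 rounding):

```python
def simulate_splitk_reduce(partials):
    # Model: the k-start owner overwrites the fp32 workspace (no memset),
    # later contributors add, and the atomic counter makes whoever
    # finishes last do the fp32 -> fp16 conversion.
    workspace, done, out = 0.0, 0, None
    n = len(partials)
    for i, p in enumerate(partials):
        if i == 0:
            workspace = p      # write_frag_fp32: overwrite acts as zero + write
        else:
            workspace += p     # atomic_add_frag_fp32
        done += 1              # atomicAdd(&tile_done_count[mn_id], 1)
        if done == n:
            out = workspace    # last contributor converts the tile
    return out
```

The result equals the sum of the partials, and the conversion step fires exactly once, regardless of how many contributors there are.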
When a block switches to a new (m_tile, n_tile) or a new k_chunk, the pipeline must restart (new data in shared memory). This costs ~2 K-tiles of pipeline fill time. Within a block's k_chunk, K-tiles are processed sequentially with continuous pipeline operation.
This is the main performance overhead of split-K: each split incurs a pipeline restart. With >= 8 K-tiles per chunk, the overhead is <= 25%. Typical values (16-32 K-tiles per chunk) give 6-12% overhead.
When m_tiles * n_tiles >= num_SMs, no split-K is needed. Each block owns complete output tiles and writes fp16 directly. No fp32 workspace, no atomics, no reduction. This is the common case for large M.
Per pipeline stage:
+-------------------------------------------+
| A tile: TILE_M * TILE_K * 2 bytes (fp16) |
| For TILE_M=64, TILE_K=64: 8 KB |
+-------------------------------------------+
| B tile (packed bit-planes): |
| TILE_N * (TILE_K/32) * K * 4 bytes |
| For TILE_N=128, K=4: 4 KB |
+-------------------------------------------+
| Absmax (E4M4): |
| TILE_N * (TILE_K/32) * 1 byte |
| = 256 bytes |
+-------------------------------------------+
Total per stage (TILE_M=64, K=4): ~12.3 KB
With 2 stages (double buffer): ~24.6 KB
With 4 stages: ~49.2 KB
GPU shared memory limits:
A100: 164 KB per SM
H100: 228 KB per SM
4090: 100 KB per SM
Even with 4 stages, we have ample room.
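The budget arithmetic as a host-side check (stage_bytes is a hypothetical helper):

```python
def stage_bytes(tile_m, tile_n, tile_k, k_bits):
    a = tile_m * tile_k * 2                    # fp16 A tile
    b = tile_n * (tile_k // 32) * k_bits * 4   # packed bit-plane B tile
    absmax = tile_n * (tile_k // 32)           # 1 E4M4 byte per block of 32
    return a + b + absmax
```

For TILE_M=64, TILE_N=128, TILE_K=64, K=4 this gives 12,544 bytes per stage, so even 4 stages stay under the 100 KB limit of the smallest target (4090).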
The compressed B tiles are 2-8x smaller than fp16 would be, which means:
- More pipeline stages fit in shared memory (better latency hiding)
- Or larger tiles fit (better compute efficiency)
Double-buffered pipeline with cp.async:
// Initial fill: one commit group per stage, so cp_async_wait<1> can
// distinguish the two buffers
fetch_tile_to_shared(/*stage=*/0, k_tile_start);
cp_async_fence();
fetch_tile_to_shared(/*stage=*/1, k_tile_start + 1);
cp_async_fence();
for (int kt = k_tile_start; kt < k_tile_end; kt++) {
    int stage = (kt - k_tile_start) % 2;
    cp_async_wait<1>(); // current stage has landed; the next may still be in flight
    __syncthreads();
    // Process: dequant + MMA for current tile
    process_k_tile(stage, frag_c, cb_h);
    __syncthreads(); // all warps must finish reading before the buffer is reused
    // Prefetch kt+2 into the buffer this iteration just freed
    if (kt + 2 < k_tile_end) {
        fetch_tile_to_shared(stage, kt + 2);
    }
    cp_async_fence();
}
cp_async_wait<0>();
__syncthreads();

__device__ void fetch_tile_to_shared(int stage, int k_tile) {
int4* sh_a = sh_a_base + stage * a_stage_words;
uint32_t* sh_b = sh_b_base + stage * b_stage_words;
uint8_t* sh_abs = sh_abs_base + stage * abs_stage_bytes;
// Load A tile: TILE_M * TILE_K / 8 int4 loads
// 256 threads, each loads ceil(A_size / 256) int4 words
for (int i = threadIdx.x; i < a_tile_int4s; i += blockDim.x) {
cp_async4(&sh_a[i], &A_global[a_offset + i]);
}
// Load B tile (packed): much smaller than A
for (int i = threadIdx.x; i < b_tile_int4s; i += blockDim.x) {
if (i < actual_b_words)
cp_async4(&sh_b_int4[i], &B_global[b_offset + i]);
}
// Load absmax: very small (256 bytes)
if (threadIdx.x < abs_tile_int4s) {
cp_async4(&sh_abs_int4[threadIdx.x], &absmax_global[abs_offset + threadIdx.x]);
}
}
Note the asymmetry: A loading dominates bandwidth, B loading is "free" relative to A. This is a key advantage of compressed weights.
A tile reads (via ldmatrix): Standard ldmatrix access pattern, well-studied, no conflicts with standard swizzled layout.
B tile reads (bit-plane words): As analyzed in Section 5.2, 4 threads per column group read the same addresses (broadcast), 8 column groups read different addresses (different banks). Zero conflicts.
Absmax reads: Each thread reads one uint8 for its column. With 8 columns per warp, these are at different byte addresses. No conflicts.
The existing dequant kernel uses float32 codebook values. For the GEMM kernel, we convert to half at kernel start:
half cb_h = (lane < (1 << K_BITS))
? __float2half(codebook[lane])
: __float2half(0.0f);
Rationale:
- Codebook values are in [-1, 1], well within half precision
- The MMA instruction takes fp16 inputs anyway
- Avoids float->half conversion in the inner loop (4 conversions per FragB)
- MMA accumulates in fp32, so precision loss in fp16 fragments is minimal
- The quantization error itself (~6% for K=4) dominates any fp16 rounding
With TILE_K=64, each K-tile spans exactly 2 kbit blocks. Each column has exactly 2 absmax values per K-tile. This is much simpler than Marlin's group boundary logic because there's no straddling -- the boundaries are always at fixed positions.
// Load 2 absmax values from shared memory for this column
uint8_t raw0 = sh_absmax[my_col * 2 + 0]; // block 0 (k=0..31)
uint8_t raw1 = sh_absmax[my_col * 2 + 1]; // block 1 (k=32..63)
// Decode E4M4 -> half (done once per column per K-tile)
half scale0 = __float2half(decode_e4m4_absmax(raw0));
half scale1 = __float2half(decode_e4m4_absmax(raw1));
// In the sub-tile loop:
for (int k_sub = 0; k_sub < 4; k_sub++) {
half scale = (k_sub < 2) ? scale0 : scale1;
// ... dequant uses __hmul(codebook_val, scale) ...
}
The decode is ~5 integer ALU ops, done twice per column per K-tile, shared across all M-rows. Negligible cost.
The per-block absmax is functionally identical to Marlin's group scale mechanism. In Marlin terminology:
- group_size = 32 (our blocksize)
- group_blocks = TILE_K / 32 = 2 (number of groups per K-tile)
But our implementation is much simpler because:
- No activation reordering (act-order) to worry about
- Group boundaries always align with K-tile boundaries
- No zero-point subtraction
- Scale format is fixed (E4M4 uint8)
Per thread block per K-tile:
- Compute: 8 warps * 32 MMA ops = 256 m16n8k16 MMAs. Each MMA performs 16 * 8 * 16 = 2,048 FMAs, so 524,288 FMAs = 1,048,576 FLOPs (= 2 * TILE_M * TILE_N * TILE_K for TILE_M=64, TILE_N=128, TILE_K=64)
- Memory:
- A: TILE_M * TILE_K * 2 bytes = 64 * 64 * 2 = 8,192 bytes
- B: TILE_N * (TILE_K/32) * K * 4 = 128 * 2 * 4 * 4 = 4,096 bytes (K=4)
- Absmax: TILE_N * (TILE_K/32) = 128 * 2 = 256 bytes
- Total: 12,544 bytes
Arithmetic intensity: 1,048,576 / 12,544 = 83.6 FLOP/byte
Compare fp16 GEMM (same tiles, B in fp16):
- B would be: 128 * 64 * 2 = 16,384 bytes
- Total: 24,832 bytes
- Intensity: 1,048,576 / 24,832 = 42.2 FLOP/byte
The kbit kernel has ~2x higher arithmetic intensity for the same tile size.
On H100 (990 TFLOPS fp16 tensor, 3.35 TB/s HBM): Compute-bound threshold: 990e12 / 3.35e12 = 295 FLOP/byte
For C[M, 4096] = A[M, 4096] * W[4096, 4096] with K=4:
- FLOPs: 2 * M * 4096 * 4096
- Bytes: M * 4096 * 2 (A) + 4096 * 4096 * 0.53 (B, K=4 + E4M4) + M * 4096 * 2 (C)
Solving for compute-bound threshold:
- M=1: intensity ~3.8, memory-bound
- M=32: intensity ~114, memory-bound
- M=96: intensity ~307, at the boundary (the crossover is near M = 92)
- M=256: intensity ~655, compute-bound
For M >= ~96 on H100, we're compute-bound and tensor core utilization determines performance.
The persistent kernel with explicit work distribution loses ~5-15% vs Marlin-style stripes in unfavorable cases. The overhead comes from:
- Pipeline startup/drain: 2 K-tiles overhead per k_chunk. With >= 8 tiles per chunk: <= 25% overhead on the chunked portion. Typical: 6-12%.
- Tail-wave imbalance: last wave of blocks may not fill all SMs. Typically 0-5%.
- AtomicAdd reduction: < 1% (negligible on Ampere+).
For K=4096 with an effective split-K factor of 2-4: expect ~10% overhead. For K=8192+ or when no split-K is needed: ~0-3% overhead. This is acceptable given the massive gain in implementation simplicity.
K=2: 2/8 + 1/32 = 0.28125 bytes/element (7.1x compression vs fp16)
K=3: 3/8 + 1/32 = 0.40625 bytes/element (4.9x compression)
K=4: 4/8 + 1/32 = 0.53125 bytes/element (3.8x compression)
K=5: 5/8 + 1/32 = 0.65625 bytes/element (3.0x compression)
(The 1/32 term is the E4M4 absmax overhead: 1 byte per 32 elements)
void kbit_gemm(
const half* A, // [M, K_dim] row-major
const uint32_t* B, // tiled kbit packed data
half* C, // [M, N] row-major
float* C_workspace, // [M, N] fp32 workspace (for split-K)
int* tile_counters, // [m_tiles * n_tiles] atomic counters
const uint8_t* absmax, // tiled E4M4 absmax
const float* codebook, // [2^K] float32 codebook
int M, int N, int K_dim, int K_bits,
cudaStream_t stream)
{
int dev;
cudaGetDevice(&dev);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
int max_shmem;
cudaDeviceGetAttribute(&max_shmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
// Choose M-blocking
int m_blocks;
if (M <= 16) m_blocks = 1;
else if (M <= 32) m_blocks = 2;
else if (M <= 48) m_blocks = 3;
else m_blocks = 4;
int tile_m = m_blocks * 16;
// Choose tile config
struct Config { int tile_k, tile_n, threads; };
Config cfg = select_config(m_blocks, M, N, K_dim, K_bits, max_shmem);
// Compute work distribution
int m_tiles = div_ceil(M, tile_m);
int n_tiles = N / cfg.tile_n;
int k_tiles = K_dim / cfg.tile_k;
int min_tiles_per_chunk = 8;
int k_chunks = max(1, div_ceil(k_tiles, max(min_tiles_per_chunk,
div_ceil(k_tiles * m_tiles * n_tiles, sms) /* target full occupancy */)));
// Zero tile counters if split-K (i.e. more than one K-chunk per output tile)
bool needs_split_k = (k_chunks > 1);
if (needs_split_k) {
cudaMemsetAsync(tile_counters, 0, m_tiles * n_tiles * sizeof(int), stream);
}
// Launch persistent kernel
int shmem_size = compute_shmem(cfg, m_blocks, K_bits);
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
// Dispatch on K_bits and m_blocks
dispatch_kernel(K_bits, m_blocks, cfg, sms, shmem_size, stream, ...);
}

Priority-ordered configs for small and large batch:
// Small batch (m_blocks == 1):
Config small_configs[] = {
{64, 128, 256}, // balanced
{64, 128, 128}, // fewer threads, less shmem
{32, 128, 128}, // shallow K, tight shmem
};
// Large batch (m_blocks > 1):
Config large_configs[] = {
{64, 256, 256}, // wide N, maximum output parallelism
{64, 128, 256}, // balanced
{64, 128, 128}, // fallback
};

Validation: config must fit in shared memory and divide problem dimensions.
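The validation step can be sketched as a first-fit walk over the priority-ordered tables. Here `config_shmem_bytes` is a stand-in for the real per-config footprint, and the pipeline depth of 2 is an assumption of this sketch:

```cpp
#include <cstddef>

struct Config { int tile_k, tile_n, threads; };

// Stand-in shared-memory footprint: fp16 A tile + B bit-plane words +
// E4M4 absmax bytes, times an assumed pipeline depth of 2.
int config_shmem_bytes(const Config& c, int m_blocks, int k_bits) {
    int tile_m = m_blocks * 16;
    int a = tile_m * c.tile_k * 2;                    // fp16 A tile
    int b = c.tile_n * (c.tile_k / 32) * k_bits * 4;  // packed B words
    int s = c.tile_n * (c.tile_k / 32);               // absmax bytes
    return 2 * (a + b + s);                           // double-buffered
}

// Return the first (highest-priority) config that divides the problem
// dimensions and fits in shared memory, or nullptr if none does.
const Config* select_config(const Config* cfgs, int n_cfgs, int m_blocks,
                            int k_bits, int N, int K_dim, int max_shmem) {
    for (int i = 0; i < n_cfgs; i++) {
        bool divides = (N % cfgs[i].tile_n == 0) && (K_dim % cfgs[i].tile_k == 0);
        if (divides && config_shmem_bytes(cfgs[i], m_blocks, k_bits) <= max_shmem)
            return &cfgs[i];
    }
    return nullptr;  // caller falls back (e.g. standalone dequant + cuBLAS)
}
```

Returning a null config rather than asserting lets the host dispatch choose a fallback path for odd problem shapes.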
Following the existing pattern in bitsandbytes/_ops.py:
torch.library.define(
"bitsandbytes::kbit_gemm",
"(Tensor A, Tensor B_packed, Tensor absmax, Tensor codebook, "
"int k, int N, int K_dim) -> Tensor",
)

CUDA backend in bitsandbytes/backends/cuda/ops.py:
@register_kernel("bitsandbytes::kbit_gemm", "cuda")
def _(A, B_packed, absmax, codebook, k, N, K_dim):
    M = A.shape[0]
    C = torch.empty(M, N, dtype=A.dtype, device=A.device)
    # ... allocate workspace, call C function ...
    return C

A separate op repacks the flat quantized layout into the tiled GEMM layout:

torch.library.define(
"bitsandbytes::kbit_repack_for_gemm",
"(Tensor packed_flat, Tensor absmax_flat, int K_dim, int N, int k, "
"int tile_k, int tile_n) -> (Tensor, Tensor)",
)

This would be called once when loading a model, before inference begins.
The GEMM kernel should go in csrc/kernels.cu (the standard location for
CUDA kernels in bitsandbytes), NOT in csrc/ops.cu.
Background: The existing kbit quantize/dequantize kernels were placed in
ops.cu to avoid RDC (relocatable device code) linking issues with template
instantiations. This was a workaround, not a deliberate architectural choice.
The CUDA_RESOLVE_DEVICE_SYMBOLS ON flag was added to CMakeLists.txt as
part of that workaround and should be removed.
For the GEMM kernel: place the kernel definition and launch wrapper in
csrc/kernels.cu with declarations in csrc/kernels.cuh. The extern "C"
wrappers go in csrc/pythonInterface.cpp following the existing pattern.
Remove the CUDA_RESOLVE_DEVICE_SYMBOLS ON flag that was added as a
workaround. The GEMM kernel doesn't need it if templates are properly
instantiated in the same compilation unit as their declarations.
No new .cu files needed. The GEMM kernel fits naturally in the existing file structure:
- Kernel code:
csrc/kernels.cu(append) - Kernel declarations:
csrc/kernels.cuh(append) - Launch wrappers:
csrc/ops.cu(append, for the host-side dispatch) - C interface:
csrc/pythonInterface.cpp(append) - Python ops:
bitsandbytes/_ops.py(append) - CUDA backend:
bitsandbytes/backends/cuda/ops.py(append) - Tests:
tests/test_kbit_gemm.py(new)
The GEMM kernel is templated on:
- K_BITS: 2, 3, 4, 5
- M_BLOCKS: 1, 2, 3, 4
- Tile config (TILE_K, TILE_N): 2-3 configs
Total: 4 * 4 * 3 = 48 kernel variants (worst case). This is manageable. Marlin has hundreds of variants.
Instantiation via macros, similar to existing pattern:
#define INSTANTIATE_KBIT_GEMM(K, M_BLOCKS, TILE_K, TILE_N) \
template __global__ void kbit_gemm_kernel<K, M_BLOCKS, TILE_K, TILE_N>(...);
INSTANTIATE_KBIT_GEMM(2, 1, 64, 128)
INSTANTIATE_KBIT_GEMM(2, 2, 64, 128)
// ... etc

The existing test suite establishes the combined error bound per block:
max_error <= (max_gap/2 + 1/16) * absmax + epsilon
where:
max_gap: maximum gap between adjacent codebook entries
1/16: maximum relative error from E4M4 absmax encoding
absmax: absolute maximum of the block
epsilon: small constant for floating-point rounding (~1e-6)
The GEMM kernel introduces no new error sources beyond the standalone dequant:
- Same bit-plane extraction (exact)
- Same codebook lookup (exact, via shuffle)
- Same absmax multiply (same precision)
- fp16 codebook storage adds at most 1 ULP of fp16 (~0.001 for values near 1.0)
- MMA accumulates in fp32 (no precision loss in accumulation)
From the test suite (1M elements, normal distribution):
- K=2: SQNR > 5 dB
- K=3: SQNR > 10 dB
- K=4: SQNR > 15 dB
- K=5: SQNR > 20 dB
E4M4 absmax degrades SQNR by < 1.5 dB vs fp32 absmax.
The GEMM kernel should match these bounds exactly, since the dequant logic is identical.
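The SQNR thresholds above can be checked with the standard definition (signal power over error power, in dB); a sketch of the metric as it would appear in a test helper:

```cpp
#include <cmath>

// SQNR in dB between a reference signal and its quantized/dequantized
// reconstruction: 10 * log10(sum(ref^2) / sum((ref - rec)^2)).
double sqnr_db(const float* ref, const float* rec, int n) {
    double sig = 0.0, noise = 0.0;
    for (int i = 0; i < n; i++) {
        sig += (double)ref[i] * ref[i];
        double e = (double)ref[i] - (double)rec[i];
        noise += e * e;
    }
    return 10.0 * std::log10(sig / noise);
}
```

Running this on GEMM outputs (reference fp16 matmul vs kbit matmul) should reproduce the per-K thresholds, since the dequant logic is shared.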
template <int K_BITS, int M_BLOCKS, int TILE_K = 64, int TILE_N = 128>
__global__ void kbit_gemm_kernel(
const half* __restrict__ A,
const uint32_t* __restrict__ B_packed,
half* __restrict__ C,
float* __restrict__ C_workspace,
int* __restrict__ tile_counters,
const uint8_t* __restrict__ B_absmax,
const float* __restrict__ codebook,
int M, int N, int K_dim,
int m_tiles, int n_tiles, int k_chunks,
int tiles_per_chunk);template <int K_BITS, int TILE_K = 64, int TILE_N = 128>
__global__ void kbit_repack_kernel(
const uint32_t* __restrict__ packed_flat,
const uint8_t* __restrict__ absmax_flat,
uint32_t* __restrict__ packed_tiled,
uint8_t* __restrict__ absmax_tiled,
int K_dim, int N);

On Hopper GPUs, warp specialization can be used: producer warps handle data loading (using TMA for efficient async copies), consumer warps handle compute. The producer warps could handle the bit-plane loading and even partial dequantization, feeding pre-dequantized fp16 tiles to consumer warps. This would further overlap memory and compute.
The current kbit implementation uses blocksize=32 (warp-size). Larger block sizes (64, 128) would reduce the absmax overhead but require different packing primitives (can't use single-warp __ballot_sync for blocks > 32). This would be a separate project.
If activations are also kbit-quantized, the GEMM becomes a fully quantized matmul. This would require a different kernel architecture (integer MMA or custom accumulation).
Common fused patterns for inference:
- kbit GEMM + bias add
- kbit GEMM + ReLU/GELU
- kbit GEMM + residual add
These can be added as epilogue options in the kernel template, similar to Marlin's bias support.
For attention computation, batched GEMM (multiple independent GEMMs) may be needed. The persistent kernel can be extended to handle batches by adding a batch dimension to the work assignment.
Key files in ~/git/vllm/csrc/quantization/marlin/:
marlin_template.h: Main kernel template (~2070 lines)
- Line 271-281: Stripe partitioning explanation
- Line 362-401: Work distribution setup
- Line 916-923: Pipeline wait/fence
- Line 927-939: Register fetch from shared memory
- Line 1167-1285: matmul() inner loop with dequant + scale + MMA
- Line 1780-1813: Main K-loop with pipeline interleaving
- Line 1839-2068: Output reduction and slice management
marlin.cu: Host dispatch (~530 lines)
- Line 128-141: Thread config tables
- Line 179-249: Config validation
- Line 265-313: Config selection
- Line 315-527: Main dispatch function
marlin_mma.h: MMA instruction wrappers
dequant.h: Dequantization functions (lop3-based)
marlin.cuh: Constants and helpers
- Block (quantization): A group of 32 consecutive elements sharing one absmax value
- Block (CUDA): A CUDA thread block (256 threads = 8 warps)
- Bit-plane: A uint32 word containing one bit from each of 32 elements
- FragA, FragB, FragC: Register fragments for tensor core MMA
- MMA: Matrix multiply-accumulate (tensor core instruction)
- m16n8k16: MMA instruction computing a 16x8 output from 16x16 and 16x8 inputs
- Split-K: Partitioning the K (reduction) dimension across multiple thread blocks
- Tile: A sub-matrix processed by one thread block or one MMA instruction
- TILE_K, TILE_M, TILE_N: Thread block tile dimensions
- Persistent kernel: A kernel that launches exactly num_SMs blocks, each looping over work
- E4M4: 8-bit float format with 4-bit exponent and 4-bit mantissa
- Codebook: A lookup table of 2^K reconstruction values for quantization
- absmax: Per-block absolute maximum, used as scale factor
- Normal-float: Quantization levels placed at quantiles of N(0,1)