diff --git a/aie_kernels/generic/fused_dequant_gemv.cc b/aie_kernels/generic/fused_dequant_gemv.cc new file mode 100644 index 00000000..fc911801 --- /dev/null +++ b/aie_kernels/generic/fused_dequant_gemv.cc @@ -0,0 +1,164 @@ +// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Fused INT4 dequantization + GEMV kernel for AIE2+. +// +// Loads INT4-packed weights, dequantizes in-register, and performs +// matrix-vector multiplication in a single pass. +// +// Weight layout per tile (m rows x K cols, group_size G): +// [m * K / 2 bytes of packed uint4 weights] +// [m * (K / G) bf16 scale factors, stored as (m * K / G * 2) bytes] +// +// Dequantization: w_bf16 = scale * unpack_uint4_to_bf16(w_uint4) +// +// The unpack chain matches the existing dequant kernel (expand.cc): +// uint4 -> uint8 (aie::unpack) -> uint16 (aie::unpack) -> bf16 (aie::to_float) +// +// Optimization: double-pump — process 2 groups (64 elements) per iteration +// so the compiler can interleave the two independent unpack chains, hiding +// the dequant latency behind computation. + +#define NOCPP + +#include "../aie_kernel_utils.h" + +#include +#include +#include + +// block_size: dequant vector width (must be 32 for aie::unpack) +// G: group size (compile-time for pipelining, must be multiple of block_size) +// DK: K dimension (compile-time for loop count optimization) +template +void fused_dequant_matvec(uint32_t m, + const uint8_t *__restrict a_in, + const bfloat16 *__restrict b_in, + bfloat16 *__restrict c_out) +{ + static_assert(block_size == 32, "block_size must be 32 to match dequant vector width"); + static_assert(G % block_size == 0, "group_size must be a multiple of block_size"); + constexpr uint32_t blocks_per_group = G / block_size; + constexpr uint32_t groups_per_row = DK / G; + // For double-pump: process 2 groups per iteration when possible + constexpr bool can_double_pump = (groups_per_row >= 2) && (groups_per_row % 2 == 0); + constexpr uint32_t pump_groups = can_double_pump ? 2 : 1; + constexpr uint32_t loop_iters = groups_per_row / pump_groups; + + ::aie::set_rounding(aie::rounding_mode::conv_even); + + const uint4 *weights_packed = reinterpret_cast(a_in); + const uint8_t *scale_bytes = a_in + m * DK / 2; + const bfloat16 *scales = reinterpret_cast(scale_bytes); + + event0(); + for (uint32_t row = 0; row < m; row++) { + const uint4 *row_weights = weights_packed + row * DK / 2; + const bfloat16 *row_scales = scales + row * groups_per_row; + const bfloat16 *b_ptr = b_in; + + aie::accum acc = aie::zeros(); + + if constexpr (can_double_pump && blocks_per_group == 1) { + // Optimized path: 2 groups per iteration, 1 block per group + // Two independent unpack chains for the compiler to interleave. + AIE_LOOP_MIN_ITERATION_COUNT(loop_iters) + for (uint32_t g = 0; g < groups_per_row; g += 2) + AIE_PREPARE_FOR_PIPELINING + { + // --- Chain A: group g --- + bfloat16 sf_a = row_scales[g]; + aie::vector sf_a_bc = + aie::broadcast(sf_a); + + aie::vector I0_a = aie::load_v(row_weights); + row_weights += block_size / 2; + + // --- Chain B: group g+1 (interleaved) --- + bfloat16 sf_b = row_scales[g + 1]; + aie::vector sf_b_bc = + aie::broadcast(sf_b); + + aie::vector I0_b = aie::load_v(row_weights); + row_weights += block_size / 2; + + // Unpack chain A + aie::vector a8_a = aie::unpack(I0_a); + aie::vector a16_a = aie::unpack(a8_a); + aie::vector abf_a = aie::to_float(a16_a, 0); + aie::vector w_a = + aie::mul(abf_a, sf_a_bc).template to_vector(); + + // Unpack chain B + aie::vector a8_b = aie::unpack(I0_b); + aie::vector a16_b = aie::unpack(a8_b); + aie::vector abf_b = aie::to_float(a16_b, 0); + aie::vector w_b = + aie::mul(abf_b, sf_b_bc).template to_vector(); + + // Load activation vectors and MAC + aie::vector b_a = aie::load_v(b_ptr); + b_ptr += block_size; + acc = aie::mac(acc, w_a, b_a); + + aie::vector b_b = aie::load_v(b_ptr); + b_ptr += block_size; + acc = aie::mac(acc, w_b, b_b); + } + } else { + // Generic path: 1 group per iteration + AIE_LOOP_MIN_ITERATION_COUNT(loop_iters) + for (uint32_t g = 0; g < groups_per_row; g++) + AIE_PREPARE_FOR_PIPELINING + { + bfloat16 sf = row_scales[g]; + aie::vector sf_broadcast = + aie::broadcast(sf); + + AIE_LOOP_MIN_ITERATION_COUNT(blocks_per_group) + for (uint32_t blk = 0; blk < blocks_per_group; blk++) { + aie::vector I0 = aie::load_v(row_weights); + row_weights += block_size / 2; + + aie::vector as_int8 = aie::unpack(I0); + aie::vector as_int16 = aie::unpack(as_int8); + aie::vector as_bf16 = + aie::to_float(as_int16, 0); + aie::vector w_dequant = + aie::mul(as_bf16, sf_broadcast).template to_vector(); + + aie::vector b_vec = aie::load_v(b_ptr); + b_ptr += block_size; + + acc = aie::mac(acc, w_dequant, b_vec); + } + } + } + + *c_out = static_cast(aie::reduce_add(acc.template to_vector())); + c_out++; + } + event1(); +} + +#ifndef GROUP_SIZE +#define GROUP_SIZE 32 +#endif + +#ifndef DIM_K +#define DIM_K 2048 +#endif + +extern "C" { + +void fused_dequant_matvec_bf16(uint32_t m, + uint32_t row_offset, + const uint8_t *__restrict a_in, + const bfloat16 *__restrict b_in, + bfloat16 *__restrict c_out) +{ + c_out += row_offset; + fused_dequant_matvec<32, GROUP_SIZE, DIM_K>(m, a_in, b_in, c_out); +} + +} // extern "C" diff --git a/iron/operators/gemv_int4/design.py b/iron/operators/gemv_int4/design.py new file mode 100644 index 00000000..0fae355d --- /dev/null +++ b/iron/operators/gemv_int4/design.py @@ -0,0 +1,161 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Fused INT4 dequantization matrix-vector design. + +Performs a fused dequantize-GEMV where the weight matrix is stored in packed +INT4 format (two 4-bit values per uint8 byte) with per-group bfloat16 scale +factors. The activation vector and output are bfloat16. + +Each AIE column processes a contiguous block of output rows. Within a column, +the worker iterates over tiles of packed weight rows, acquires the full +activation vector once per outer iteration, and calls the fused dequant-matvec +kernel which unpacks, dequantizes, and accumulates in a single pass. + +Buffer layout for A (packed weights, uint8): + For each tile of m_input rows: [m_input * K / 2 bytes of packed weights] + [m_input * (K / group_size) * 2 bytes of scales] +""" + +import numpy as np +from ml_dtypes import bfloat16 + +import aie.dialects.index as index +from aie.dialects.aie import T +from aie.helpers.dialects.scf import _for as range_ +from aie.helpers.taplib import TensorAccessPattern +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer + + +def my_fused_dequant_matvec( + dev, cols, M, K, m_input, m_output=None, group_size=32 +): + if m_output is None: + m_output = m_input + + # --- Assertions --- + assert ( + m_output % m_input == 0 and m_output >= m_input + ), "m_output must be a multiple of m_input" + assert m_output <= M // cols, "m_output must be less than or equal to M/cols" + assert (M // cols) % m_output == 0, "m_output must evenly divide M/cols" + assert m_input <= M // cols, "m_input must be less than or equal to M/cols" + assert (M // cols) % m_input == 0, "m_input must evenly divide M/cols" + assert K % group_size == 0, "K must be divisible by group_size" + assert group_size % 32 == 0, "group_size must be a multiple of 32" + assert M % cols == 0, "M must be divisible by cols" + + # --- Data types --- + dtype_in = np.dtype[np.uint8] + dtype_vec = np.dtype[bfloat16] + dtype_out = np.dtype[bfloat16] + + # --- Per-tile sizes (in uint8 bytes) --- + num_groups_per_row = K // group_size + packed_tile_bytes = m_input * K // 2 + m_input * num_groups_per_row * 2 + rows_per_col = M // cols + tiles_per_col = rows_per_col // m_input + bytes_per_col = tiles_per_col * packed_tile_bytes + packed_total_bytes = cols * bytes_per_col + + # --- L1 (on-chip) tensor types --- + L1_A_ty = np.ndarray[(packed_tile_bytes,), dtype_in] + L1_B_ty = np.ndarray[(K,), dtype_vec] + L1_C_ty = np.ndarray[(m_output,), dtype_out] + + # --- L3 (DDR) tensor types --- + L3_A_ty = np.ndarray[(packed_total_bytes,), dtype_in] + L3_B_ty = np.ndarray[(K,), dtype_vec] + L3_C_ty = np.ndarray[(M,), dtype_out] + + # --- Kernel declaration --- + # K and group_size are compile-time via -DDIM_K/-DGROUP_SIZE. + fused_matvec = Kernel( + "fused_dequant_matvec_bf16", + f"fused_dequant_gemv_{K}k_g{group_size}.o", + [np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty], + ) + + # --- ObjectFIFOs --- + A_L3L1_fifos = [ + ObjectFifo(L1_A_ty, name=f"A_L3L1_{i}", depth=2) for i in range(cols) + ] + B_L3L1_fifos = [ + ObjectFifo(L1_B_ty, name=f"B_L3L1_{i}", depth=1) for i in range(cols) + ] + C_L1L3_fifos = [ + ObjectFifo(L1_C_ty, name=f"C_L1L3_{i}", depth=2) for i in range(cols) + ] + + # --- Worker core body --- + N_div_n = tiles_per_col // (m_output // m_input) + + def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, fused_matvec_fn): + for _ in range_(0xFFFFFFFF): + b = B_L3L1_fifo.acquire(1) + for i_idx in range_(N_div_n): + c = C_L1L3_fifo.acquire(1) + for j_idx in range_(m_output // m_input): + j_i32 = index.casts(T.i32(), j_idx) + output_row_offset = j_i32 * m_input + a = A_L3L1_fifo.acquire(1) + fused_matvec_fn( + m_input, output_row_offset, a, b, c + ) + A_L3L1_fifo.release(1) + C_L1L3_fifo.release(1) + B_L3L1_fifo.release(1) + + workers = [ + Worker( + core_body, + [ + A_L3L1_fifos[i].cons(), + B_L3L1_fifos[i].cons(), + C_L1L3_fifos[i].prod(), + fused_matvec, + ], + ) + for i in range(cols) + ] + + # --- TensorAccessPatterns --- + # A: each column gets a contiguous chunk of bytes_per_col packed bytes + A_taps = [ + TensorAccessPattern( + tensor_dims=L3_A_ty.__args__[0], + offset=col * bytes_per_col, + sizes=[1, 1, 1, bytes_per_col], + strides=[0, 0, 0, 1], + ) + for col in range(cols) + ] + + # C: each column writes contiguous rows_per_col bfloat16 values + C_taps = [ + TensorAccessPattern( + tensor_dims=L3_C_ty.__args__[0], + offset=col * rows_per_col, + sizes=[1, 1, 1, rows_per_col], + strides=[0, 0, 0, 1], + ) + for col in range(cols) + ] + + # --- Runtime sequence --- + rt = Runtime() + with rt.sequence(L3_A_ty, L3_B_ty, L3_C_ty) as (A, B, C): + rt.start(*workers) + tg = rt.task_group() + for i in range(cols): + rt.fill(A_L3L1_fifos[i].prod(), A, A_taps[i], task_group=tg) + rt.fill(B_L3L1_fifos[i].prod(), B, task_group=tg) + for i in range(cols): + rt.drain( + C_L1L3_fifos[i].cons(), C, C_taps[i], task_group=tg, wait=True + ) + rt.finish_task_group(tg) + + return Program(dev, rt).resolve_program(SequentialPlacer()) diff --git a/iron/operators/gemv_int4/op.py b/iron/operators/gemv_int4/op.py new file mode 100644 index 00000000..89973427 --- /dev/null +++ b/iron/operators/gemv_int4/op.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import ClassVar, Dict + +import numpy as np +from ml_dtypes import bfloat16 + +from iron.common import ( + MLIROperator, + AIERuntimeArgSpec, + KernelObjectArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, + DesignGenerator, +) +import aie.utils as aie_utils + + +@dataclass +class GEMVInt4(MLIROperator): + """AIE-accelerated fused INT4 dequantization and GEMV operator""" + + M: int + K: int + num_aie_columns: int = 4 + tile_size_input: int = 1 + tile_size_output: int | None = None + group_size: int = 32 + context: object = field(default=None, repr=False) + + _name_aliases: ClassVar[Dict[str, str]] = { + **MLIROperator._name_aliases, + "num_aie_columns": "col", + "tile_size_input": "tsi", + "tile_size_output": "tso", + "group_size": "g", + } + + def __post_init__(self): + if self.tile_size_output is None: + self.tile_size_output = self.M // self.num_aie_columns + + if not ( + self.tile_size_output % self.tile_size_input == 0 + and self.tile_size_output >= self.tile_size_input + ): + raise ValueError("tile_size_output must be a multiple of tile_size_input") + if not (self.K % self.group_size == 0): + raise ValueError("K must be a multiple of group_size") + if not (self.group_size % 32 == 0): + raise ValueError("group_size must be a multiple of 32") + if not (self.M % self.num_aie_columns == 0): + raise ValueError("M must be a multiple of num_aie_columns") + + MLIROperator.__init__(self, context=self.context) + + @property + def _packed_buffer_size(self): + num_groups_per_row = self.K // self.group_size + packed_tile_bytes = ( + self.tile_size_input * self.K // 2 + + self.tile_size_input * num_groups_per_row * 2 + ) + rows_per_col = self.M // self.num_aie_columns + tiles_per_col = rows_per_col // self.tile_size_input + return self.num_aie_columns * tiles_per_col * packed_tile_bytes + + def get_mlir_artifact(self): + return PythonGeneratedMLIRArtifact( + f"{self.name}.mlir", + DesignGenerator( + self.operator_dir / "design.py", + "my_fused_dequant_matvec", + ( + aie_utils.get_current_device(), + self.num_aie_columns, + self.M, + self.K, + self.tile_size_input, + self.tile_size_output, + self.group_size, + ), + ), + ) + + def get_kernel_artifacts(self): + return [ + KernelObjectArtifact( + f"fused_dequant_gemv_{self.K}k_g{self.group_size}.o", + dependencies=[ + SourceArtifact( + self.context.base_dir + / "aie_kernels" + / "generic" + / "fused_dequant_gemv.cc" + ) + ], + extra_flags=[ + f"-DDIM_K={self.K}", + f"-DGROUP_SIZE={self.group_size}", + ], + ), + ] + + def get_arg_spec(self): + return [ + AIERuntimeArgSpec( + "in", (self._packed_buffer_size,), dtype=np.uint8 + ), # packed INT4 weights + AIERuntimeArgSpec("in", (self.K,)), # bf16 activation vector + AIERuntimeArgSpec("out", (self.M,)), # bf16 output vector + ] diff --git a/iron/operators/gemv_int4/reference.py b/iron/operators/gemv_int4/reference.py new file mode 100644 index 00000000..e9e27506 --- /dev/null +++ b/iron/operators/gemv_int4/reference.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 + + +def quantize_and_pack(M, K, group_size=32, m_input=1, cols=4): + """Generate quantized INT4 weights and pack for the fused dequant-GEMV kernel. + + Uses the same quantization scheme as the existing dequant operator + (iron/operators/dequant/reference.py): unsigned INT4 values with per-group + bf16 scale factors, zero-point fixed at 0. + + The DDR buffer is laid out per-tile, where each tile corresponds to + ``m_input`` matrix rows. Tiles for column 0 come first, then column 1, + etc. Within each tile the layout is: + + [m_input * K / 2 bytes] packed uint4 weights (2 values per byte, + low nibble first in little-endian order) + [m_input * (K / group_size) * 2 bytes] bf16 scale factors + + Args: + M: Number of rows in the weight matrix. + K: Number of columns in the weight matrix. + group_size: Number of elements per quantization group (default 32). + m_input: Number of rows per kernel tile invocation. + cols: Number of AIE columns the work is split across. + + Returns: + packed: numpy uint8 array with the complete packed DDR buffer. + W_dequant: torch.bfloat16 (M, K) tensor of dequantized weights. + """ + assert K % group_size == 0, "K must be a multiple of group_size" + assert M % cols == 0, "M must be a multiple of cols" + rows_per_col = M // cols + assert rows_per_col % m_input == 0, "rows_per_col must be a multiple of m_input" + + num_groups_per_row = K // group_size + val_range = 3.75 + r1, r2 = 1 / val_range, 1.0 + + # Generate per-group scale factors in [r1, r2) + total_groups = M * num_groups_per_row + scales_flat = r1 + (r2 - r1) * torch.rand(total_groups, dtype=torch.bfloat16) + zero_points = torch.zeros(total_groups, dtype=torch.bfloat16) + + # Generate random data in [0, val_range) shaped for per-group quantization + W_grouped = torch.rand(total_groups, group_size, dtype=torch.bfloat16) * val_range + + # Quantize with PyTorch per-channel (per-group) quantization + A_quant = torch.quantize_per_channel( + W_grouped.to(torch.float32), + scales=scales_flat.to(torch.float32), + zero_points=zero_points.to(torch.float32), + axis=0, + dtype=torch.quint8, + ) + W_dequant = torch.dequantize(A_quant).to(torch.bfloat16).reshape(M, K) + A_int = A_quant.int_repr() # (total_groups, group_size) with values in [0,15] + + # Now pack into the tile-based DDR layout. + # Tile order: column 0 tiles first, then column 1, etc. + packed_bytes_per_tile = m_input * K // 2 + m_input * num_groups_per_row * 2 + tiles_per_col = rows_per_col // m_input + total_tiles = cols * tiles_per_col + total_bytes = total_tiles * packed_bytes_per_tile + + packed = np.zeros(total_bytes, dtype=np.uint8) + + for col in range(cols): + for tile_idx in range(tiles_per_col): + # Global row range for this tile + row_start = col * rows_per_col + tile_idx * m_input + # Offset into the packed buffer + flat_tile = col * tiles_per_col + tile_idx + tile_offset = flat_tile * packed_bytes_per_tile + + # 1) Pack uint4 weights for m_input rows + for r in range(m_input): + global_row = row_start + r + for grp in range(num_groups_per_row): + flat_grp = global_row * num_groups_per_row + grp + for k in range(group_size // 2): + val_lo = int(A_int[flat_grp, 2 * k].item()) & 0x0F + val_hi = int(A_int[flat_grp, 2 * k + 1].item()) & 0x0F + byte_idx = ( + tile_offset + r * (K // 2) + grp * (group_size // 2) + k + ) + packed[byte_idx] = val_lo | (val_hi << 4) + + # 2) Pack bf16 scale factors for m_input rows + scale_region_start = tile_offset + m_input * K // 2 + for r in range(m_input): + global_row = row_start + r + for grp in range(num_groups_per_row): + flat_grp = global_row * num_groups_per_row + grp + sf_val = scales_flat[flat_grp] + sf_uint16 = sf_val.view(torch.uint16).item() + sf_offset = scale_region_start + (r * num_groups_per_row + grp) * 2 + packed[sf_offset] = sf_uint16 & 0xFF + packed[sf_offset + 1] = (sf_uint16 >> 8) & 0xFF + + return packed, W_dequant + + +def generate_golden_reference( + M=2048, K=2048, group_size=32, m_input=1, cols=4, seed=42 +): + """Generate golden reference for fused dequant-GEMV. + + Args: + M: Number of rows in the weight matrix. + K: Number of columns (== input vector length). + group_size: Quantization group size. + m_input: Number of rows per kernel tile invocation. + cols: Number of AIE columns. + seed: Random seed for reproducibility. + + Returns: + dict with packed_weights, x, output, W_dequant. + """ + torch.manual_seed(seed) + + # Generate random input vector + val_range = 4 + x = torch.randn(K, dtype=torch.bfloat16) * val_range + + # Generate quantized + packed weights + packed_weights, W_dequant = quantize_and_pack(M, K, group_size, m_input, cols) + + # Reference output: dequantized_weights @ x + output = W_dequant @ x + + return { + "packed_weights": packed_weights, + "x": x, + "output": output, + "W_dequant": W_dequant, + } diff --git a/iron/operators/gemv_int4/test.py b/iron/operators/gemv_int4/test.py new file mode 100644 index 00000000..6c8b32b2 --- /dev/null +++ b/iron/operators/gemv_int4/test.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import aie.utils as aie_utils + +from iron.operators.gemv_int4.op import GEMVInt4 +from iron.operators.gemv_int4.reference import generate_golden_reference +from iron.common.test_utils import run_test + + +def get_params(): + max_aie_columns = aie_utils.get_current_device().cols + + params_list = [ + # (M, K, num_aie_columns, tile_size_input, tile_size_output, group_size) + (2048, 2048, 4, 1, 512, 32), # Basic, 4 cols + (8192, 2048, 4, 1, 2048, 32), # Llama down_proj, 4 cols + (2048, 8192, 4, 1, 512, 32), # Llama up_proj, 4 cols + (2048, 8192, 8, 1, 256, 32), # Llama up_proj, 8 cols + (8192, 2048, 8, 1, 1024, 32), # Llama down_proj, 8 cols + (2048, 8192, 4, 4, 512, 32), # tsi=4 for better amortization + (8192, 2048, 4, 4, 2048, 32), # tsi=4 + ] + + params = [] + for p in params_list: + M, K, num_aie_columns, tile_size_input, tile_size_output, group_size = p + # Skip tests that require more columns than available on the device + if num_aie_columns > max_aie_columns: + continue + params.append( + pytest.param( + *p, + id=f"gemv_int4_{M}x{K}_{tile_size_input}tsi_{tile_size_output}tso_{num_aie_columns}col_g{group_size}", + ) + ) + return params + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) +@pytest.mark.parametrize( + "M,K,num_aie_columns,tile_size_input,tile_size_output,group_size", get_params() +) +def test_gemv_int4( + M, K, num_aie_columns, tile_size_input, tile_size_output, group_size, aie_context +): + golden_ref = generate_golden_reference( + M=M, + K=K, + group_size=group_size, + m_input=tile_size_input, + cols=num_aie_columns, + ) + + operator = GEMVInt4( + M=M, + K=K, + num_aie_columns=num_aie_columns, + tile_size_input=tile_size_input, + tile_size_output=tile_size_output, + group_size=group_size, + context=aie_context, + ) + + input_buffers = { + "packed_weights": torch.from_numpy(golden_ref["packed_weights"]), + "vector": golden_ref["x"], + } + output_buffers = {"output": golden_ref["output"]} + + # Tolerances are looser than bf16 GEMV (rel_tol=0.04, abs_tol=1e-3) because + # INT4 quantization introduces significant per-group rounding error. + errors, latency_us, bandwidth_gbps = run_test( + operator, input_buffers, output_buffers, rel_tol=0.07, abs_tol=0.7 + ) + + print(f"\nLatency (us): {latency_us:.1f}") + + gflops = (2.0 * M * K) / (latency_us * 1e-6) / 1e9 + print(f"Throughput: {gflops:.2e} GFLOP/s") + + # INT4 weights: M*K/2 bytes + scales (bf16): M*(K//group_size)*2 bytes + weight_bytes = M * K / 2 + M * (K // group_size) * 2 + vector_bytes = K * 2 # bf16 + output_bytes = M * 2 # bf16 + total_bytes = weight_bytes + vector_bytes + output_bytes + bandwidth = total_bytes / (latency_us * 1e-6) / 1e9 + print(f"Effective Bandwidth: {bandwidth:.2e} GB/s\n") + + assert not errors, f"Test failed with errors: {errors}"