Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions aie_kernels/generic/fused_dequant_gemv.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// Fused INT4 dequantization + GEMV kernel for AIE2+.
//
// Loads INT4-packed weights, dequantizes in-register, and performs
// matrix-vector multiplication in a single pass.
//
// Weight layout per tile (m rows x K cols, group_size G):
// [m * K / 2 bytes of packed uint4 weights]
// [m * (K / G) bf16 scale factors, stored as (m * K / G * 2) bytes]
//
// Dequantization: w_bf16 = scale * unpack_uint4_to_bf16(w_uint4)
//
// The unpack chain matches the existing dequant kernel (expand.cc):
// uint4 -> uint8 (aie::unpack) -> uint16 (aie::unpack) -> bf16 (aie::to_float)
//
// Optimization: double-pump — process 2 groups (64 elements) per iteration
// so the compiler can interleave the two independent unpack chains, hiding
// the dequant latency behind computation.

#define NOCPP

#include "../aie_kernel_utils.h"

#include <aie_api/aie.hpp>
#include <stdint.h>
#include <type_traits>

// block_size: dequant vector width (must be 32 for aie::unpack)
// G: group size (compile-time for pipelining, must be multiple of block_size)
// DK: K dimension (compile-time for loop count optimization)
// Runtime arguments:
//   m     - number of output rows in this weight tile
//   a_in  - packed tile: m*DK/2 bytes of uint4 weights, immediately followed
//           by m*(DK/G) bf16 scales (layout documented at top of file)
//   b_in  - activation vector of DK bf16 values
//   c_out - output: m bf16 dot products, one per row
template <uint32_t block_size, uint32_t G, uint32_t DK>
void fused_dequant_matvec(uint32_t m,
                          const uint8_t *__restrict a_in,
                          const bfloat16 *__restrict b_in,
                          bfloat16 *__restrict c_out)
{
  static_assert(block_size == 32, "block_size must be 32 to match dequant vector width");
  static_assert(G % block_size == 0, "group_size must be a multiple of block_size");
  constexpr uint32_t blocks_per_group = G / block_size;
  constexpr uint32_t groups_per_row = DK / G;
  // For double-pump: process 2 groups per iteration when possible
  constexpr bool can_double_pump = (groups_per_row >= 2) && (groups_per_row % 2 == 0);
  constexpr uint32_t pump_groups = can_double_pump ? 2 : 1;
  // Lower bound on trip count, fed to AIE_LOOP_MIN_ITERATION_COUNT below so
  // the compiler can pipeline aggressively.
  constexpr uint32_t loop_iters = groups_per_row / pump_groups;

  // Round-to-nearest-even for float -> bf16 conversions in this kernel.
  ::aie::set_rounding(aie::rounding_mode::conv_even);

  // Packed uint4 weights occupy the first m*DK/2 bytes of a_in; the bf16
  // scale table starts right after them.
  const uint4 *weights_packed = reinterpret_cast<const uint4 *>(a_in);
  const uint8_t *scale_bytes = a_in + m * DK / 2;
  const bfloat16 *scales = reinterpret_cast<const bfloat16 *>(scale_bytes);

  event0(); // profiling marker: start of compute
  for (uint32_t row = 0; row < m; row++) {
    // Each row owns DK/2 packed weight bytes and groups_per_row scales.
    const uint4 *row_weights = weights_packed + row * DK / 2;
    const bfloat16 *row_scales = scales + row * groups_per_row;
    const bfloat16 *b_ptr = b_in;

    // Vector accumulator for this row; lanes are reduced to a scalar at the
    // end of the row loop.
    aie::accum<accfloat, block_size> acc = aie::zeros<accfloat, block_size>();

    if constexpr (can_double_pump && blocks_per_group == 1) {
      // Optimized path: 2 groups per iteration, 1 block per group
      // Two independent unpack chains for the compiler to interleave.
      // NOTE: the A/B statement interleaving below is deliberate — keep the
      // chains independent so the scheduler can overlap their latencies.
      AIE_LOOP_MIN_ITERATION_COUNT(loop_iters)
      for (uint32_t g = 0; g < groups_per_row; g += 2)
        AIE_PREPARE_FOR_PIPELINING
        {
          // --- Chain A: group g ---
          bfloat16 sf_a = row_scales[g];
          aie::vector<bfloat16, block_size> sf_a_bc =
              aie::broadcast<bfloat16, block_size>(sf_a);

          aie::vector<uint4, block_size> I0_a = aie::load_v<block_size>(row_weights);
          row_weights += block_size / 2; // uint4: two weights per byte

          // --- Chain B: group g+1 (interleaved) ---
          bfloat16 sf_b = row_scales[g + 1];
          aie::vector<bfloat16, block_size> sf_b_bc =
              aie::broadcast<bfloat16, block_size>(sf_b);

          aie::vector<uint4, block_size> I0_b = aie::load_v<block_size>(row_weights);
          row_weights += block_size / 2;

          // Unpack chain A: uint4 -> uint8 -> uint16 -> bf16, then scale.
          // NOTE(review): weights are treated as unsigned with no zero-point;
          // presumably any zero-point is folded into the scales upstream —
          // confirm against the quantizer.
          aie::vector<uint8, block_size> a8_a = aie::unpack(I0_a);
          aie::vector<uint16, block_size> a16_a = aie::unpack(a8_a);
          aie::vector<bfloat16, block_size> abf_a = aie::to_float<bfloat16>(a16_a, 0);
          aie::vector<bfloat16, block_size> w_a =
              aie::mul(abf_a, sf_a_bc).template to_vector<bfloat16>();

          // Unpack chain B
          aie::vector<uint8, block_size> a8_b = aie::unpack(I0_b);
          aie::vector<uint16, block_size> a16_b = aie::unpack(a8_b);
          aie::vector<bfloat16, block_size> abf_b = aie::to_float<bfloat16>(a16_b, 0);
          aie::vector<bfloat16, block_size> w_b =
              aie::mul(abf_b, sf_b_bc).template to_vector<bfloat16>();

          // Load activation vectors and MAC
          aie::vector<bfloat16, block_size> b_a = aie::load_v<block_size>(b_ptr);
          b_ptr += block_size;
          acc = aie::mac(acc, w_a, b_a);

          aie::vector<bfloat16, block_size> b_b = aie::load_v<block_size>(b_ptr);
          b_ptr += block_size;
          acc = aie::mac(acc, w_b, b_b);
        }
    } else {
      // Generic path: 1 group per iteration
      AIE_LOOP_MIN_ITERATION_COUNT(loop_iters)
      for (uint32_t g = 0; g < groups_per_row; g++)
        AIE_PREPARE_FOR_PIPELINING
        {
          // One scale factor covers the whole group of G weights.
          bfloat16 sf = row_scales[g];
          aie::vector<bfloat16, block_size> sf_broadcast =
              aie::broadcast<bfloat16, block_size>(sf);

          AIE_LOOP_MIN_ITERATION_COUNT(blocks_per_group)
          for (uint32_t blk = 0; blk < blocks_per_group; blk++) {
            aie::vector<uint4, block_size> I0 = aie::load_v<block_size>(row_weights);
            row_weights += block_size / 2; // uint4: two weights per byte

            // Same unpack chain as the fast path: uint4 -> uint8 -> uint16
            // -> bf16, then multiply by the broadcast scale.
            aie::vector<uint8, block_size> as_int8 = aie::unpack(I0);
            aie::vector<uint16, block_size> as_int16 = aie::unpack(as_int8);
            aie::vector<bfloat16, block_size> as_bf16 =
                aie::to_float<bfloat16>(as_int16, 0);
            aie::vector<bfloat16, block_size> w_dequant =
                aie::mul(as_bf16, sf_broadcast).template to_vector<bfloat16>();

            aie::vector<bfloat16, block_size> b_vec = aie::load_v<block_size>(b_ptr);
            b_ptr += block_size;

            acc = aie::mac(acc, w_dequant, b_vec);
          }
        }
    }

    // Horizontal reduce of the lane-wise accumulator to this row's scalar
    // output (reduced in float, then rounded to bf16).
    *c_out = static_cast<bfloat16>(aie::reduce_add(acc.template to_vector<float>()));
    c_out++;
  }
  event1(); // profiling marker: end of compute
}

#ifndef GROUP_SIZE
#define GROUP_SIZE 32
#endif

#ifndef DIM_K
#define DIM_K 2048
#endif

extern "C" {

// C-ABI entry point used by the IRON design. Instantiates the template with
// the build-time GROUP_SIZE / DIM_K constants and applies the caller-supplied
// output row offset before dispatching to the fused kernel.
void fused_dequant_matvec_bf16(uint32_t m,
                               uint32_t row_offset,
                               const uint8_t *__restrict a_in,
                               const bfloat16 *__restrict b_in,
                               bfloat16 *__restrict c_out)
{
    fused_dequant_matvec<32, GROUP_SIZE, DIM_K>(m, a_in, b_in,
                                                c_out + row_offset);
}

} // extern "C"
161 changes: 161 additions & 0 deletions iron/operators/gemv_int4/design.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Fused INT4 dequantization matrix-vector design.

Performs a fused dequantize-GEMV where the weight matrix is stored in packed
INT4 format (two 4-bit values per uint8 byte) with per-group bfloat16 scale
factors. The activation vector and output are bfloat16.

Each AIE column processes a contiguous block of output rows. Within a column,
the worker iterates over tiles of packed weight rows, acquires the full
activation vector once per outer iteration, and calls the fused dequant-matvec
kernel which unpacks, dequantizes, and accumulates in a single pass.

Buffer layout for A (packed weights, uint8):
For each tile of m_input rows: [m_input * K / 2 bytes of packed weights]
[m_input * (K / group_size) * 2 bytes of scales]
"""

import numpy as np
from ml_dtypes import bfloat16

import aie.dialects.index as index
from aie.dialects.aie import T
from aie.helpers.dialects.scf import _for as range_
from aie.helpers.taplib import TensorAccessPattern
from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker
from aie.iron.placers import SequentialPlacer


def my_fused_dequant_matvec(
    dev, cols, M, K, m_input, m_output=None, group_size=32
):
    """Build the fused INT4-dequant GEMV program for ``cols`` AIE columns.

    Args:
        dev: Target AIE device.
        cols: Number of AIE columns; output rows are split evenly across them.
        M: Total number of output rows (weight-matrix rows).
        K: Inner dimension (length of the bf16 activation vector).
        m_input: Weight-matrix rows per tile streamed into a core.
        m_output: Rows per output buffer a core produces before releasing it;
            defaults to ``m_input`` and must be a multiple of it.
        group_size: INT4 quantization group size (weights per scale factor);
            must be a multiple of 32 to match the kernel's vector width.

    Returns:
        The resolved IRON ``Program``.
    """
    if m_output is None:
        m_output = m_input

    # --- Assertions ---
    # Check global divisibility first so the M // cols floor divisions used in
    # the bounds checks below are exact.
    assert M % cols == 0, "M must be divisible by cols"
    assert K % group_size == 0, "K must be divisible by group_size"
    assert group_size % 32 == 0, "group_size must be a multiple of 32"
    assert (
        m_output % m_input == 0 and m_output >= m_input
    ), "m_output must be a multiple of m_input"
    assert m_output <= M // cols, "m_output must be less than or equal to M/cols"
    assert (M // cols) % m_output == 0, "m_output must evenly divide M/cols"
    assert m_input <= M // cols, "m_input must be less than or equal to M/cols"
    assert (M // cols) % m_input == 0, "m_input must evenly divide M/cols"

    # --- Data types ---
    dtype_in = np.dtype[np.uint8]
    dtype_vec = np.dtype[bfloat16]
    dtype_out = np.dtype[bfloat16]

    # --- Per-tile sizes (in uint8 bytes) ---
    num_groups_per_row = K // group_size
    # Each tile: packed uint4 weights (K/2 bytes per row) followed by the bf16
    # scale table (2 bytes per group per row) — matches the kernel's layout.
    packed_tile_bytes = m_input * K // 2 + m_input * num_groups_per_row * 2
    rows_per_col = M // cols
    tiles_per_col = rows_per_col // m_input
    bytes_per_col = tiles_per_col * packed_tile_bytes
    packed_total_bytes = cols * bytes_per_col

    # --- L1 (on-chip) tensor types ---
    L1_A_ty = np.ndarray[(packed_tile_bytes,), dtype_in]
    L1_B_ty = np.ndarray[(K,), dtype_vec]
    L1_C_ty = np.ndarray[(m_output,), dtype_out]

    # --- L3 (DDR) tensor types ---
    L3_A_ty = np.ndarray[(packed_total_bytes,), dtype_in]
    L3_B_ty = np.ndarray[(K,), dtype_vec]
    L3_C_ty = np.ndarray[(M,), dtype_out]

    # --- Kernel declaration ---
    # K and group_size are compile-time via -DDIM_K/-DGROUP_SIZE.
    fused_matvec = Kernel(
        "fused_dequant_matvec_bf16",
        f"fused_dequant_gemv_{K}k_g{group_size}.o",
        [np.int32, np.int32, L1_A_ty, L1_B_ty, L1_C_ty],
    )

    # --- ObjectFIFOs ---
    # A is double-buffered to overlap weight DMA with compute; B is acquired
    # once per outer iteration so depth 1 suffices.
    A_L3L1_fifos = [
        ObjectFifo(L1_A_ty, name=f"A_L3L1_{i}", depth=2) for i in range(cols)
    ]
    B_L3L1_fifos = [
        ObjectFifo(L1_B_ty, name=f"B_L3L1_{i}", depth=1) for i in range(cols)
    ]
    C_L1L3_fifos = [
        ObjectFifo(L1_C_ty, name=f"C_L1L3_{i}", depth=2) for i in range(cols)
    ]

    # --- Worker core body ---
    # Number of output buffers produced per activation vector.
    N_div_n = tiles_per_col // (m_output // m_input)

    def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, fused_matvec_fn):
        # Effectively-infinite outer loop: one iteration per host invocation.
        for _ in range_(0xFFFFFFFF):
            b = B_L3L1_fifo.acquire(1)
            for i_idx in range_(N_div_n):
                c = C_L1L3_fifo.acquire(1)
                for j_idx in range_(m_output // m_input):
                    # Cast the scf loop index to i32 for the kernel's
                    # row_offset argument.
                    j_i32 = index.casts(T.i32(), j_idx)
                    output_row_offset = j_i32 * m_input
                    a = A_L3L1_fifo.acquire(1)
                    fused_matvec_fn(
                        m_input, output_row_offset, a, b, c
                    )
                    A_L3L1_fifo.release(1)
                C_L1L3_fifo.release(1)
            B_L3L1_fifo.release(1)

    workers = [
        Worker(
            core_body,
            [
                A_L3L1_fifos[i].cons(),
                B_L3L1_fifos[i].cons(),
                C_L1L3_fifos[i].prod(),
                fused_matvec,
            ],
        )
        for i in range(cols)
    ]

    # --- TensorAccessPatterns ---
    # A: each column gets a contiguous chunk of bytes_per_col packed bytes
    A_taps = [
        TensorAccessPattern(
            tensor_dims=L3_A_ty.__args__[0],
            offset=col * bytes_per_col,
            sizes=[1, 1, 1, bytes_per_col],
            strides=[0, 0, 0, 1],
        )
        for col in range(cols)
    ]

    # C: each column writes contiguous rows_per_col bfloat16 values
    C_taps = [
        TensorAccessPattern(
            tensor_dims=L3_C_ty.__args__[0],
            offset=col * rows_per_col,
            sizes=[1, 1, 1, rows_per_col],
            strides=[0, 0, 0, 1],
        )
        for col in range(cols)
    ]

    # --- Runtime sequence ---
    rt = Runtime()
    with rt.sequence(L3_A_ty, L3_B_ty, L3_C_ty) as (A, B, C):
        rt.start(*workers)
        tg = rt.task_group()
        for i in range(cols):
            rt.fill(A_L3L1_fifos[i].prod(), A, A_taps[i], task_group=tg)
            # B has no TAP: every column receives the full activation vector.
            rt.fill(B_L3L1_fifos[i].prod(), B, task_group=tg)
        for i in range(cols):
            rt.drain(
                C_L1L3_fifos[i].cons(), C, C_taps[i], task_group=tg, wait=True
            )
        rt.finish_task_group(tg)

    return Program(dev, rt).resolve_program(SequentialPlacer())
Loading