Skip to content

Commit 7a7476b

Browse files
jgmelber and claude committed
Fuse dual-GEMV + SiLU + Mul into single NPU design for SwiGLU decode
Collapses three separate NPU designs (GEMV W1, GEMV W2, SiLU+Mul) into a single fused operator. Each AIE core loads vector x once, processes both W1 and W2 rows through a shared A FIFO with pre-interleaved weights, then computes silu(left)*right entirely in L1 via kernel-local static buffers. The intermediate vectors never touch DRAM. Reduces SwiGLU decode from 4 to 2 runlist entries and eliminates the left/right buffer allocations. Uses 4 AIE columns (DMA channel limit: 2 input + 1 output per tile). Note: swiglu_prefill unchanged — uses GEMM (not GEMV) so dual-GEMV fusion does not apply. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 329ccf4 commit 7a7476b

File tree

9 files changed

+642
-102
lines changed

9 files changed

+642
-102
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
// Fused dual-GEMV + SiLU + elementwise multiply kernel for AIE2.
5+
// Same structure as AIE2+ variant but uses LUT-based getTanhBf16.
6+
7+
#define NOCPP
8+
9+
#include "../aie_kernel_utils.h"
10+
#include "lut_based_ops.h"
11+
12+
#include <aie_api/aie.hpp>
13+
#include <stdint.h>
14+
#include <type_traits>
15+
16+
// Kernel-local L1 intermediates for the fused SwiGLU decode step:
// left_buf receives the W1@x dot products (SiLU gate input) and right_buf
// receives W2@x. 64-byte alignment allows aie::load_v/store_v vector access.
// Capacity is 1024 bf16 entries; callers must keep row_offset + m within it.
static bfloat16 left_buf[1024] __attribute__((aligned(64)));
17+
static bfloat16 right_buf[1024] __attribute__((aligned(64)));
18+
19+
template <uint32_t r>
20+
void matvec_vectorized(uint32_t m,
21+
uint32_t k,
22+
const bfloat16 *__restrict a,
23+
const bfloat16 *__restrict b,
24+
bfloat16 *__restrict c)
25+
{
26+
::aie::set_rounding(aie::rounding_mode::conv_even);
27+
bfloat16 *c_end = c + m;
28+
const bfloat16 *b_end = b + k;
29+
for (; c < c_end; c++) {
30+
aie::accum acc = aie::zeros<accfloat, r>();
31+
AIE_LOOP_MIN_ITERATION_COUNT(2)
32+
for (const bfloat16 *__restrict b_cur = b; b_cur < b_end; b_cur += r, a += r) {
33+
aie::vector<bfloat16, r> a_vec = aie::load_v<r>(a);
34+
aie::vector<bfloat16, r> b_vec = aie::load_v<r>(b_cur);
35+
acc = aie::mac(acc, a_vec, b_vec);
36+
}
37+
*c = static_cast<bfloat16>(aie::reduce_add(acc.template to_vector<float>()));
38+
}
39+
}
40+
41+
extern "C" {
42+
43+
void dual_gemv_matvec_bf16(uint32_t m,
44+
uint32_t k,
45+
uint32_t row_offset,
46+
const bfloat16 *__restrict a_in,
47+
const bfloat16 *__restrict b_in,
48+
uint32_t phase)
49+
{
50+
bfloat16 *dst = (phase == 0) ? left_buf : right_buf;
51+
dst += row_offset;
52+
matvec_vectorized<64>(m, k, a_in, b_in, dst);
53+
}
54+
55+
// Fused activation phase: c_out[i] = silu(left_buf[i]) * right_buf[i].
// SiLU is computed via the tanh identity x * sigmoid(x) = x * 0.5 * (1 + tanh(x/2)),
// with tanh supplied by the LUT-based getTanhBf16 (AIE2 variant).
// Processes 16 bf16 lanes per iteration; m_output is assumed to be a
// multiple of 16 — TODO confirm with the calling design.
void dual_gemv_silu_mul_bf16(bfloat16 *__restrict c_out, int32_t m_output)
{
    event0();

    aie::vector<bfloat16, 16> half = aie::broadcast<bfloat16, 16>(0.5f);
    aie::vector<bfloat16, 16> one = aie::broadcast<bfloat16, 16>(1.0f);
    AIE_PREPARE_FOR_PIPELINING
    for (int idx = 0; idx < m_output; idx += 16) {
        aie::vector<bfloat16, 16> gate = aie::load_v<16>(left_buf + idx);
        aie::vector<bfloat16, 16> up = aie::load_v<16>(right_buf + idx);

        // sigmoid(x) ~= 0.5 * (1 + tanh(x/2))
        aie::vector<bfloat16, 16> scaled = aie::mul(gate, half);
        aie::vector<bfloat16, 16> t = getTanhBf16(scaled);
        auto shifted = aie::add(t, one);
        aie::vector<bfloat16, 16> sigmoid = aie::mul(shifted, half);
        auto silu = aie::mul(gate, sigmoid);

        auto fused = aie::mul(silu.to_vector<bfloat16>(), up);
        aie::store_v(c_out + idx, fused.to_vector<bfloat16>());
    }

    event1();
}
78+
79+
} // extern "C"
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
// Fused dual-GEMV + SiLU + elementwise multiply kernel for AIE2+.
5+
//
6+
// Computes: output = silu(W1 @ x) * (W2 @ x)
7+
//
8+
// Two entry points called from the NPU design's core body:
9+
// 1. dual_gemv_matvec_bf16: GEMV writing to FIFO buffer c_out + row_offset
10+
// 2. dual_gemv_silu_mul_bf16: reads from static left_buf/right_buf, writes to FIFO c_out
11+
//
12+
// The static buffers are written via scalar stores (from matvec) and read
13+
// via aie::load_v in the silu_mul phase. Aligned to 64 bytes for safe vector access.
14+
15+
#define NOCPP
16+
17+
#include "../aie_kernel_utils.h"
18+
19+
#include <aie_api/aie.hpp>
20+
#include <stdint.h>
21+
#include <type_traits>
22+
23+
// Kernel-local L1 intermediates: left_buf holds W1@x (the SiLU gate input),
// right_buf holds W2@x. Written by the matvec phases, read vectorized by the
// silu_mul phase; 64-byte aligned so aie::load_v/store_v access is safe.
// Capacity is 1024 bf16 entries; callers must keep row_offset + m within it.
static bfloat16 left_buf[1024] __attribute__((aligned(64)));
24+
static bfloat16 right_buf[1024] __attribute__((aligned(64)));
25+
26+
template <uint32_t r>
27+
void matvec_vectorized(uint32_t m,
28+
uint32_t k,
29+
const bfloat16 *__restrict a,
30+
const bfloat16 *__restrict b,
31+
bfloat16 *__restrict c)
32+
{
33+
::aie::set_rounding(aie::rounding_mode::conv_even);
34+
bfloat16 *c_end = c + m;
35+
const bfloat16 *b_end = b + k;
36+
for (; c < c_end; c++) {
37+
aie::accum acc = aie::zeros<accfloat, r>();
38+
AIE_LOOP_MIN_ITERATION_COUNT(2)
39+
for (const bfloat16 *__restrict b_cur = b; b_cur < b_end; b_cur += r, a += r) {
40+
aie::vector<bfloat16, r> a_vec = aie::load_v<r>(a);
41+
aie::vector<bfloat16, r> b_vec = aie::load_v<r>(b_cur);
42+
acc = aie::mac(acc, a_vec, b_vec);
43+
}
44+
*c = static_cast<bfloat16>(aie::reduce_add(acc.template to_vector<float>()));
45+
}
46+
}
47+
48+
extern "C" {
49+
50+
// Phase 1 & 2: GEMV writing to a static buffer (left_buf or right_buf)
51+
// phase=0 writes to left_buf, phase=1 writes to right_buf
52+
void dual_gemv_matvec_bf16(uint32_t m,
53+
uint32_t k,
54+
uint32_t row_offset,
55+
const bfloat16 *__restrict a_in,
56+
const bfloat16 *__restrict b_in,
57+
uint32_t phase)
58+
{
59+
bfloat16 *dst = (phase == 0) ? left_buf : right_buf;
60+
dst += row_offset;
61+
matvec_vectorized<64>(m, k, a_in, b_in, dst);
62+
}
63+
64+
// Phase 3: silu(left_buf) * right_buf -> c_out (FIFO buffer)
65+
void dual_gemv_silu_mul_bf16(bfloat16 *__restrict c_out, int32_t m_output)
66+
{
67+
event0();
68+
69+
aie::vector<bfloat16, 16> register_0_5 = aie::broadcast<bfloat16, 16>(0.5f);
70+
aie::vector<bfloat16, 16> register_1 = aie::broadcast<bfloat16, 16>(1.0f);
71+
AIE_PREPARE_FOR_PIPELINING
72+
for (int i = 0; i < m_output; i += 16) {
73+
aie::vector<bfloat16, 16> left_val = aie::load_v<16>(left_buf + i);
74+
aie::vector<bfloat16, 16> right_val = aie::load_v<16>(right_buf + i);
75+
76+
// SiLU(x) = x * sigmoid(x) = x * 0.5 * (1 + tanh(x/2))
77+
auto half_x = aie::mul(left_val, register_0_5);
78+
auto tanh_half_x = aie::tanh<bfloat16>(half_x.to_vector<float>());
79+
auto tanh_half_x_approx = aie::add(tanh_half_x, register_1);
80+
aie::vector<bfloat16, 16> sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5);
81+
auto silu_output = aie::mul(left_val, sigmoid_approx);
82+
83+
auto fused_output = aie::mul(silu_output.to_vector<bfloat16>(), right_val);
84+
aie::store_v(c_out + i, fused_output.to_vector<bfloat16>());
85+
}
86+
87+
event1();
88+
}
89+
90+
} // extern "C"

iron/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from .axpy.op import AIEAXPY
55
from .dequant.op import AIEDequant
6+
from .dual_gemv_silu_mul.op import AIEDualGEMVSiLUMul
67
from .elementwise_add.op import AIEElementwiseAdd
78
from .elementwise_mul.op import AIEElementwiseMul
89
from .gelu.op import AIEGELU
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import numpy as np
5+
from pathlib import Path
6+
from ml_dtypes import bfloat16
7+
import argparse
8+
9+
import aie.dialects.index as index
10+
from aie.dialects.aie import *
11+
from aie.dialects.aiex import *
12+
from aie.helpers.dialects.scf import _for as range_
13+
from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker
14+
from aie.iron.placers import SequentialPlacer
15+
from aie.iron.device import NPU1, NPU2
16+
17+
"""
18+
Dual matrix-vector + SiLU + elementwise multiply design.
19+
20+
Computes: output = silu(W1 @ x) * (W2 @ x)
21+
22+
W1 and W2 rows are pre-interleaved in DDR by the operator (op.py).
23+
GEMV phases write to kernel-internal static buffers (left_buf, right_buf)
24+
controlled by a phase parameter. The silu_mul phase reads from those
25+
buffers and writes the result to the output C FIFO.
26+
27+
Each AIE core:
28+
1. Acquires vector x (held in L1 for both GEMV passes)
29+
2. Consumes W1 rows from A FIFO, writes dot products to left_buf (phase=0)
30+
3. Consumes W2 rows from A FIFO, writes dot products to right_buf (phase=1)
31+
4. Computes silu(left_buf) * right_buf -> C FIFO output
32+
"""
33+
34+
35+
def my_dual_gemv_silu_mul(dev, cols, M, K, m_input, m_output=None):
    """Build the fused dual-GEMV + SiLU + elementwise-mul NPU program.

    Computes output = silu(W1 @ x) * (W2 @ x). W1/W2 rows are assumed to be
    pre-interleaved in DDR by the operator layer. Each core holds x in L1,
    streams W1 then W2 tiles through one A FIFO into kernel-internal static
    buffers (selected by the phase argument), then emits silu(left)*right.

    Args:
        dev: "npu" selects NPU1, anything else NPU2.
        cols: number of AIE columns (cores) used.
        M: output rows per weight matrix; K: reduction dimension.
        m_input: rows per A-FIFO tile; m_output: rows per silu_mul/C tile
            (defaults to m_input).

    Returns:
        The resolved IRON Program (MLIR module).
    """
    if m_output is None:
        m_output = m_input

    # Tiling invariants: tiles must evenly partition each column's row share.
    assert m_output % m_input == 0 and m_output >= m_input
    assert m_output <= M // cols
    assert (M // cols) % m_output == 0
    assert m_input <= M // cols
    assert (M // cols) % m_input == 0

    # bf16 throughout, both input and output.
    elem_dt = np.dtype[bfloat16]

    assert M % cols == 0

    device = NPU1() if dev == "npu" else NPU2()

    # L1 (tile-local) tensor types.
    tile_a_ty = np.ndarray[(m_input, K), elem_dt]
    tile_b_ty = np.ndarray[(K,), elem_dt]
    tile_c_ty = np.ndarray[(m_output,), elem_dt]

    # L3 (DDR) tensor types. W carries both matrices interleaved: 2*M rows.
    ddr_w_ty = np.ndarray[(2 * M, K), elem_dt]
    ddr_b_ty = np.ndarray[(K,), elem_dt]
    ddr_c_ty = np.ndarray[(M,), elem_dt]

    # GEMV kernel: phase argument selects left_buf (0) or right_buf (1).
    matvec = Kernel(
        "dual_gemv_matvec_bf16",
        "dual_gemv_silu_mul.o",
        [np.int32, np.int32, np.int32, tile_a_ty, tile_b_ty, np.int32],
    )

    # Activation kernel: reads the static buffers, writes the C FIFO.
    silu_mul_fn = Kernel(
        "dual_gemv_silu_mul_bf16",
        "dual_gemv_silu_mul.o",
        [tile_c_ty, np.int32],
    )

    # Per-column FIFOs: 2 inputs + 1 output fits the tile DMA channel limits.
    A_fifos = []
    B_fifos = []
    C_fifos = []
    for col in range(cols):
        A_fifos.append(ObjectFifo(tile_a_ty, name=f"A_{col}", depth=2))
        B_fifos.append(ObjectFifo(tile_b_ty, name=f"B_{col}", depth=1))
        C_fifos.append(ObjectFifo(tile_c_ty, name=f"C_{col}", depth=2))

    def core_body(A_fifo, B_fifo, C_fifo, matvec_fn, silu_mul):
        # x is acquired once and reused across both GEMV passes.
        for _ in range_(0xFFFFFFFF):
            b = B_fifo.acquire(1)
            for _blk in range_(M // m_output // cols):
                # phase 0: W1 tiles -> left_buf; phase 1: W2 tiles -> right_buf.
                for phase in (0, 1):
                    for tile in range_(m_output // m_input):
                        offset = index.casts(T.i32(), tile) * m_input
                        a = A_fifo.acquire(1)
                        matvec_fn(m_input, K, offset, a, b, phase)
                        A_fifo.release(1)
                # silu(left_buf) * right_buf -> output tile.
                c = C_fifo.acquire(1)
                silu_mul(c, m_output)
                C_fifo.release(1)
            B_fifo.release(1)

    workers = []
    for col in range(cols):
        workers.append(
            Worker(
                core_body,
                [
                    A_fifos[col].cons(),
                    B_fifos[col].cons(),
                    C_fifos[col].prod(),
                    matvec,
                    silu_mul_fn,
                ],
            )
        )

    rows_per_col = M // cols
    A_taps = []
    C_taps = []
    for col in range(cols):
        # Each column streams its 2*rows_per_col interleaved W1/W2 rows
        # as one contiguous DDR span.
        A_taps.append(
            TensorAccessPattern(
                tensor_dims=(2 * M, K),
                offset=col * 2 * rows_per_col * K,
                sizes=[1, 1, 1, 2 * rows_per_col * K],
                strides=[0, 0, 0, 1],
            )
        )
        # Each column drains its rows_per_col outputs back contiguously.
        C_taps.append(
            TensorAccessPattern(
                tensor_dims=(1, M),
                offset=col * rows_per_col,
                sizes=[1, 1, 1, rows_per_col],
                strides=[0, 0, 0, 1],
            )
        )

    rt = Runtime()
    with rt.sequence(ddr_w_ty, ddr_b_ty, ddr_c_ty) as (W, B, C):
        rt.start(*workers)
        tg = rt.task_group()
        for col in range(cols):
            rt.fill(A_fifos[col].prod(), W, A_taps[col], task_group=tg)
            rt.fill(B_fifos[col].prod(), B, task_group=tg)
        for col in range(cols):
            rt.drain(C_fifos[col].cons(), C, C_taps[col], task_group=tg, wait=True)
        rt.finish_task_group(tg)

    return Program(device, rt).resolve_program(SequentialPlacer())
154+
155+
156+
if __name__ == "__main__":
    # CLI entry point: build the design and write the MLIR module to a file.
    parser = argparse.ArgumentParser(
        prog="AIE Dual GEMV + SiLU + Mul Design",
    )
    parser.add_argument("--dev", type=str, choices=["npu", "npu2"], default="npu")
    parser.add_argument("-M", type=int, required=True)
    parser.add_argument("-K", type=int, required=True)
    parser.add_argument("-m", type=int, required=True, dest="m_input")
    parser.add_argument("--m-output", type=int, default=None, dest="m_output")
    parser.add_argument("--cols", type=int, required=True)
    parser.add_argument(
        "--output-file-path",
        "-o",
        type=str,
        help="Output file path for the generated MLIR module",
    )
    args = parser.parse_args()

    module = my_dual_gemv_silu_mul(
        args.dev, args.cols, args.M, args.K, args.m_input, args.m_output
    )

    # Serialize the resolved module to the requested path.
    Path(args.output_file_path).write_text(str(module))

0 commit comments

Comments
 (0)