Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions iron/operators/swiglu_decode/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@


def get_params():
params_list = [(2048, 2048)]
# (embedding_dim, hidden_dim)
# Square shape is the historical smoke-test config; the rectangular
# shape reflects real decoder-model FFN dims (e.g. Qwen3.5-0.8B
# embedding=1024, hidden=3584) that downstream runtimes actually hit.
params_list = [
(2048, 2048),
(1024, 3584),
]

params = []
for p in params_list:
Expand Down Expand Up @@ -55,25 +62,43 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context):
print(f"Effective Bandwidth: {bandwidth_gbps:.4f} GB/s")

errors = {}
# Verify intermediate result
# Reshape to (1, hidden_dim) using the unpadded dimension to match the golden reference shape.
# Note: op.hidden_dim_padded may differ if padding was applied; we use hidden_dim here
# because the golden reference was generated with the unpadded hidden_dim.

# Verify intermediate result (left_swished * right) against a chained
# reference built from the observed AIE left_swished and right buffers.
# This isolates eltwise_mul from any sub-tolerance drift accumulated in
# the upstream gemv_1 / silu stages that would otherwise be amplified by
# multiplication against a large-magnitude right operand (e.g. silu
# outputs that land near zero for very-negative inputs, where bf16
# rounding asymmetrically flushes NPU vs fp32-CPU). This mirrors the
# approach used by swiglu_prefill/test.py.
# Reshape to (1, hidden_dim) using the unpadded dimension to match the
# reference shape. Note: op.hidden_dim_padded may differ if padding was
# applied; we use hidden_dim here because the golden reference was
# generated with the unpadded hidden_dim.
left_swished = op_func.left_swished.to_torch().reshape((1, hidden_dim))
right = op_func.right.to_torch().reshape((1, hidden_dim))
ref_intermediate = left_swished * right

intermediate = op_func.intermediate.to_torch().reshape((1, hidden_dim))
errors_intermediate = verify_buffer(
intermediate,
"intermediate",
golden_ref["intermediate"],
rel_tol=0.07,
abs_tol=0.7,
ref_intermediate,
rel_tol=0.04,
abs_tol=0.4,
)
if errors_intermediate:
errors["intermediate"] = errors_intermediate

# Verify output using intermediate result
ref_2 = intermediate @ golden_ref["w_down"]
# Verify output using intermediate result.
# Note: we use the AIE intermediate buffer as reference (rather than
# golden_ref["output"]) because this better matches the bfloat16 precision
# path and isolates errors to gemv_2.
ref_output = intermediate @ golden_ref["w_down"]
output = output_buf.to_torch().reshape((1, embedding_dim))
errors_output = verify_buffer(output, "output", ref_2, rel_tol=0.04, abs_tol=0.4)
errors_output = verify_buffer(
output, "output", ref_output, rel_tol=0.04, abs_tol=0.4
)
if errors_output:
errors["output"] = errors_output

Expand Down
Loading