Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions olive/passes/onnx/kquant_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
depending on onnxruntime's quantization modules.
"""

import fnmatch
import logging
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -248,7 +249,11 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"nodes_to_exclude": PassConfigParam(
type_=list,
default_value=None,
description="List of node names to exclude from quantization.",
description=(
"List of node names to exclude from quantization. Entries may be exact node "
"names or Unix shell-style glob patterns (e.g. '*/projector/*' to exclude all "
"projector MatMuls). A node is excluded if its name equals or matches any entry."
),
),
**get_external_data_config(),
}
Expand Down Expand Up @@ -296,7 +301,9 @@ def _quantize_model(
for node in ir_model.graph.all_nodes():
node_name = node.name

if node_name in nodes_to_exclude:
if node_name in nodes_to_exclude or any(
fnmatch.fnmatchcase(node_name or "", pattern) for pattern in nodes_to_exclude
):
Comment on lines +304 to +306
logger.debug("Exclude quantization of %s as specified by nodes_to_exclude.", node_name)
continue

Expand Down
29 changes: 29 additions & 0 deletions test/passes/onnx/test_kquant_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,32 @@ def test_kquant_with_nodes_to_exclude(self, matmul_model_path, tmp_path):

assert len(matmul_nbits_nodes) == 1, "Expected 1 MatMulNBits node (MatMul_2 quantized)"
assert len(matmul_nodes) == 1, "Expected 1 original MatMul node (MatMul_1 excluded)"

def test_kquant_with_nodes_to_exclude_glob(self, matmul_model_path, tmp_path):
    """Check that a glob entry in ``nodes_to_exclude`` keeps matching nodes unquantized.

    The pass is run with 4-bit, block-size-32 settings and the single
    exclusion pattern ``"*_1"``. The saved model is then reloaded and its
    nodes are counted by op type: exactly one ``MatMulNBits`` node and exactly
    one plain ``MatMul`` node are expected.
    """
    input_model = ONNXModelHandler(model_path=str(matmul_model_path))
    cpu_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")

    # The pattern "*_1" is meant to hit MatMul_1 only, leaving MatMul_2 to be quantized.
    kquant_pass = create_pass_from_dict(
        OnnxKQuantQuantization,
        {"bits": 4, "block_size": 32, "nodes_to_exclude": ["*_1"]},
        disable_search=True,
        accelerator_spec=cpu_spec,
    )

    result = kquant_pass.run(input_model, tmp_path / "quantized_glob_model.onnx")

    assert os.path.exists(result.model_path)

    # Reload the written model and tally the two op types of interest.
    graph_nodes = onnx.load(result.model_path).graph.node
    nbits_op_type = str(OpType.MatMulNBits)
    nbits_count = sum(1 for node in graph_nodes if node.op_type == nbits_op_type)
    plain_matmul_count = sum(1 for node in graph_nodes if node.op_type == "MatMul")

    assert nbits_count == 1, "Expected 1 MatMulNBits node (MatMul_2 quantized)"
    assert plain_matmul_count == 1, "Expected 1 original MatMul node (MatMul_1 excluded via glob)"
Loading