diff --git a/olive/passes/onnx/kquant_quantization.py b/olive/passes/onnx/kquant_quantization.py index 4263406af..1e6b4b66c 100644 --- a/olive/passes/onnx/kquant_quantization.py +++ b/olive/passes/onnx/kquant_quantization.py @@ -12,6 +12,7 @@ depending on onnxruntime's quantization modules. """ +import fnmatch import logging from pathlib import Path from typing import Optional @@ -248,7 +249,11 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon "nodes_to_exclude": PassConfigParam( type_=list, default_value=None, - description="List of node names to exclude from quantization.", + description=( + "List of node names to exclude from quantization. Entries may be exact node " + "names or Unix shell-style glob patterns (e.g. '*/projector/*' to exclude all " + "projector MatMuls). A node is excluded if its name equals or matches any entry." + ), ), **get_external_data_config(), } @@ -296,7 +301,9 @@ def _quantize_model( for node in ir_model.graph.all_nodes(): node_name = node.name - if node_name in nodes_to_exclude: + if node_name in nodes_to_exclude or any( + fnmatch.fnmatchcase(node_name or "", pattern) for pattern in nodes_to_exclude + ): logger.debug("Exclude quantization of %s as specified by nodes_to_exclude.", node_name) continue diff --git a/test/passes/onnx/test_kquant_quantization.py b/test/passes/onnx/test_kquant_quantization.py index 4440c4337..3bb88b631 100644 --- a/test/passes/onnx/test_kquant_quantization.py +++ b/test/passes/onnx/test_kquant_quantization.py @@ -135,3 +135,32 @@ def test_kquant_with_nodes_to_exclude(self, matmul_model_path, tmp_path): assert len(matmul_nbits_nodes) == 1, "Expected 1 MatMulNBits node (MatMul_2 quantized)" assert len(matmul_nodes) == 1, "Expected 1 original MatMul node (MatMul_1 excluded)" + + def test_kquant_with_nodes_to_exclude_glob(self, matmul_model_path, tmp_path): + """Test k-quant where nodes_to_exclude uses a glob pattern.""" + olive_model = ONNXModelHandler(model_path=str(matmul_model_path)) + accelerator_spec = AcceleratorSpec( + accelerator_type="CPU", + execution_provider="CPUExecutionProvider", + ) + # "*_1" matches MatMul_1 only; MatMul_2 should still be quantized. + pass_config = { + "bits": 4, + "block_size": 32, + "nodes_to_exclude": ["*_1"], + } + p = create_pass_from_dict( + OnnxKQuantQuantization, pass_config, disable_search=True, accelerator_spec=accelerator_spec + ) + + output_path = tmp_path / "quantized_glob_model.onnx" + quantized_model = p.run(olive_model, output_path) + + assert os.path.exists(quantized_model.model_path) + + quantized_onnx = onnx.load(quantized_model.model_path) + matmul_nbits_nodes = [n for n in quantized_onnx.graph.node if n.op_type == str(OpType.MatMulNBits)] + matmul_nodes = [n for n in quantized_onnx.graph.node if n.op_type == "MatMul"] + + assert len(matmul_nbits_nodes) == 1, "Expected 1 MatMulNBits node (MatMul_2 quantized)" + assert len(matmul_nodes) == 1, "Expected 1 original MatMul node (MatMul_1 excluded via glob)"