# Patch summary (free-text preamble placed before the first "diff --git" header).
#
# Files touched:
#   * olive/passes/qairt/gen_ai_builder.py
#   * test/passes/qairt/test_gen_ai_builder.py
#
# olive/passes/qairt/gen_ai_builder.py
#   - _default_config: adds a "context_lengths" PassConfigParam (type_=list[int],
#     default_value=None) and extends the "multi_graph" description to say the two
#     options are mutually exclusive.
#   - validate_config: logs an error and returns False when context_lengths is set
#     in the same branch that already rejects multi_graph with the message
#     "... unsupported on non-HTP backends", and also when context_lengths and
#     multi_graph are both set.
#   - _run_for_config: when config.context_lengths is truthy, assigns it to
#     gen_ai_builder._transformation_config.model_transformer_config
#     .arn_cl_options.context_length and does not assign gen_ai_builder.multi_graph;
#     otherwise assigns gen_ai_builder.multi_graph = config.multi_graph as before.
#
# test/passes/qairt/test_gen_ai_builder.py
#   - Imports PropertyMock and asserts the new default (context_lengths is None).
#   - Adds five tests: context_length assignment on the mocked builder, the
#     multi_graph property mock not being called when context_lengths is given,
#     rejection with backend "CPU", rejection when combined with multi_graph=True,
#     and acceptance with backend "HTP".
#
# NOTE(review): every new check uses truthiness (`if config.context_lengths`), so an
#   explicit empty list takes the same path as None -- confirm this is intended.
# NOTE(review): _run_for_config writes through the underscore-prefixed attribute
#   `_transformation_config` of the builder object; presumably no public setter
#   exists for arn_cl_options -- verify against the builder API.
# NOTE(review): the line breaks of this patch appear to have been collapsed; it
#   presumably needs its original newlines and indentation restored before it
#   can be applied -- TODO confirm against the original commit.
diff --git a/olive/passes/qairt/gen_ai_builder.py b/olive/passes/qairt/gen_ai_builder.py index b0437d715..eeab07ac4 100644 --- a/olive/passes/qairt/gen_ai_builder.py +++ b/olive/passes/qairt/gen_ai_builder.py @@ -94,7 +94,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon default_value=False, description="Produces context binaries with additional context length combinations. " "Improves token generation performance for different context lengths but increases preparation time. " - "HTP only.", + "Mutually exclusive with context_lengths. HTP only.", + ), + "context_lengths": PassConfigParam( + type_=list[int], + default_value=None, + description="Explicit list of context lengths (CLs) to compile. " + "Overrides the default CL set produced by multi_graph. " + "Mutually exclusive with multi_graph. HTP only.", ), } @@ -134,6 +141,13 @@ def validate_config( if config.multi_graph: logger.error("multi_graph is unsupported on non-HTP backends") return False + if config.context_lengths: + logger.error("context_lengths is unsupported on non-HTP backends") + return False + + if config.context_lengths and config.multi_graph: + logger.error("context_lengths and multi_graph are mutually exclusive") + return False native_kv_supported_sequence_lengths = [[32, 128]] if config.native_kv and config.sequence_lengths not in native_kv_supported_sequence_lengths: @@ -237,7 +251,12 @@ def _run_for_config( config.num_splits ) - gen_ai_builder.multi_graph = config.multi_graph + if config.context_lengths: + gen_ai_builder._transformation_config.model_transformer_config.arn_cl_options.context_length = ( + config.context_lengths + ) + else: + gen_ai_builder.multi_graph = config.multi_graph gen_ai_container = gen_ai_builder.build() gen_ai_container.save(output_model_path, exist_ok=True) diff --git a/test/passes/qairt/test_gen_ai_builder.py b/test/passes/qairt/test_gen_ai_builder.py index 4d6c51c72..2fa67b9fa 100644 --- 
a/test/passes/qairt/test_gen_ai_builder.py +++ b/test/passes/qairt/test_gen_ai_builder.py @@ -6,7 +6,7 @@ import builtins from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -36,6 +36,8 @@ def test_gen_ai_builder_default_config(mock_accelerator_spec): assert config["num_splits"].default_value == -1 assert "multi_graph" in config assert config["multi_graph"].default_value is False + assert "context_lengths" in config + assert config["context_lengths"].default_value is None def test_gen_ai_builder_cpu_backend_success(tmp_path, mock_hf_model, mock_qairt_modules): @@ -495,3 +497,100 @@ def test_gen_ai_builder_native_kv_validation_valid_sequence_lengths(mock_acceler ).config assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is True + + +def test_gen_ai_builder_context_lengths_configuration(tmp_path, mock_qairt_prepared_model, mock_qairt_modules): + """Test that context_lengths directly sets arn_cl_options.context_length.""" + output_path = tmp_path / "output" + + mock_builder = MagicMock() + mock_container = MagicMock() + mock_builder.build.return_value = mock_container + mock_builder._compilation_config = MagicMock() + mock_builder._compilation_config.graph_custom_configs = [MagicMock()] + mock_builder._compilation_config.device_custom_configs = [MagicMock()] + mock_builder._compilation_config.context_custom_configs = [MagicMock()] + mock_builder._transformation_config = MagicMock() + mock_builder._transformation_config.model_transformer_config = MagicMock() + mock_builder._transformation_config.model_transformer_config.arn_cl_options = MagicMock() + mock_builder._transformation_config.model_transformer_config.split_model = MagicMock() + + mock_qairt_modules["gen_ai_api"].GenAIBuilderFactory.create.return_value = mock_builder + + custom_cls = [1024, 2048, 4096] + gen_ai_pass = create_pass_from_dict( + QairtGenAIBuilder, + {"backend": "HTP", 
"context_lengths": custom_cls}, + disable_search=True, + ) + + result = gen_ai_pass.run(mock_qairt_prepared_model, str(output_path)) + + assert mock_builder._transformation_config.model_transformer_config.arn_cl_options.context_length == custom_cls + assert isinstance(result, QairtModelHandler) + + +def test_gen_ai_builder_context_lengths_skips_multi_graph_setter( + tmp_path, mock_qairt_prepared_model, mock_qairt_modules +): + """Test that multi_graph setter is not invoked when context_lengths is set.""" + output_path = tmp_path / "output" + + mock_builder_cls = type("mock_builder_cls", (MagicMock,), {}) + mock_builder = mock_builder_cls() + mock_container = MagicMock() + mock_builder.build.return_value = mock_container + mock_builder._compilation_config = MagicMock() + mock_builder._compilation_config.graph_custom_configs = [MagicMock()] + mock_builder._compilation_config.device_custom_configs = [MagicMock()] + mock_builder._compilation_config.context_custom_configs = [MagicMock()] + mock_builder._transformation_config = MagicMock() + mock_builder._transformation_config.model_transformer_config = MagicMock() + mock_builder._transformation_config.model_transformer_config.arn_cl_options = MagicMock() + mock_builder._transformation_config.model_transformer_config.split_model = MagicMock() + + multi_graph_mock = PropertyMock() + type(mock_builder).multi_graph = multi_graph_mock + + mock_qairt_modules["gen_ai_api"].GenAIBuilderFactory.create.return_value = mock_builder + + gen_ai_pass = create_pass_from_dict( + QairtGenAIBuilder, + {"backend": "HTP", "context_lengths": [512, 2048]}, + disable_search=True, + ) + + gen_ai_pass.run(mock_qairt_prepared_model, str(output_path)) + multi_graph_mock.assert_not_called() + + +def test_gen_ai_builder_validate_config_context_lengths_cpu_rejected(mock_accelerator_spec, mock_qairt_modules): + """Test that context_lengths is rejected on non-HTP backends.""" + config = create_pass_from_dict( + QairtGenAIBuilder, + {"backend": "CPU", 
"context_lengths": [1024, 2048]}, + disable_search=True, + ).config + assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is False + + +def test_gen_ai_builder_validate_config_context_lengths_and_multi_graph_rejected( + mock_accelerator_spec, mock_qairt_modules +): + """Test that context_lengths and multi_graph are mutually exclusive.""" + config = create_pass_from_dict( + QairtGenAIBuilder, + {"backend": "HTP", "context_lengths": [1024, 2048], "multi_graph": True}, + disable_search=True, + ).config + assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is False + + +def test_gen_ai_builder_validate_config_context_lengths_htp_valid(mock_accelerator_spec, mock_qairt_modules): + """Test that context_lengths passes validation on HTP.""" + config = create_pass_from_dict( + QairtGenAIBuilder, + {"backend": "HTP", "context_lengths": [1024, 2048, 4096]}, + disable_search=True, + ).config + assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is True