Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions olive/passes/qairt/gen_ai_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
default_value=False,
description="Produces context binaries with additional context length combinations. "
"Improves token generation performance for different context lengths but increases preparation time. "
"HTP only.",
"Mutually exclusive with context_lengths. HTP only.",
),
"context_lengths": PassConfigParam(
type_=list[int],
default_value=None,
description="Explicit list of context lengths (CLs) to compile. "
"Overrides the default CL set produced by multi_graph. "
"Mutually exclusive with multi_graph. HTP only.",
),
Comment thread
qti-kromero marked this conversation as resolved.
}

Expand Down Expand Up @@ -134,6 +141,13 @@ def validate_config(
if config.multi_graph:
logger.error("multi_graph is unsupported on non-HTP backends")
return False
if config.context_lengths:
logger.error("context_lengths is unsupported on non-HTP backends")
return False

if config.context_lengths and config.multi_graph:
logger.error("context_lengths and multi_graph are mutually exclusive")
return False
Comment thread
qti-kromero marked this conversation as resolved.

native_kv_supported_sequence_lengths = [[32, 128]]
if config.native_kv and config.sequence_lengths not in native_kv_supported_sequence_lengths:
Expand Down Expand Up @@ -237,7 +251,12 @@ def _run_for_config(
config.num_splits
)

gen_ai_builder.multi_graph = config.multi_graph
if config.context_lengths:
gen_ai_builder._transformation_config.model_transformer_config.arn_cl_options.context_length = (
config.context_lengths
)
else:
gen_ai_builder.multi_graph = config.multi_graph
Comment thread
qti-kromero marked this conversation as resolved.

gen_ai_container = gen_ai_builder.build()
gen_ai_container.save(output_model_path, exist_ok=True)
Expand Down
98 changes: 98 additions & 0 deletions test/passes/qairt/test_gen_ai_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def test_gen_ai_builder_default_config(mock_accelerator_spec):
assert config["num_splits"].default_value == -1
assert "multi_graph" in config
assert config["multi_graph"].default_value is False
assert "context_lengths" in config
assert config["context_lengths"].default_value is None


def test_gen_ai_builder_cpu_backend_success(tmp_path, mock_hf_model, mock_qairt_modules):
Expand Down Expand Up @@ -495,3 +497,99 @@ def test_gen_ai_builder_native_kv_validation_valid_sequence_lengths(mock_acceler
).config

assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is True


def test_gen_ai_builder_context_lengths_configuration(tmp_path, mock_qairt_prepared_model, mock_qairt_modules):
    """Check that a configured ``context_lengths`` list ends up on ``arn_cl_options.context_length``."""
    requested_context_lengths = [1024, 2048, 4096]
    destination = tmp_path / "output"

    # Compilation side of the stand-in builder: each custom-config list holds one mock entry.
    compilation_cfg = MagicMock()
    compilation_cfg.graph_custom_configs = [MagicMock()]
    compilation_cfg.device_custom_configs = [MagicMock()]
    compilation_cfg.context_custom_configs = [MagicMock()]

    # Transformation side: this is where the pass is expected to write the context lengths.
    transformer_cfg = MagicMock()
    transformer_cfg.arn_cl_options = MagicMock()
    transformer_cfg.split_model = MagicMock()
    transformation_cfg = MagicMock()
    transformation_cfg.model_transformer_config = transformer_cfg

    fake_builder = MagicMock()
    fake_builder.build.return_value = MagicMock()
    fake_builder._compilation_config = compilation_cfg
    fake_builder._transformation_config = transformation_cfg

    # The mocked factory hands the stand-in builder to the pass under test.
    mock_qairt_modules["gen_ai_api"].GenAIBuilderFactory.create.return_value = fake_builder

    qairt_pass = create_pass_from_dict(
        QairtGenAIBuilder,
        {"backend": "HTP", "context_lengths": requested_context_lengths},
        disable_search=True,
    )
    output_model = qairt_pass.run(mock_qairt_prepared_model, str(destination))

    assert transformer_cfg.arn_cl_options.context_length == requested_context_lengths
    assert isinstance(output_model, QairtModelHandler)


def test_gen_ai_builder_context_lengths_skips_multi_graph_setter(
    tmp_path, mock_qairt_prepared_model, mock_qairt_modules
):
    """Test that multi_graph setter is not invoked when context_lengths is set.

    The builder's ``multi_graph`` attribute is replaced by a ``PropertyMock`` so
    that every assignment to it is recorded and can be asserted on.
    """
    # PropertyMock is only needed by this test, so it is imported locally.
    from unittest.mock import PropertyMock

    output_path = tmp_path / "output"

    mock_builder = MagicMock()
    mock_container = MagicMock()
    mock_builder.build.return_value = mock_container
    mock_builder._compilation_config = MagicMock()
    mock_builder._compilation_config.graph_custom_configs = [MagicMock()]
    mock_builder._compilation_config.device_custom_configs = [MagicMock()]
    mock_builder._compilation_config.context_custom_configs = [MagicMock()]
    mock_builder._transformation_config = MagicMock()
    mock_builder._transformation_config.model_transformer_config = MagicMock()
    mock_builder._transformation_config.model_transformer_config.arn_cl_options = MagicMock()
    mock_builder._transformation_config.model_transformer_config.split_model = MagicMock()

    # Plain attribute assignment on a MagicMock is NOT recorded in ``mock_calls``,
    # so scanning that list can never detect ``builder.multi_graph = ...``.
    # A PropertyMock attached to the mock's type records each access instead:
    # a read is a call with no arguments, a write is a call with the assigned value.
    multi_graph_property = PropertyMock(return_value=False)
    type(mock_builder).multi_graph = multi_graph_property

    mock_qairt_modules["gen_ai_api"].GenAIBuilderFactory.create.return_value = mock_builder

    gen_ai_pass = create_pass_from_dict(
        QairtGenAIBuilder,
        {"backend": "HTP", "context_lengths": [512, 2048]},
        disable_search=True,
    )

    gen_ai_pass.run(mock_qairt_prepared_model, str(output_path))

    # Reads of the property are tolerated; only calls that carry a value are setter invocations.
    setter_calls = [recorded for recorded in multi_graph_property.call_args_list if recorded.args]
    assert not setter_calls, f"multi_graph setter unexpectedly invoked: {setter_calls}"
Comment thread
qti-kromero marked this conversation as resolved.
Outdated


def test_gen_ai_builder_validate_config_context_lengths_cpu_rejected(mock_accelerator_spec, mock_qairt_modules):
    """Validation must return False when context_lengths is combined with the CPU backend."""
    pass_options = {"backend": "CPU", "context_lengths": [1024, 2048]}
    cpu_pass = create_pass_from_dict(QairtGenAIBuilder, pass_options, disable_search=True)

    verdict = QairtGenAIBuilder.validate_config(cpu_pass.config, mock_accelerator_spec)

    assert verdict is False


def test_gen_ai_builder_validate_config_context_lengths_and_multi_graph_rejected(
    mock_accelerator_spec, mock_qairt_modules
):
    """Validation must return False when both context_lengths and multi_graph are requested."""
    conflicting_options = {
        "backend": "HTP",
        "context_lengths": [1024, 2048],
        "multi_graph": True,
    }
    conflicting_pass = create_pass_from_dict(QairtGenAIBuilder, conflicting_options, disable_search=True)

    outcome = QairtGenAIBuilder.validate_config(conflicting_pass.config, mock_accelerator_spec)

    assert outcome is False


def test_gen_ai_builder_validate_config_context_lengths_htp_valid(mock_accelerator_spec, mock_qairt_modules):
    """Validation must return True for an HTP config that only sets context_lengths."""
    htp_options = {"backend": "HTP", "context_lengths": [1024, 2048, 4096]}
    htp_pass = create_pass_from_dict(QairtGenAIBuilder, htp_options, disable_search=True)

    accepted = QairtGenAIBuilder.validate_config(htp_pass.config, mock_accelerator_spec)

    assert accepted is True
Loading