Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions olive/common/hf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,20 @@ def get_generation_config(model_name_or_path: str, **kwargs) -> Optional["Genera
return None


def resolve_diffusers_tokenizer_path(model_path: str, load_kwargs: Optional[dict[str, Any]] = None) -> str:
"""Resolve tokenizer path for diffusers pipelines with subfoldered sub-models."""
pipeline_path = Path(model_path)
load_kwargs = load_kwargs or {}
subfolder = load_kwargs.get("subfolder") or load_kwargs.get("extra_args", {}).get("subfolder")
if not subfolder:
return str(pipeline_path)

tokenizer_path = pipeline_path / "tokenizer"
if (tokenizer_path / "tokenizer_config.json").exists():
return str(tokenizer_path)
return str(pipeline_path / subfolder)


def get_tokenizer(model_name_or_path: str, **kwargs) -> Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"]:
"""Get HF model's tokenizer."""
tokenizer = from_pretrained(AutoTokenizer, model_name_or_path, "tokenizer", **kwargs)
Expand Down
16 changes: 16 additions & 0 deletions olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,22 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
required=False,
description="Remove language modeling head from your ONNX model.",
),
"use_cache": PassConfigParam(
type_=bool,
required=False,
description=(
"Include past/present key-value cache inputs and outputs in the ONNX model. "
"Set to false for Stable Diffusion text encoder exports."
),
),
"hidden_states_layers": PassConfigParam(
type_=list[int],
required=False,
description=(
"Hugging Face hidden_states layer indices to concatenate into prompt_embeds "
"(for example [9, 18, 27] for Flux/Qwen3 text encoders)."
),
),
"enable_cuda_graph": PassConfigParam(
type_=bool,
required=False,
Expand Down
3 changes: 2 additions & 1 deletion olive/passes/pytorch/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from transformers import __version__ as transformers_version

from olive.common.config_utils import NestedConfig, validate_config
from olive.common.hf.utils import resolve_diffusers_tokenizer_path
from olive.common.utils import cleanup_memory
from olive.data.config import DataConfig
from olive.data.template import huggingface_data_config_template
Expand Down Expand Up @@ -279,7 +280,7 @@ def get_calibration_dataset(
"""
if not data_config and isinstance(model, HfModelHandler):
data_config = get_calibration_data_config(
model.model_name_or_path,
resolve_diffusers_tokenizer_path(model.model_name_or_path, model.get_load_kwargs()),
trust_remote_code=model.get_load_kwargs().get("trust_remote_code", False),
split=split,
batch_size=batch_size,
Expand Down
5 changes: 4 additions & 1 deletion olive/workflows/run/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from olive.cache import CacheConfig
from olive.common.config_utils import NestedConfig, validate_config
from olive.common.constants import DEFAULT_CACHE_DIR, DEFAULT_HF_TASK, DEFAULT_WORKFLOW_ID
from olive.common.hf.utils import resolve_diffusers_tokenizer_path
from olive.data.config import DataComponentConfig, DataConfig
from olive.data.container.dummy_data_container import TRANSFORMER_DUMMY_DATA_CONTAINER
from olive.data.container.huggingface_container import HuggingfaceContainer
Expand Down Expand Up @@ -224,7 +225,9 @@ def validate_data_configs_with_hf_model(cls, v, info):
model_name = input_model_config["config"]["model_attributes"].get("_name_or_path")
else:
task = input_model_config["config"].get("task", DEFAULT_HF_TASK)
model_name = input_model_config["config"]["model_path"]
model_path = input_model_config["config"]["model_path"]
load_kwargs = input_model_config["config"].get("load_kwargs") or {}
model_name = resolve_diffusers_tokenizer_path(model_path, load_kwargs)

model_info = {
"model_name": model_name,
Expand Down
31 changes: 31 additions & 0 deletions test/passes/onnx/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,37 @@ def fake_create_model(*_, **kwargs):
assert fake_builder.create_model.call_args.kwargs["input_path"] == str(test_model_path)


def test_model_builder_prompt_embeds_options_forwarded(tmp_path, monkeypatch):
input_model = make_local_tiny_llama(tmp_path / "input_model", "hf")
output_folder = tmp_path / "output_model"
captured_kwargs = {}

def fake_create_model(*_, **kwargs):
captured_kwargs.update(kwargs)
output_dir = Path(kwargs["output_dir"])
(output_dir / kwargs["filename"]).write_text("dummy onnx file")
(output_dir / "genai_config.json").write_text("{}")

_mock_genai_builder(monkeypatch, fake_create_model)

p = create_pass_from_dict(
ModelBuilder,
{
"precision": "int4",
"exclude_lm_head": True,
"use_cache": False,
"hidden_states_layers": [9, 18, 27],
},
disable_search=True,
)
output_model = p.run(input_model, output_folder)

assert isinstance(output_model, ONNXModelHandler)
assert captured_kwargs["exclude_lm_head"] is True
assert captured_kwargs["use_cache"] is False
assert captured_kwargs["hidden_states_layers"] == [9, 18, 27]


def test_model_builder_apply_annotations_on_single_file_fallback(tmp_path, monkeypatch):
def fake_create_model(
model_name, input_path, output_dir, precision, execution_provider, cache_dir, filename, **kwargs
Expand Down