diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 26b0c635..ceecd5c2 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -388,7 +388,10 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo # Tokenizer check (light API call, no download) model_name = config.model_params.name - tokenizer_name = model_name if _check_tokenizer_exists(model_name) else None + tokenizer_source = config.model_params.tokenizer_name or model_name + tokenizer_name = ( + tokenizer_source if _check_tokenizer_exists(tokenizer_source) else None + ) # Streaming logger.info( diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index b5762f64..614d7599 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -199,6 +199,10 @@ class ModelParams(BaseModel): StreamingMode, cyclopts.Parameter(alias="--streaming", help="Streaming mode: auto/on/off"), ] = StreamingMode.AUTO + tokenizer_name: str | None = Field( + None, + description="HuggingFace tokenizer repo ID. Overrides model name for tokenizer loading.", + ) class SubmissionReference(BaseModel): diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index d9191714..b92495f2 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -14,6 +14,7 @@ model_params: max_new_tokens: 1024 # Max output tokens osl_distribution: null # Output sequence length distribution streaming: 'on' # Streaming mode: auto/on/off | options: auto, on, off + tokenizer_name: null # HuggingFace tokenizer repo ID. Overrides model name for tokenizer loading. datasets: # Dataset configs - name: perf type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index a40d469c..59da6977 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -14,6 +14,7 @@ model_params: max_new_tokens: 1024 # Max output tokens osl_distribution: null # Output sequence length distribution streaming: 'off' # Streaming mode: auto/on/off | options: auto, on, off + tokenizer_name: null # HuggingFace tokenizer repo ID. Overrides model name for tokenizer loading. datasets: # Dataset configs - name: perf type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index 978be652..2e54aa8d 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -14,6 +14,7 @@ model_params: max_new_tokens: 1024 # Max output tokens osl_distribution: null # Output sequence length distribution streaming: 'on' # Streaming mode: auto/on/off | options: auto, on, off + tokenizer_name: null # HuggingFace tokenizer repo ID. Overrides model name for tokenizer loading. datasets: # Dataset configs - name: perf type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index e7ea0e51..a6077012 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -68,6 +68,7 @@ def test_defaults(self): params = ModelParams(name="test") assert params.temperature is None assert params.max_new_tokens == 1024 + assert params.tokenizer_name is None @pytest.mark.unit def test_with_osl_distribution(self): @@ -84,6 +85,14 @@ def test_with_osl_distribution(self): assert params.temperature == 0.5 assert params.osl_distribution.type == OSLDistributionType.NORMAL + @pytest.mark.unit + def test_tokenizer_name_override(self): + params = ModelParams( + name="qwen/qwen3.6-35b-a3b", tokenizer_name="Qwen/Qwen3.6-35B-A3B" + ) + assert params.tokenizer_name == "Qwen/Qwen3.6-35B-A3B" + assert params.name == "qwen/qwen3.6-35b-a3b" + class TestAPIType: @pytest.mark.unit