Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING

from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizerFast

if TYPE_CHECKING:
from transformers import PreTrainedTokenizerBase
Expand Down Expand Up @@ -76,9 +76,19 @@ def __init__(self, tokenizer_name: str, n_workers: int) -> None:
def _get_thread_tokenizer(self) -> PreTrainedTokenizerBase:
"""Return the tokenizer for the current thread, loading it if needed."""
if getattr(self._thread_local, "tokenizer", None) is None:
self._thread_local.tokenizer = AutoTokenizer.from_pretrained(
self._tokenizer_name
)
try:
self._thread_local.tokenizer = AutoTokenizer.from_pretrained(
self._tokenizer_name
)
except Exception:
# AutoTokenizer loads config.json to detect the model type; for
# models with unknown model_type (e.g. deepseek_v4 in older
# transformers) or missing rope config fields, this fails.
# Fall back to PreTrainedTokenizerFast which reads only
# tokenizer.json / tokenizer_config.json and skips model config.
self._thread_local.tokenizer = PreTrainedTokenizerFast.from_pretrained(
self._tokenizer_name
)
Comment on lines +79 to +91
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Catching a broad Exception from AutoTokenizer.from_pretrained without logging can mask unrelated configuration issues (e.g., missing files, permission errors, offline mode) and make token-metric failures opaque. Consider catching a narrower set of expected failures (or at least logging the original exception at debug/warning level) before falling back to PreTrainedTokenizerFast.

Copilot uses AI. Check for mistakes.
return self._thread_local.tokenizer

def _token_count_worker(self, text: str) -> int:
Expand Down
71 changes: 49 additions & 22 deletions src/inference_endpoint/commands/benchmark/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@
APIType,
BenchmarkConfig,
DatasetType,
LoadPattern,
LoadPatternType,
StreamingMode,
TestMode,
TestType,
Expand Down Expand Up @@ -302,7 +300,15 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo

# Tokenizer check (light API call, no download)
model_name = config.model_params.name
tokenizer_name = model_name if _check_tokenizer_exists(model_name) else None
tokenizer_override = config.model_params.tokenizer_name
tokenizer_name: str | None
if tokenizer_override:
tokenizer_name = tokenizer_override
logger.info(
f"Tokenizer available for model: {model_name} (override: {tokenizer_override})"
)
Comment on lines +306 to +309
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The log line says “Tokenizer available for model … (override: …)” but when tokenizer_override is set you’re not actually checking that the override exists/is loadable. This message is misleading and can confuse debugging; consider logging that an override is being used (and optionally validating the override with _check_tokenizer_exists(tokenizer_override) before proceeding).

Suggested change
tokenizer_name = tokenizer_override
logger.info(
f"Tokenizer available for model: {model_name} (override: {tokenizer_override})"
)
if _check_tokenizer_exists(tokenizer_override):
tokenizer_name = tokenizer_override
logger.info(
f"Using tokenizer override for model: {model_name} ({tokenizer_override})"
)
else:
tokenizer_name = None
logger.warning(
f"Tokenizer override not available for model: {model_name} ({tokenizer_override})"
)

Copilot uses AI. Check for mistakes.
else:
Comment on lines +303 to +310
tokenizer_name = model_name if _check_tokenizer_exists(model_name) else None

# Streaming
logger.info(
Expand Down Expand Up @@ -368,7 +374,7 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]:
min_sample_count=acc_ds.num_samples() * acc_ds.repeats,
rng_sched=ctx.rt_settings.rng_sched,
rng_sample_index=ctx.rt_settings.rng_sample_index,
load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT),
load_pattern=ctx.rt_settings.load_pattern,
)
Comment on lines 374 to 378
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In accuracy phases, load_pattern=ctx.rt_settings.load_pattern means accuracy evaluation will be throttled by the main run’s load pattern (e.g., Poisson target QPS / concurrency), which can make scoring unnecessarily slow. Previously this forced MAX_THROUGHPUT; consider keeping Burst/MAX_THROUGHPUT for accuracy phases (or making it explicitly configurable) to avoid long eval runs.

Copilot uses AI. Check for mistakes.
phases.append(
PhaseConfig(eval_cfg.dataset_name, acc_settings, acc_ds, PhaseType.ACCURACY)
Expand Down Expand Up @@ -649,27 +655,48 @@ def finalize_benchmark(ctx: BenchmarkContext, bench: BenchmarkResult) -> None:
# Write scoring artifacts + copy event log from tmpfs to disk
_write_scoring_artifacts(ctx, result, bench.tmpfs_dir)

# Accuracy scoring
# Accuracy scoring — continue past per-scorer failures so partial results are saved
accuracy_scores: dict[str, Any] = {}
scoring_failed = False
for eval_cfg in ctx.eval_configs:
scorer_instance = eval_cfg.scorer(
eval_cfg.dataset_name,
eval_cfg.dataset,
eval_cfg.report_dir,
extractor=eval_cfg.extractor,
ground_truth_column=eval_cfg.ground_truth_column,
try:
scorer_instance = eval_cfg.scorer(
eval_cfg.dataset_name,
eval_cfg.dataset,
eval_cfg.report_dir,
extractor=eval_cfg.extractor,
ground_truth_column=eval_cfg.ground_truth_column,
)
score, n_repeats = scorer_instance.score()
assert eval_cfg.dataset.data is not None
accuracy_scores[eval_cfg.dataset_name] = {
"dataset_name": eval_cfg.dataset_name,
"num_samples": len(eval_cfg.dataset.data),
"extractor": eval_cfg.extractor.__name__,
"ground_truth_column": eval_cfg.ground_truth_column,
"score": score,
"n_repeats": n_repeats,
}
logger.info(
f"Score for {eval_cfg.dataset_name}: {score} ({n_repeats} repeats)"
)
except Exception as e:
scoring_failed = True
logger.error(f"Scoring failed for {eval_cfg.dataset_name}: {e}")
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except Exception as e: logger.error(...) logs only str(e) and drops the traceback, which makes scorer failures hard to diagnose in CI/user runs. Consider including exc_info=True (and/or logging the exception type) so the report includes actionable stack traces while still continuing past failures.

Suggested change
logger.error(f"Scoring failed for {eval_cfg.dataset_name}: {e}")
logger.error(
"Scoring failed for %s: %s: %s",
eval_cfg.dataset_name,
type(e).__name__,
e,
exc_info=True,
)

Copilot uses AI. Check for mistakes.
assert eval_cfg.dataset.data is not None
accuracy_scores[eval_cfg.dataset_name] = {
"dataset_name": eval_cfg.dataset_name,
"num_samples": len(eval_cfg.dataset.data),
"extractor": eval_cfg.extractor.__name__,
"ground_truth_column": eval_cfg.ground_truth_column,
"score": None,
"error": str(e),
}
Comment on lines 661 to +694
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The accuracy scoring loop contains significant code duplication for the result dictionary and uses assert statements for data validation inside an exception handler. If dataset.data is missing, the assert will raise an AssertionError which, while caught by the except Exception, makes the logic brittle. Refactoring this to use a base dictionary and safer access improves maintainability and ensures the loop continues gracefully as intended.

    for eval_cfg in ctx.eval_configs:
        # Prepare base results; ensure data exists to avoid TypeError in len()
        num_samples = len(eval_cfg.dataset.data) if eval_cfg.dataset.data is not None else 0
        res_base = {
            "dataset_name": eval_cfg.dataset_name,
            "num_samples": num_samples,
            "extractor": eval_cfg.extractor.__name__,
            "ground_truth_column": eval_cfg.ground_truth_column,
        }
        try:
            scorer_instance = eval_cfg.scorer(
                eval_cfg.dataset_name,
                eval_cfg.dataset,
                eval_cfg.report_dir,
                extractor=eval_cfg.extractor,
                ground_truth_column=eval_cfg.ground_truth_column,
            )
            score, n_repeats = scorer_instance.score()
            accuracy_scores[eval_cfg.dataset_name] = {
                **res_base,
                "score": score,
                "n_repeats": n_repeats,
            }
            logger.info(
                f"Score for {eval_cfg.dataset_name}: {score} ({n_repeats} repeats)"
            )
        except Exception as e:
            scoring_failed = True
            logger.error(f"Scoring failed for {eval_cfg.dataset_name}: {e}")
            accuracy_scores[eval_cfg.dataset_name] = {
                **res_base,
                "score": None,
                "error": str(e),
            }


if scoring_failed:
logger.warning(
"One or more accuracy scorers failed — partial accuracy results saved"
)
score, n_repeats = scorer_instance.score()
assert eval_cfg.dataset.data is not None
accuracy_scores[eval_cfg.dataset_name] = {
"dataset_name": eval_cfg.dataset_name,
"num_samples": len(eval_cfg.dataset.data),
"extractor": eval_cfg.extractor.__name__,
"ground_truth_column": eval_cfg.ground_truth_column,
"score": score,
"n_repeats": n_repeats,
}
logger.info(f"Score for {eval_cfg.dataset_name}: {score} ({n_repeats} repeats)")

# Report metrics: prefer Report from KVStore, fall back to SessionResult
if report is not None and report.duration_ns is not None:
Expand Down
6 changes: 6 additions & 0 deletions src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,12 @@ class ModelParams(BaseModel):
StreamingMode,
cyclopts.Parameter(alias="--streaming", help="Streaming mode: auto/on/off"),
] = StreamingMode.AUTO
tokenizer_name: str | None = Field(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add this to CLI as well? (by using cyclopts.Parameter)

None,
description="Local tokenizer path override. Use when AutoTokenizer.from_pretrained "
"fails for the HF model name (e.g. transformers ≥5.4 rope_theta regression "
"for DeepSeek-V4). Defaults to the model name if unset.",
)


class SubmissionReference(BaseModel):
Expand Down
8 changes: 6 additions & 2 deletions src/inference_endpoint/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,13 @@ def text_after_first_chunk(self) -> str:
"""
parts: list[str] = []
if self.reasoning:
if isinstance(self.reasoning, tuple) and len(self.reasoning) > 1:
if isinstance(self.reasoning, str):
# str reasoning is the fully joined streaming trace — include it
# in the TPOT denominator. Over-counts by one token (the first
# token is not excluded), but the error is negligible in practice.
parts.append(self.reasoning)
elif isinstance(self.reasoning, tuple) and len(self.reasoning) > 1:
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

text_after_first_chunk() now treats reasoning as a joined str for streaming cases, but the method docstring above still states that for non-streaming (str fields) it returns an empty string. Please update the docstring (and ideally the TextModelOutput field docs) to reflect that str reasoning may also represent a streaming trace, otherwise future callers may make incorrect assumptions.

Copilot uses AI. Check for mistakes.
parts.extend(self.reasoning[1:])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic to include string reasoning in the TPOT denominator introduces a slight inaccuracy in metrics by failing to exclude the first token of the response. If the accumulator preserves reasoning as a tuple of chunks (as suggested in accumulator.py), this special case is unnecessary. Reverting to the previous logic ensures that TPOT is calculated only on tokens generated after the first chunk, maintaining metric precision.

Suggested change
if isinstance(self.reasoning, str):
# str reasoning is the fully joined streaming trace — include it
# in the TPOT denominator. Over-counts by one token (the first
# token is not excluded), but the error is negligible in practice.
parts.append(self.reasoning)
elif isinstance(self.reasoning, tuple) and len(self.reasoning) > 1:
parts.extend(self.reasoning[1:])
if isinstance(self.reasoning, tuple) and len(self.reasoning) > 1:
parts.extend(self.reasoning[1:])

Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

text_after_first_chunk() now treats str reasoning as streaming output and includes the entire reasoning string in the denominator, which contradicts the method’s docstring (it says str fields are non-streaming and should return an empty string) and still can’t actually exclude the first streamed chunk when reasoning is stored as a joined str. Consider updating the docstring to match the new semantics and/or preserving chunk boundaries (tuple) so the “after first chunk” calculation remains exact.

Copilot uses AI. Check for mistakes.
# str reasoning: single chunk, skip entirely (it IS the first chunk)
if self.output:
if isinstance(self.output, str):
# Non-streaming: if reasoning was present and was the first chunk,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Preset transforms for the AIME25 dataset."""

from inference_endpoint.dataset_manager.transforms import (
AddStaticColumns,
Transform,
UserPromptFormatter,
)
Expand All @@ -27,4 +28,38 @@ def gptoss() -> list[Transform]:
UserPromptFormatter(
user_prompt_format="{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
),
# Enable DeepSeek thinking mode so the model uses chain-of-thought reasoning.
# vLLM's reasoning_parser strips <think>...</think> tokens into reasoning_content;
# the final boxed answer ends up in content where boxed_math_extractor finds it.
AddStaticColumns({"chat_template_kwargs": {"thinking": True}}),
]


def gptoss_budget() -> list[Transform]:
return [
UserPromptFormatter(
user_prompt_format="{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
),
# Same as gptoss but caps thinking at 8192 tokens via budget_tokens so the model
# is forced to emit a final answer rather than consuming all max_new_tokens in
# the reasoning phase (observed issue: 85% of responses had empty answer text).
AddStaticColumns({"chat_template_kwargs": {"thinking": True, "budget_tokens": 8192}}),
]


def gptoss_budget_20k() -> list[Transform]:
return [
UserPromptFormatter(
user_prompt_format="{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
),
AddStaticColumns({"chat_template_kwargs": {"thinking": True, "budget_tokens": 20000}}),
]


def gptoss_budget_20k_pre() -> list[Transform]:
return [
UserPromptFormatter(
user_prompt_format="Please reason step by step, and put your final answer within \\boxed{{}}.\n\n{question}",
),
AddStaticColumns({"chat_template_kwargs": {"thinking": True, "budget_tokens": 20000}}),
]
7 changes: 6 additions & 1 deletion src/inference_endpoint/dataset_manager/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ def __init__(self, data: dict[str, Any]):
def __call__(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add the static columns to the row."""
for key, value in self.data.items():
df[key] = value
# Wrap dict/list values in a list so pandas doesn't try to align
# on index keys (e.g. {"thinking": True} would produce NaN otherwise).
if isinstance(value, (dict, list)):
df[key] = [value] * len(df)
else:
df[key] = value
Comment on lines +126 to +131
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AddStaticColumns now broadcasts dict/list values by doing [value] * len(df), which repeats the same mutable object reference for every row. If anything downstream mutates one row’s dict/list (e.g., chat_template_kwargs), it will silently affect all rows. Consider using per-row copies (e.g., deepcopy) or constructing a fresh object per row to avoid shared mutable state.

Copilot uses AI. Check for mistakes.
Comment on lines +126 to +131
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change adds special handling for dict/list static values, but the existing unit tests for AddStaticColumns don’t cover dict/list inputs. Please add a regression test ensuring dict/list values are preserved per-row (and not converted to NaN via pandas alignment).

Copilot uses AI. Check for mistakes.
return df


Expand Down
117 changes: 116 additions & 1 deletion src/inference_endpoint/evaluation/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ABCDExtractor(Extractor, extractor_id="abcd_extractor"):
Returns:
"choice" key (see GQPA dataset columns) or empty string if no answer is found.
Examples:
>>> ABCDExtractor.extract("The answer is B")
>>> ABCDExtractor.extract("Answer: B")
'choice2'
>>> ABCDExtractor.extract("**Answer:** C")
'choice3'
Expand Down Expand Up @@ -220,6 +220,121 @@ def extract(cls, text: str, default: str | None = None) -> str | None:
return default if default is not None else ""


class LetterExtractor(Extractor, extractor_id="letter_extractor"):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: abcd extractor and letter abstractor names are bit confusing

"""Extract MCQ answer letter (A–J) from response text, returning the letter directly.

Like ABCDExtractor but returns the raw letter ("A", "B", … "J") instead of
mapping to "choice1"–"choice4". Supports datasets with up to ten answer
options (e.g. MMLU-Pro A–J) where ground-truth labels are stored as the
letter itself.

Examples:
>>> LetterExtractor.extract("Answer: B")
'B'
>>> LetterExtractor.extract("**Answer:** G")
'G'
>>> LetterExtractor.extract("\\\\boxed{E}")
'E'
"""

LETTERS = frozenset("ABCDEFGHIJ")

PATTERNS = [
# 0) **Answer:** A or *Answers* – B
re.compile(
r"""(?ix)
(?:\*{1,2}|_{1,2})
Answer[s]?
\s*[:\-–]?
(?:\*{1,2}|_{1,2})
\s*
([A-J])\b
""",
re.X,
),
# 0.1) Answer: A (with optional markdown)
re.compile(
r"""(?ix)
^\s*
(?:\*{1,2}|_{1,2})?
Answer:?
(?:\*{1,2}|_{1,2})?
\s*:?\s*
(?:\*{1,2}|_{1,2})?
([A-J])
(?:\*{1,2}|_{1,2})?
\s*
""",
re.MULTILINE,
),
# 1) Answer: (C)
re.compile(r"(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([A-J])\s*\)"),
# 2) Answer: C
re.compile(r"(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([A-J])\b"),
# 3) Option B or Choice: C
re.compile(r"(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([A-J])\b"),
# 7) \boxed{A}
re.compile(r"(?x)\\boxed\{[^}]*?([A-J])[^}]*\}", re.MULTILINE),
# 7.5) \boxed{\textbf{C}}
re.compile(
r"(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([A-J])[^}]*\}[^}]*\}", re.MULTILINE
),
# 7.51) \boxed{\text{C}}
re.compile(
r"(?x)\\boxed\{[^}]*?\\text\{[^}]*?([A-J])[^}]*\}[^}]*\}", re.MULTILINE
),
# 4) bare singletons: (A) [B]
re.compile(r"(?x)(?<![A-Za-z0-9])[\(\[]\s*([A-J])\s*[\)\]](?![A-Za-z0-9])"),
# 5) Markdown-wrapped: *A* **B**
re.compile(
r"(?x)(?<![A-Za-z0-9])(?:\*{1,2}|_{1,2})([A-J])(?:\*{1,2}|_{1,2})(?![A-Za-z0-9])"
),
# 6) \textbf{C}
re.compile(r"(?x)\\textbf\{[^}]*?([A-J])[^}]*\}"),
# 8) **D) description**
re.compile(r"""(?x)
(?<![A-Za-z0-9])
(?:\*{1,2}|_{1,2})
\s*([A-J])\)
[^*_\n]+?
(?:\*{1,2}|_{1,2})
(?![A-Za-z0-9])
"""),
# 9) final fallback: line starting with a single letter
re.compile(
r"""(?x)^\s*
(?:\*{1,2}|_{1,2})?
([A-J])
(?:\*{1,2}|_{1,2})?
\s*[\.\)\-–:]?
\s*.*$
""",
re.MULTILINE,
),
]

@classmethod
def extract(cls, text: str, default: str | None = None) -> str | None:
matches = []
for prio, pat in enumerate(cls.PATTERNS):
m = pat.search(text)
if m:
letter = m.group(1).upper()
if letter in cls.LETTERS:
matches.append((prio, m, letter))

matches.sort(key=lambda triple: (triple[0], len(triple[1].group(0))))

for _, _, letter in matches:
return letter

stripped = text.removeprefix("**")
if stripped and stripped[0].upper() in cls.LETTERS:
return stripped[0].upper()

return default if default is not None else ""


class BoxedMathExtractor(Extractor, extractor_id="boxed_math_extractor"):
"""Extract boxed math answer from response text.
Based on OpenAI's extract_boxed_math function from GPT-OSS.
Expand Down
13 changes: 6 additions & 7 deletions src/inference_endpoint/openai/accumulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,14 @@ def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None:

def get_final_output(self) -> QueryResult:
if self.reasoning_chunks:
# If there are reasoning chunks, then the first chunk received
# is the first reasoning chunk. The rest of the reasoning chunks,
# as well as the output chunks can be joined together.
resp_reasoning: list[str] = [self.reasoning_chunks[0]]
if len(self.reasoning_chunks) > 1:
resp_reasoning.append("".join(self.reasoning_chunks[1:]))
# All reasoning chunks are joined into a single string so the full
# thinking trace is captured as-is in events.jsonl. TPOT still uses
# text_after_first_chunk(), which includes string reasoning in the
# denominator (off by one token vs. the true "after first chunk"
# count, which is negligible).
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment claims the TPOT denominator error is only “off by one token”, but joining all reasoning chunks into a single string loses the first-chunk boundary and can over-count by all tokens in the first streamed reasoning chunk (SSE deltas often contain multiple tokens). If TPOT accuracy matters, keep reasoning as chunked data (e.g., tuple with the first chunk separated) or record the first-chunk split explicitly so text_after_first_chunk() can exclude it precisely while still writing a joined reasoning string to events.jsonl for readability.

Copilot uses AI. Check for mistakes.
text_output = TextModelOutput(
output="".join(self.output_chunks),
reasoning=resp_reasoning,
reasoning="".join(self.reasoning_chunks),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Joining all reasoning chunks into a single string causes a loss of precision in TPOT (Time Per Output Token) calculations. The text_after_first_chunk() method in TextModelOutput relies on the reasoning being a tuple of chunks to correctly exclude the first chunk from the denominator. By joining them here, the first token is included in the TPOT calculation, leading to inaccurate metrics.

Since TextModelOutput already handles joining chunks in its __str__ method (used for logging and scoring), it is better to preserve the chunks as a tuple.

            # Preserve reasoning chunks as a tuple to allow accurate TPOT 
            # calculation (excluding the first chunk). TextModelOutput 
            # handles joining for display/logging via __str__.
            text_output = TextModelOutput(
                output=tuple(self.output_chunks),
                reasoning=tuple(self.reasoning_chunks),
            )

elif self.output_chunks:
# If there are only output chunks, the first chunk is used for
Expand Down
2 changes: 2 additions & 0 deletions src/inference_endpoint/openai/openai_msgspec_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def dataset_transforms(cls, model_params: ModelParams) -> list[Transform]:
"logit_bias",
"user",
"chat_template",
"chat_template_kwargs",
]
return [
ColumnFilter(
Expand Down Expand Up @@ -164,6 +165,7 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest:
logit_bias=query.data.get("logit_bias"),
user=query.data.get("user"),
chat_template=query.data.get("chat_template"),
chat_template_kwargs=query.data.get("chat_template_kwargs"),
)

@classmethod
Expand Down
Loading
Loading