477 changes: 469 additions & 8 deletions docs/tutorials/tlm_structured_outputs/index.ipynb

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions tlm/api.py
@@ -10,6 +10,7 @@
from tlm.config.presets import WorkflowType
from tlm.inference import InferenceResult, tlm_inference
from tlm.types import Eval
from tlm.utils.structured_output_utils import _get_untrustworthy_fields


def is_notebook() -> bool:
@@ -174,3 +175,29 @@ async def _async_inference(
context=context,
config=config,
)

    def get_untrustworthy_fields(
        self,
        *,
        tlm_result: InferenceResult,
        threshold: float = 0.8,
        display_details: bool = True,
    ) -> list[str]:
        """Gets the fields of a structured output response that are considered untrustworthy by TLM.
        Only works for responses that are valid JSON objects (i.e. when `response_format` was used to specify the output format).
        Prints detailed information about the untrustworthy fields if `display_details` is True.

        Args:
            tlm_result: The result object from a previous TLM call
            threshold: Trustworthiness score below which a field is considered untrustworthy
            display_details: Whether to print detailed information about the untrustworthy fields

        Returns:
            list[str]: The fields of the response that are considered untrustworthy by TLM
        """
        return _get_untrustworthy_fields(
            tlm_result=tlm_result,
            threshold=threshold,
            display_details=display_details,
        )
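A minimal usage sketch of the new method; the `TLM()` client, `inference()` call, and `Invoice` model below are illustrative assumptions, only `get_untrustworthy_fields()` itself comes from this diff:

    # Hypothetical usage -- client name, inference signature, and Invoice model are assumptions.
    tlm = TLM()
    result = tlm.inference(
        "Extract the invoice fields as JSON.",
        response_format=Invoice,  # a pydantic model describing the structured output
    )
    # Fields whose per-field trustworthiness score falls below 0.8, worst first:
    untrustworthy = tlm.get_untrustworthy_fields(tlm_result=result, threshold=0.8)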
2 changes: 1 addition & 1 deletion tlm/config/presets.py
@@ -63,7 +63,7 @@ def from_inference_params(
        score: bool,
        constrain_outputs: list[str] | None = None,
    ) -> "WorkflowType":
        if openai_args.get("response_format") is not None and score:
        if openai_args.get("response_format") is not None:
            return cls.STRUCTURED_OUTPUT_SCORING

        if rag:
17 changes: 13 additions & 4 deletions tlm/templates/reflection_completion_templates.py
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Callable, ClassVar, Literal
import json
import ast
from pydantic import BaseModel, Field

from tlm.types.base import SOReflectionScoreConfigType
@@ -808,7 +809,12 @@ def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionComple

    @classmethod
    def construct_response_format(cls, response_json: str) -> type[BaseModel] | None:
        response_fields = json.loads(response_json).keys()
        try:
            response_dict = json.loads(response_json)
        except Exception:
            response_dict = ast.literal_eval(response_json)

        response_fields = response_dict.keys()
        ResponseFields = Literal[tuple(response_fields)]  # type: ignore

        class IncorrectField(BaseModel):
@@ -876,7 +882,12 @@ def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionComple

    @classmethod
    def construct_response_format(cls, response_json: str) -> type[BaseModel] | None:
        response_fields = json.loads(response_json).keys()
        try:
            response_dict = json.loads(response_json)
        except Exception:
            response_dict = ast.literal_eval(response_json)

        response_fields = response_dict.keys()
        ResponseFields = Literal[tuple(response_fields)]  # type: ignore

        class IncorrectField(BaseModel):
@@ -910,8 +921,6 @@ class RatingModel(IncorrectFieldEvaluationBase):
        ReflectionRAGIssuesTemplate,
    ],
    WorkflowType.STRUCTURED_OUTPUT_SCORING: [
        # ReflectionCertaintyTemplate,
        # ReflectionKnowledgeGapTemplate,
        SelfReflectionSOFieldAccuracyConfig,
        SelfReflectionSOFieldKnowledgeGapConfig,
        ReflectionArgumentTemplate,
2 changes: 1 addition & 1 deletion tlm/utils/completion_utils.py
@@ -280,7 +280,7 @@ def _parse_completion(completion: Completion, reference_answer: str | None = Non

    completion.add_response_field(
        ExtractedResponseField.MAPPED_SCORE,
        score_mapper(unmapped_overall_score),
        score_mapper(str(unmapped_overall_score)),
    )


8 changes: 4 additions & 4 deletions tlm/utils/openai_utils.py
@@ -48,12 +48,12 @@ def extract_message_content(completion: Dict[str, Any]) -> str:
def extract_structured_output_field(message_content: str, field: str) -> str | None:
    try:
        return str(ast.literal_eval(message_content)[field])
    except Exception as e:
        logger.warning(f"ast.literal_eval failed for message_content: {message_content}\nError: {e}")
    except Exception:
        pass

    try:
        return str(json.loads(message_content)[field])
    except Exception as e:
        logger.warning(f"json.loads failed for message_content: {message_content}\nError: {e}")
    except Exception:
        pass

    return None
8 changes: 6 additions & 2 deletions tlm/utils/response_format_utils.py
@@ -2,7 +2,7 @@
from pydantic import BaseModel, Field, create_model
import json
import copy

import ast
from tlm.types import CompletionParams
from tlm.config.defaults import get_settings

@@ -55,6 +55,10 @@ def add_explanation_field(schema: Dict[str, Any]) -> Dict[str, Any]:
def construct_per_field_response_format_model(
    reference_answer: str, per_field_score_response_format: type[BaseModel]
) -> type[BaseModel]:
    answer_keys = json.loads(reference_answer).keys()
    try:
        answer_keys = json.loads(reference_answer).keys()
    except Exception:
        answer_keys = ast.literal_eval(reference_answer).keys()

    fields = {key: (per_field_score_response_format, Field(...)) for key in answer_keys}
    return create_model(per_field_score_response_format.__name__, **fields)  # type:ignore
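For orientation, a small sketch of what this helper produces; the `FieldScore` model and the reference answer below are illustrative assumptions, not part of this diff:

    # Illustrative sketch -- FieldScore and the reference answer are assumptions.
    from pydantic import BaseModel
    from tlm.utils.response_format_utils import construct_per_field_response_format_model

    class FieldScore(BaseModel):
        score: float
        explanation: str

    reference_answer = '{"name": "Acme Corp", "total": 1200}'
    # Builds a model with one FieldScore-typed field per key of the reference answer:
    PerFieldModel = construct_per_field_response_format_model(reference_answer, FieldScore)
    # PerFieldModel now requires: name: FieldScore, total: FieldScore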
8 changes: 6 additions & 2 deletions tlm/utils/scoring/per_field_scoring_utils.py
@@ -1,7 +1,7 @@
import json
import numpy as np
from typing import Callable

import ast
from tlm.types import FieldMetadata, Completion, SOReflectionScoreConfigType
from tlm.utils.math_utils import make_score_asymptotic
from tlm.config.presets import (
@@ -41,7 +41,11 @@ def extract_incorrect_fields_reflection_metadata(
    incorrect_fields_list = answer_json["incorrect_fields"]
    incorrect_field_names_and_explanations = {item["field_name"]: item["explanation"] for item in incorrect_fields_list}

    field_names = json.loads(reference_answer).keys()
    try:
        field_names = json.loads(reference_answer).keys()
    except Exception:
        field_names = ast.literal_eval(reference_answer).keys()

    per_field_metadata = {}

    # construct scores and mapped scores for each field for downstream use of per-field score details
68 changes: 68 additions & 0 deletions tlm/utils/structured_output_utils.py
@@ -0,0 +1,68 @@
import ast
import json

from tlm.inference import InferenceResult


def _get_untrustworthy_fields(
    tlm_result: InferenceResult,
    threshold: float = 0.8,
    display_details: bool = True,
) -> list[str]:
    tlm_metadata = tlm_result["metadata"]
    response_text = tlm_result["response"].choices[0].message.content  # type: ignore

    if tlm_metadata is None or "per_field_score" not in tlm_metadata:
        raise ValueError(
            "`per_field_score` is not present in the metadata.\n"
            "`get_untrustworthy_fields()` can only be called when scoring structured output responses."
        )

    try:
        so_response = json.loads(response_text)
    except Exception:
        try:
            so_response = ast.literal_eval(response_text)
        except Exception:
            raise ValueError(
                "The LLM response must be a valid JSON object (use `response_format` to specify the output format)"
            )

    per_field_score = tlm_metadata["per_field_score"]
    per_score_details = []

    # handle cases where an error log is returned instead of per-field scores
    if len(per_field_score) == 1 and isinstance(per_field_score.get("error"), str):
        print("Per-field score returned an error:")
        print(per_field_score.get("error"))
        return []

    for key, value in per_field_score.items():
        score = value["score"]
        if float(score) < threshold:
            key_details = {
                "response": so_response[key],
                "score": score,
                "explanation": value["explanation"],
            }
            per_score_details.append({key: key_details})

    per_score_details.sort(key=lambda x: next(iter(x.values()))["score"])
    untrustworthy_fields = [next(iter(item.keys())) for item in per_score_details]

    if display_details:
        if len(untrustworthy_fields) == 0:
            print("No untrustworthy fields found")
        else:
            print(f"Untrustworthy fields: {untrustworthy_fields}\n")
            for item in per_score_details:
                print(f"Field: {next(iter(item.keys()))}")
                details = next(iter(item.values()))
                print(f"Response: {details['response']}")
                print(f"Score: {details['score']}")
                print(f"Explanation: {details['explanation']}")
                print()

    return untrustworthy_fields
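For reference, an illustrative shape of the metadata this helper consumes; the field names and values below are made up, not from this diff:

    # Illustrative per-field metadata shape consumed by _get_untrustworthy_fields:
    # tlm_result["metadata"]["per_field_score"] == {
    #     "name":  {"score": 0.97, "explanation": "..."},
    #     "total": {"score": 0.42, "explanation": "The total is not supported by the context."},
    # }
    # With threshold=0.8 this would return ["total"] and, if display_details is True,
    # print the field, its response value, score, and explanation.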