From f78fd94091fb79961f9f84afe84b27d35d479e84 Mon Sep 17 00:00:00 2001 From: rafapi Date: Thu, 16 Apr 2026 12:48:14 +0000 Subject: [PATCH 01/18] Core multi-step --- pipelinerl/actor.py | 4 ++-- pipelinerl/async_llm.py | 13 ++++++++++- pipelinerl/domains/miniwob/rollouts.py | 19 ++++----------- pipelinerl/finetune/rl/__init__.py | 31 +++++++++++++++++-------- pipelinerl/rollouts.py | 32 +++++++++++++++++++++++++- 5 files changed, 71 insertions(+), 28 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 074d16ab..64c06f57 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -24,7 +24,7 @@ from pipelinerl.finetune_loop import calculate_train_steps from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb from pipelinerl.llm import TrainableLLM -from pipelinerl.rollouts import BaseMetrics, RolloutResult +from pipelinerl.rollouts import BaseMetrics, RolloutResult, rollout_has_overflow from pipelinerl.shared_memory_array import SharedMemoryQueue from pipelinerl.state import TrainerState from pipelinerl.streams import ( @@ -393,7 +393,7 @@ def init_stats(self): def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]: metrics = {} - metrics['overflow'] = all([not training_text.finished for training_text in result.training_texts ]) + metrics['overflow'] = rollout_has_overflow(result.training_texts) metrics['num_turns'] = len(result.training_texts) metrics['prompt_tokens'] = [training_text.prompt_tokens for training_text in result.training_texts] metrics['output_tokens'] = [training_text.output_tokens for training_text in result.training_texts] diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index c541fd67..35e45ff5 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -8,7 +8,7 @@ from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM from pipelinerl.finetune.data import MASKED_TOKEN_ID -from pipelinerl.rollouts import TrainingText +from pipelinerl.rollouts import TrainingText, apply_rollout_reward from pipelinerl.processor_factory import get_processor from omegaconf import DictConfig, ListConfig, OmegaConf @@ -250,3 +250,14 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: output_tokens=output_tokens, visual_features=visual_features, ) + + +def make_training_texts_from_llm_calls( + llm: TrainableLLM, + llm_calls: list[LLMCall], + reward: float | None = None, +) -> list[TrainingText]: + training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] + if reward is not None: + training_texts = apply_rollout_reward(training_texts, reward) + return training_texts diff --git a/pipelinerl/domains/miniwob/rollouts.py b/pipelinerl/domains/miniwob/rollouts.py index 8c489291..89e9cc5a 100644 --- a/pipelinerl/domains/miniwob/rollouts.py +++ b/pipelinerl/domains/miniwob/rollouts.py @@ -18,9 +18,9 @@ from tapeagents.remote_environment import AsyncRemoteEnvironment from tapeagents.tools.simple_browser import PageObservation -from pipelinerl.async_llm import make_training_text +from pipelinerl.async_llm import make_training_texts_from_llm_calls from pipelinerl.llm import LLMCall, TrainableLLM -from pipelinerl.rollouts import BaseMetrics, RolloutResult +from pipelinerl.rollouts import BaseMetrics, RolloutResult, summarize_training_texts from pipelinerl.world import Job from .steps import WebTape @@ -271,13 +271,8 @@ async def _execute_rollout_with_timeout( ] # (4) # For each LLM interaction in the tape, make a training example. - all_finished = 1 - prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls] - output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls] - training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] - for text in training_texts: - text.reward = reward - all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0 + training_texts = make_training_texts_from_llm_calls(llm, llm_calls, reward=reward) + training_summary = summarize_training_texts(training_texts) latency = time.time() - start_time agent_time = tape.metadata.result.get("agent_execution_time", -1.0) @@ -289,7 +284,7 @@ async def _execute_rollout_with_timeout( success=reward > 0.5, no_error=no_error, no_answer=reward < 0, - overflow=not all_finished, + overflow=training_summary.overflow, n_llm_calls=n_llm_calls, n_step_errors=n_step_errors, n_page_observations=n_page_observations, @@ -307,8 +302,6 @@ async def _execute_rollout_with_timeout( latency=latency, dataset_name=problem["dataset"], domain="miniwob", - prompt_tokens=prompt_tokens, - output_tokens=output_tokens, ) @@ -340,6 +333,4 @@ def _create_failed_rollout_result(problem: dict, start_time: float, error_type: latency=latency, dataset_name=problem["dataset"], domain="miniwob", - prompt_tokens=[], - output_tokens=[], ) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 075e1d1e..4ca8ef8c 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -1,7 +1,7 @@ import logging import os from functools import partial -from typing import Any +from typing import Any, TYPE_CHECKING from pydantic import BaseModel, Field import numpy as np @@ -9,10 +9,14 @@ import torch import torch.nn.functional as F from datasets import Dataset -from transformers import PreTrainedModel from pipelinerl.finetune.types import PipelineBatchEncoding from pipelinerl.finetune.rl.utils import per_segment_sums +if TYPE_CHECKING: + from transformers import PreTrainedModel +else: + PreTrainedModel = Any + from .utils import ( sum_sum, mean_sum, @@ -427,22 +431,29 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R df_init = pd.DataFrame(dataset) assert isinstance(df_init, pd.DataFrame) - # Step 1: calculate group-level statistics + # Step 1: calculate rollout- and group-level statistics df_stats = df_init[["group_id", "rollout_index", "step_index"]].copy() df_stats["num_tokens"] = df_init["input_ids"].apply(len) - # We assume that rewards for all tokens are the same + # RL preprocessing currently assumes a single reward per rollout. df_stats["rollout_reward"] = df_init["rewards"].apply(lambda x: x[0]) - # Check that the reward is the same for each step in the rollout - assert df_stats.groupby(["group_id", "rollout_index"])["rollout_reward"].nunique().max() == 1 - # Only keep step_index == 0 - df_stats = df_stats[df_stats["step_index"] == 0].drop(columns=["step_index"]) + assert df_stats.groupby(["group_id", "rollout_index"])["rollout_reward"].nunique().max() == 1, ( + "RL preprocessing expects the same reward for every step in a rollout" + ) + df_rollouts = ( + df_stats.groupby(["group_id", "rollout_index"]) + .agg( + rollout_reward=("rollout_reward", "first"), + rollout_tokens=("num_tokens", "sum"), + ) + .reset_index() + ) df_grouped = ( - df_stats.groupby("group_id") + df_rollouts.groupby("group_id") .agg( rollout_reward_sum=("rollout_reward", "sum"), rollout_reward_count=("rollout_reward", "count"), rollout_reward_std=("rollout_reward", "std"), - group_tokens=("num_tokens", "mean"), + group_tokens=("rollout_tokens", "mean"), ) .reset_index() ) diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py index 1200ba23..4c71dda7 100644 --- a/pipelinerl/rollouts.py +++ b/pipelinerl/rollouts.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from pydantic import BaseModel, Field -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Sequence import numpy as np class BaseMetrics(BaseModel): @@ -65,3 +66,32 @@ class RolloutResult(BaseModel): dataset_name: str | None = None group_id: str | None = None domain: str | None = None + + +@dataclass(frozen=True) +class TrainingTextSummary: + prompt_tokens: list[int] + output_tokens: list[int] + overflow: bool + num_turns: int + + +def apply_rollout_reward(training_texts: Sequence[TrainingText], reward: float) -> list[TrainingText]: + texts = list(training_texts) + for training_text in texts: + training_text.reward = reward + return texts + + +def rollout_has_overflow(training_texts: Sequence[TrainingText]) -> bool: + return any(not training_text.finished for training_text in training_texts) + + +def summarize_training_texts(training_texts: Sequence[TrainingText]) -> TrainingTextSummary: + texts = list(training_texts) + return TrainingTextSummary( + prompt_tokens=[training_text.prompt_tokens for training_text in texts], + output_tokens=[training_text.output_tokens for training_text in texts], + overflow=rollout_has_overflow(texts), + num_turns=len(texts), + ) From 48f50cba51c73697d905039f45107f9f9c7f17c7 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:45:36 +0000 Subject: [PATCH 02/18] Extract get_reward into reusable function --- pipelinerl/domains/math/rollouts.py | 42 ++++++++++++++++------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 3e6560c0..8b712ff0 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -55,6 +55,28 @@ def log_config(self, domain: str = "unknown") -> None: f"buffer_tokens={self.buffer_tokens}" ) +def get_reward(answer_status: str, finished: bool, reward_table: RewardTable) -> float: + match (answer_status, finished): + case ("wrong", False): + return reward_table.wrong_answer_not_finished + case ("wrong", True): + return reward_table.wrong_answer_finished + case ("no_answer", False): + return reward_table.no_answer_not_finished + case ("no_answer", True): + return reward_table.no_answer_finished + case ("unparsable", False): + return reward_table.unparsable_not_finished + case ("unparsable", True): + return reward_table.unparsable_finished + case ("correct", False): + return reward_table.correct_answer_not_finished + case ("correct", True): + return reward_table.correct_answer_finished + case _: + raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{finished}") + + def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float: """ Compute the overlong penalty @@ -100,25 +122,7 @@ async def generate_math_rollout( trace = make_training_text(llm, llm_call) # Determine reward based on answer status and finished state - match (answer_status, trace.finished): - case ("wrong", False): - reward = rewards.wrong_answer_not_finished - case ("wrong", True): - reward = rewards.wrong_answer_finished - case ("no_answer", False): - reward = rewards.no_answer_not_finished - case ("no_answer", True): - reward = rewards.no_answer_finished - case ("unparsable", False): - reward = rewards.unparsable_not_finished - case ("unparsable", True): - reward = rewards.unparsable_finished - case ("correct", False): - reward = rewards.correct_answer_not_finished - case ("correct", True): - reward = rewards.correct_answer_finished - case _: - raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{trace.finished}") + reward = get_reward(answer_status, trace.finished, rewards) # Apply discount factor based on output length reward *= discount_factor**llm_call.output_length_tokens From d69d60adb637da01b99d4df5b09ae263f9a23285 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:45:39 +0000 Subject: [PATCH 03/18] Export get_reward and length_penalty --- pipelinerl/domains/math/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelinerl/domains/math/__init__.py b/pipelinerl/domains/math/__init__.py index 9aee0b8f..7a9809b7 100644 --- a/pipelinerl/domains/math/__init__.py +++ b/pipelinerl/domains/math/__init__.py @@ -1,3 +1,3 @@ from .load_datasets import load_datasets -from .rollouts import generate_math_rollout, RewardTable +from .rollouts import generate_math_rollout, RewardTable, get_reward, length_penalty from .verifier_api import MathEnvironment, verify_answer, verify_answer_rpc \ No newline at end of file From 129033ea0cdba7b7af9ba2ae7487dd174e00185c Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:45:44 +0000 Subject: [PATCH 04/18] Add tool calling support to llm_async_generate --- pipelinerl/async_llm.py | 91 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 2 deletions(-) diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index 35e45ff5..9c0d4bfa 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -3,6 +3,7 @@ import logging import aiohttp +import litellm import numpy as np from PIL import Image from pipelinerl.llm import LLMCall, LLMOutput, Prompt, TokenLogprob, TrainableLLM @@ -85,6 +86,9 @@ async def llm_async_generate( logger.debug(f"POST request to {llm.base_url}/v1/chat/completions") + if prompt.tools: + data["tools"] = _to_plain_obj(prompt.tools) + # Merge extra_parameters first so that data (model, messages, logprobs settings) takes precedence payload = _to_plain_obj({**extra_parameters, **data}) async with session.post( @@ -101,7 +105,8 @@ async def llm_async_generate( try: content = data["choices"][0]["message"]["content"] - if not content: + raw_tool_calls = data["choices"][0]["message"].get("tool_calls", []) + if not content and not raw_tool_calls: logger.warning(f"Empty completion {data}") parsed_logprobs = [] @@ -128,7 +133,9 @@ async def llm_async_generate( logger.exception(f"Failed to parse llm response: {data}") raise - output = LLMOutput(content=content) + output = LLMOutput(content=content or "") + if raw_tool_calls: + output.tool_calls = [litellm.ChatCompletionMessageToolCall(**tc) for tc in raw_tool_calls] llm_call = llm.log_output(prompt, output, count_tokens=False) llm_call.prompt_length_tokens = data["usage"]["prompt_tokens"] llm_call.output_length_tokens = data["usage"]["completion_tokens"] @@ -261,3 +268,83 @@ def make_training_texts_from_llm_calls( if reward is not None: training_texts = apply_rollout_reward(training_texts, reward) return training_texts + + +def make_training_text_with_tools(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: + """Build a TrainingText for an assistant turn that may contain tool_calls. + + For turns without tool_calls this delegates to ``make_training_text``. + When tool_calls are present the assistant message dict includes them so + that ``apply_chat_template`` produces the correct token sequence matching + what vLLM actually generated (and for which we have logprobs). + """ + if not llm_call.output.tool_calls: + return make_training_text(llm, llm_call) + + llm.load_tokenizer() + + # Build the assistant message with tool_calls + assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""} + assistant_msg["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in llm_call.output.tool_calls + ] + + full_messages = llm_call.prompt.messages + [assistant_msg] + + prompt_text = llm.tokenizer.apply_chat_template( + conversation=llm_call.prompt.messages, + tokenize=False, + add_generation_prompt=True, + tools=llm_call.prompt.tools, + ) + text = llm.tokenizer.apply_chat_template( + full_messages, + tokenize=False, + tools=llm_call.prompt.tools, + ) + prompt_token_ids = llm.tokenizer.apply_chat_template( + llm_call.prompt.messages, + add_special_tokens=True, + add_generation_prompt=True, + tools=llm_call.prompt.tools, + ) + + output_text = text[len(prompt_text):] + + tokenizer = llm.tokenizer + if tokenizer.bos_token and text.startswith(tokenizer.bos_token): + text = text[len(tokenizer.bos_token):] + + if not llm_call.logprobs: + raise ValueError("Logprobs are required to make training data for RL") + + labels = [lp.token_id for lp in llm_call.logprobs] + input_ids = prompt_token_ids + labels + labels = [MASKED_TOKEN_ID] * len(prompt_token_ids) + labels + logprobs = [lp.logprob for lp in llm_call.logprobs] + + finish_reason = llm_call.llm_info.get("finish_reason") + if finish_reason is not None: + finished = finish_reason != "length" + else: + eos_token = tokenizer.eos_token or "" + finished = bool(eos_token) and (llm_call.output.content or "").endswith(eos_token) + + return TrainingText( + text=text, + n_predicted=len(output_text), + input_ids=input_ids, + labels=labels, + logprobs=logprobs, + finished=finished, + prompt_tokens=llm_call.prompt_length_tokens, + output_tokens=llm_call.output_length_tokens, + ) From a3057bf7980bb2ecda8286d6bb3c28e2b9996247 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:45:45 +0000 Subject: [PATCH 05/18] Port vLLM tool parser from mcp_tir --- pipelinerl/rl_tool_parser_plugin.py | 247 ++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 pipelinerl/rl_tool_parser_plugin.py diff --git a/pipelinerl/rl_tool_parser_plugin.py b/pipelinerl/rl_tool_parser_plugin.py new file mode 100644 index 00000000..12e6fc2d --- /dev/null +++ b/pipelinerl/rl_tool_parser_plugin.py @@ -0,0 +1,247 @@ +""" +Tool parser plugin for RL tool calling format. +""" + +import json +import re +from typing import Any # noqa: F401 +import logging + +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ExtractedToolCallInformation, + ToolCall, + FunctionCall +) + + +@ToolParserManager.register_module("rl_tool") +class HermesRLToolParser(ToolParser): + """ + Tool parser for RL tool calling format using markers. + Supports both standard format and Apriel-style formats: + - [{...}, {...}] (preferred if present) + - [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE] wrapper + """ + + def __init__(self, tokenizer): + super().__init__(tokenizer) + + # Tool call markers + self.tool_call_start_token = "" + self.tool_call_end_token = "" + + # Regex pattern for parsing tool calls + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL + ) + + # Apriel-specific patterns + self.apriel_final_response_regex = re.compile( + r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", re.DOTALL + ) + # Prefer parsing aggregated tool calls from ... + # Be lenient: case-insensitive; tolerate missing closing tag by capturing to end. + self.apriel_tool_calls_regex = re.compile( + r"\s*(.*?)\s*(?:|$)", re.DOTALL | re.IGNORECASE + ) + + # State for streaming + self.current_tool_name_sent = False + self.prev_tool_call_arr = [] + self.current_tool_id = -1 + self.streamed_args_for_tool = [] + + def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract tool calls from the model output. + + Args: + model_output: The raw model output string + request: The request object + + Returns: + ExtractedToolCallInformation with tool calls and metadata + """ + logger = logging.getLogger("pipelinerl.tool_parser") + # Ensure variable exists for any fallback references below + final_response_match = None + + try: + # 1) Apriel aggregated tool calls block has priority + tool_calls_matches = list(self.apriel_tool_calls_regex.finditer(model_output)) + if tool_calls_matches: + # Use the last match (in case of multiple blocks) + last_match = tool_calls_matches[-1] + tool_calls_json = last_match.group(1).strip() + parsed_calls = [] + try: + parsed_calls = json.loads(tool_calls_json) if tool_calls_json else [] + except Exception: + logger.debug("Failed to parse aggregated JSON; falling back", exc_info=True) + parsed_calls = [] + + tool_calls: list[ToolCall] = [] + for i, pc in enumerate(parsed_calls): + try: + name = pc.get("name", "") + args_obj = pc.get("arguments", {}) + if not isinstance(args_obj, (dict, list, str, int, float, bool)): + args_obj = {} + args_str = json.dumps(args_obj, ensure_ascii=False) + call_id = pc.get("id", f"call_{i}") + tool_calls.append( + ToolCall( + id=call_id, + type="function", + function=FunctionCall(name=str(name), arguments=args_str), + ) + ) + except Exception: + logger.debug("Skipping malformed aggregated tool call", exc_info=True) + continue + + # Prefer final response content if present; otherwise empty string + final_response_match = self.apriel_final_response_regex.search(model_output) + content = final_response_match.group(1).strip() if final_response_match else "" + + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content, + ) + + # 2) Try bare JSON tool-calls (no tags), but only if tools are declared in the request + # Accept either a list of {name, arguments} or a single dict + try: + tools_declared = bool(getattr(request, "tools", None)) + except Exception: + tools_declared = False + + if tools_declared: + candidate_strings: list[str] = [] + final_response_match = self.apriel_final_response_regex.search(model_output) + if final_response_match: + candidate_strings.append(final_response_match.group(1).strip()) + candidate_strings.append(model_output.strip()) + + for candidate in candidate_strings: + try: + parsed = json.loads(candidate) + except Exception: + continue + parsed_list = [] + if isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed: + parsed_list = [parsed] + elif isinstance(parsed, list) and all(isinstance(it, dict) for it in parsed): + parsed_list = [it for it in parsed if "name" in it and "arguments" in it] + if not parsed_list: + continue + tool_calls: list[ToolCall] = [] + for i, pc in enumerate(parsed_list): + try: + name = pc.get("name", "") + args_obj = pc.get("arguments", {}) + if not isinstance(args_obj, (dict, list, str, int, float, bool)): + args_obj = {} + args_str = json.dumps(args_obj, ensure_ascii=False) + call_id = pc.get("id", f"call_{i}") + tool_calls.append( + ToolCall( + id=call_id, + type="function", + function=FunctionCall(name=str(name), arguments=args_str), + ) + ) + except Exception: + logger.debug("Skipping malformed bare-JSON tool call", exc_info=True) + continue + content = final_response_match.group(1).strip() if final_response_match else "" + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content, + ) + + # 3) Fallback: look for single blocks (legacy / other models) + content_to_search = model_output + final_response_match = self.apriel_final_response_regex.search(model_output) + if final_response_match: + final_response_content = final_response_match.group(1).strip() + if self.tool_call_start_token in final_response_content: + content_to_search = final_response_content + elif self.tool_call_start_token not in model_output: + # No tool calls found, return final response as content + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=final_response_content + ) + + # Quick check to avoid unnecessary processing + if self.tool_call_start_token not in content_to_search: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output + ) + + # Find all tool call matches + function_call_tuples = self.tool_call_regex.findall(content_to_search) + + # Parse JSON from matches + tool_calls = [] + for i, match in enumerate(function_call_tuples): + json_str = match[0] if match[0] else match[1] + try: + parsed_call = json.loads(json_str.strip()) + args_obj = parsed_call.get("arguments", {}) + if not isinstance(args_obj, (dict, list, str, int, float, bool)): + args_obj = {} + tool_call = ToolCall( + id=f"call_{i}", + type="function", + function=FunctionCall( + name=str(parsed_call.get("name", "")), + arguments=json.dumps(args_obj, ensure_ascii=False) + ) + ) + tool_calls.append(tool_call) + except Exception: + logger.debug("Skipping malformed JSON", exc_info=True) + continue + + # Determine content based on whether we found tool calls + if tool_calls and final_response_match: + # If we found tool calls in final response, use just the tool calls + content = "" + elif final_response_match: + # If we have final response but no tool calls there, use final response + content = final_response_match.group(1).strip() + else: + # Standard processing + content = model_output + + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content + ) + + except Exception: + # Never propagate exceptions to the server; log and return a safe fallback. + logger.exception("Tool parser encountered an exception; returning safe fallback.") + if final_response_match: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=final_response_match.group(1).strip() + ) + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output + ) + \ No newline at end of file From c1e096b8b8b20c3a5353421fde788f2235c09362 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:45:46 +0000 Subject: [PATCH 06/18] Create TIR domain package --- pipelinerl/domains/tir/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 pipelinerl/domains/tir/__init__.py diff --git a/pipelinerl/domains/tir/__init__.py b/pipelinerl/domains/tir/__init__.py new file mode 100644 index 00000000..4a658bd0 --- /dev/null +++ b/pipelinerl/domains/tir/__init__.py @@ -0,0 +1 @@ +from .rollouts import generate_tir_rollout From 23419a9693933b97408259383d54f0fd67bd602c Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:45:47 +0000 Subject: [PATCH 07/18] Implement multi-turn TIR rollout with SandboxFusion --- pipelinerl/domains/tir/rollouts.py | 340 +++++++++++++++++++++++++++++ 1 file changed, 340 insertions(+) create mode 100644 pipelinerl/domains/tir/rollouts.py diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py new file mode 100644 index 00000000..bd8bb1db --- /dev/null +++ b/pipelinerl/domains/tir/rollouts.py @@ -0,0 +1,340 @@ +import asyncio +import json +import logging +import random +import re +import time + +import aiohttp +from omegaconf import DictConfig +from pydantic import BaseModel + +from sandbox_fusion import RunCodeRequest, set_sandbox_endpoint, run_code_async + +from pipelinerl.async_llm import llm_async_generate, make_training_text_with_tools +from pipelinerl.domains.math import RewardTable, get_reward, length_penalty, verify_answer_rpc +from pipelinerl.llm import Prompt, TrainableLLM +from pipelinerl.rollouts import BaseMetrics, RolloutResult +from pipelinerl.utils import get_environment_jobs, resolve_environment_key + +logger = logging.getLogger(__name__) + +_SANDBOX_CONFIGURED = False + +# Python safety blocklist: patterns that must not appear in user-submitted code +_BLOCKED_PATTERNS = [ + re.compile(r"\bsys\.exit\b"), + re.compile(r"\bos\._exit\b"), + re.compile(r"\bos\.system\b"), + re.compile(r"\bsubprocess\b"), + re.compile(r"\bos\.popen\b"), + re.compile(r"\bos\.exec\w*\b"), + re.compile(r"\bos\.spawn\w*\b"), + re.compile(r"\bos\.kill\b"), + re.compile(r"\bshutil\.rmtree\b"), + re.compile(r"\bos\.remove\b"), + re.compile(r"\bos\.unlink\b"), +] + + +def build_tool_definitions() -> list[dict]: + return [ + { + "type": "function", + "function": { + "name": "run_python_code", + "description": "Execute Python code. Print only the final result.", + "parameters": { + "type": "object", + "properties": {"code": {"type": "string", "description": "Python code to execute"}}, + "required": ["code"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "MathAnswer", + "description": "Submit the final answer in LaTeX \\boxed{} format.", + "parameters": { + "type": "object", + "properties": {"answer": {"type": "string", "description": "The final answer"}}, + "required": ["answer"], + }, + }, + }, + ] + + +def _check_code_safety(code: str) -> str | None: + """Return a rejection message if the code matches a blocked pattern, else None.""" + for pattern in _BLOCKED_PATTERNS: + if pattern.search(code): + return f"Blocked: code contains forbidden pattern '{pattern.pattern}'" + return None + + +async def execute_python_sandbox(code: str, endpoint: str, timeout: float) -> str: + """Execute Python code via SandboxFusion and return formatted output.""" + global _SANDBOX_CONFIGURED + if not _SANDBOX_CONFIGURED: + set_sandbox_endpoint(endpoint) + _SANDBOX_CONFIGURED = True + logger.info("Configured SandboxFusion endpoint: %s", endpoint) + + rejection = _check_code_safety(code) + if rejection is not None: + return rejection + + try: + request = RunCodeRequest(code=code, language="python", run_timeout=timeout) + response = await run_code_async(request) + + stdout = "" + stderr = "" + if response.run_result: + stdout = response.run_result.stdout or "" + stderr = response.run_result.stderr or "" + + status = response.status.value if hasattr(response.status, "value") else str(response.status) + is_timeout = "timeout" in status.lower() or "timeout" in (response.message or "").lower() + + parts = [] + if stdout: + parts.append(stdout.rstrip()) + if stderr: + parts.append(f"[stderr]\n{stderr.rstrip()}") + if is_timeout: + parts.append("[execution timed out]") + if not parts: + parts.append("[no output]") + return "\n".join(parts) + + except asyncio.TimeoutError: + return "[execution timed out]" + except Exception as exc: + logger.warning("SandboxFusion error: %s", exc) + return f"[execution error: {exc}]" + + +def _serialize_tool_calls(tool_calls) -> list[dict]: + """Serialize litellm tool call objects to dicts for conversation history.""" + return [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls + ] + + +def _compute_shaping( + cfg: DictConfig, + answer_status: str, + num_python_calls: int, + llm_calls: list, + llm: TrainableLLM, +) -> float: + """Compute reward shaping (python tool bonus + length shaping).""" + total = 0.0 + + shaping_cfg = getattr(cfg, "python_tool_shaping", None) + if shaping_cfg is not None: + bonus = float(getattr(shaping_cfg, "bonus_on_correct_with_python", 0.0)) + penalty = float(getattr(shaping_cfg, "penalty_on_incorrect_without_python", 0.0)) + max_abs = float(getattr(shaping_cfg, "max_abs", 0.2)) + + if answer_status == "correct" and num_python_calls >= 1: + total += bonus + if answer_status in ("wrong", "unparsable") and num_python_calls == 0: + total -= penalty + + total = max(-max_abs, min(max_abs, total)) + + length_cfg = getattr(cfg, "length_shaping", None) + if length_cfg is not None: + try: + if hasattr(length_cfg, "target_ratio"): + ratio = float(getattr(length_cfg, "target_ratio")) + max_gen = int(llm.parameters.get("max_tokens", 2048)) + target_tokens = int(max(1, ratio * max_gen)) + min_t = int(getattr(length_cfg, "min_target_tokens", 0)) + max_t = int(getattr(length_cfg, "max_target_tokens", 10**9)) + target_tokens = max(min_t, min(max_t, target_tokens)) + else: + target_tokens = int(getattr(length_cfg, "target_output_tokens", 512)) + slope = float(getattr(length_cfg, "slope", 0.0)) + max_penalty = float(getattr(length_cfg, "max_penalty", 0.0)) + bonus_short_correct = float(getattr(length_cfg, "bonus_on_short_correct", 0.0)) + except Exception: + target_tokens, slope, max_penalty, bonus_short_correct = 512, 0.0, 0.0, 0.0 + + total_output_tokens = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls) + avg_output_tokens = total_output_tokens / max(1, len(llm_calls)) + + if slope > 0.0 and max_penalty > 0.0 and avg_output_tokens > target_tokens: + over_by = float(avg_output_tokens - target_tokens) + total -= min(max_penalty, slope * over_by) + + if bonus_short_correct > 0.0 and answer_status == "correct" and avg_output_tokens <= target_tokens: + total += bonus_short_correct + + return total + + +class Metrics(BaseMetrics): + num_python_calls: int = 0 + num_steps: int = 0 + n_llm_calls: int = 0 + overflow: bool = False + + +async def generate_tir_rollout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict, + session: aiohttp.ClientSession, +) -> RolloutResult: + start = time.perf_counter() + + # 1. Build initial messages + messages: list[dict] = [] + if cfg.actor.system_prompt: + messages.append({"role": "system", "content": cfg.actor.system_prompt}) + messages.append({"role": "user", "content": cfg.actor.task_template.format(task=problem["task"])}) + + # 2. Tool definitions + tools = build_tool_definitions() + + # 3. Multi-turn loop + llm_calls = [] + final_answer = None + submitted_final_answer = False + num_python_calls = 0 + agent_max_loops = int(getattr(cfg.actor, "agent_max_loops", 3)) + sandbox_endpoint = str(cfg.sandbox_endpoint) + sandbox_timeout = float(cfg.sandbox_timeout) + + for _turn in range(agent_max_loops): + prompt = Prompt(messages=list(messages), tools=tools) + llm_call = await llm_async_generate(llm, prompt, session) + llm_calls.append(llm_call) + + if not llm_call.output.tool_calls: + # Text-only response, no tool call -- end the loop + break + + # Append assistant message with tool_calls to conversation history + assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""} + assistant_msg["tool_calls"] = _serialize_tool_calls(llm_call.output.tool_calls) + messages.append(assistant_msg) + + # Execute each tool call + for tc in llm_call.output.tool_calls: + if tc.function.name == "MathAnswer": + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError: + args = {} + final_answer = args.get("answer", "") + submitted_final_answer = True + # Still append tool result for completeness + messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": f"Answer submitted: {final_answer}", + }) + break + elif tc.function.name == "run_python_code": + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError: + args = {} + code = args.get("code") or args.get("python_code", "") + result = await execute_python_sandbox(code, sandbox_endpoint, sandbox_timeout) + num_python_calls += 1 + messages.append({"role": "tool", "tool_call_id": tc.id, "content": result}) + else: + # Unknown tool, return error + messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": f"Unknown tool: {tc.function.name}", + }) + + if submitted_final_answer: + break + + # 4. Determine prediction for grading + if final_answer is not None: + prediction = final_answer + elif llm_calls: + prediction = llm_calls[-1].output.content or "" + else: + prediction = "" + + # 5. Verify answer via math verifier + env_key = resolve_environment_key(cfg, default="math") + env_jobs = get_environment_jobs(cfg, env_key) + if not env_jobs: + raise RuntimeError("No environment servers available for math domain") + env_job = random.choice(env_jobs) + assert env_job.port is not None + answer_status = await verify_answer_rpc( + session=session, + host=env_job.hostname, + port=env_job.port, + prediction=prediction, + gold=problem["answer"], + strict=True, + ) + + # 6. Compute reward + reward_table = RewardTable(**dict(cfg.rewards)) + base_reward = get_reward(answer_status, submitted_final_answer, reward_table) + + discount_factor = float(getattr(cfg.actor, "discount_factor", 1.0)) + if discount_factor != 1.0: + total_generated_tokens = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls) + base_reward *= discount_factor ** total_generated_tokens + + buffer_tokens = getattr(reward_table, "buffer_tokens", 0) + if buffer_tokens: + max_tokens = int(llm.parameters.get("max_tokens", 0)) + total_output_tokens = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls) + if max_tokens > 0: + base_reward += length_penalty(max_tokens, total_output_tokens, buffer_tokens) + + shaping = _compute_shaping(cfg, answer_status, num_python_calls, llm_calls, llm) + reward = base_reward + shaping + + # 7. Build training texts (tool-call aware) + training_texts = [make_training_text_with_tools(llm, call) for call in llm_calls] + for text in training_texts: + text.reward = reward + text.finished = submitted_final_answer + + latency = time.perf_counter() - start + + metrics = Metrics( + reward=reward, + success=answer_status == "correct", + no_error=answer_status != "unparsable", + no_answer=answer_status == "no_answer", + num_python_calls=num_python_calls, + num_steps=len(llm_calls), + n_llm_calls=len(llm_calls), + overflow=not submitted_final_answer, + ) + + return RolloutResult( + training_texts=training_texts, + metrics=metrics, + latency=latency, + dataset_name=problem.get("dataset"), + domain="tir", + ) From f8ba3e0bd897322ed4e64a40d885466f901dcfed Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:45:50 +0000 Subject: [PATCH 08/18] Add TIR config inheriting from math --- conf/tir.yaml | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 conf/tir.yaml diff --git a/conf/tir.yaml b/conf/tir.yaml new file mode 100644 index 00000000..781a5059 --- /dev/null +++ b/conf/tir.yaml @@ -0,0 +1,54 @@ +defaults: + - math + - _self_ + +actor: + rollout_policy: pipelinerl.domains.tir.generate_tir_rollout + system_prompt: | + You are a math-focused AI Agent. Solve problems by combining clear symbolic reasoning + with short, deterministic Python code. + Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration. + Always present the final answer in LaTeX \boxed{}. + Do not express emotions or opinions about user questions. + + Workflow: + 1. Draft a brief plan in plain text. + 2. Execute one run_python_code call to compute or verify the result. + 3. Finalize by calling MathAnswer with the LaTeX-formatted answer. + + Python execution policy (run_python_code): + - Use Python strictly for pure computation to verify and validate the final answer. + - No network, file system, OS or environment access. + - Keep snippets minimal and self-contained; print only the final result. + + Validation: + - Cross-check results (alternative derivation, invariants, higher precision) before finalizing. + - If execution fails, propose the minimal fix and retry. + Always verify with run_python_code before invoking MathAnswer. + task_template: "{task}" + agent_max_loops: 3 + +# SandboxFusion config +sandbox_endpoint: ${oc.env:SANDBOX_ENDPOINT,http://127.0.0.1:8080} +sandbox_timeout: 30.0 + +# Optional reward shaping +python_tool_shaping: + bonus_on_correct_with_python: 0.1 + penalty_on_incorrect_without_python: 0.1 + max_abs: 0.2 + +# vLLM tool-call parser config +vllm_config: + vllm_kwargs: + enable-auto-tool-choice: "" + tool-call-parser: rl_tool + tool-parser-plugin: ${hydra:runtime.cwd}/pipelinerl/rl_tool_parser_plugin.py + max_model_len: 32000 + +llm: + parameters: + max_tokens: 8192 + +finetune: + seq_length: 32000 From 409f23e6962393d280ee701f95e5b15f3ae60b69 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:45:51 +0000 Subject: [PATCH 09/18] Add tir optional dependency group --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4e73d8ba..996b4c6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,9 +72,12 @@ ifeval = [ "langdetect", "absl-py", ] +tir = [ + "sandbox-fusion>=0.3.7", +] # Install all domain dependencies domains = [ - "pipelinerl[coding,fn_calling,logic,ifeval]", + "pipelinerl[coding,fn_calling,logic,ifeval,tir]", ] [tool.setuptools.packages.find] From 8a9a9b7e2b991d7e5ffb65da51162e5e3b5540c9 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:47:33 +0000 Subject: [PATCH 10/18] Pass tools to apply_chat_template for correct prompt reconstruction --- pipelinerl/async_llm.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index 9c0d4bfa..7289c99b 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -205,19 +205,24 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: raise ValueError(f"Failed to process with vision-language processor: {e}") else: # Use tokenizer for text-only models + # Pass tools so the chat template matches what vLLM actually served + tools_kwarg = {"tools": llm_call.prompt.tools} if llm_call.prompt.tools else {} prompt_text = llm.tokenizer.apply_chat_template( conversation=llm_call.prompt.messages, tokenize=False, add_generation_prompt=True, + **tools_kwarg, ) text = llm.tokenizer.apply_chat_template( full_messages, tokenize=False, + **tools_kwarg, ) prompt_token_ids = llm.tokenizer.apply_chat_template( llm_call.prompt.messages, add_special_tokens=True, add_generation_prompt=True, + **tools_kwarg, ) output_text = text[len(prompt_text) :] From 0328249cf720e76be3957eea371a37a6c417ec95 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:47:35 +0000 Subject: [PATCH 11/18] Guard against non-dict JSON in tool argument parsing --- pipelinerl/domains/tir/rollouts.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py index bd8bb1db..1aa812b2 100644 --- a/pipelinerl/domains/tir/rollouts.py +++ b/pipelinerl/domains/tir/rollouts.py @@ -235,14 +235,15 @@ async def generate_tir_rollout( # Execute each tool call for tc in llm_call.output.tool_calls: + try: + parsed = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + parsed = None + args = parsed if isinstance(parsed, dict) else {} + if tc.function.name == "MathAnswer": - try: - args = json.loads(tc.function.arguments) - except json.JSONDecodeError: - args = {} final_answer = args.get("answer", "") submitted_final_answer = True - # Still append tool result for completeness messages.append({ "role": "tool", "tool_call_id": tc.id, @@ -250,16 +251,11 @@ async def generate_tir_rollout( }) break elif tc.function.name == "run_python_code": - try: - args = json.loads(tc.function.arguments) - except json.JSONDecodeError: - args = {} code = args.get("code") or args.get("python_code", "") result = await execute_python_sandbox(code, sandbox_endpoint, sandbox_timeout) num_python_calls += 1 messages.append({"role": "tool", "tool_call_id": tc.id, "content": result}) else: - # Unknown tool, return error messages.append({ "role": "tool", "tool_call_id": tc.id, From 2c5ca49e0b333a22e93079e78fa99a5a2a441f82 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:52:08 +0000 Subject: [PATCH 12/18] Recover bare-string tool arguments via fallback_key --- pipelinerl/domains/tir/rollouts.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py index 1aa812b2..a66f5fc9 100644 --- a/pipelinerl/domains/tir/rollouts.py +++ b/pipelinerl/domains/tir/rollouts.py @@ -132,6 +132,23 @@ def _serialize_tool_calls(tool_calls) -> list[dict]: ] +def _parse_tool_arguments(arguments: str, *, fallback_key: str | None = None) -> dict: + """Parse tool-call arguments into an object payload. + + Valid JSON that is not an object should not crash the rollout loop. A bare + string can still be recovered for simple single-field tool schemas. + """ + try: + parsed = json.loads(arguments) + except (json.JSONDecodeError, TypeError): + return {} + if isinstance(parsed, dict): + return parsed + if fallback_key is not None and isinstance(parsed, str): + return {fallback_key: parsed} + return {} + + def _compute_shaping( cfg: DictConfig, answer_status: str, @@ -235,13 +252,8 @@ async def generate_tir_rollout( # Execute each tool call for tc in llm_call.output.tool_calls: - try: - parsed = json.loads(tc.function.arguments) - except (json.JSONDecodeError, TypeError): - parsed = None - args = parsed if isinstance(parsed, dict) else {} - if tc.function.name == "MathAnswer": + args = _parse_tool_arguments(tc.function.arguments, fallback_key="answer") final_answer = args.get("answer", "") submitted_final_answer = True messages.append({ @@ -251,6 +263,7 @@ async def generate_tir_rollout( }) break elif tc.function.name == "run_python_code": + args = _parse_tool_arguments(tc.function.arguments, fallback_key="code") code = args.get("code") or args.get("python_code", "") result = await execute_python_sandbox(code, sandbox_endpoint, sandbox_timeout) num_python_calls += 1 From 5e15dc449e8ea308c4ef03af0e143064b77abefc Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 11:52:10 +0000 Subject: [PATCH 13/18] Use tools-aware template for all turns when prompt.tools is set --- pipelinerl/async_llm.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index 7289c99b..987720e3 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -276,31 +276,30 @@ def make_training_texts_from_llm_calls( def make_training_text_with_tools(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: - """Build a TrainingText for an assistant turn that may contain tool_calls. + """Build a TrainingText for a tool-enabled assistant turn. - For turns without tool_calls this delegates to ``make_training_text``. - When tool_calls are present the assistant message dict includes them so - that ``apply_chat_template`` produces the correct token sequence matching - what vLLM actually generated (and for which we have logprobs). + This helper keeps prompts on the same chat-template path used at generation + time whenever ``prompt.tools`` is set, even if the current assistant turn + itself is plain text with no tool_calls. """ - if not llm_call.output.tool_calls: + if not llm_call.prompt.tools and not llm_call.output.tool_calls: return make_training_text(llm, llm_call) llm.load_tokenizer() - # Build the assistant message with tool_calls assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""} - assistant_msg["tool_calls"] = [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - } - for tc in llm_call.output.tool_calls - ] + if llm_call.output.tool_calls: + assistant_msg["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in llm_call.output.tool_calls + ] full_messages = llm_call.prompt.messages + [assistant_msg] From ebe16d4ac961ea257d27e642c5f20011dbda5cdf Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 12:33:47 +0000 Subject: [PATCH 14/18] Cap max_tokens per turn to fit within max_model_len --- pipelinerl/async_llm.py | 8 +++++++- pipelinerl/domains/tir/rollouts.py | 25 ++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index 987720e3..e376f247 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -55,7 +55,10 @@ def _to_plain_obj(value): async def llm_async_generate( - llm: TrainableLLM, prompt: Prompt, session: aiohttp.ClientSession + llm: TrainableLLM, + prompt: Prompt, + session: aiohttp.ClientSession, + max_tokens_override: int | None = None, ) -> LLMCall: llm.load_tokenizer() headers = {"Content-Type": "application/json"} @@ -89,6 +92,9 @@ async def llm_async_generate( if prompt.tools: data["tools"] = _to_plain_obj(prompt.tools) + if max_tokens_override is not None: + data["max_tokens"] = max_tokens_override + # Merge extra_parameters first so that data (model, messages, logprobs settings) takes precedence payload = _to_plain_obj({**extra_parameters, **data}) async with session.post( diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py index a66f5fc9..e024fbc8 100644 --- a/pipelinerl/domains/tir/rollouts.py +++ b/pipelinerl/domains/tir/rollouts.py @@ -235,10 +235,33 @@ async def generate_tir_rollout( agent_max_loops = int(getattr(cfg.actor, "agent_max_loops", 3)) sandbox_endpoint = str(cfg.sandbox_endpoint) sandbox_timeout = float(cfg.sandbox_timeout) + configured_max_tokens = int(llm.parameters.get("max_tokens", 16000)) + max_model_len = int(cfg.vllm_config.vllm_kwargs.get("max_model_len", 32000)) + # Reserve a minimum budget so the model can still produce a useful response + min_generation_tokens = 256 for _turn in range(agent_max_loops): prompt = Prompt(messages=list(messages), tools=tools) - llm_call = await llm_async_generate(llm, prompt, session) + + # Estimate prompt length and cap max_tokens to fit within max_model_len + llm.load_tokenizer() + prompt_token_ids = llm.tokenizer.apply_chat_template( + messages, + add_special_tokens=True, + add_generation_prompt=True, + tools=tools, + ) + prompt_len = len(prompt_token_ids) + remaining = max_model_len - prompt_len + if remaining < min_generation_tokens: + logger.warning( + "Prompt length %d leaves only %d tokens for generation (max_model_len=%d), stopping loop", + prompt_len, remaining, max_model_len, + ) + break + max_tokens_this_turn = min(configured_max_tokens, remaining) + + llm_call = await llm_async_generate(llm, prompt, session, max_tokens_override=max_tokens_this_turn) llm_calls.append(llm_call) if not llm_call.output.tool_calls: From 88ea7950203385299869f4b5e0e085262b92321a Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 20 Mar 2026 12:36:49 +0000 Subject: [PATCH 15/18] Reduce llm_max_rollouts to 128, warn on max_tokens capping --- conf/tir.yaml | 48 ++++++++++++++++++++++++++++-- pipelinerl/domains/tir/rollouts.py | 5 ++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/conf/tir.yaml b/conf/tir.yaml index 781a5059..606ab2dc 100644 --- a/conf/tir.yaml +++ b/conf/tir.yaml @@ -1,5 +1,6 @@ defaults: - - math + - base + - override rewards: success_and_format - _self_ actor: @@ -27,10 +28,32 @@ actor: Always verify with run_python_code before invoking MathAnswer. task_template: "{task}" agent_max_loops: 3 + llm_max_rollouts: 128 + max_rollout_retries: 20 + rollout_workers: 8 + shared_memory_entry_size: 1000000000 + +rewards: + correct_answer_not_finished: 0.0 + buffer_tokens: 0 + +# Math verifier environment +environments: + - key: math + mode: remote + _target_: pipelinerl.domains.math.MathEnvironment +environment_key: math +dataset_loader: pipelinerl.domains.math.load_datasets + +train_dataset_names: + - open_reasoner_zero_57k + - open_reasoner_zero_extended_72k +test_dataset_names: + - aime_2025 # SandboxFusion config sandbox_endpoint: ${oc.env:SANDBOX_ENDPOINT,http://127.0.0.1:8080} -sandbox_timeout: 30.0 +sandbox_timeout: 10.0 # Optional reward shaping python_tool_shaping: @@ -48,7 +71,26 @@ vllm_config: llm: parameters: - max_tokens: 8192 + max_tokens: 16000 + temperature: 1.0 + +test_llm: + parameters: + max_tokens: 16000 + temperature: 1.0 + top_p: 0.95 + top_k: 50 finetune: seq_length: 32000 + seq_parallel: 8 + gradient_accumulation_passes: 1024 + rl: + policy_loss: gspo + overlong_filtering: true + +preprocess: + input: actor + output: training_data + n_workers: 8 + shared_memory_entry_size: 1000000000 diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py index e024fbc8..b9291c79 100644 --- a/pipelinerl/domains/tir/rollouts.py +++ b/pipelinerl/domains/tir/rollouts.py @@ -260,6 +260,11 @@ async def generate_tir_rollout( ) break max_tokens_this_turn = min(configured_max_tokens, remaining) + if max_tokens_this_turn < configured_max_tokens: + logger.warning( + "Turn %d: capping max_tokens from %d to %d (prompt_len=%d, max_model_len=%d)", + _turn, configured_max_tokens, max_tokens_this_turn, prompt_len, max_model_len, + ) llm_call = await llm_async_generate(llm, prompt, session, max_tokens_override=max_tokens_this_turn) llm_calls.append(llm_call) From f4f7370329b63586b695f97f6eec6527390cc87a Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 17 Apr 2026 10:19:08 +0000 Subject: [PATCH 16/18] Fold tool_calls into make_training_text, drop _with_tools variant TIR now goes through make_training_texts_from_llm_calls like every other multi-turn domain, giving us a single per-turn converter and a single reward-injection point for future per-turn credit assignment. --- pipelinerl/async_llm.py | 97 +++++------------------------- pipelinerl/domains/tir/rollouts.py | 5 +- 2 files changed, 17 insertions(+), 85 deletions(-) diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index e376f247..28fae829 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -157,9 +157,20 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: images = [] use_processor = False visual_features = None - full_messages = llm_call.prompt.messages + [ - {"role": "assistant", "content": llm_call.output.content} - ] + assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""} + if llm_call.output.tool_calls: + assistant_msg["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in llm_call.output.tool_calls + ] + full_messages = llm_call.prompt.messages + [assistant_msg] if hasattr(llm_call.prompt, "messages"): images = extract_images_from_messages(llm_call.prompt.messages) @@ -253,7 +264,7 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: finished = finish_reason != "length" else: eos_token = tokenizer.eos_token or "" - finished = bool(eos_token) and llm_call.output.content.endswith(eos_token) + finished = bool(eos_token) and (llm_call.output.content or "").endswith(eos_token) prompt_tokens = llm_call.prompt_length_tokens output_tokens = llm_call.output_length_tokens @@ -280,81 +291,3 @@ def make_training_texts_from_llm_calls( training_texts = apply_rollout_reward(training_texts, reward) return training_texts - -def make_training_text_with_tools(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: - """Build a TrainingText for a tool-enabled assistant turn. - - This helper keeps prompts on the same chat-template path used at generation - time whenever ``prompt.tools`` is set, even if the current assistant turn - itself is plain text with no tool_calls. - """ - if not llm_call.prompt.tools and not llm_call.output.tool_calls: - return make_training_text(llm, llm_call) - - llm.load_tokenizer() - - assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""} - if llm_call.output.tool_calls: - assistant_msg["tool_calls"] = [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - } - for tc in llm_call.output.tool_calls - ] - - full_messages = llm_call.prompt.messages + [assistant_msg] - - prompt_text = llm.tokenizer.apply_chat_template( - conversation=llm_call.prompt.messages, - tokenize=False, - add_generation_prompt=True, - tools=llm_call.prompt.tools, - ) - text = llm.tokenizer.apply_chat_template( - full_messages, - tokenize=False, - tools=llm_call.prompt.tools, - ) - prompt_token_ids = llm.tokenizer.apply_chat_template( - llm_call.prompt.messages, - add_special_tokens=True, - add_generation_prompt=True, - tools=llm_call.prompt.tools, - ) - - output_text = text[len(prompt_text):] - - tokenizer = llm.tokenizer - if tokenizer.bos_token and text.startswith(tokenizer.bos_token): - text = text[len(tokenizer.bos_token):] - - if not llm_call.logprobs: - raise ValueError("Logprobs are required to make training data for RL") - - labels = [lp.token_id for lp in llm_call.logprobs] - input_ids = prompt_token_ids + labels - labels = [MASKED_TOKEN_ID] * len(prompt_token_ids) + labels - logprobs = [lp.logprob for lp in llm_call.logprobs] - - finish_reason = llm_call.llm_info.get("finish_reason") - if finish_reason is not None: - finished = finish_reason != "length" - else: - eos_token = tokenizer.eos_token or "" - finished = bool(eos_token) and (llm_call.output.content or "").endswith(eos_token) - - return TrainingText( - text=text, - n_predicted=len(output_text), - input_ids=input_ids, - labels=labels, - logprobs=logprobs, - finished=finished, - prompt_tokens=llm_call.prompt_length_tokens, - output_tokens=llm_call.output_length_tokens, - ) diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py index b9291c79..83eb5644 100644 --- a/pipelinerl/domains/tir/rollouts.py +++ b/pipelinerl/domains/tir/rollouts.py @@ -11,7 +11,7 @@ from sandbox_fusion import RunCodeRequest, set_sandbox_endpoint, run_code_async -from pipelinerl.async_llm import llm_async_generate, make_training_text_with_tools +from pipelinerl.async_llm import llm_async_generate, make_training_texts_from_llm_calls from pipelinerl.domains.math import RewardTable, get_reward, length_penalty, verify_answer_rpc from pipelinerl.llm import Prompt, TrainableLLM from pipelinerl.rollouts import BaseMetrics, RolloutResult @@ -350,9 +350,8 @@ async def generate_tir_rollout( reward = base_reward + shaping # 7. Build training texts (tool-call aware) - training_texts = [make_training_text_with_tools(llm, call) for call in llm_calls] + training_texts = make_training_texts_from_llm_calls(llm, llm_calls, reward=reward) for text in training_texts: - text.reward = reward text.finished = submitted_final_answer latency = time.perf_counter() - start From cdbcec977b99ecc9d0f0e171647b208027c3df82 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 17 Apr 2026 12:35:45 +0000 Subject: [PATCH 17/18] Remove unnecessary comments --- pipelinerl/async_llm.py | 3 --- pipelinerl/domains/tir/rollouts.py | 16 ---------------- pipelinerl/rl_tool_parser_plugin.py | 3 +-- 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index 28fae829..d93f0bda 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -221,8 +221,6 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: except Exception as e: raise ValueError(f"Failed to process with vision-language processor: {e}") else: - # Use tokenizer for text-only models - # Pass tools so the chat template matches what vLLM actually served tools_kwarg = {"tools": llm_call.prompt.tools} if llm_call.prompt.tools else {} prompt_text = llm.tokenizer.apply_chat_template( conversation=llm_call.prompt.messages, @@ -244,7 +242,6 @@ def make_training_text(llm: TrainableLLM, llm_call: LLMCall) -> TrainingText: output_text = text[len(prompt_text) :] - # Get the appropriate tokenizer (from processor if using vision model) tokenizer = processor.tokenizer if use_processor else llm.tokenizer if tokenizer.bos_token and text.startswith(tokenizer.bos_token): diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py index 83eb5644..6aaa4bd5 100644 --- a/pipelinerl/domains/tir/rollouts.py +++ b/pipelinerl/domains/tir/rollouts.py @@ -21,7 +21,6 @@ _SANDBOX_CONFIGURED = False -# Python safety blocklist: patterns that must not appear in user-submitted code _BLOCKED_PATTERNS = [ re.compile(r"\bsys\.exit\b"), re.compile(r"\bos\._exit\b"), @@ -67,7 +66,6 @@ def build_tool_definitions() -> list[dict]: def _check_code_safety(code: str) -> str | None: - """Return a rejection message if the code matches a blocked pattern, else None.""" for pattern in _BLOCKED_PATTERNS: if pattern.search(code): return f"Blocked: code contains forbidden pattern '{pattern.pattern}'" @@ -206,7 +204,6 @@ def _compute_shaping( class Metrics(BaseMetrics): num_python_calls: int = 0 num_steps: int = 0 - n_llm_calls: int = 0 overflow: bool = False @@ -218,16 +215,13 @@ async def generate_tir_rollout( ) -> RolloutResult: start = time.perf_counter() - # 1. Build initial messages messages: list[dict] = [] if cfg.actor.system_prompt: messages.append({"role": "system", "content": cfg.actor.system_prompt}) messages.append({"role": "user", "content": cfg.actor.task_template.format(task=problem["task"])}) - # 2. Tool definitions tools = build_tool_definitions() - # 3. Multi-turn loop llm_calls = [] final_answer = None submitted_final_answer = False @@ -237,13 +231,11 @@ async def generate_tir_rollout( sandbox_timeout = float(cfg.sandbox_timeout) configured_max_tokens = int(llm.parameters.get("max_tokens", 16000)) max_model_len = int(cfg.vllm_config.vllm_kwargs.get("max_model_len", 32000)) - # Reserve a minimum budget so the model can still produce a useful response min_generation_tokens = 256 for _turn in range(agent_max_loops): prompt = Prompt(messages=list(messages), tools=tools) - # Estimate prompt length and cap max_tokens to fit within max_model_len llm.load_tokenizer() prompt_token_ids = llm.tokenizer.apply_chat_template( messages, @@ -270,15 +262,12 @@ async def generate_tir_rollout( llm_calls.append(llm_call) if not llm_call.output.tool_calls: - # Text-only response, no tool call -- end the loop break - # Append assistant message with tool_calls to conversation history assistant_msg: dict = {"role": "assistant", "content": llm_call.output.content or ""} assistant_msg["tool_calls"] = _serialize_tool_calls(llm_call.output.tool_calls) messages.append(assistant_msg) - # Execute each tool call for tc in llm_call.output.tool_calls: if tc.function.name == "MathAnswer": args = _parse_tool_arguments(tc.function.arguments, fallback_key="answer") @@ -306,7 +295,6 @@ async def generate_tir_rollout( if submitted_final_answer: break - # 4. Determine prediction for grading if final_answer is not None: prediction = final_answer elif llm_calls: @@ -314,7 +302,6 @@ async def generate_tir_rollout( else: prediction = "" - # 5. Verify answer via math verifier env_key = resolve_environment_key(cfg, default="math") env_jobs = get_environment_jobs(cfg, env_key) if not env_jobs: @@ -330,7 +317,6 @@ async def generate_tir_rollout( strict=True, ) - # 6. Compute reward reward_table = RewardTable(**dict(cfg.rewards)) base_reward = get_reward(answer_status, submitted_final_answer, reward_table) @@ -349,7 +335,6 @@ async def generate_tir_rollout( shaping = _compute_shaping(cfg, answer_status, num_python_calls, llm_calls, llm) reward = base_reward + shaping - # 7. Build training texts (tool-call aware) training_texts = make_training_texts_from_llm_calls(llm, llm_calls, reward=reward) for text in training_texts: text.finished = submitted_final_answer @@ -363,7 +348,6 @@ async def generate_tir_rollout( no_answer=answer_status == "no_answer", num_python_calls=num_python_calls, num_steps=len(llm_calls), - n_llm_calls=len(llm_calls), overflow=not submitted_final_answer, ) diff --git a/pipelinerl/rl_tool_parser_plugin.py b/pipelinerl/rl_tool_parser_plugin.py index 12e6fc2d..29c28d9e 100644 --- a/pipelinerl/rl_tool_parser_plugin.py +++ b/pipelinerl/rl_tool_parser_plugin.py @@ -3,9 +3,8 @@ """ import json -import re -from typing import Any # noqa: F401 import logging +import re from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParserManager From 24554fbee942da17133abfa5f6d6a486477f5ccc Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 17 Apr 2026 12:50:38 +0000 Subject: [PATCH 18/18] Refactor shaping into class, dispatch tool handlers, share parser helper --- pipelinerl/domains/tir/rollouts.py | 187 ++++++++++++++++------------ pipelinerl/rl_tool_parser_plugin.py | 134 +++++++------------- 2 files changed, 151 insertions(+), 170 deletions(-) diff --git a/pipelinerl/domains/tir/rollouts.py b/pipelinerl/domains/tir/rollouts.py index 6aaa4bd5..c25a30b5 100644 --- a/pipelinerl/domains/tir/rollouts.py +++ b/pipelinerl/domains/tir/rollouts.py @@ -4,10 +4,11 @@ import random import re import time +from dataclasses import dataclass +from typing import Awaitable, Callable import aiohttp from omegaconf import DictConfig -from pydantic import BaseModel from sandbox_fusion import RunCodeRequest, set_sandbox_endpoint, run_code_async @@ -147,58 +148,100 @@ def _parse_tool_arguments(arguments: str, *, fallback_key: str | None = None) -> return {} -def _compute_shaping( - cfg: DictConfig, - answer_status: str, - num_python_calls: int, - llm_calls: list, - llm: TrainableLLM, -) -> float: - """Compute reward shaping (python tool bonus + length shaping).""" - total = 0.0 +@dataclass +class _ToolContext: + sandbox_endpoint: str + sandbox_timeout: float + messages: list[dict] + final_answer: str | None = None + submitted_final_answer: bool = False + num_python_calls: int = 0 - shaping_cfg = getattr(cfg, "python_tool_shaping", None) - if shaping_cfg is not None: - bonus = float(getattr(shaping_cfg, "bonus_on_correct_with_python", 0.0)) - penalty = float(getattr(shaping_cfg, "penalty_on_incorrect_without_python", 0.0)) - max_abs = float(getattr(shaping_cfg, "max_abs", 0.2)) - if answer_status == "correct" and num_python_calls >= 1: - total += bonus - if answer_status in ("wrong", "unparsable") and num_python_calls == 0: - total -= penalty +ToolHandler = Callable[[object, _ToolContext], Awaitable[None]] - total = max(-max_abs, min(max_abs, total)) - - length_cfg = getattr(cfg, "length_shaping", None) - if length_cfg is not None: - try: - if hasattr(length_cfg, "target_ratio"): - ratio = float(getattr(length_cfg, "target_ratio")) - max_gen = int(llm.parameters.get("max_tokens", 2048)) - target_tokens = int(max(1, ratio * max_gen)) - min_t = int(getattr(length_cfg, "min_target_tokens", 0)) - max_t = int(getattr(length_cfg, "max_target_tokens", 10**9)) - target_tokens = max(min_t, min(max_t, target_tokens)) - else: - target_tokens = int(getattr(length_cfg, "target_output_tokens", 512)) - slope = float(getattr(length_cfg, "slope", 0.0)) - max_penalty = float(getattr(length_cfg, "max_penalty", 0.0)) - bonus_short_correct = float(getattr(length_cfg, "bonus_on_short_correct", 0.0)) - except Exception: - target_tokens, slope, max_penalty, bonus_short_correct = 512, 0.0, 0.0, 0.0 - total_output_tokens = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls) - avg_output_tokens = total_output_tokens / max(1, len(llm_calls)) +async def _handle_math_answer(tc, ctx: _ToolContext) -> None: + args = _parse_tool_arguments(tc.function.arguments, fallback_key="answer") + ctx.final_answer = args.get("answer", "") + ctx.submitted_final_answer = True + ctx.messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": f"Answer submitted: {ctx.final_answer}", + }) - if slope > 0.0 and max_penalty > 0.0 and avg_output_tokens > target_tokens: - over_by = float(avg_output_tokens - target_tokens) - total -= min(max_penalty, slope * over_by) - if bonus_short_correct > 0.0 and answer_status == "correct" and avg_output_tokens <= target_tokens: - total += bonus_short_correct +async def _handle_run_python_code(tc, ctx: _ToolContext) -> None: + args = _parse_tool_arguments(tc.function.arguments, fallback_key="code") + code = args.get("code") or args.get("python_code", "") + result = await execute_python_sandbox(code, ctx.sandbox_endpoint, ctx.sandbox_timeout) + ctx.num_python_calls += 1 + ctx.messages.append({"role": "tool", "tool_call_id": tc.id, "content": result}) + + +async def _handle_unknown_tool(tc, ctx: _ToolContext) -> None: + ctx.messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": f"Unknown tool: {tc.function.name}", + }) + - return total +_TOOL_HANDLERS: dict[str, ToolHandler] = { + "MathAnswer": _handle_math_answer, + "run_python_code": _handle_run_python_code, +} + + +class RewardShaper: + def __init__(self, cfg: DictConfig, llm: TrainableLLM): + self._python_cfg = getattr(cfg, "python_tool_shaping", None) + self._length_cfg = getattr(cfg, "length_shaping", None) + self._max_gen_tokens = int(llm.parameters.get("max_tokens", 2048)) + + def compute(self, answer_status: str, num_python_calls: int, llm_calls: list) -> float: + return ( + self._python_tool_bonus(answer_status, num_python_calls) + + self._length_adjustment(answer_status, llm_calls) + ) + + def _python_tool_bonus(self, answer_status: str, num_python_calls: int) -> float: + cfg = self._python_cfg + if cfg is None: + return 0.0 + bonus = float(getattr(cfg, "bonus_on_correct_with_python", 0.0)) + penalty = float(getattr(cfg, "penalty_on_incorrect_without_python", 0.0)) + max_abs = float(getattr(cfg, "max_abs", 0.2)) + total = 0.0 + if answer_status == "correct" and num_python_calls >= 1: + total += bonus + if answer_status in ("wrong", "unparsable") and num_python_calls == 0: + total -= penalty + return max(-max_abs, min(max_abs, total)) + + def _length_adjustment(self, answer_status: str, llm_calls: list) -> float: + cfg = self._length_cfg + if cfg is None or not llm_calls: + return 0.0 + if hasattr(cfg, "target_ratio"): + ratio = float(getattr(cfg, "target_ratio")) + target = int(max(1, ratio * self._max_gen_tokens)) + target = max(int(getattr(cfg, "min_target_tokens", 0)), target) + target = min(int(getattr(cfg, "max_target_tokens", 10**9)), target) + else: + target = int(getattr(cfg, "target_output_tokens", 512)) + slope = float(getattr(cfg, "slope", 0.0)) + max_penalty = float(getattr(cfg, "max_penalty", 0.0)) + bonus_short_correct = float(getattr(cfg, "bonus_on_short_correct", 0.0)) + + avg_out = sum(getattr(c, "output_length_tokens", 0) for c in llm_calls) / len(llm_calls) + total = 0.0 + if slope > 0.0 and max_penalty > 0.0 and avg_out > target: + total -= min(max_penalty, slope * (avg_out - target)) + if bonus_short_correct > 0.0 and answer_status == "correct" and avg_out <= target: + total += bonus_short_correct + return total class Metrics(BaseMetrics): @@ -223,12 +266,12 @@ async def generate_tir_rollout( tools = build_tool_definitions() llm_calls = [] - final_answer = None - submitted_final_answer = False - num_python_calls = 0 + ctx = _ToolContext( + sandbox_endpoint=str(cfg.sandbox_endpoint), + sandbox_timeout=float(cfg.sandbox_timeout), + messages=messages, + ) agent_max_loops = int(getattr(cfg.actor, "agent_max_loops", 3)) - sandbox_endpoint = str(cfg.sandbox_endpoint) - sandbox_timeout = float(cfg.sandbox_timeout) configured_max_tokens = int(llm.parameters.get("max_tokens", 16000)) max_model_len = int(cfg.vllm_config.vllm_kwargs.get("max_model_len", 32000)) min_generation_tokens = 256 @@ -269,34 +312,16 @@ async def generate_tir_rollout( messages.append(assistant_msg) for tc in llm_call.output.tool_calls: - if tc.function.name == "MathAnswer": - args = _parse_tool_arguments(tc.function.arguments, fallback_key="answer") - final_answer = args.get("answer", "") - submitted_final_answer = True - messages.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": f"Answer submitted: {final_answer}", - }) + handler = _TOOL_HANDLERS.get(tc.function.name, _handle_unknown_tool) + await handler(tc, ctx) + if ctx.submitted_final_answer: break - elif tc.function.name == "run_python_code": - args = _parse_tool_arguments(tc.function.arguments, fallback_key="code") - code = args.get("code") or args.get("python_code", "") - result = await execute_python_sandbox(code, sandbox_endpoint, sandbox_timeout) - num_python_calls += 1 - messages.append({"role": "tool", "tool_call_id": tc.id, "content": result}) - else: - messages.append({ - "role": "tool", - "tool_call_id": tc.id, - "content": f"Unknown tool: {tc.function.name}", - }) - - if submitted_final_answer: + + if ctx.submitted_final_answer: break - if final_answer is not None: - prediction = final_answer + if ctx.final_answer is not None: + prediction = ctx.final_answer elif llm_calls: prediction = llm_calls[-1].output.content or "" else: @@ -318,7 +343,7 @@ async def generate_tir_rollout( ) reward_table = RewardTable(**dict(cfg.rewards)) - base_reward = get_reward(answer_status, submitted_final_answer, reward_table) + base_reward = get_reward(answer_status, ctx.submitted_final_answer, reward_table) discount_factor = float(getattr(cfg.actor, "discount_factor", 1.0)) if discount_factor != 1.0: @@ -332,12 +357,12 @@ async def generate_tir_rollout( if max_tokens > 0: base_reward += length_penalty(max_tokens, total_output_tokens, buffer_tokens) - shaping = _compute_shaping(cfg, answer_status, num_python_calls, llm_calls, llm) + shaping = RewardShaper(cfg, llm).compute(answer_status, ctx.num_python_calls, llm_calls) reward = base_reward + shaping training_texts = make_training_texts_from_llm_calls(llm, llm_calls, reward=reward) for text in training_texts: - text.finished = submitted_final_answer + text.finished = ctx.submitted_final_answer latency = time.perf_counter() - start @@ -346,9 +371,9 @@ async def generate_tir_rollout( success=answer_status == "correct", no_error=answer_status != "unparsable", no_answer=answer_status == "no_answer", - num_python_calls=num_python_calls, + num_python_calls=ctx.num_python_calls, num_steps=len(llm_calls), - overflow=not submitted_final_answer, + overflow=not ctx.submitted_final_answer, ) return RolloutResult( diff --git a/pipelinerl/rl_tool_parser_plugin.py b/pipelinerl/rl_tool_parser_plugin.py index 29c28d9e..e48ec22c 100644 --- a/pipelinerl/rl_tool_parser_plugin.py +++ b/pipelinerl/rl_tool_parser_plugin.py @@ -9,12 +9,35 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.protocol import ( - ChatCompletionRequest, + ChatCompletionRequest, ExtractedToolCallInformation, ToolCall, - FunctionCall + FunctionCall, ) +_JSON_SCALAR_TYPES = (dict, list, str, int, float, bool) + + +def _build_tool_call(index: int, parsed: dict, *, force_id: str | None = None) -> ToolCall | None: + try: + args_obj = parsed.get("arguments", {}) + if not isinstance(args_obj, _JSON_SCALAR_TYPES): + args_obj = {} + call_id = force_id if force_id is not None else parsed.get("id", f"call_{index}") + return ToolCall( + id=call_id, + type="function", + function=FunctionCall( + name=str(parsed.get("name", "")), + arguments=json.dumps(args_obj, ensure_ascii=False), + ), + ) + except Exception: + logging.getLogger("pipelinerl.tool_parser").debug( + "Skipping malformed tool call", exc_info=True + ) + return None + @ToolParserManager.register_module("rl_tool") class HermesRLToolParser(ToolParser): @@ -27,52 +50,35 @@ class HermesRLToolParser(ToolParser): def __init__(self, tokenizer): super().__init__(tokenizer) - - # Tool call markers + self.tool_call_start_token = "" self.tool_call_end_token = "" - - # Regex pattern for parsing tool calls + self.tool_call_regex = re.compile( r"(.*?)|(.*)", re.DOTALL ) - - # Apriel-specific patterns + self.apriel_final_response_regex = re.compile( r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", re.DOTALL ) - # Prefer parsing aggregated tool calls from ... - # Be lenient: case-insensitive; tolerate missing closing tag by capturing to end. + # Lenient match: case-insensitive and tolerate a missing closing tag. self.apriel_tool_calls_regex = re.compile( r"\s*(.*?)\s*(?:|$)", re.DOTALL | re.IGNORECASE ) - - # State for streaming + + # vLLM streaming hooks expect these attributes on the parser instance. self.current_tool_name_sent = False self.prev_tool_call_arr = [] self.current_tool_id = -1 self.streamed_args_for_tool = [] - + def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: - """ - Extract tool calls from the model output. - - Args: - model_output: The raw model output string - request: The request object - - Returns: - ExtractedToolCallInformation with tool calls and metadata - """ logger = logging.getLogger("pipelinerl.tool_parser") - # Ensure variable exists for any fallback references below final_response_match = None try: - # 1) Apriel aggregated tool calls block has priority tool_calls_matches = list(self.apriel_tool_calls_regex.finditer(model_output)) if tool_calls_matches: - # Use the last match (in case of multiple blocks) last_match = tool_calls_matches[-1] tool_calls_json = last_match.group(1).strip() parsed_calls = [] @@ -82,27 +88,11 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) logger.debug("Failed to parse aggregated JSON; falling back", exc_info=True) parsed_calls = [] - tool_calls: list[ToolCall] = [] - for i, pc in enumerate(parsed_calls): - try: - name = pc.get("name", "") - args_obj = pc.get("arguments", {}) - if not isinstance(args_obj, (dict, list, str, int, float, bool)): - args_obj = {} - args_str = json.dumps(args_obj, ensure_ascii=False) - call_id = pc.get("id", f"call_{i}") - tool_calls.append( - ToolCall( - id=call_id, - type="function", - function=FunctionCall(name=str(name), arguments=args_str), - ) - ) - except Exception: - logger.debug("Skipping malformed aggregated tool call", exc_info=True) - continue + tool_calls = [ + tc for tc in (_build_tool_call(i, pc) for i, pc in enumerate(parsed_calls)) + if tc is not None + ] - # Prefer final response content if present; otherwise empty string final_response_match = self.apriel_final_response_regex.search(model_output) content = final_response_match.group(1).strip() if final_response_match else "" @@ -112,8 +102,6 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) content=content, ) - # 2) Try bare JSON tool-calls (no tags), but only if tools are declared in the request - # Accept either a list of {name, arguments} or a single dict try: tools_declared = bool(getattr(request, "tools", None)) except Exception: @@ -138,25 +126,10 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) parsed_list = [it for it in parsed if "name" in it and "arguments" in it] if not parsed_list: continue - tool_calls: list[ToolCall] = [] - for i, pc in enumerate(parsed_list): - try: - name = pc.get("name", "") - args_obj = pc.get("arguments", {}) - if not isinstance(args_obj, (dict, list, str, int, float, bool)): - args_obj = {} - args_str = json.dumps(args_obj, ensure_ascii=False) - call_id = pc.get("id", f"call_{i}") - tool_calls.append( - ToolCall( - id=call_id, - type="function", - function=FunctionCall(name=str(name), arguments=args_str), - ) - ) - except Exception: - logger.debug("Skipping malformed bare-JSON tool call", exc_info=True) - continue + tool_calls = [ + tc for tc in (_build_tool_call(i, pc) for i, pc in enumerate(parsed_list)) + if tc is not None + ] content = final_response_match.group(1).strip() if final_response_match else "" return ExtractedToolCallInformation( tools_called=bool(tool_calls), @@ -164,7 +137,7 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) content=content, ) - # 3) Fallback: look for single blocks (legacy / other models) + # Fallback: legacy blocks. content_to_search = model_output final_response_match = self.apriel_final_response_regex.search(model_output) if final_response_match: @@ -172,14 +145,12 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) if self.tool_call_start_token in final_response_content: content_to_search = final_response_content elif self.tool_call_start_token not in model_output: - # No tool calls found, return final response as content return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=final_response_content ) - # Quick check to avoid unnecessary processing if self.tool_call_start_token not in content_to_search: return ExtractedToolCallInformation( tools_called=False, @@ -187,40 +158,25 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) content=model_output ) - # Find all tool call matches function_call_tuples = self.tool_call_regex.findall(content_to_search) - # Parse JSON from matches tool_calls = [] for i, match in enumerate(function_call_tuples): json_str = match[0] if match[0] else match[1] try: parsed_call = json.loads(json_str.strip()) - args_obj = parsed_call.get("arguments", {}) - if not isinstance(args_obj, (dict, list, str, int, float, bool)): - args_obj = {} - tool_call = ToolCall( - id=f"call_{i}", - type="function", - function=FunctionCall( - name=str(parsed_call.get("name", "")), - arguments=json.dumps(args_obj, ensure_ascii=False) - ) - ) - tool_calls.append(tool_call) except Exception: logger.debug("Skipping malformed JSON", exc_info=True) continue + tc = _build_tool_call(i, parsed_call, force_id=f"call_{i}") + if tc is not None: + tool_calls.append(tc) - # Determine content based on whether we found tool calls if tool_calls and final_response_match: - # If we found tool calls in final response, use just the tool calls content = "" elif final_response_match: - # If we have final response but no tool calls there, use final response content = final_response_match.group(1).strip() else: - # Standard processing content = model_output return ExtractedToolCallInformation( @@ -230,7 +186,7 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) ) except Exception: - # Never propagate exceptions to the server; log and return a safe fallback. + # Never propagate to the vLLM server. logger.exception("Tool parser encountered an exception; returning safe fallback.") if final_response_match: return ExtractedToolCallInformation(