diff --git a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py
new file mode 100644
index 0000000000..94777db5d9
--- /dev/null
+++ b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py
@@ -0,0 +1,754 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Operator that runs a ReAct agentic retrieval loop per query."""
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Callable, Dict, List, Literal, Optional
+
+import pandas as pd
+import requests
+
+from nemo_retriever.graph.abstract_operator import AbstractOperator
+from nemo_retriever.graph.cpu_operator import CPUOperator
+from nemo_retriever.nim.chat_completions import invoke_chat_completion_step
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Prompt rendering (verbatim content of 02_v1.j2, rendered via Python)
+# ---------------------------------------------------------------------------
+
+_GOAL = """\
+You are a retrieval agent that finds all documents related to a given query.
+
+
+You are given a search query and a list of documents retrieved for that query. Your task is to write new \
+queries and use the given search tool to find *ALL* the related and somewhat related documents to the given \
+query (i.e., maximize recall).
+If the user's query is a question, you should not answer the question yourself. Instead, you should find \
+the related documents for the given query.
+"""
+
+_RELEVANCE_DEFINITION = """
+
+
+- You should be careful: in the context of this task, what it means to be a "query", "document", and \
+"relevant" can sometimes be very complex and might not follow the traditional definition of these terms \
+in standard information retrieval.
+- In standard retrieval, a query is usually a user question (like a web search query), the document is \
+some sort of content that provides information (e.g., a web page), and these two are considered relevant if \
+the document provides information that answers the user's query.
+- However, in our setting, this could be different. Here are some examples:
+ * the query is a programming problem and documents are programming language syntax references. A document \
+is relevant if it contains the reference for the programming syntax used for solving the problem.
+  * both query and documents are descriptions of programming problems, and a query and document are relevant if \
+the same approach is used to solve them.
+ * the query is a math problem and documents are theorems. Relevant documents (theorems) are the ones \
+that are useful for solving the math problem.
+ * the query and document are both math problems. A query and a document are relevant if the same theorem \
+is used for solving them.
+ * the query is a task description (e.g., for an API programmer) and documents are descriptions of \
+available APIs. Relevant documents (e.g., APIs) are the ones needed for completing the task.
+- This is not an exhaustive list. These are just some examples to show you the complexity of queries, \
+documents, and the concept of relevance in this task.
+- Note that even here, the relevant documents are still the ones that are useful for a user who is \
+searching for the given query. But the relation is more nuanced.
+- You should analyze the query and some of the available documents. And then reason about what could be a \
+meaningful definition of relevance in this case, and what the user could be looking for.
+- Moreover, sometimes, the query could be even a prompt that is given to a Large Language Model (LLM) and \
+the user wants to find the useful documents for the LLM that help answering/solving this prompt.
+"""
+
+_WORKFLOW_TEMPLATE = """
+
+- You are given a retrieval tool, powered by a dense embedding model, that takes a text query and returns \
+the most similar documents.
+{extended_relevance_line}\
+- You can call the search tool multiple times.
+- Search for related documents to the user's query from different angles.
+- If needed, revise your search queries based on the documents you find in previous steps.
+- Once you are confident that you have found all the related and somewhat related documents and there are \
+no more related documents in the corpus, call the "final_results" tool to finish the task.
+{enforce_top_k_line}\
+- When calling the "final_results" tool, the list of documents must be sorted in the decreasing level of \
+relevance to the query. I.e., the first document is the most relevant to the query, the second document is \
+the second most relevant to the query, and so on.
+"""
+
+_BEST_PRACTICES_TEMPLATE = """
+
+
+- You should be thorough and find all related and somewhat related documents.
+- The goal is to increase the **Recall** of your search attempt. So, if multiple documents are relevant \
+to the given query, you should find and report all of them even if only a subset of them is enough \
+for answering the query.
+{with_init_docs_line}\
+"""
+
+
+def _render_react_agent_prompt(
+ top_k: int,
+ *,
+ with_init_docs: bool = True,
+ enforce_top_k: bool = True,
+ extended_relevance: bool = False,
+) -> str:
+ """Render the ReAct agent system prompt (verbatim 02_v1.j2 logic)."""
+ parts = [_GOAL]
+ if extended_relevance:
+ parts.append(_RELEVANCE_DEFINITION)
+
+ ext_line = (
+ "- As explained above, reason and figure out what the meaning of relevance is in this case, "
+ "and what could be relevant and useful information for the given query.\n"
+ if extended_relevance
+ else ""
+ )
+ enforce_line = (
+ f'- When calling "final_results", you must select exactly the {top_k} most relevant documents '
+ "among all documents you have retrieved.\n"
+ if enforce_top_k
+ else ""
+ )
+ parts.append(
+ _WORKFLOW_TEMPLATE.format(
+ extended_relevance_line=ext_line,
+ enforce_top_k_line=enforce_line,
+ )
+ )
+
+ init_docs_line = (
+ "- **TIP**: you can look at the list of documents retrieved using the original query and think "
+ "what other queries you can use to find the potentially related documents that are missing in these results.\n"
+ if with_init_docs
+ else ""
+ )
+ parts.append(_BEST_PRACTICES_TEMPLATE.format(with_init_docs_line=init_docs_line))
+ return "".join(parts)
+
+
+# ---------------------------------------------------------------------------
+# Tool specs (verbatim from retrieval_bench/nemo_agentic/tool_helpers.py)
+# ---------------------------------------------------------------------------
+
+
+def _make_think_tool_spec(extended_relevance: bool = False) -> Dict[str, Any]:
+ ext = ""
+ if extended_relevance:
+ ext = (
+ "- When it is difficult to understand what is the intent of the user and what they are trying "
+ "to find with this query, use this tool to think about potential definitions of relevance that "
+ "could be meaningful/useful to the user for this task.\n"
+            "- If the intention of the user is vague, especially given the available documents, use this tool "
+ "to think how you should decide what documents are relevant and what the metric of relevance is.\n"
+ )
+ description = (
+ "Use the tool to think about something. It will not obtain new information or make any changes, "
+ "but just log the thought. Use it when complex reasoning or brainstorming is needed.\n\n"
+ "Common use cases:\n"
+ f"{ext}"
+ "- When processing a complex query, use this tool to organize your thoughts and think about "
+        "the sub-queries that you need to search for to find the relevant information\n"
+        "- If a query is vague and it is very difficult to find information for it, you can use this tool to think "
+ "about clues in the query that you can use to narrow down the search and spot relevant pieces of information.\n"
+ "- When finding related documents that help you create better search queries in the next step, use this "
+ "tool to think about what pieces of information from these documents are helpful to search for.\n"
+ "- When you fail to find any related information to the query, use this tool to think about other "
+ "search strategies that you can take to retrieve the related documents\n\n"
+ "The tool simply logs your thought process for better transparency and does not make any changes."
+ )
+ return {
+ "type": "function",
+ "function": {
+ "name": "think",
+ "description": description,
+ "parameters": {
+ "type": "object",
+ "properties": {"thought": {"type": "string", "description": "The thought to log."}},
+ "required": ["thought"],
+ },
+ },
+ }
+
+
+def _make_retrieve_tool_spec(top_k: int) -> Dict[str, Any]:
+ return {
+ "type": "function",
+ "function": {
+ "name": "retrieve",
+ "description": (
+ "Search for documents relevant to the given query using a dense embedding retrieval system. "
+                f"Returns up to {top_k} of the most semantically similar documents from the corpus."
+ ),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "The search query to retrieve documents for.",
+ },
+ },
+ "required": ["query"],
+ },
+ },
+ }
+
+
+def _make_final_results_tool_spec(top_k: Optional[int]) -> Dict[str, Any]:
+ tk_ins = ""
+ if top_k is not None:
+ tk_ins = f"- You must choose exactly {top_k} document IDs when calling this function.\n"
+
+ description = (
+ "Signals the completion of the search process for the current query.\n\n"
+ "Use this tool when:\n"
+ "- You have found all the relevant documents to the query.\n"
+ "- Despite several attempts, you cannot find good documents for the given query.\n\n"
+ "The message should include:\n"
+ "- A brief summary of your exploration and the results\n"
+ "- Explanation if the search was unsuccessful\n\n"
+ "When reporting the selected document IDs, make sure:\n"
+ "- the list of document IDs is sorted in the decreasing level of relevance to the query. "
+ "I.e., the first document in the list is the most relevant to the query, the second is the "
+ "second most relevant to the query, and so on.\n"
+ f"{tk_ins}"
+        "\nThe search_successful field should be set to true if you believe you have found the most "
+        "relevant documents to the user's query, false otherwise, and partial if it is in between."
+ )
+ return {
+ "type": "function",
+ "function": {
+ "name": "final_results",
+ "description": description,
+ "parameters": {
+ "type": "object",
+ "required": ["doc_ids", "message", "search_successful"],
+ "properties": {
+ "message": {
+ "type": "string",
+ "description": (
+ "A message for the user to explain why you think you found all the related "
+ "documents and there is no related document is missing. Also, include a short "
+ "description of your exploration process. If your attempts to find related "
+ "documents were unsuccessful, explain why."
+ ),
+ },
+ "doc_ids": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": (
+ "List of document IDs that are relevant to the user's query sorted descending "
+ "by their level of relevance to the user's query. I.e., the first document is "
+ "the most relevant to the query, the second is the second most relevant to the "
+ "query, and so on."
+ ),
+ },
+ "search_successful": {
+ "type": "string",
+ "enum": ["true", "false", "partial"],
+ "description": "Whether you managed to find all the related documents to the query.",
+ },
+ },
+ },
+ },
+ }
+
+
+# ---------------------------------------------------------------------------
+# Operator
+# ---------------------------------------------------------------------------
+
+#: Nudge message sent when the LLM stops without calling any tool.
+_AUTO_USER_MSG = (
+ "continue with the task. Do not re-read the query. Do not summarize your progress. "
+ "If you believe you have done all the required steps, call the `final_results` tool"
+)
+
+
+class ReActAgentOperator(AbstractOperator, CPUOperator):
+ """Run an iterative ReAct retrieval loop per query and emit the full retrieval log.
+
+ Each query row is processed independently by an LLM-driven ReAct loop
+ (Reason + Act) that has access to three tools: ``think``, ``retrieve``,
+ and ``final_results``. The operator emits one output row per retrieved
+ document per retrieval step, enabling downstream
+ :class:`RRFAggregatorOperator` to fuse the ranked lists with Reciprocal
+ Rank Fusion.
+
+ The system prompt is a verbatim Python rendering of the retrieval-bench
+ ``02_v1.j2`` template, including optional ``extended_relevance`` and
+ ``enforce_top_k`` blocks.
+
+ Input DataFrame schema
+ ----------------------
+ query_id : str — unique query identifier
+ query_text : str — the search query text
+ (additional columns are ignored)
+
+ Output DataFrame schema
+ -----------------------
+ query_id : str — same ``query_id`` as the input
+ query_text : str — same ``query_text`` (passed through for downstream)
+ step_idx : int — 0 = initial seed retrieval; 1 … N = per-loop retrieve calls
+ doc_id : str — retrieved document identifier
+ text : str — document text
+ rank : int — 1-indexed rank within this step (1 = most relevant)
+
+ Parameters
+ ----------
+ invoke_url : str
+ Full ``/v1/chat/completions`` endpoint URL.
+ llm_model : str
+ Model identifier forwarded to the endpoint.
+ retriever_fn : Callable[[str, int], list[dict]]
+ ``(query_text, top_k) → [{doc_id: str, text: str, ...}]``.
+ The callable is invoked for every retrieve tool call the agent makes.
+ Each returned dict must contain ``doc_id`` and ``text`` keys.
+ retriever_top_k : int
+ Number of documents fetched per retrieve call. Defaults to ``500``.
+ target_top_k : int
+ Number of final documents to select, communicated to the LLM via the
+ system prompt and ``final_results`` tool spec. Defaults to ``10``.
+ enforce_top_k : bool
+ When ``True``, the system prompt instructs the LLM to select exactly
+ ``target_top_k`` documents in its ``final_results`` call.
+ Defaults to ``True``.
+ user_msg_type : {"with_results", "simple"}
+ ``"with_results"`` (default): make one upfront retrieval call with the
+ original query and include those documents in the first user message,
+ mirroring the retrieval-bench ``with_results`` mode.
+ ``"simple"``: start the loop with just the query text.
+ extended_relevance : bool
+        Include the ``_RELEVANCE_DEFINITION`` block in the system prompt for
+ tasks with non-standard relevance definitions. Defaults to ``False``.
+ max_steps : int
+ Maximum ReAct loop iterations per query before forced exit.
+ Defaults to ``10``.
+ num_concurrent : int
+ Number of queries processed concurrently via ``ThreadPoolExecutor``.
+ Defaults to ``8``.
+ api_key : str, optional
+ Literal API key **or** an ``"os.environ/VAR_NAME"`` reference.
+ max_tokens : int, optional
+ Upper bound on tokens in each LLM response.
+
+ Notes
+ -----
+ ``retriever_fn`` must be serialisable when used with ``RayDataExecutor``.
+ Prefer module-level functions or picklable callable objects over lambdas.
+
+ Examples
+ --------
+ ::
+
+ from nemo_retriever.graph.react_agent_operator import ReActAgentOperator
+ from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+ from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator
+ from nemo_retriever.graph.executor import InprocessExecutor
+
+ def my_retriever(query_text: str, top_k: int) -> list[dict]:
+ # Returns [{doc_id, text, score?}, ...]
+ ...
+
+ pipeline = (
+ ReActAgentOperator(
+ invoke_url="https://integrate.api.nvidia.com/v1/chat/completions",
+ llm_model="nvidia/llama-3.3-nemotron-super-49b-v1",
+ retriever_fn=my_retriever,
+ retriever_top_k=500,
+ target_top_k=10,
+ )
+ >> RRFAggregatorOperator(k=60)
+ >> SelectionAgentOperator(
+ invoke_url="https://integrate.api.nvidia.com/v1/chat/completions",
+ llm_model="nvidia/llama-3.3-nemotron-super-49b-v1",
+ top_k=10,
+ )
+ )
+
+ result_df = InprocessExecutor(pipeline).ingest(query_df)
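+
+    A minimal ``retriever_fn`` sketch, assuming a hypothetical ``my_index``
+    object with a ``search(query, limit)`` method; any callable matching the
+    ``(query_text, top_k) -> list[dict]`` contract works::
+
+        def my_retriever(query_text: str, top_k: int) -> list[dict]:
+            # hypothetical vector-index API; swap in your own search backend
+            hits = my_index.search(query_text, limit=top_k)
+            return [{"doc_id": h.id, "text": h.text, "score": h.score} for h in hits]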
+ """
+
+ _NVIDIA_BUILD_ENDPOINT = "https://integrate.api.nvidia.com/v1/chat/completions"
+
+ def __init__(
+ self,
+ *,
+ invoke_url: Optional[str] = None,
+ llm_model: str,
+ retriever_fn: Callable[[str, int], List[Dict[str, Any]]],
+ retriever_top_k: int = 500,
+ target_top_k: int = 10,
+ enforce_top_k: bool = True,
+ user_msg_type: Literal["with_results", "simple"] = "with_results",
+ extended_relevance: bool = False,
+ max_steps: int = 10,
+ num_concurrent: int = 8,
+ api_key: Optional[str] = None,
+ max_tokens: Optional[int] = None,
+ parallel_tool_calls: bool = True,
+ ) -> None:
+ super().__init__()
+ self._invoke_url = invoke_url or self._NVIDIA_BUILD_ENDPOINT
+ self._llm_model = llm_model
+ self._retriever_fn = retriever_fn
+ self._retriever_top_k = retriever_top_k
+ self._target_top_k = target_top_k
+ self._enforce_top_k = enforce_top_k
+ self._user_msg_type = user_msg_type
+ self._extended_relevance = extended_relevance
+ self._max_steps = max_steps
+ self._num_concurrent = num_concurrent
+ self._api_key = api_key
+ self._max_tokens = max_tokens
+ self._parallel_tool_calls = parallel_tool_calls
+
+ # ------------------------------------------------------------------
+ # AbstractOperator interface
+ # ------------------------------------------------------------------
+
+ def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame:
+ if not isinstance(data, pd.DataFrame):
+ raise TypeError(f"ReActAgentOperator expects a pd.DataFrame, got {type(data).__name__!r}.")
+ required = {"query_id", "query_text"}
+ missing = required - set(data.columns)
+ if missing:
+ raise ValueError(
+ f"Input DataFrame is missing required column(s): {sorted(missing)}. " f"Expected: {sorted(required)}."
+ )
+ return data[["query_id", "query_text"]].copy()
+
+ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+ """Run the ReAct loop for each query, concurrently up to num_concurrent."""
+ api_key = self._resolve_api_key()
+ rows: List[Dict[str, Any]] = []
+
+ query_rows = [(str(r["query_id"]), str(r["query_text"])) for _, r in data.iterrows()]
+
+ if len(query_rows) == 1:
+ # Fast path: single query, no threading overhead
+ qid, qtxt = query_rows[0]
+ rows.extend(self._run_single_query(qid, qtxt, api_key))
+ else:
+ with ThreadPoolExecutor(max_workers=min(self._num_concurrent, len(query_rows))) as executor:
+ futures = {
+ executor.submit(self._run_single_query, qid, qtxt, api_key): (qid, qtxt) for qid, qtxt in query_rows
+ }
+ for future in as_completed(futures):
+ try:
+ rows.extend(future.result())
+ except TimeoutError as exc:
+ qid, qtxt = futures[future]
+ logger.warning("ReActAgentOperator: query %r timed out: %s", qid, exc, exc_info=True)
+ except RuntimeError as exc:
+ qid, qtxt = futures[future]
+ logger.warning("ReActAgentOperator: query %r retries exhausted: %s", qid, exc, exc_info=True)
+ except requests.RequestException as exc:
+ qid, qtxt = futures[future]
+ logger.warning("ReActAgentOperator: query %r HTTP error: %s", qid, exc, exc_info=True)
+ except (json.JSONDecodeError, ValueError) as exc:
+ qid, qtxt = futures[future]
+ logger.warning("ReActAgentOperator: query %r data error: %s", qid, exc, exc_info=True)
+ except Exception as exc: # catches unexpected worker errors not covered above
+ qid, qtxt = futures[future]
+ logger.warning("ReActAgentOperator: query %r failed: %s", qid, exc, exc_info=True)
+
+ if not rows:
+ return pd.DataFrame(columns=["query_id", "query_text", "step_idx", "doc_id", "text", "rank"])
+
+ return pd.DataFrame(rows)
+
+ def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+ return data
+
+ # ------------------------------------------------------------------
+ # Internal: single query ReAct loop
+ # ------------------------------------------------------------------
+
+ def _run_single_query(
+ self,
+ query_id: str,
+ query_text: str,
+ api_key: Optional[str],
+ ) -> List[Dict[str, Any]]:
+ """Run the full ReAct loop for one query; return a list of row dicts."""
+ with_init_docs = self._user_msg_type == "with_results"
+
+ system_prompt = _render_react_agent_prompt(
+ self._target_top_k,
+ with_init_docs=with_init_docs,
+ enforce_top_k=self._enforce_top_k,
+ extended_relevance=self._extended_relevance,
+ )
+ tools = [
+ _make_think_tool_spec(self._extended_relevance),
+ _make_retrieve_tool_spec(self._retriever_top_k),
+ _make_final_results_tool_spec(self._target_top_k if self._enforce_top_k else None),
+ ]
+
+ messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}]
+
+ # Retrieval log: one list per step, each item is {doc_id, text, score?}
+ retrieval_log: List[List[Dict[str, Any]]] = []
+ seen_doc_ids: set[str] = set()
+ step_counter = 0
+
+ # ------ optional initial retrieval (with_results mode) ------
+ if with_init_docs:
+ init_docs = self._call_retriever(query_text, seen_doc_ids, api_key)
+ retrieval_log.append(init_docs)
+ step_counter += 1
+ for d in init_docs:
+ seen_doc_ids.add(d["doc_id"])
+
+ doc_content = _docs_to_message_content(init_docs)
+ user_msg_content: List[Dict[str, Any]] = [
+ {"type": "text", "text": f"Query:\n{query_text}\n\nRetrieved Documents:"}
+ ] + doc_content
+ messages.append({"role": "user", "content": user_msg_content})
+ else:
+ messages.append({"role": "user", "content": f"Query:\n{query_text}"})
+
+ final_doc_ids: Optional[List[str]] = None
+
+ # ------ main ReAct loop ------
+ for _step in range(self._max_steps):
+ logger.debug("query=%r loop_step=%d seen_docs=%d", query_id, _step, len(seen_doc_ids))
+ try:
+ response = invoke_chat_completion_step(
+ invoke_url=self._invoke_url,
+ messages=messages,
+ model=self._llm_model,
+ api_key=api_key,
+ tools=tools,
+ tool_choice="auto",
+ max_tokens=self._max_tokens,
+ extra_body={"parallel_tool_calls": False} if not self._parallel_tool_calls else None,
+ )
+ except TimeoutError as exc:
+ logger.warning(
+ "ReActAgentOperator: LLM call timed out on step %d for query %r: %s",
+ _step,
+ query_id,
+ exc,
+ exc_info=True,
+ )
+ break
+ except RuntimeError as exc:
+ logger.warning(
+ "ReActAgentOperator: LLM retries exhausted on step %d for query %r: %s",
+ _step,
+ query_id,
+ exc,
+ exc_info=True,
+ )
+ break
+ except requests.RequestException as exc:
+ logger.warning(
+ "ReActAgentOperator: LLM HTTP error on step %d for query %r: %s",
+ _step,
+ query_id,
+ exc,
+ exc_info=True,
+ )
+ break
+ except json.JSONDecodeError as exc:
+ logger.warning(
+ "ReActAgentOperator: LLM returned invalid JSON on step %d for query %r: %s",
+ _step,
+ query_id,
+ exc,
+ exc_info=True,
+ )
+ break
+
+ if not response.get("choices"):
+ logger.warning(
+ "ReActAgentOperator: empty choices in API response on step %d for query %r", _step, query_id
+ )
+ break
+ choice = response["choices"][0]
+ msg = choice["message"]
+ finish_reason = choice.get("finish_reason")
+ tool_calls = msg.get("tool_calls") or []
+
+ # Append assistant turn
+ assistant_turn: Dict[str, Any] = {"role": "assistant"}
+ if msg.get("content"):
+ assistant_turn["content"] = msg["content"]
+ if tool_calls:
+ assistant_turn["tool_calls"] = tool_calls
+ messages.append(assistant_turn)
+
+ if finish_reason == "stop" or not tool_calls:
+ messages.append({"role": "user", "content": _AUTO_USER_MSG})
+ continue
+
+ tool_messages: List[Dict[str, Any]] = []
+ loop_done = False
+
+ for tc in tool_calls:
+ tc_id = tc.get("id", "")
+ fn = tc.get("function", {})
+ fn_name = fn.get("name", "")
+ try:
+ fn_args = json.loads(fn.get("arguments", "{}"))
+ except json.JSONDecodeError:
+ tool_messages.append(
+ {"role": "tool", "tool_call_id": tc_id, "content": "Error: could not parse tool arguments."}
+ )
+ continue
+
+ if fn_name == "think":
+ logger.debug("query=%r step=%d [think] %s", query_id, _step, str(fn_args.get("thought", ""))[:120])
+ tool_messages.append(
+ {"role": "tool", "tool_call_id": tc_id, "content": "Your thought has been logged."}
+ )
+
+ elif fn_name == "retrieve":
+ subquery = str(fn_args.get("query", query_text))
+ logger.debug("query=%r step=%d [retrieve] subquery=%r", query_id, _step, subquery)
+ retrieved = self._call_retriever(subquery, seen_doc_ids, api_key)
+ logger.debug("query=%r step=%d [retrieve] got %d new docs", query_id, _step, len(retrieved))
+ retrieval_log.append(retrieved)
+ step_counter += 1
+ for d in retrieved:
+ seen_doc_ids.add(d["doc_id"])
+ doc_content = _docs_to_message_content(retrieved)
+ tool_content: List[Dict[str, Any]] = [
+ {"type": "text", "text": f"Retrieved {len(retrieved)} documents:"}
+ ] + doc_content
+ tool_messages.append({"role": "tool", "tool_call_id": tc_id, "content": tool_content})
+
+ elif fn_name == "final_results":
+ raw_ids: List[str] = fn_args.get("doc_ids", [])
+ logger.debug("query=%r step=%d [final_results] doc_ids=%s", query_id, _step, raw_ids)
+ if isinstance(raw_ids, list) and raw_ids:
+ final_doc_ids = [str(d) for d in raw_ids]
+ tool_messages.append(
+ {
+ "role": "tool",
+ "tool_call_id": tc_id,
+ "content": "The results have been successfully logged and the interaction ended.",
+ }
+ )
+ loop_done = True
+
+ else:
+ tool_messages.append(
+ {"role": "tool", "tool_call_id": tc_id, "content": f"Error: unknown tool '{fn_name}'."}
+ )
+
+ messages.extend(tool_messages)
+ if loop_done:
+ break
+
+ return _build_output_rows(query_id, query_text, retrieval_log, final_doc_ids)
+
+ def _call_retriever(
+ self,
+ query_text: str,
+ seen_doc_ids: set[str],
+ api_key: Optional[str],
+ ) -> List[Dict[str, Any]]:
+ """Call retriever_fn, over-fetching to ensure new results after dedup."""
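+        # Worked example: with retriever_top_k=500 and 120 docs already seen,
+        # fetch 620 candidates so that up to 500 previously unseen docs can
+        # survive the dedup filter below.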
+ fetch_k = self._retriever_top_k + len(seen_doc_ids)
+ try:
+ raw = self._retriever_fn(query_text, fetch_k)
+ except TimeoutError as exc:
+ logger.warning(
+ "ReActAgentOperator: retriever_fn timed out for query %r: %s", query_text, exc, exc_info=True
+ )
+ return []
+ except (TypeError, ValueError) as exc:
+ logger.warning(
+ "ReActAgentOperator: retriever_fn bad call/return for query %r: %s", query_text, exc, exc_info=True
+ )
+ return []
+ except Exception as exc: # retriever_fn is user-supplied; catches remaining unexpected errors.
+ logger.warning("ReActAgentOperator: retriever_fn failed for query %r: %s", query_text, exc, exc_info=True)
+ return []
+
+ # Filter already-seen and normalise keys
+ results: List[Dict[str, Any]] = []
+ for item in raw:
+ doc_id = str(item.get("doc_id", item.get("id", "")))
+ text = str(item.get("text", ""))
+ score = float(item.get("score", 0.0))
+ if doc_id and doc_id not in seen_doc_ids:
+ results.append({"doc_id": doc_id, "text": text, "score": score})
+ if len(results) >= self._retriever_top_k:
+ break
+
+ return results
+
+ def _resolve_api_key(self) -> Optional[str]:
+ api_key = self._api_key
+ if api_key is not None and api_key.strip().startswith("os.environ/"):
+ var = api_key.strip().removeprefix("os.environ/")
+ value = os.environ.get(var)
+ if value is None:
+ raise ValueError(
+                    f"Environment variable '{var}' is not set. Set it with: export {var}=<your-api-key>"
+ )
+ return value
+ return api_key
+
+
+# ---------------------------------------------------------------------------
+# Module-level helpers
+# ---------------------------------------------------------------------------
+
+
+def _docs_to_message_content(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Convert a list of doc dicts to LLM message content blocks."""
+ content: List[Dict[str, Any]] = []
+ for doc in docs:
+ doc_id = doc.get("doc_id", "")
+ text = doc.get("text", "").strip()
+ entry: Dict[str, Any] = {"id": doc_id}
+ if text:
+ entry["text"] = text
+ score = doc.get("score")
+ if score is not None:
+ entry["score"] = score
+ content.append({"type": "text", "text": json.dumps(entry)})
+ return content
+
+
+def _build_output_rows(
+ query_id: str,
+ query_text: str,
+ retrieval_log: List[List[Dict[str, Any]]],
+ final_doc_ids: Optional[List[str]],
+) -> List[Dict[str, Any]]:
+ """Convert the retrieval log to one row per (step_idx, rank, doc_id)."""
+ rows: List[Dict[str, Any]] = []
+ for step_idx, step_docs in enumerate(retrieval_log):
+ for rank, doc in enumerate(step_docs, 1):
+ rows.append(
+ {
+ "query_id": query_id,
+ "query_text": query_text,
+ "step_idx": step_idx,
+ "doc_id": doc.get("doc_id", ""),
+ "text": doc.get("text", ""),
+ "rank": rank,
+ }
+ )
+
+    # final_doc_ids is intentionally not re-emitted as a synthetic extra step
+    # (e.g. step_idx = len(retrieval_log)): the IDs it contains were retrieved
+    # in earlier steps, so those rows already exist above, and RRF naturally
+    # up-weights documents that recur across steps.
+ return rows
diff --git a/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py b/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py
new file mode 100644
index 0000000000..39053bc3da
--- /dev/null
+++ b/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py
@@ -0,0 +1,133 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Operator that fuses per-step retrieval results using Reciprocal Rank Fusion."""
+
+from __future__ import annotations
+
+from collections import defaultdict
+from typing import Any, Dict, List
+
+import pandas as pd
+
+from nemo_retriever.graph.abstract_operator import AbstractOperator
+from nemo_retriever.graph.cpu_operator import CPUOperator
+
+
+class RRFAggregatorOperator(AbstractOperator, CPUOperator):
+ """Fuse multiple per-step ranked lists into a single ranking per query using RRF.
+
+ Implements the Reciprocal Rank Fusion formula
+ ``score(d) = sum(1 / (rank_i + k))`` across all retrieval steps where
+ document *d* appears. This is the same formula used in
+ ``retrieval_bench/nemo_agentic/utils.py:rrf_from_subquery_results``.
+
+ Designed to consume the output of :class:`ReActAgentOperator` (or any
+ operator that emits one row per ``(query_id, step_idx, doc_id)`` triple)
+ and produce a single fused ranking per ``query_id`` suitable as input to
+ :class:`SelectionAgentOperator`.
+
+ Input DataFrame schema
+ ----------------------
+ query_id : str — unique query identifier
+ query_text : str — original query text (carried through)
+ step_idx : int — which retrieval step produced this row (0, 1, 2 …)
+ doc_id : str — retrieved document identifier
+ text : str — document text content
+ rank : int — 1-indexed rank within its step (1 = most relevant)
+ (additional columns are ignored)
+
+ Output DataFrame schema
+ -----------------------
+ query_id : str — same ``query_id`` as the input
+ query_text: str — original query text (first occurrence per query)
+ doc_id : str — document identifier
+ rrf_score : float — fused RRF score (higher = more relevant)
+ text : str — document text (first occurrence per ``doc_id``)
+ Rows are sorted by ``rrf_score`` descending within each ``query_id``.
+
+ Parameters
+ ----------
+ k : int
+ RRF damping factor. The standard value is ``60`` (default).
+ Larger values reduce the influence of top-ranked documents.
+
+ Examples
+ --------
+ ::
+
+ import pandas as pd
+ from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+ op = RRFAggregatorOperator(k=60)
+ df = pd.DataFrame({
+ "query_id": ["q1", "q1", "q1", "q1"],
+ "query_text": ["inflation causes"] * 4,
+            "step_idx": [0, 0, 1, 1],
+            "doc_id": ["d1", "d2", "d1", "d3"],
+            "text": ["t1", "t2", "t1", "t3"],
+            "rank": [1, 2, 1, 2],
+ })
+ result = op.run(df)
+ # d1 appears in both steps at rank 1 → highest RRF score
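+        #
+        # Worked RRF scores with k=60 (score += 1 / (rank + k)):
+        #   d1: 1/(1+60) + 1/(1+60) = 2/61 ≈ 0.0328
+        #   d2: 1/(2+60)            = 1/62 ≈ 0.0161
+        #   d3: 1/(2+60)            = 1/62 ≈ 0.0161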
+ """
+
+ def __init__(self, *, k: int = 60) -> None:
+ super().__init__()
+ self._k = k
+
+ # ------------------------------------------------------------------
+ # AbstractOperator interface
+ # ------------------------------------------------------------------
+
+ def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame:
+ if not isinstance(data, pd.DataFrame):
+ raise TypeError(f"RRFAggregatorOperator expects a pd.DataFrame, got {type(data).__name__!r}.")
+ required = {"query_id", "query_text", "step_idx", "doc_id", "text", "rank"}
+ missing = required - set(data.columns)
+ if missing:
+ raise ValueError(
+ f"Input DataFrame is missing required column(s): {sorted(missing)}. " f"Expected: {sorted(required)}."
+ )
+ return data.copy()
+
+ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+ """Compute RRF scores, group by query_id, sort by score descending."""
+ k = self._k
+ rows: List[Dict[str, Any]] = []
+
+ for query_id, qgroup in data.groupby("query_id", sort=False):
+ query_text = str(qgroup["query_text"].iloc[0])
+
+ rrf_scores: Dict[str, float] = defaultdict(float)
+ first_text: Dict[str, str] = {}
+
+ # Process each step's ranked list
+ for _step_idx, sgroup in qgroup.groupby("step_idx", sort=True):
+ # Sort by rank ascending so rank=1 is processed first
+ for _, row in sgroup.sort_values("rank").iterrows():
+ doc_id = str(row["doc_id"])
+ rank = int(row["rank"])
+ rrf_scores[doc_id] += 1.0 / (rank + k)
+ if doc_id not in first_text:
+ first_text[doc_id] = str(row["text"])
+
+ for doc_id, score in sorted(rrf_scores.items(), key=lambda kv: kv[1], reverse=True):
+ rows.append(
+ {
+ "query_id": query_id,
+ "query_text": query_text,
+ "doc_id": doc_id,
+ "rrf_score": score,
+ "text": first_text.get(doc_id, ""),
+ }
+ )
+
+ if not rows:
+ return pd.DataFrame(columns=["query_id", "query_text", "doc_id", "rrf_score", "text"])
+
+ return pd.DataFrame(rows)
+
+ def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+ return data
diff --git a/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py
new file mode 100644
index 0000000000..14371969a8
--- /dev/null
+++ b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py
@@ -0,0 +1,512 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Operator that re-ranks retrieved documents using an LLM-based selection agent."""
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+from typing import Any, Dict, List, Optional
+
+import pandas as pd
+import requests
+
+from nemo_retriever.graph.abstract_operator import AbstractOperator
+from nemo_retriever.graph.cpu_operator import CPUOperator
+from nemo_retriever.nim.chat_completions import invoke_chat_completion_step
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Prompt rendering (verbatim content of 01_v0.j2, rendered via Python)
+# ---------------------------------------------------------------------------
+
+_ROLE = """\
+You are a document re-ranker agent, which is the final stage in an information retrieval pipeline.
+
+
+You are given a search query and a list of retrieved candidate documents that are potentially relevant to \
+the given query. Your goal is to help the users identify the most relevant documents to the given query \
+from the list of candidate documents.
+"""
+
+_RELEVANCE_DEFINITION = """\
+
+
+- You should be careful: in the context of this task, what it means to be a "query", "document", and \
+"relevant" can sometimes be very complex and might not follow the traditional definition of these terms \
+in standard re-ranking and retrieval.
+- In standard re-ranking/retrieval, a query is usually a user question (like a web search query), the \
+document is some sort of content that provides information (e.g., a web page), and these two are considered \
+relevant if the document provides information that answers the user's query.
+- However, in our setting, this could be different. Here are some examples:
+ * the query is a programming problem and documents are programming language syntax references. A document \
+is relevant if it contains the reference for the programming syntax used for solving the problem.
+  * both query and documents are descriptions of programming problems, and a query and document are relevant if \
+the same approach is used to solve them.
+ * the query is a math problem and documents are theorems. Relevant documents (theorems) are the ones \
+that are useful for solving the math problem.
+ * the query and document are both math problems. A query and a document are relevant if the same theorem \
+is used for solving them.
+ * the query is a task description (e.g., for an API programmer) and documents are descriptions of \
+available APIs. Relevant documents (e.g., APIs) are the ones needed for completing the task.
+- This is not an exhaustive list. These are just some examples to show you the complexity of queries, \
+documents, and the concept of relevance in this task.
+- Note that even here, the relevant documents are still the ones that are useful for a user who is \
+searching for the given query. But the relation is more nuanced.
+- You should analyze the query and the available documents. And then reason about what could be a meaningful \
+definition of relevance in this case, and what the user could be looking for.
+- Moreover, sometimes, the query could be even a prompt that is given to a Large Language Model (LLM) and \
+the user wants to find the useful documents for the LLM that help answering/solving this prompt.
+"""
+
+_WORKFLOW_TEMPLATE = """\
+
+
+* You are given a search query and a list of candidate documents. You have access to the ID and content of \
+each candidate document.
+* You should read the query carefully and understand it.
+{extended_relevance_line}\
+* Then you should compare the query with each one of the candidate documents. In this comparison, you want \
+to identify if the document is relevant/useful for the given query and to what extent.
+* Select the {top_k} most relevant candidate documents for the given query.
+* Note that just selecting the most relevant documents is not enough. You should identify the relative level \
+of relevance between the query and selected documents. This helps you sort the selected documents later \
+based on how relevant they are to the query.
+* Once you have this information, you should call the "log_selected_documents" function to report the final \
+results and signal the completion of the task.
+* Note that the selected document IDs must be reported in the decreasing level of relevance. I.e., The \
+first document in the list is the most relevant, the second is the second most relevant, and so on. This \
+is similar to what a search engine (e.g., Google Search) does (it shows you the relevant results in a \
+sorted order, where the most relevant results appear on top of the list).
+"""
+
+_THINKING_TIPS = """
+
+
+* You have access to a "think" tool that you can use for complex thinking and analysis. Here are examples \
+of cases where the think tool might be useful:
+ - complex analysis and thinking to understand the meaning and intent of the query. E.g., what is the \
+user trying to find with this query? what kind of information is helpful for the user?
+ - extended thinking to analyze how each candidate document could or could not be relevant to the given query.
+ - reasoning to identify the relative level of relevance between the query and selected documents. It \
+helps you sort the documents correctly when reporting the final answer.
+"""
+
+
+def _render_selection_prompt(top_k: int, *, extended_relevance: bool = False) -> str:
+ """Render the selection agent system prompt (verbatim 01_v0.j2 logic)."""
+ parts = [_ROLE]
+ if extended_relevance:
+ parts.append(_RELEVANCE_DEFINITION)
+ ext_line = (
+ "* As explained above, reason and figure out what the meaning of relevance is in this case, "
+ "and what could be relevant and useful information for the given query.\n"
+ if extended_relevance
+ else ""
+ )
+ parts.append(_WORKFLOW_TEMPLATE.format(top_k=top_k, extended_relevance_line=ext_line))
+ parts.append(_THINKING_TIPS)
+ return "".join(parts)
+
+
+# ---------------------------------------------------------------------------
+# Operator
+# ---------------------------------------------------------------------------
+
+
+class SelectionAgentOperator(AbstractOperator, CPUOperator):
+ """Re-rank a set of retrieved documents using an LLM-based selection agent.
+
+ For each ``query_id`` group in the input DataFrame, the operator runs an
+ agentic LLM loop that reads the query and all candidate documents, then
+ calls a ``log_selected_documents`` tool to report the final ranked list.
+ The loop also has access to a ``think`` tool for extended reasoning.
+
+ The system prompt matches the retrieval-bench ``01_v0.j2`` template verbatim,
+ with an optional ``extended_relevance`` mode for complex retrieval tasks.
+
+ Input DataFrame schema
+ ----------------------
+ query_id : str — unique query identifier
+ query_text : str — original query text shown to the LLM
+ doc_id : str — unique document identifier
+ text : str — document text content shown to the LLM
+ (any additional columns are ignored)
+
+ Output DataFrame schema
+ -----------------------
+ query_id : str — same ``query_id`` as the input
+ doc_id : str — selected document ID
+ rank : int — 1-indexed rank (1 = most relevant)
+ message : str — LLM explanation of the selection
+
+ Parameters
+ ----------
+ llm_model : str
+ Model identifier forwarded to the endpoint.
+ invoke_url : str
+ Full ``/v1/chat/completions`` endpoint URL.
+ top_k : int
+ Number of documents to select per query. Defaults to ``5``.
+ api_key : str, optional
+ Literal API key **or** an ``"os.environ/VAR_NAME"`` reference.
+ max_tokens : int, optional
+ Upper bound on tokens in each LLM response.
+ max_steps : int
+ Maximum agentic loop iterations per query. Defaults to ``10``.
+ extended_relevance : bool
+        When ``True``, include the ``_RELEVANCE_DEFINITION`` block in the
+ system prompt for tasks with non-standard relevance definitions.
+ Defaults to ``False``.
+ system_prompt_override : str, optional
+ Fully custom system prompt. Use ``{top_k}`` as a placeholder.
+ text_truncation : int
+ Maximum characters of each document's text shown to the LLM.
+ Defaults to ``2000``.
+ base_url : str, optional
+ Deprecated alias for ``invoke_url``. Prefer ``invoke_url``.
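+
+    Examples
+    --------
+    An illustrative sketch; the model, endpoint, and ``os.environ/`` api_key
+    convention mirror the other operators in this module, and the actual
+    ranking depends on the LLM::
+
+        import pandas as pd
+        from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator
+
+        op = SelectionAgentOperator(
+            llm_model="nvidia/llama-3.3-nemotron-super-49b-v1",
+            invoke_url="https://integrate.api.nvidia.com/v1/chat/completions",
+            top_k=3,
+            api_key="os.environ/NVIDIA_API_KEY",
+        )
+        df = pd.DataFrame({
+            "query_id": ["q1"] * 3,
+            "query_text": ["inflation causes"] * 3,
+            "doc_id": ["d1", "d2", "d3"],
+            "text": ["t1", "t2", "t3"],
+        })
+        ranked = op.run(df)  # at most 3 rows: query_id, doc_id, rank, message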
+ """
+
+ _NVIDIA_BUILD_ENDPOINT = "https://integrate.api.nvidia.com/v1/chat/completions"
+
+ def __init__(
+ self,
+ *,
+ llm_model: str,
+ invoke_url: Optional[str] = None,
+ top_k: int = 5,
+ api_key: Optional[str] = None,
+ max_tokens: Optional[int] = None,
+ max_steps: int = 10,
+ extended_relevance: bool = False,
+ system_prompt_override: Optional[str] = None,
+ text_truncation: int = 2000,
+ parallel_tool_calls: bool = True,
+ base_url: Optional[str] = None,
+ ) -> None:
+ super().__init__()
+ self._llm_model = llm_model
+ self._top_k = top_k
+ self._api_key = api_key
+ self._max_tokens = max_tokens
+ self._max_steps = max_steps
+ self._extended_relevance = extended_relevance
+ self._system_prompt_override = system_prompt_override
+ self._text_truncation = text_truncation
+ self._parallel_tool_calls = parallel_tool_calls
+
+ if invoke_url is not None:
+ self._invoke_url = invoke_url
+ elif base_url is not None:
+ import warnings
+
+ warnings.warn(
+ "SelectionAgentOperator: 'base_url' is deprecated, use 'invoke_url' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self._invoke_url = base_url.rstrip("/") + "/v1/chat/completions"
+ else:
+ self._invoke_url = self._NVIDIA_BUILD_ENDPOINT
+
+ # ------------------------------------------------------------------
+ # AbstractOperator interface
+ # ------------------------------------------------------------------
+
+ def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame:
+ """Validate that *data* is a DataFrame with the required columns."""
+ if not isinstance(data, pd.DataFrame):
+ raise TypeError(f"SelectionAgentOperator expects a pd.DataFrame, got {type(data).__name__!r}.")
+ required = {"query_id", "query_text", "doc_id", "text"}
+ missing = required - set(data.columns)
+ if missing:
+ raise ValueError(
+ f"Input DataFrame is missing required column(s): {sorted(missing)}. " f"Expected: {sorted(required)}."
+ )
+ return data.copy()
+
+ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+ """Run the selection agent loop for each query group."""
+ rows: List[Dict[str, Any]] = []
+
+ for query_id, group in data.groupby("query_id", sort=False):
+ query_text = str(group["query_text"].iloc[0])
+ docs = [{"id": str(row["doc_id"]), "text": str(row["text"])} for _, row in group.iterrows()]
+ result = self._select_documents(query_text, docs)
+ message = result.get("message", "")
+ for rank, doc_id in enumerate(result.get("doc_ids", []), 1):
+ rows.append(
+ {
+ "query_id": query_id,
+ "doc_id": doc_id,
+ "rank": rank,
+ "message": message,
+ }
+ )
+
+ if not rows:
+ return pd.DataFrame(columns=["query_id", "doc_id", "rank", "message"])
+
+ return pd.DataFrame(rows)
+
+ def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+ return data
+
+ # ------------------------------------------------------------------
+ # Internal helpers
+ # ------------------------------------------------------------------
+
+ def _resolve_api_key(self) -> Optional[str]:
+ api_key = self._api_key
+ if api_key is not None and api_key.strip().startswith("os.environ/"):
+ var = api_key.strip().removeprefix("os.environ/")
+ value = os.environ.get(var)
+ if value is None:
+ raise ValueError(
+                    f"Environment variable '{var}' is not set. Set it with: export {var}=<your-api-key>"
+ )
+ return value
+ return api_key
+
+ def _build_system_prompt(self, top_k: int) -> str:
+ if self._system_prompt_override:
+ return self._system_prompt_override.format(top_k=top_k)
+ return _render_selection_prompt(top_k, extended_relevance=self._extended_relevance)
+
+ def _build_tools(self, top_k: int, valid_doc_ids: List[str]) -> List[Dict[str, Any]]:
+ """Return the two tool specs for the selection agent loop."""
+ return [
+ {
+ "type": "function",
+ "function": {
+ "name": "think",
+ "description": (
+ "Use this tool to think through complex analysis before making a decision. "
+ "It logs your reasoning without making any changes. Use it to compare "
+ "documents against the query or to reason about relevance."
+ ),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "thought": {
+ "type": "string",
+ "description": "Your reasoning or analysis.",
+ }
+ },
+ "required": ["thought"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "log_selected_documents",
+ "description": (
+ f"Records the {top_k} most relevant documents and ends the task. "
+ f"Call this when you have finished evaluating all candidate documents. "
+ f"The doc_ids list must be sorted from most to least relevant. "
+ f"Valid document IDs are: {valid_doc_ids}."
+ ),
+ "parameters": {
+ "type": "object",
+ "required": ["doc_ids", "message"],
+ "properties": {
+ "doc_ids": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": (
+ f"The IDs of the {top_k} most relevant documents, sorted from "
+ "most to least relevant. Must be valid document IDs from the candidates."
+ ),
+ },
+ "message": {
+ "type": "string",
+ "description": "A brief explanation of your selection and the relevance ordering.",
+ },
+ },
+ },
+ },
+ },
+ ]
+
+ def _build_user_message(self, query_text: str, docs: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """Format query + candidate documents as a multi-part user message."""
+ content: List[Dict[str, Any]] = [
+ {"type": "text", "text": f"Query:\n{query_text}"},
+ {"type": "text", "text": "Candidate Documents:"},
+ ]
+ seen: set[str] = set()
+ for doc in docs:
+ doc_id = doc["id"]
+ if doc_id in seen:
+ continue
+ seen.add(doc_id)
+ content.append({"type": "text", "text": f"Doc ID: {doc_id}"})
+ text = doc.get("text", "").strip()
+ if text:
+ truncated = text[: self._text_truncation]
+ if len(text) > self._text_truncation:
+ truncated += "..."
+ content.append({"type": "text", "text": f"Doc Text: {truncated}"})
+ return {"role": "user", "content": content}
+
+ def _select_documents(
+ self,
+ query_text: str,
+ docs: List[Dict[str, Any]],
+ ) -> Dict[str, Any]:
+ """Run the agentic selection loop for a single query."""
+ valid_ids = list(dict.fromkeys(d["id"] for d in docs))
+ feasible_k = min(self._top_k, len(valid_ids))
+
+ system_prompt = self._build_system_prompt(feasible_k)
+ tools = self._build_tools(feasible_k, valid_ids)
+ valid_id_set = set(valid_ids)
+ api_key = self._resolve_api_key()
+
+ messages: List[Dict[str, Any]] = [
+ {"role": "system", "content": system_prompt},
+ self._build_user_message(query_text, docs),
+ ]
+
+ extra_body: Dict[str, Any] = {}
+ if not self._parallel_tool_calls:
+ extra_body["parallel_tool_calls"] = False
+
+ for _step in range(self._max_steps):
+ try:
+ response = invoke_chat_completion_step(
+ invoke_url=self._invoke_url,
+ messages=messages,
+ model=self._llm_model,
+ api_key=api_key,
+ tools=tools,
+ tool_choice="auto",
+ max_tokens=self._max_tokens,
+ extra_body=extra_body or None,
+ )
+ except TimeoutError as exc:
+ logger.warning(
+ "SelectionAgentOperator: LLM call timed out on step %d for query %r: %s",
+ _step,
+ query_text,
+ exc,
+ exc_info=True,
+ )
+ break
+ except RuntimeError as exc:
+ logger.warning(
+ "SelectionAgentOperator: LLM retries exhausted on step %d for query %r: %s",
+ _step,
+ query_text,
+ exc,
+ exc_info=True,
+ )
+ break
+ except requests.RequestException as exc:
+ logger.warning(
+ "SelectionAgentOperator: LLM HTTP error on step %d for query %r: %s",
+ _step,
+ query_text,
+ exc,
+ exc_info=True,
+ )
+ break
+ except json.JSONDecodeError as exc:
+ logger.warning(
+ "SelectionAgentOperator: LLM returned invalid JSON on step %d for query %r: %s",
+ _step,
+ query_text,
+ exc,
+ exc_info=True,
+ )
+ break
+
+ if not response.get("choices"):
+                logger.warning(
+                    "SelectionAgentOperator: empty choices in API response on step %d for query %r",
+                    _step,
+                    query_text,
+                )
+ break
+ choice = response["choices"][0]
+ msg = choice["message"]
+ finish_reason = choice.get("finish_reason")
+
+ # Append the assistant turn to history
+ assistant_turn: Dict[str, Any] = {"role": "assistant"}
+ if msg.get("content"):
+ assistant_turn["content"] = msg["content"]
+ tool_calls = msg.get("tool_calls") or []
+ if tool_calls:
+ assistant_turn["tool_calls"] = tool_calls
+ messages.append(assistant_turn)
+
+ if finish_reason == "stop" or not tool_calls:
+ messages.append(
+ {
+ "role": "user",
+ "content": "Please call log_selected_documents to report your final selection.",
+ }
+ )
+ continue
+
+ tool_messages: List[Dict[str, Any]] = []
+ should_end = False
+ end_kwargs: Dict[str, Any] = {}
+
+ for tc in tool_calls:
+ tc_id = tc.get("id", "")
+ fn = tc.get("function", {})
+ try:
+ fn_args = json.loads(fn.get("arguments", "{}"))
+ except json.JSONDecodeError:
+ tool_messages.append(
+ {"role": "tool", "tool_call_id": tc_id, "content": "Error: could not parse tool arguments."}
+ )
+ continue
+
+ if fn.get("name") == "think":
+ tool_messages.append(
+ {"role": "tool", "tool_call_id": tc_id, "content": "Your thought has been logged."}
+ )
+
+ elif fn.get("name") == "log_selected_documents":
+ raw_doc_ids = fn_args.get("doc_ids", [])
+ if isinstance(raw_doc_ids, str):
+ try:
+ raw_doc_ids = json.loads(raw_doc_ids)
+ except json.JSONDecodeError:
+ raw_doc_ids = []
+ doc_ids = [d for d in raw_doc_ids if d in valid_id_set][:feasible_k]
+ if not doc_ids and raw_doc_ids:
+ logger.warning(
+ "SelectionAgentOperator: LLM returned %d doc_id(s) for query %r "
+ "but none matched the candidate set — possible hallucination. "
+ "Returned IDs: %s",
+ len(raw_doc_ids),
+ query_text,
+ raw_doc_ids[:10],
+ )
+ end_kwargs = {"doc_ids": doc_ids, "message": fn_args.get("message", "")}
+ should_end = True
+
+ else:
+ tool_messages.append(
+ {
+ "role": "tool",
+ "tool_call_id": tc_id,
+ "content": f"Error: unknown tool '{fn.get('name')}'.",
+ }
+ )
+
+ if should_end:
+ return end_kwargs
+
+ messages.extend(tool_messages)
+
+ return {
+ "doc_ids": [],
+ "message": "Selection agent reached max steps without completing.",
+ }
diff --git a/nemo_retriever/src/nemo_retriever/graph/subquery_operator.py b/nemo_retriever/src/nemo_retriever/graph/subquery_operator.py
new file mode 100644
index 0000000000..d5f0e6a3c1
--- /dev/null
+++ b/nemo_retriever/src/nemo_retriever/graph/subquery_operator.py
@@ -0,0 +1,364 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Operator that expands a query DataFrame into sub-queries via an LLM."""
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+from typing import Any, List, Literal, Optional
+
+import pandas as pd
+import requests
+
+from nemo_retriever.graph.abstract_operator import AbstractOperator
+from nemo_retriever.graph.cpu_operator import CPUOperator
+from nemo_retriever.nim.chat_completions import invoke_chat_completions
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Built-in strategy prompts
+# ---------------------------------------------------------------------------
+
+_PROMPTS: dict[str, str] = {
+ "decompose": """\
+You are a query decomposition assistant for a retrieval system.
+
+Given a search query, break it down into up to {max_subqueries} distinct sub-queries that \
+together cover all aspects and angles of the original query. Generate as many sub-queries as \
+are genuinely useful — do not pad with redundant ones just to hit the maximum. Each sub-query \
+should target a specific facet, making it easier for a dense retrieval system to find all \
+relevant documents.
+
+Rules:
+- Each sub-query must be self-contained and meaningful on its own.
+- Sub-queries should be diverse and complementary, not redundant.
+- Use clear, precise language suited for dense embedding retrieval.
+- Output a JSON array of strings only — no explanation, no markdown fences.""",
+ "hyde": """\
+You are a Hypothetical Document Embedding (HyDE) assistant for a retrieval system.
+
+Given a search query, generate up to {max_subqueries} short hypothetical document passages \
+(2–4 sentences each) that would directly answer or address the query. Generate as many as are \
+genuinely useful — fewer is fine if the query is simple. These passages will be used as queries \
+to a dense retrieval system to find real, similar documents.
+
+Rules:
+- Each passage should read like a real document excerpt that answers the query.
+- Vary the style and perspective across passages (e.g., academic, technical, narrative).
+- Be factually plausible; focus on covering the query intent.
+- Output a JSON array of strings only — no explanation, no markdown fences.""",
+ "multi_perspective": """\
+You are a multi-perspective query expansion assistant for a retrieval system.
+
+Given a search query, generate up to {max_subqueries} reformulations from different angles, \
+perspectives, or levels of specificity to maximise recall in a dense retrieval system. Only \
+generate reformulations that add genuine coverage — do not pad.
+
+Rules:
+- Vary terminology: use synonyms, technical vs. casual language, acronyms vs. full names.
+- Vary scope: broad overview queries alongside narrow, specific ones.
+- Vary form: declarative statements, questions, and entity-focused queries.
+- Each reformulation must have a meaningfully different surface form from the others.
+- Output a JSON array of strings only — no explanation, no markdown fences.""",
+}
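+
+# Each strategy prompt instructs the LLM to emit a bare JSON array of strings,
+# e.g. '["what drives inflation?", "how does monetary policy affect prices?"]',
+# so the response can be parsed directly with json.loads.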
+
+
+# ---------------------------------------------------------------------------
+# Operator
+# ---------------------------------------------------------------------------
+
+
+class SubQueryGeneratorOperator(AbstractOperator, CPUOperator):
+ """Expand each query row into sub-query rows using an LLM.
+
+ The operator calls an LLM once per input query and explodes the result into
+ one output row per generated sub-query. The LLM decides how many
+ sub-queries to generate up to ``max_subqueries``. This
+ makes it a natural upstream stage for a retrieval operator: the downstream
+ operator can retrieve documents independently for every sub-query row, and
+ a subsequent aggregation step (e.g. RRF) can merge the per-sub-query
+ ranked lists back into a single ranking per ``query_id``.
+
+ Input DataFrame schema
+ ----------------------
+ query_id : str — unique identifier for the query
+ query_text : str — the search query text
+ (any additional columns are passed through unchanged)
+
+ Output DataFrame schema
+ -----------------------
+ query_id : str — same ``query_id`` as the input row
+ query_text : str — original query text (preserved for context)
+ subquery_idx : int — 0-based position within the generated sub-query group
+ subquery_text : str — the generated sub-query text
+ (additional input columns are forwarded to every expanded row)
+
+ Parameters
+ ----------
+ llm_model : str
+ OpenAI model identifier, e.g. ``"gpt-4o"``.
+ max_subqueries : int
+ Maximum number of sub-queries the LLM may generate per query.
+ The LLM will generate fewer if the query does not warrant the maximum.
+ Defaults to ``4``.
+ strategy : {"decompose", "hyde", "multi_perspective"}
+ Built-in sub-query generation strategy.
+
+ ``"decompose"``
+ Break the query into complementary sub-aspects (default).
+ ``"hyde"``
+ Generate hypothetical answer passages (HyDE).
+ ``"multi_perspective"``
+            Rewrite the query from diverse angles to maximize recall.
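+
+        For example, ``strategy="hyde"`` makes the operator emit short
+        hypothetical answer passages as sub-queries instead of decomposed
+        facets of the original query.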
+ api_key : str, optional
+ Literal API key **or** an ``"os.environ/VAR_NAME"`` reference that is
+ resolved at call time.
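+        For example, ``api_key="os.environ/NVIDIA_API_KEY"`` resolves to
+        ``os.environ["NVIDIA_API_KEY"]`` when the operator runs.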
+ invoke_url : str, optional
+ Full ``/v1/chat/completions`` endpoint URL. Defaults to the NVIDIA
+ build endpoint when omitted (requires ``api_key`` / ``NVIDIA_API_KEY``).
+ base_url : str, optional
+ Deprecated alias for ``invoke_url``. Prefer ``invoke_url``.
+ max_tokens : int, optional
+ Upper bound on tokens in the LLM response.
+ system_prompt_override : str, optional
+ Fully custom system prompt. Use ``{max_subqueries}`` as a placeholder.
+ When provided, ``strategy`` is ignored.
+
+ Examples
+ --------
+ Standalone use::
+
+ import pandas as pd
+ from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator
+
+ op = SubQueryGeneratorOperator(
+ llm_model="nvidia/llama-3.3-nemotron-super-49b-v1",
+ invoke_url="https://integrate.api.nvidia.com/v1/chat/completions",
+ max_subqueries=5,
+ )
+ df = pd.DataFrame({
+ "query_id": ["q1", "q2"],
+ "query_text": ["What causes inflation?", "How do vaccines work?"],
+ })
+ result = op.run(df)
+ # result has up to 10 rows: ≤5 sub-queries × 2 original queries
+
+    Composing into a graph (``RetrievalOperator`` is a placeholder for any
+    retrieval stage that consumes the sub-query rows)::
+
+        from nemo_retriever.graph import InprocessExecutor
+        from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+        from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator
+
+ graph = (
+ SubQueryGeneratorOperator(llm_model="gpt-4o", max_subqueries=4)
+ >> RetrievalOperator(retriever=my_retriever)
+ >> RRFAggregatorOperator()
+ )
+ executor = InprocessExecutor(graph)
+ result_df = executor.ingest(query_df)
+ """
+
+ _NVIDIA_BUILD_ENDPOINT = "https://integrate.api.nvidia.com/v1/chat/completions"
+
+ def __init__(
+ self,
+ *,
+ llm_model: str,
+ max_subqueries: int = 4,
+ strategy: Literal["decompose", "hyde", "multi_perspective"] = "decompose",
+ api_key: Optional[str] = None,
+ invoke_url: Optional[str] = None,
+ base_url: Optional[str] = None,
+ max_tokens: Optional[int] = None,
+ system_prompt_override: Optional[str] = None,
+ ) -> None:
+ super().__init__()
+ self._llm_model = llm_model
+ self._max_subqueries = max_subqueries
+ self._strategy = strategy
+ self._api_key = api_key
+ self._max_tokens = max_tokens
+ self._system_prompt_override = system_prompt_override
+
+ # Resolve invoke_url: explicit > deprecated base_url > default NVIDIA endpoint
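+        # (e.g. base_url="https://integrate.api.nvidia.com" is expanded to
+        #  "https://integrate.api.nvidia.com/v1/chat/completions")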
+ if invoke_url is not None:
+ self._invoke_url = invoke_url
+ elif base_url is not None:
+ import warnings
+
+ warnings.warn(
+ "SubQueryGeneratorOperator: 'base_url' is deprecated, use 'invoke_url' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self._invoke_url = base_url.rstrip("/") + "/v1/chat/completions"
+ else:
+ self._invoke_url = self._NVIDIA_BUILD_ENDPOINT
+
+ # ------------------------------------------------------------------
+ # AbstractOperator interface
+ # ------------------------------------------------------------------
+
+ def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame:
+ """Normalise *data* to a DataFrame with ``query_id`` / ``query_text`` columns.
+
+ Accepted input types
+ --------------------
+ ``pd.DataFrame``
+ Must contain at least ``query_id`` and ``query_text`` columns.
+ ``list[str]``
+ Plain query strings; ``query_id`` values are auto-assigned as
+ ``"q0"``, ``"q1"``, …
+        ``list[tuple[str, str]]`` or ``list[list[str]]`` (two-element lists)
+ ``(query_id, query_text)`` pairs.
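+
+        Example (illustrative)::
+
+            op.preprocess(["alpha", "beta"])
+            # → DataFrame with query_id == ["q0", "q1"]
+            #   and query_text == ["alpha", "beta"]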
+ """
+ if isinstance(data, pd.DataFrame):
+ missing = {"query_id", "query_text"} - set(data.columns)
+ if missing:
+ raise ValueError(
+ f"Input DataFrame is missing required column(s): {sorted(missing)}. "
+ "Expected at minimum: 'query_id' and 'query_text'."
+ )
+ return data.copy()
+
+ if isinstance(data, list) and data:
+ first = data[0]
+ if isinstance(first, str):
+ return pd.DataFrame(
+ {
+ "query_id": [f"q{i}" for i in range(len(data))],
+ "query_text": list(data),
+ }
+ )
+ if isinstance(first, (tuple, list)) and len(first) == 2:
+ return pd.DataFrame(data, columns=["query_id", "query_text"])
+
+ raise TypeError(
+ f"Unsupported input type {type(data).__name__!r}. "
+ "Pass a pd.DataFrame with 'query_id' and 'query_text' columns, "
+ "a list[str], or a list[tuple[str, str]]."
+ )
+
+ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+ """Generate sub-queries for every row and explode to one row per sub-query."""
+ system_prompt = self._build_system_prompt()
+
+ passthrough_cols = [c for c in data.columns if c not in ("query_id", "query_text")]
+ rows: List[dict[str, Any]] = []
+
+ for _, row in data.iterrows():
+ subqueries = self._generate_one(row["query_text"], system_prompt)
+ for idx, sq in enumerate(subqueries):
+ new_row: dict[str, Any] = {
+ "query_id": row["query_id"],
+ "query_text": row["query_text"],
+ "subquery_idx": idx,
+ "subquery_text": sq,
+ }
+ for col in passthrough_cols:
+ new_row[col] = row[col]
+ rows.append(new_row)
+
+ if not rows:
+ return pd.DataFrame(columns=["query_id", "query_text", "subquery_idx", "subquery_text"])
+
+ return pd.DataFrame(rows)
+
+ def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
+ return data
+
+ # ------------------------------------------------------------------
+ # Internal helpers
+ # ------------------------------------------------------------------
+
+ def _resolve_api_key(self) -> Optional[str]:
+ api_key = self._api_key
+ if api_key is not None and api_key.strip().startswith("os.environ/"):
+ var = api_key.strip().removeprefix("os.environ/")
+ value = os.environ.get(var)
+ if value is None:
+ raise ValueError(
+ f"Environment variable '{var}' is not set. " f"Set it with: export {var}="
+ )
+ return value
+ return api_key
+
+ def _build_system_prompt(self) -> str:
+ template = self._system_prompt_override or _PROMPTS[self._strategy]
+ return template.format(max_subqueries=self._max_subqueries)
+
+ def _generate_one(self, query: str, system_prompt: str) -> List[str]:
+ """Call the LLM and return a list of sub-query strings for *query*."""
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": f"Query: {query}"},
+ ]
+ extra_body: dict[str, Any] = {}
+ if self._max_tokens is not None:
+ extra_body["max_tokens"] = self._max_tokens
+
+ api_key = self._resolve_api_key() # raises ValueError immediately on config error
+
+ try:
+ results = invoke_chat_completions(
+ invoke_url=self._invoke_url,
+ messages_list=[messages],
+ model=self._llm_model,
+ api_key=api_key,
+ extra_body=extra_body or None,
+ )
+ except TimeoutError as exc:
+ logger.warning("SubQueryGeneratorOperator: LLM call timed out for query %r: %s", query, exc, exc_info=True)
+ return [query]
+ except RuntimeError as exc:
+ logger.warning(
+ "SubQueryGeneratorOperator: LLM retries exhausted for query %r: %s", query, exc, exc_info=True
+ )
+ return [query]
+ except requests.RequestException as exc:
+ logger.warning("SubQueryGeneratorOperator: LLM HTTP error for query %r: %s", query, exc, exc_info=True)
+ return [query]
+ except json.JSONDecodeError as exc:
+ logger.warning(
+ "SubQueryGeneratorOperator: LLM returned invalid JSON for query %r: %s", query, exc, exc_info=True
+ )
+ return [query]
+ raw = results[0].strip()
+ return _parse_json_list(raw, fallback=query)
+
+
+# ---------------------------------------------------------------------------
+# Module-level helpers (no instance state — easier to test in isolation)
+# ---------------------------------------------------------------------------
+
+
+def _parse_json_list(raw: str, *, fallback: str) -> List[str]:
+ """Parse a JSON array from *raw*, stripping markdown fences if present.
+
+ Returns *[fallback]* when parsing fails so downstream stages always
+ receive at least one sub-query row.
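+
+    Examples
+    --------
+    >>> _parse_json_list('["a", "b"]', fallback="q")
+    ['a', 'b']
+    >>> _parse_json_list("not json", fallback="q")
+    ['q']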
+ """
+ text = raw
+ found_fence = False
+ for fence in ("```json", "```"):
+ if text.startswith(fence):
+ text = text[len(fence) :]
+ found_fence = True
+ break
+ if found_fence and text.endswith("```"):
+ text = text[:-3]
+ text = text.strip()
+
+ try:
+ parsed = json.loads(text)
+ if isinstance(parsed, list) and parsed and all(isinstance(s, str) for s in parsed):
+ return parsed
+ except json.JSONDecodeError:
+ pass
+
+ return [fallback]
diff --git a/nemo_retriever/src/nemo_retriever/nim/chat_completions.py b/nemo_retriever/src/nemo_retriever/nim/chat_completions.py
index 98849c7dc4..2fdebe4576 100644
--- a/nemo_retriever/src/nemo_retriever/nim/chat_completions.py
+++ b/nemo_retriever/src/nemo_retriever/nim/chat_completions.py
@@ -17,6 +17,8 @@
from typing import Any, Dict, List, Optional, Sequence
+from nemo_retriever.nim.nim import _parse_invoke_urls, _post_with_retries
+
def extract_chat_completion_text(response_json: Any) -> str:
"""Extract generated text from an OpenAI-compatible chat completions response."""
@@ -86,6 +88,68 @@ def invoke_chat_completions(
client.shutdown()
+def invoke_chat_completion_step(
+ *,
+ invoke_url: str,
+ messages: List[Dict[str, Any]],
+ model: Optional[str] = None,
+ api_key: Optional[str] = None,
+ tools: Optional[List[Dict[str, Any]]] = None,
+ tool_choice: str = "auto",
+ timeout_s: float = 120.0,
+ temperature: float = 0.0,
+ max_tokens: Optional[int] = None,
+ extra_body: Optional[Dict[str, Any]] = None,
+ max_retries: int = 10,
+ max_429_retries: int = 5,
+) -> Dict[str, Any]:
+ """Single synchronous tool-aware chat completion call.
+
+ Parameters
+ ----------
+ invoke_url
+ Full ``/v1/chat/completions`` endpoint URL.
+ messages
+ OpenAI-format message list for this single request.
+ tools
+ List of OpenAI tool-spec dicts. When provided, ``tool_choice`` is also
+ forwarded so the model knows which tools it may call.
+ tool_choice
+ ``"auto"`` (default) lets the model decide; ``"none"`` suppresses tool
+ use; or a specific tool name dict.
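+
+    Examples
+    --------
+    A minimal call sketch (endpoint and model are placeholders)::
+
+        resp = invoke_chat_completion_step(
+            invoke_url="http://localhost/v1/chat/completions",
+            messages=[{"role": "user", "content": "hello"}],
+            model="test-model",
+        )
+        text = resp["choices"][0]["message"]["content"]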
+ """
+ token = (api_key or "").strip()
+ headers: Dict[str, str] = {"Accept": "application/json", "Content-Type": "application/json"}
+ if token:
+ headers["Authorization"] = f"Bearer {token}"
+
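+    # _parse_invoke_urls can yield multiple endpoints; this single-call helper
+    # uses only the first.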
+ invoke_urls = _parse_invoke_urls(invoke_url)
+ endpoint_url = invoke_urls[0]
+
+ payload: Dict[str, Any] = {
+ "messages": messages,
+ "temperature": temperature,
+ }
+ if model:
+ payload["model"] = model
+ if max_tokens is not None:
+ payload["max_tokens"] = max_tokens
+ if tools:
+ payload["tools"] = tools
+ payload["tool_choice"] = tool_choice
+ if extra_body:
+ payload.update(extra_body)
+
+ return _post_with_retries(
+ invoke_url=endpoint_url,
+ payload=payload,
+ headers=headers,
+ timeout_s=float(timeout_s),
+ max_retries=int(max_retries),
+ max_429_retries=int(max_429_retries),
+ )
+
+
def invoke_chat_completions_images(
*,
invoke_url: str,
diff --git a/nemo_retriever/tests/test_agentic_operators.py b/nemo_retriever/tests/test_agentic_operators.py
new file mode 100644
index 0000000000..78ae3dd1e2
--- /dev/null
+++ b/nemo_retriever/tests/test_agentic_operators.py
@@ -0,0 +1,596 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Smoke tests for the agentic retrieval operators.
+
+Run with:
+ cd nemo_retriever && uv run pytest tests/test_agentic_operators.py -v
+"""
+
+from __future__ import annotations
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pandas as pd
+import pytest
+
+# ---------------------------------------------------------------------------
+# RRFAggregatorOperator — pure pandas, no mocking needed
+# ---------------------------------------------------------------------------
+
+
+class TestRRFAggregatorOperator:
+ def _make_input(self):
+ """Two queries; q1 has doc d1 in both steps, q2 has one step."""
+ return pd.DataFrame(
+ {
+ "query_id": ["q1", "q1", "q1", "q1", "q2", "q2"],
+ "query_text": ["inflation"] * 4 + ["vaccines"] * 2,
+ "step_idx": [0, 0, 1, 1, 0, 0],
+ "doc_id": ["d1", "d2", "d1", "d3", "d4", "d5"],
+ "text": ["t1", "t2", "t1", "t3", "t4", "t5"],
+ "rank": [1, 2, 1, 2, 1, 2],
+ }
+ )
+
+ def test_rrf_scores_correct(self):
+ from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+ op = RRFAggregatorOperator(k=60)
+ result = op.run(self._make_input())
+
+ q1 = result[result["query_id"] == "q1"].set_index("doc_id")
+ k = 60
+ # d1 appears in step 0 rank 1 and step 1 rank 1
+ expected_d1 = 1 / (1 + k) + 1 / (1 + k)
+ # d2 appears only in step 0 rank 2
+ expected_d2 = 1 / (2 + k)
+ assert abs(q1.loc["d1", "rrf_score"] - expected_d1) < 1e-10
+ assert abs(q1.loc["d2", "rrf_score"] - expected_d2) < 1e-10
+
+ def test_sorted_descending_per_query(self):
+ from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+ op = RRFAggregatorOperator(k=60)
+ result = op.run(self._make_input())
+
+ for _, grp in result.groupby("query_id"):
+ scores = grp["rrf_score"].tolist()
+ assert scores == sorted(scores, reverse=True), "Scores not sorted descending"
+
+ def test_text_carried_through(self):
+ from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+ op = RRFAggregatorOperator(k=60)
+ result = op.run(self._make_input())
+ q1 = result[result["query_id"] == "q1"].set_index("doc_id")
+ assert q1.loc["d1", "text"] == "t1"
+
+ def test_output_schema(self):
+ from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+ op = RRFAggregatorOperator(k=60)
+ result = op.run(self._make_input())
+ assert set(result.columns) >= {"query_id", "query_text", "doc_id", "rrf_score", "text"}
+
+ def test_missing_column_raises(self):
+ from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+ op = RRFAggregatorOperator(k=60)
+ bad_df = pd.DataFrame({"query_id": ["q1"], "query_text": ["x"]})
+ with pytest.raises(ValueError, match="missing required column"):
+ op.run(bad_df)
+
+
+# ---------------------------------------------------------------------------
+# Prompt rendering — pure Python, no mocking needed
+# ---------------------------------------------------------------------------
+
+
+class TestPromptRendering:
+ def test_react_prompt_no_extended_relevance(self):
+ from nemo_retriever.graph.react_agent_operator import _render_react_agent_prompt
+
+ prompt = _render_react_agent_prompt(10, with_init_docs=True, enforce_top_k=True, extended_relevance=False)
+ assert "" in prompt
+ assert "" in prompt
+ assert "" in prompt
+ assert "RELEVANCE_DEFINITION" not in prompt
+ assert "exactly the 10 most relevant" in prompt
+ assert "TIP" in prompt
+
+ def test_react_prompt_with_extended_relevance(self):
+ from nemo_retriever.graph.react_agent_operator import _render_react_agent_prompt
+
+ prompt = _render_react_agent_prompt(5, with_init_docs=False, enforce_top_k=False, extended_relevance=True)
+ assert "RELEVANCE_DEFINITION" in prompt
+ assert "exactly the 5" not in prompt
+ assert "TIP" not in prompt
+
+ def test_selection_prompt_no_extended_relevance(self):
+ from nemo_retriever.graph.selection_agent_operator import _render_selection_prompt
+
+ prompt = _render_selection_prompt(5, extended_relevance=False)
+ assert "" in prompt
+ assert "" in prompt
+ assert "THINKING TIPS" in prompt
+ assert "RELEVANCE_DEFINITION" not in prompt
+ assert "5 most relevant" in prompt
+
+ def test_selection_prompt_with_extended_relevance(self):
+ from nemo_retriever.graph.selection_agent_operator import _render_selection_prompt
+
+ prompt = _render_selection_prompt(5, extended_relevance=True)
+ assert "RELEVANCE_DEFINITION" in prompt
+ assert "As explained above" in prompt
+
+
+# ---------------------------------------------------------------------------
+# SelectionAgentOperator — mock invoke_chat_completion_step
+# ---------------------------------------------------------------------------
+
+
+def _make_tool_call_response(fn_name: str, fn_args: dict, tc_id: str = "call_1") -> dict:
+ """Build a canned /v1/chat/completions response with one tool call."""
+ return {
+ "choices": [
+ {
+ "message": {
+ "content": None,
+ "tool_calls": [
+ {
+ "id": tc_id,
+ "type": "function",
+ "function": {"name": fn_name, "arguments": json.dumps(fn_args)},
+ }
+ ],
+ },
+ "finish_reason": "tool_calls",
+ }
+ ]
+ }
+
+
+class TestSelectionAgentOperator:
+ def _make_input(self):
+ return pd.DataFrame(
+ {
+ "query_id": ["q1", "q1", "q1"],
+ "query_text": ["What causes inflation?"] * 3,
+ "doc_id": ["d1", "d2", "d3"],
+ "text": ["monetary policy doc", "supply chain doc", "unrelated doc"],
+ }
+ )
+
+ @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step")
+ def test_happy_path_selects_docs(self, mock_step):
+ from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator
+
+ # LLM immediately calls log_selected_documents
+ mock_step.return_value = _make_tool_call_response(
+ "log_selected_documents",
+ {"doc_ids": ["d1", "d2"], "message": "d1 most relevant"},
+ )
+
+ op = SelectionAgentOperator(
+ llm_model="test-model",
+ invoke_url="http://localhost/v1/chat/completions",
+ top_k=2,
+ )
+ result = op.run(self._make_input())
+
+ assert list(result.columns) == ["query_id", "doc_id", "rank", "message"]
+ assert result["query_id"].tolist() == ["q1", "q1"]
+ assert result["doc_id"].tolist() == ["d1", "d2"]
+ assert result["rank"].tolist() == [1, 2]
+
+ @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step")
+ def test_think_then_select(self, mock_step):
+ from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator
+
+ # First call: think; second call: log_selected_documents
+ mock_step.side_effect = [
+ _make_tool_call_response("think", {"thought": "let me reason..."}),
+ _make_tool_call_response("log_selected_documents", {"doc_ids": ["d3"], "message": "only d3"}),
+ ]
+
+ op = SelectionAgentOperator(
+ llm_model="test-model",
+ invoke_url="http://localhost/v1/chat/completions",
+ top_k=1,
+ )
+ result = op.run(self._make_input())
+
+ assert result["doc_id"].tolist() == ["d3"]
+ assert mock_step.call_count == 2
+
+ @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step")
+ def test_extended_relevance_in_prompt(self, mock_step):
+ from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator
+
+ captured_prompts = []
+
+ def capture_and_respond(**kwargs):
+ captured_prompts.append(kwargs["messages"][0]["content"])
+ return _make_tool_call_response("log_selected_documents", {"doc_ids": ["d1"], "message": "ok"})
+
+ mock_step.side_effect = capture_and_respond
+
+ op = SelectionAgentOperator(
+ llm_model="test-model",
+ invoke_url="http://localhost/v1/chat/completions",
+ top_k=1,
+ extended_relevance=True,
+ )
+ op.run(self._make_input())
+
+ assert "RELEVANCE_DEFINITION" in captured_prompts[0]
+
+
+# ---------------------------------------------------------------------------
+# ReActAgentOperator — mock retriever_fn + invoke_chat_completion_step
+# ---------------------------------------------------------------------------
+
+
+class TestReActAgentOperator:
+ def _make_input(self):
+ return pd.DataFrame(
+ {
+ "query_id": ["q1"],
+ "query_text": ["What causes inflation?"],
+ }
+ )
+
+ def _make_retriever(self, docs=None):
+ """Return a mock retriever_fn that returns canned docs."""
+ if docs is None:
+ docs = [{"doc_id": "d1", "text": "monetary policy"}, {"doc_id": "d2", "text": "supply chains"}]
+
+ def retriever_fn(query_text: str, top_k: int):
+ return docs[:top_k]
+
+ return retriever_fn
+
+ @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step")
+ def test_simple_mode_retrieve_then_final(self, mock_step):
+ from nemo_retriever.graph.react_agent_operator import ReActAgentOperator
+
+ # 1) agent calls retrieve("subquery"), 2) agent calls final_results
+ mock_step.side_effect = [
+ _make_tool_call_response("retrieve", {"query": "inflation monetary policy"}),
+ _make_tool_call_response(
+ "final_results",
+ {"doc_ids": ["d1", "d2"], "message": "found them", "search_successful": "true"},
+ ),
+ ]
+
+ op = ReActAgentOperator(
+ invoke_url="http://localhost/v1/chat/completions",
+ llm_model="test-model",
+ retriever_fn=self._make_retriever(),
+ user_msg_type="simple",
+ target_top_k=2,
+ )
+ result = op.run(self._make_input())
+
+ assert set(result.columns) >= {"query_id", "query_text", "step_idx", "doc_id", "text", "rank"}
+ assert result["query_id"].unique().tolist() == ["q1"]
+ # step 0 is the retrieve tool call result
+ assert 0 in result["step_idx"].values
+ assert "d1" in result["doc_id"].values
+
+ @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step")
+ def test_with_results_mode_initial_retrieval(self, mock_step):
+ from nemo_retriever.graph.react_agent_operator import ReActAgentOperator
+
+ # with_results=True → retriever_fn called once before LLM, then LLM immediately calls final_results
+ mock_step.return_value = _make_tool_call_response(
+ "final_results",
+ {"doc_ids": ["d1"], "message": "ok", "search_successful": "true"},
+ )
+ retriever = MagicMock(return_value=[{"doc_id": "d1", "text": "monetary policy"}])
+
+ op = ReActAgentOperator(
+ invoke_url="http://localhost/v1/chat/completions",
+ llm_model="test-model",
+ retriever_fn=retriever,
+ user_msg_type="with_results",
+ target_top_k=1,
+ )
+ result = op.run(self._make_input())
+
+ # retriever was called upfront (step_idx=0) before any LLM step
+ assert retriever.call_count >= 1
+ assert 0 in result["step_idx"].values
+
+ @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step")
+ def test_output_row_structure(self, mock_step):
+ from nemo_retriever.graph.react_agent_operator import ReActAgentOperator
+
+ mock_step.side_effect = [
+ _make_tool_call_response("retrieve", {"query": "q"}),
+ _make_tool_call_response(
+ "final_results", {"doc_ids": ["d1"], "message": "ok", "search_successful": "true"}
+ ),
+ ]
+
+ op = ReActAgentOperator(
+ invoke_url="http://localhost/v1/chat/completions",
+ llm_model="test-model",
+ retriever_fn=self._make_retriever(),
+ user_msg_type="simple",
+ )
+ result = op.run(self._make_input())
+
+ assert (result["rank"] >= 1).all()
+ assert result["step_idx"].dtype in (int, "int64")
+ assert result["doc_id"].notna().all()
+
+ @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step")
+ @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step")
+ def test_pipeline_end_to_end_with_mocks(self, mock_react_step, mock_selection_step):
+ """Wire ReAct → RRF → Selection with mocks; verify final output shape.
+
+ Each operator imports invoke_chat_completion_step into its own module
+ namespace, so both must be patched independently.
+ """
+ from nemo_retriever.graph.executor import InprocessExecutor
+ from nemo_retriever.graph.react_agent_operator import ReActAgentOperator
+ from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+ from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator
+
+ # ReAct: retrieve once, then final_results
+ mock_react_step.side_effect = [
+ _make_tool_call_response("retrieve", {"query": "inflation"}),
+ _make_tool_call_response(
+ "final_results", {"doc_ids": ["d1", "d2"], "message": "ok", "search_successful": "true"}
+ ),
+ ]
+ # Selection: immediately log_selected_documents
+ mock_selection_step.return_value = _make_tool_call_response(
+ "log_selected_documents", {"doc_ids": ["d1"], "message": "d1 best"}
+ )
+
+ def retriever_fn(query_text, top_k):
+ return [{"doc_id": "d1", "text": "monetary policy"}, {"doc_id": "d2", "text": "supply chains"}]
+
+ pipeline = (
+ ReActAgentOperator(
+ invoke_url="http://localhost/v1/chat/completions",
+ llm_model="test-model",
+ retriever_fn=retriever_fn,
+ user_msg_type="simple",
+ target_top_k=1,
+ )
+ >> RRFAggregatorOperator(k=60)
+ >> SelectionAgentOperator(
+ invoke_url="http://localhost/v1/chat/completions",
+ llm_model="test-model",
+ top_k=1,
+ )
+ )
+
+ query_df = pd.DataFrame({"query_id": ["q1"], "query_text": ["What causes inflation?"]})
+ result = InprocessExecutor(pipeline).ingest(query_df)
+
+ assert list(result.columns) == ["query_id", "doc_id", "rank", "message"]
+ assert result["query_id"].tolist() == ["q1"]
+ assert result["rank"].tolist() == [1]
+
+
+# ---------------------------------------------------------------------------
+# _parse_json_list — pure Python, no mocking needed
+# ---------------------------------------------------------------------------
+
+
+class TestParseJsonList:
+ def _parse(self, raw, fallback="orig"):
+ from nemo_retriever.graph.subquery_operator import _parse_json_list
+
+ return _parse_json_list(raw, fallback=fallback)
+
+ def test_plain_json_array(self):
+ assert self._parse('["a", "b", "c"]') == ["a", "b", "c"]
+
+ def test_json_fence(self):
+ assert self._parse('```json\n["x", "y"]\n```') == ["x", "y"]
+
+ def test_plain_fence(self):
+ assert self._parse('```\n["x"]\n```') == ["x"]
+
+ def test_trailing_fence_without_leading_not_stripped(self):
+        # A JSON string that happens to end with ``` but has no leading fence.
+        # The trailing strip must NOT fire, so the dangling backticks make
+        # json.loads fail and the fallback is returned.
+ raw = '["valid"]```'
+ result = self._parse(raw)
+ assert result == ["orig"] # malformed JSON → fallback
+
+ def test_malformed_json_returns_fallback(self):
+ assert self._parse("not json at all", fallback="q") == ["q"]
+
+ def test_empty_list_returns_fallback(self):
+ assert self._parse("[]", fallback="q") == ["q"]
+
+ def test_non_string_items_returns_fallback(self):
+ assert self._parse("[1, 2, 3]", fallback="q") == ["q"]
+
+ def test_mixed_types_returns_fallback(self):
+ assert self._parse('["a", 1]', fallback="q") == ["q"]
+
+
+# ---------------------------------------------------------------------------
+# SubQueryGeneratorOperator.preprocess — no LLM calls needed
+# ---------------------------------------------------------------------------
+
+
+class TestSubQueryPreprocess:
+ def _op(self):
+ from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator
+
+ return SubQueryGeneratorOperator(llm_model="test-model")
+
+ def test_dataframe_accepted(self):
+ op = self._op()
+ df = pd.DataFrame({"query_id": ["q1"], "query_text": ["hello"]})
+ result = op.preprocess(df)
+ assert list(result["query_id"]) == ["q1"]
+
+ def test_dataframe_missing_query_id_raises(self):
+ op = self._op()
+ bad = pd.DataFrame({"query_text": ["hello"]})
+ with pytest.raises(ValueError, match="query_id"):
+ op.preprocess(bad)
+
+ def test_dataframe_missing_query_text_raises(self):
+ op = self._op()
+ bad = pd.DataFrame({"query_id": ["q1"]})
+ with pytest.raises(ValueError, match="query_text"):
+ op.preprocess(bad)
+
+ def test_list_of_strings_auto_ids(self):
+ op = self._op()
+ result = op.preprocess(["alpha", "beta"])
+ assert result["query_id"].tolist() == ["q0", "q1"]
+ assert result["query_text"].tolist() == ["alpha", "beta"]
+
+ def test_list_of_tuples(self):
+ op = self._op()
+ result = op.preprocess([("id1", "alpha"), ("id2", "beta")])
+ assert result["query_id"].tolist() == ["id1", "id2"]
+
+ def test_unsupported_type_raises(self):
+ op = self._op()
+ with pytest.raises(TypeError):
+ op.preprocess({"query_id": "q1", "query_text": "hello"})
+
+
+class TestSubQueryGeneratorOperator:
+ """Tests for _build_system_prompt and _generate_one."""
+
+ def _op(self, **kwargs):
+ from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator
+
+ return SubQueryGeneratorOperator(llm_model="test-model", **kwargs)
+
+ # -- _build_system_prompt -------------------------------------------------
+
+ def test_decompose_prompt_contains_max_subqueries(self):
+ op = self._op(strategy="decompose", max_subqueries=6)
+ prompt = op._build_system_prompt()
+ assert "6" in prompt
+ assert "decompos" in prompt.lower()
+
+ def test_hyde_prompt_contains_max_subqueries(self):
+ op = self._op(strategy="hyde", max_subqueries=3)
+ prompt = op._build_system_prompt()
+ assert "3" in prompt
+ assert "hypothetical" in prompt.lower()
+
+ def test_multi_perspective_prompt_contains_max_subqueries(self):
+ op = self._op(strategy="multi_perspective", max_subqueries=5)
+ prompt = op._build_system_prompt()
+ assert "5" in prompt
+ assert "perspective" in prompt.lower()
+
+ def test_system_prompt_override_used_instead_of_strategy(self):
+ op = self._op(system_prompt_override="Custom prompt max={max_subqueries}", max_subqueries=2)
+ assert op._build_system_prompt() == "Custom prompt max=2"
+
+ # -- _generate_one --------------------------------------------------------
+
+ @patch("nemo_retriever.graph.subquery_operator.invoke_chat_completions")
+ def test_generate_one_happy_path(self, mock_invoke):
+ mock_invoke.return_value = ['["sub1", "sub2", "sub3"]']
+ op = self._op(max_subqueries=4)
+ result = op._generate_one("What causes inflation?", "system prompt")
+ assert result == ["sub1", "sub2", "sub3"]
+ mock_invoke.assert_called_once()
+
+ @patch("nemo_retriever.graph.subquery_operator.invoke_chat_completions")
+ def test_generate_one_fenced_json(self, mock_invoke):
+ mock_invoke.return_value = ['```json\n["a", "b"]\n```']
+ op = self._op()
+ assert op._generate_one("q", "sys") == ["a", "b"]
+
+ @patch("nemo_retriever.graph.subquery_operator.invoke_chat_completions")
+ def test_generate_one_malformed_json_falls_back(self, mock_invoke):
+ mock_invoke.return_value = ["not valid json"]
+ op = self._op()
+ assert op._generate_one("original query", "sys") == ["original query"]
+
+ @patch("nemo_retriever.graph.subquery_operator.invoke_chat_completions")
+ def test_generate_one_llm_error_falls_back(self, mock_invoke):
+ mock_invoke.side_effect = RuntimeError("connection timeout")
+ op = self._op()
+ assert op._generate_one("original query", "sys") == ["original query"]
+
+
+# ---------------------------------------------------------------------------
+# SelectionAgentOperator.preprocess — no LLM calls needed
+# ---------------------------------------------------------------------------
+
+
+class TestSelectionAgentPreprocess:
+ def _op(self):
+ from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator
+
+ return SelectionAgentOperator(llm_model="test-model", invoke_url="http://localhost/v1/chat/completions")
+
+ def test_valid_dataframe_accepted(self):
+ op = self._op()
+ df = pd.DataFrame({"query_id": ["q1"], "query_text": ["q"], "doc_id": ["d1"], "text": ["t"]})
+ result = op.preprocess(df)
+ assert len(result) == 1
+
+ def test_missing_doc_id_raises(self):
+ op = self._op()
+ bad = pd.DataFrame({"query_id": ["q1"], "query_text": ["q"], "text": ["t"]})
+ with pytest.raises(ValueError, match="doc_id"):
+ op.preprocess(bad)
+
+ def test_missing_text_raises(self):
+ op = self._op()
+ bad = pd.DataFrame({"query_id": ["q1"], "query_text": ["q"], "doc_id": ["d1"]})
+ with pytest.raises(ValueError, match="text"):
+ op.preprocess(bad)
+
+ def test_non_dataframe_raises(self):
+ op = self._op()
+ with pytest.raises(TypeError):
+ op.preprocess([{"query_id": "q1"}])
+
+
+# ---------------------------------------------------------------------------
+# SelectionAgentOperator — max_steps exhausted fallback
+# ---------------------------------------------------------------------------
+
+
+class TestSelectionAgentMaxSteps:
+ @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step")
+ def test_max_steps_exhausted_returns_empty(self, mock_step):
+ from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator
+
+ # LLM only ever calls think — never log_selected_documents
+ mock_step.return_value = _make_tool_call_response("think", {"thought": "still thinking..."})
+
+ op = SelectionAgentOperator(
+ llm_model="test-model",
+ invoke_url="http://localhost/v1/chat/completions",
+ top_k=2,
+ max_steps=3,
+ )
+ df = pd.DataFrame(
+ {
+ "query_id": ["q1", "q1"],
+ "query_text": ["What causes inflation?"] * 2,
+ "doc_id": ["d1", "d2"],
+ "text": ["doc one", "doc two"],
+ }
+ )
+ result = op.run(df)
+
+ assert len(result) == 0
+ assert list(result.columns) == ["query_id", "doc_id", "rank", "message"]
+ assert mock_step.call_count == 3