diff --git a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py new file mode 100644 index 0000000000..94777db5d9 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py @@ -0,0 +1,754 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Operator that runs a ReAct agentic retrieval loop per query.""" + +from __future__ import annotations + +import json +import logging +import os + +import requests +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Callable, Dict, List, Literal, Optional + +import pandas as pd + +from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.graph.cpu_operator import CPUOperator +from nemo_retriever.nim.chat_completions import invoke_chat_completion_step + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Prompt rendering (verbatim content of 02_v1.j2, rendered via Python) +# --------------------------------------------------------------------------- + +_GOAL = """\ +You are a retrieval agent that finds all documents related to a given query. + + +You are given a search query and a list of documents retrieved for that query. Your task is to write new \ +queries and use the given search tool to find *ALL* the related and somewhat related documents to the given \ +query (i.e., maximize recall). +If the user's query is a question, you should not answer the question yourself. Instead, you should find \ +the related documents for the given query. +""" + +_RELEVANCE_DEFINITION = """ + + +- You should be careful, in the context of this task, what it means to be a "query", "document", and \ +"relevant" can sometimes be very complex and might not follow the traditional definition of these terms \ +in standard information retrieval. +- In standard retrieval, a query is usually a user question (like a web search query), the document is \ +some sort of content that provides information (e.g., a web page), and these two are considered relevant if \ +the document provides information that answers the user's query. +- However, in our setting, this could be different. Here are some examples: + * the query is a programming problem and documents are programming language syntax references. A document \ +is relevant if it contains the reference for the programming syntax used for solving the problem. + * both query and documents are descriptions programming problems and a query and document are relevant if \ +the same approach is used to solve them. + * the query is a math problem and documents are theorems. Relevant documents (theorems) are the ones \ +that are useful for solving the math problem. + * the query and document are both math problems. A query and a document are relevant if the same theorem \ +is used for solving them. + * the query is a task description (e.g., for an API programmer) and documents are descriptions of \ +available APIs. Relevant documents (e.g., APIs) are the ones needed for completing the task. +- This is not an exhaustive list. These are just some examples to show you the complexity of queries, \ +documents, and the concept of relevance in this task. +- Note that even here, the relevant documents are still the ones that are useful for a user who is \ +searching for the given query. 
But the relation is more nuanced. +- You should analyze the query and some of the available documents. And then reason about what could be a \ +meaningful definition of relevance in this case, and what the user could be looking for. +- Moreover, sometimes, the query could be even a prompt that is given to a Large Language Model (LLM) and \ +the user wants to find the useful documents for the LLM that help answering/solving this prompt. +""" + +_WORKFLOW_TEMPLATE = """ + +- You are given a retrieval tool, powered by a dense embedding model, that takes a text query and returns \ +the most similar documents. +{extended_relevance_line}\ +- You can call the search tool multiple times. +- Search for related documents to the user's query from different angles. +- If needed, revise your search queries based on the documents you find in previous steps. +- Once you are confident that you have found all the related and somewhat related documents and there are \ +no more related documents in the corpus, call the "final_results" tool to finish the task. +{enforce_top_k_line}\ +- When calling the "final_results" tool, the list of documents must be sorted in the decreasing level of \ +relevance to the query. I.e., the first document is the most relevant to the query, the second document is \ +the second most relevant to the query, and so on. +""" + +_BEST_PRACTICES_TEMPLATE = """ + + +- You should be thorough and find all related and somewhat related documents. +- The goal is to increase the **Recall** of your search attempt. So, if multiple documents are relevant \ +to the given query, you should find and report all of them even if only a subset of them is enough \ +for answering the query. +{with_init_docs_line}\ +""" + + +def _render_react_agent_prompt( + top_k: int, + *, + with_init_docs: bool = True, + enforce_top_k: bool = True, + extended_relevance: bool = False, +) -> str: + """Render the ReAct agent system prompt (verbatim 02_v1.j2 logic).""" + parts = [_GOAL] + if extended_relevance: + parts.append(_RELEVANCE_DEFINITION) + + ext_line = ( + "- As explained above, reason and figure out what the meaning of relevance is in this case, " + "and what could be relevant and useful information for the given query.\n" + if extended_relevance + else "" + ) + enforce_line = ( + f'- When calling "final_results", you must select exactly the {top_k} most relevant documents ' + "among all documents you have retrieved.\n" + if enforce_top_k + else "" + ) + parts.append( + _WORKFLOW_TEMPLATE.format( + extended_relevance_line=ext_line, + enforce_top_k_line=enforce_line, + ) + ) + + init_docs_line = ( + "- **TIP**: you can look at the list of documents retrieved using the original query and think " + "what other queries you can use to find the potentially related documents that are missing in these results.\n" + if with_init_docs + else "" + ) + parts.append(_BEST_PRACTICES_TEMPLATE.format(with_init_docs_line=init_docs_line)) + return "".join(parts) + + +# --------------------------------------------------------------------------- +# Tool specs (verbatim from retrieval_bench/nemo_agentic/tool_helpers.py) +# --------------------------------------------------------------------------- + + +def _make_think_tool_spec(extended_relevance: bool = False) -> Dict[str, Any]: + ext = "" + if extended_relevance: + ext = ( + "- When it is difficult to understand what is the intent of the user and what they are trying " + "to find with this query, use this tool to think about potential definitions of relevance that " + "could be 
meaningful/useful to the user for this task.\n" + "- If the intention of the user is vague especially given the available documents, use this tool " + "to think how you should decide what documents are relevant and what the metric of relevance is.\n" + ) + description = ( + "Use the tool to think about something. It will not obtain new information or make any changes, " + "but just log the thought. Use it when complex reasoning or brainstorming is needed.\n\n" + "Common use cases:\n" + f"{ext}" + "- When processing a complex query, use this tool to organize your thoughts and think about " + "the sub-queries that you need to search for to find the relevant information.\n" + "- If a query is vague and it is very difficult to find information for it, you can use this tool to think " + "about clues in the query that you can use to narrow down the search and spot relevant pieces of information.\n" + "- When finding related documents that help you create better search queries in the next step, use this " + "tool to think about what pieces of information from these documents are helpful to search for.\n" + "- When you fail to find any related information to the query, use this tool to think about other " + "search strategies that you can take to retrieve the related documents.\n\n" + "The tool simply logs your thought process for better transparency and does not make any changes." + ) + return { + "type": "function", + "function": { + "name": "think", + "description": description, + "parameters": { + "type": "object", + "properties": {"thought": {"type": "string", "description": "The thought to log."}}, + "required": ["thought"], + }, + }, + } + + +def _make_retrieve_tool_spec(top_k: int) -> Dict[str, Any]: + return { + "type": "function", + "function": { + "name": "retrieve", + "description": ( + "Search for documents relevant to the given query using a dense embedding retrieval system. " + f"Returns the {top_k} most semantically similar documents from the corpus." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to retrieve documents for.", + }, + }, + "required": ["query"], + }, + }, + } + + +def _make_final_results_tool_spec(top_k: Optional[int]) -> Dict[str, Any]: + tk_ins = "" + if top_k is not None: + tk_ins = f"- You must choose exactly {top_k} document IDs when calling this function.\n" + + description = ( + "Signals the completion of the search process for the current query.\n\n" + "Use this tool when:\n" + "- You have found all the relevant documents to the query.\n" + "- Despite several attempts, you cannot find good documents for the given query.\n\n" + "The message should include:\n" + "- A brief summary of your exploration and the results\n" + "- Explanation if the search was unsuccessful\n\n" + "When reporting the selected document IDs, make sure:\n" + "- the list of document IDs is sorted in the decreasing level of relevance to the query. " + "I.e., the first document in the list is the most relevant to the query, the second is the " + "second most relevant to the query, and so on.\n" + f"{tk_ins}" + "\nThe search_successful field should be set to true if you believe you have found the most " + "relevant documents to the user's query, false otherwise, and partial if it is in between."
+ ) + return { + "type": "function", + "function": { + "name": "final_results", + "description": description, + "parameters": { + "type": "object", + "required": ["doc_ids", "message", "search_successful"], + "properties": { + "message": { + "type": "string", + "description": ( + "A message for the user to explain why you think you found all the related " + "documents and no related document is missing. Also, include a short " + "description of your exploration process. If your attempts to find related " + "documents were unsuccessful, explain why." + ), + }, + "doc_ids": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "List of document IDs that are relevant to the user's query sorted descending " + "by their level of relevance to the user's query. I.e., the first document is " + "the most relevant to the query, the second is the second most relevant to the " + "query, and so on." + ), + }, + "search_successful": { + "type": "string", + "enum": ["true", "false", "partial"], + "description": "Whether you managed to find all the related documents to the query.", + }, + }, + }, + }, + } + + +# --------------------------------------------------------------------------- +# Operator +# --------------------------------------------------------------------------- + +#: Message sent when the LLM produces a stop without calling any tool. +_AUTO_USER_MSG = ( + "continue with the task. Do not re-read the query. Do not summarize your progress. " + "If you believe you have done all the required steps, call the `final_results` tool" +) + + +class ReActAgentOperator(AbstractOperator, CPUOperator): + """Run an iterative ReAct retrieval loop per query and emit the full retrieval log. + + Each query row is processed independently by an LLM-driven ReAct loop + (Reason + Act) that has access to three tools: ``think``, ``retrieve``, + and ``final_results``. The operator emits one output row per retrieved + document per retrieval step, enabling downstream + :class:`RRFAggregatorOperator` to fuse the ranked lists with Reciprocal + Rank Fusion. + + The system prompt is a verbatim Python rendering of the retrieval-bench + ``02_v1.j2`` template, including optional ``extended_relevance`` and + ``enforce_top_k`` blocks. + + Input DataFrame schema + ---------------------- + query_id : str — unique query identifier + query_text : str — the search query text + (additional columns are ignored) + + Output DataFrame schema + ----------------------- + query_id : str — same ``query_id`` as the input + query_text : str — same ``query_text`` (passed through for downstream) + step_idx : int — 0 = initial seed retrieval; 1 … N = per-loop retrieve calls + doc_id : str — retrieved document identifier + text : str — document text + rank : int — 1-indexed rank within this step (1 = most relevant) + + Parameters + ---------- + invoke_url : str + Full ``/v1/chat/completions`` endpoint URL. + llm_model : str + Model identifier forwarded to the endpoint. + retriever_fn : Callable[[str, int], list[dict]] + ``(query_text, top_k) → [{doc_id: str, text: str, ...}]``. + The callable is invoked for every retrieve tool call the agent makes. + Each returned dict must contain ``doc_id`` and ``text`` keys. + retriever_top_k : int + Number of documents fetched per retrieve call. Defaults to ``500``. + target_top_k : int + Number of final documents to select, communicated to the LLM via the + system prompt and ``final_results`` tool spec. Defaults to ``10``.
+ enforce_top_k : bool + When ``True``, the system prompt instructs the LLM to select exactly + ``target_top_k`` documents in its ``final_results`` call. + Defaults to ``True``. + user_msg_type : {"with_results", "simple"} + ``"with_results"`` (default): make one upfront retrieval call with the + original query and include those documents in the first user message, + mirroring the retrieval-bench ``with_results`` mode. + ``"simple"``: start the loop with just the query text. + extended_relevance : bool + Include the ``_RELEVANCE_DEFINITION`` block in the system prompt for + tasks with non-standard relevance definitions. Defaults to ``False``. + max_steps : int + Maximum ReAct loop iterations per query before forced exit. + Defaults to ``10``. + num_concurrent : int + Number of queries processed concurrently via ``ThreadPoolExecutor``. + Defaults to ``8``. + api_key : str, optional + Literal API key **or** an ``"os.environ/VAR_NAME"`` reference. + max_tokens : int, optional + Upper bound on tokens in each LLM response. + parallel_tool_calls : bool + When ``False``, ask the endpoint to emit at most one tool call per turn + (sent as ``extra_body={"parallel_tool_calls": False}``). Defaults to ``True``. + + Notes + ----- + ``retriever_fn`` must be serialisable when used with ``RayDataExecutor``. + Prefer module-level functions or picklable callable objects over lambdas. + + Examples + -------- + :: + + from nemo_retriever.graph.react_agent_operator import ReActAgentOperator + from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + from nemo_retriever.graph.executor import InprocessExecutor + + def my_retriever(query_text: str, top_k: int) -> list[dict]: + # Returns [{doc_id, text, score?}, ...] + ... + + pipeline = ( + ReActAgentOperator( + invoke_url="https://integrate.api.nvidia.com/v1/chat/completions", + llm_model="nvidia/llama-3.3-nemotron-super-49b-v1", + retriever_fn=my_retriever, + retriever_top_k=500, + target_top_k=10, + ) + >> RRFAggregatorOperator(k=60) + >> SelectionAgentOperator( + invoke_url="https://integrate.api.nvidia.com/v1/chat/completions", + llm_model="nvidia/llama-3.3-nemotron-super-49b-v1", + top_k=10, + ) + ) + + result_df = InprocessExecutor(pipeline).ingest(query_df) + """ + + _NVIDIA_BUILD_ENDPOINT = "https://integrate.api.nvidia.com/v1/chat/completions" + + def __init__( + self, + *, + invoke_url: Optional[str] = None, + llm_model: str, + retriever_fn: Callable[[str, int], List[Dict[str, Any]]], + retriever_top_k: int = 500, + target_top_k: int = 10, + enforce_top_k: bool = True, + user_msg_type: Literal["with_results", "simple"] = "with_results", + extended_relevance: bool = False, + max_steps: int = 10, + num_concurrent: int = 8, + api_key: Optional[str] = None, + max_tokens: Optional[int] = None, + parallel_tool_calls: bool = True, + ) -> None: + super().__init__() + self._invoke_url = invoke_url or self._NVIDIA_BUILD_ENDPOINT + self._llm_model = llm_model + self._retriever_fn = retriever_fn + self._retriever_top_k = retriever_top_k + self._target_top_k = target_top_k + self._enforce_top_k = enforce_top_k + self._user_msg_type = user_msg_type + self._extended_relevance = extended_relevance + self._max_steps = max_steps + self._num_concurrent = num_concurrent + self._api_key = api_key + self._max_tokens = max_tokens + self._parallel_tool_calls = parallel_tool_calls + + # ------------------------------------------------------------------ + # AbstractOperator interface + # ------------------------------------------------------------------ + + def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame: + if not isinstance(data, pd.DataFrame): + raise
TypeError(f"ReActAgentOperator expects a pd.DataFrame, got {type(data).__name__!r}.") + required = {"query_id", "query_text"} + missing = required - set(data.columns) + if missing: + raise ValueError( + f"Input DataFrame is missing required column(s): {sorted(missing)}. " f"Expected: {sorted(required)}." + ) + return data[["query_id", "query_text"]].copy() + + def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + """Run the ReAct loop for each query, concurrently up to num_concurrent.""" + api_key = self._resolve_api_key() + rows: List[Dict[str, Any]] = [] + + query_rows = [(str(r["query_id"]), str(r["query_text"])) for _, r in data.iterrows()] + + if len(query_rows) == 1: + # Fast path: single query, no threading overhead + qid, qtxt = query_rows[0] + rows.extend(self._run_single_query(qid, qtxt, api_key)) + else: + with ThreadPoolExecutor(max_workers=min(self._num_concurrent, len(query_rows))) as executor: + futures = { + executor.submit(self._run_single_query, qid, qtxt, api_key): (qid, qtxt) for qid, qtxt in query_rows + } + for future in as_completed(futures): + try: + rows.extend(future.result()) + except TimeoutError as exc: + qid, qtxt = futures[future] + logger.warning("ReActAgentOperator: query %r timed out: %s", qid, exc, exc_info=True) + except RuntimeError as exc: + qid, qtxt = futures[future] + logger.warning("ReActAgentOperator: query %r retries exhausted: %s", qid, exc, exc_info=True) + except requests.RequestException as exc: + qid, qtxt = futures[future] + logger.warning("ReActAgentOperator: query %r HTTP error: %s", qid, exc, exc_info=True) + except (json.JSONDecodeError, ValueError) as exc: + qid, qtxt = futures[future] + logger.warning("ReActAgentOperator: query %r data error: %s", qid, exc, exc_info=True) + except Exception as exc: # catches unexpected worker errors not covered above + qid, qtxt = futures[future] + logger.warning("ReActAgentOperator: query %r failed: %s", qid, exc, exc_info=True) + + if not rows: + return pd.DataFrame(columns=["query_id", "query_text", "step_idx", "doc_id", "text", "rank"]) + + return pd.DataFrame(rows) + + def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + return data + + # ------------------------------------------------------------------ + # Internal: single query ReAct loop + # ------------------------------------------------------------------ + + def _run_single_query( + self, + query_id: str, + query_text: str, + api_key: Optional[str], + ) -> List[Dict[str, Any]]: + """Run the full ReAct loop for one query; return a list of row dicts.""" + with_init_docs = self._user_msg_type == "with_results" + + system_prompt = _render_react_agent_prompt( + self._target_top_k, + with_init_docs=with_init_docs, + enforce_top_k=self._enforce_top_k, + extended_relevance=self._extended_relevance, + ) + tools = [ + _make_think_tool_spec(self._extended_relevance), + _make_retrieve_tool_spec(self._retriever_top_k), + _make_final_results_tool_spec(self._target_top_k if self._enforce_top_k else None), + ] + + messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt}] + + # Retrieval log: one list per step, each item is {doc_id, text, score?} + retrieval_log: List[List[Dict[str, Any]]] = [] + seen_doc_ids: set[str] = set() + step_counter = 0 + + # ------ optional initial retrieval (with_results mode) ------ + if with_init_docs: + init_docs = self._call_retriever(query_text, seen_doc_ids, api_key) + retrieval_log.append(init_docs) + step_counter += 1 + for d in init_docs: + 
seen_doc_ids.add(d["doc_id"]) + + doc_content = _docs_to_message_content(init_docs) + user_msg_content: List[Dict[str, Any]] = [ + {"type": "text", "text": f"Query:\n{query_text}\n\nRetrieved Documents:"} + ] + doc_content + messages.append({"role": "user", "content": user_msg_content}) + else: + messages.append({"role": "user", "content": f"Query:\n{query_text}"}) + + final_doc_ids: Optional[List[str]] = None + + # ------ main ReAct loop ------ + for _step in range(self._max_steps): + logger.debug("query=%r loop_step=%d seen_docs=%d", query_id, _step, len(seen_doc_ids)) + try: + response = invoke_chat_completion_step( + invoke_url=self._invoke_url, + messages=messages, + model=self._llm_model, + api_key=api_key, + tools=tools, + tool_choice="auto", + max_tokens=self._max_tokens, + extra_body={"parallel_tool_calls": False} if not self._parallel_tool_calls else None, + ) + except TimeoutError as exc: + logger.warning( + "ReActAgentOperator: LLM call timed out on step %d for query %r: %s", + _step, + query_id, + exc, + exc_info=True, + ) + break + except RuntimeError as exc: + logger.warning( + "ReActAgentOperator: LLM retries exhausted on step %d for query %r: %s", + _step, + query_id, + exc, + exc_info=True, + ) + break + except requests.RequestException as exc: + logger.warning( + "ReActAgentOperator: LLM HTTP error on step %d for query %r: %s", + _step, + query_id, + exc, + exc_info=True, + ) + break + except json.JSONDecodeError as exc: + logger.warning( + "ReActAgentOperator: LLM returned invalid JSON on step %d for query %r: %s", + _step, + query_id, + exc, + exc_info=True, + ) + break + + if not response.get("choices"): + logger.warning( + "ReActAgentOperator: empty choices in API response on step %d for query %r", _step, query_id + ) + break + choice = response["choices"][0] + msg = choice["message"] + finish_reason = choice.get("finish_reason") + tool_calls = msg.get("tool_calls") or [] + + # Append assistant turn + assistant_turn: Dict[str, Any] = {"role": "assistant"} + if msg.get("content"): + assistant_turn["content"] = msg["content"] + if tool_calls: + assistant_turn["tool_calls"] = tool_calls + messages.append(assistant_turn) + + if finish_reason == "stop" or not tool_calls: + messages.append({"role": "user", "content": _AUTO_USER_MSG}) + continue + + tool_messages: List[Dict[str, Any]] = [] + loop_done = False + + for tc in tool_calls: + tc_id = tc.get("id", "") + fn = tc.get("function", {}) + fn_name = fn.get("name", "") + try: + fn_args = json.loads(fn.get("arguments", "{}")) + except json.JSONDecodeError: + tool_messages.append( + {"role": "tool", "tool_call_id": tc_id, "content": "Error: could not parse tool arguments."} + ) + continue + + if fn_name == "think": + logger.debug("query=%r step=%d [think] %s", query_id, _step, str(fn_args.get("thought", ""))[:120]) + tool_messages.append( + {"role": "tool", "tool_call_id": tc_id, "content": "Your thought has been logged."} + ) + + elif fn_name == "retrieve": + subquery = str(fn_args.get("query", query_text)) + logger.debug("query=%r step=%d [retrieve] subquery=%r", query_id, _step, subquery) + retrieved = self._call_retriever(subquery, seen_doc_ids, api_key) + logger.debug("query=%r step=%d [retrieve] got %d new docs", query_id, _step, len(retrieved)) + retrieval_log.append(retrieved) + step_counter += 1 + for d in retrieved: + seen_doc_ids.add(d["doc_id"]) + doc_content = _docs_to_message_content(retrieved) + tool_content: List[Dict[str, Any]] = [ + {"type": "text", "text": f"Retrieved {len(retrieved)} documents:"} + ] + 
doc_content + tool_messages.append({"role": "tool", "tool_call_id": tc_id, "content": tool_content}) + + elif fn_name == "final_results": + raw_ids: List[str] = fn_args.get("doc_ids", []) + logger.debug("query=%r step=%d [final_results] doc_ids=%s", query_id, _step, raw_ids) + if isinstance(raw_ids, list) and raw_ids: + final_doc_ids = [str(d) for d in raw_ids] + tool_messages.append( + { + "role": "tool", + "tool_call_id": tc_id, + "content": "The results have been successfully logged and the interaction ended.", + } + ) + loop_done = True + + else: + tool_messages.append( + {"role": "tool", "tool_call_id": tc_id, "content": f"Error: unknown tool '{fn_name}'."} + ) + + messages.extend(tool_messages) + if loop_done: + break + + return _build_output_rows(query_id, query_text, retrieval_log, final_doc_ids) + + def _call_retriever( + self, + query_text: str, + seen_doc_ids: set[str], + api_key: Optional[str], + ) -> List[Dict[str, Any]]: + """Call retriever_fn, over-fetching to ensure new results after dedup.""" + fetch_k = self._retriever_top_k + len(seen_doc_ids) + try: + raw = self._retriever_fn(query_text, fetch_k) + except TimeoutError as exc: + logger.warning( + "ReActAgentOperator: retriever_fn timed out for query %r: %s", query_text, exc, exc_info=True + ) + return [] + except (TypeError, ValueError) as exc: + logger.warning( + "ReActAgentOperator: retriever_fn bad call/return for query %r: %s", query_text, exc, exc_info=True + ) + return [] + except Exception as exc: # retriever_fn is user-supplied; catches remaining unexpected errors. + logger.warning("ReActAgentOperator: retriever_fn failed for query %r: %s", query_text, exc, exc_info=True) + return [] + + # Filter already-seen and normalise keys + results: List[Dict[str, Any]] = [] + for item in raw: + doc_id = str(item.get("doc_id", item.get("id", ""))) + text = str(item.get("text", "")) + score = float(item.get("score", 0.0)) + if doc_id and doc_id not in seen_doc_ids: + results.append({"doc_id": doc_id, "text": text, "score": score}) + if len(results) >= self._retriever_top_k: + break + + return results + + def _resolve_api_key(self) -> Optional[str]: + api_key = self._api_key + if api_key is not None and api_key.strip().startswith("os.environ/"): + var = api_key.strip().removeprefix("os.environ/") + value = os.environ.get(var) + if value is None: + raise ValueError( + f"Environment variable '{var}' is not set. 
" f"Set it with: export {var}=" + ) + return value + return api_key + + +# --------------------------------------------------------------------------- +# Module-level helpers +# --------------------------------------------------------------------------- + + +def _docs_to_message_content(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert a list of doc dicts to LLM message content blocks.""" + content: List[Dict[str, Any]] = [] + for doc in docs: + doc_id = doc.get("doc_id", "") + text = doc.get("text", "").strip() + entry: Dict[str, Any] = {"id": doc_id} + if text: + entry["text"] = text + score = doc.get("score") + if score is not None: + entry["score"] = score + content.append({"type": "text", "text": json.dumps(entry)}) + return content + + +def _build_output_rows( + query_id: str, + query_text: str, + retrieval_log: List[List[Dict[str, Any]]], + final_doc_ids: Optional[List[str]], +) -> List[Dict[str, Any]]: + """Convert the retrieval log to one row per (step_idx, rank, doc_id).""" + rows: List[Dict[str, Any]] = [] + for step_idx, step_docs in enumerate(retrieval_log): + for rank, doc in enumerate(step_docs, 1): + rows.append( + { + "query_id": query_id, + "query_text": query_text, + "step_idx": step_idx, + "doc_id": doc.get("doc_id", ""), + "text": doc.get("text", ""), + "rank": rank, + } + ) + + # If final_results was called, also emit those as a synthetic final step + # (step_idx = len(retrieval_log)) so RRF can optionally weight it. + # These are already covered by the existing steps, so we skip deduplication + # here — RRF will naturally up-weight docs that appeared in final_results + # because they were retrieved in earlier steps. + return rows diff --git a/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py b/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py new file mode 100644 index 0000000000..39053bc3da --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Operator that fuses per-step retrieval results using Reciprocal Rank Fusion.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Dict, List + +import pandas as pd + +from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.graph.cpu_operator import CPUOperator + + +class RRFAggregatorOperator(AbstractOperator, CPUOperator): + """Fuse multiple per-step ranked lists into a single ranking per query using RRF. + + Implements the Reciprocal Rank Fusion formula + ``score(d) = sum(1 / (rank_i + k))`` across all retrieval steps where + document *d* appears. This is the same formula used in + ``retrieval_bench/nemo_agentic/utils.py:rrf_from_subquery_results``. + + Designed to consume the output of :class:`ReActAgentOperator` (or any + operator that emits one row per ``(query_id, step_idx, doc_id)`` triple) + and produce a single fused ranking per ``query_id`` suitable as input to + :class:`SelectionAgentOperator`. 
+ + Input DataFrame schema + ---------------------- + query_id : str — unique query identifier + query_text : str — original query text (carried through) + step_idx : int — which retrieval step produced this row (0, 1, 2 …) + doc_id : str — retrieved document identifier + text : str — document text content + rank : int — 1-indexed rank within its step (1 = most relevant) + (additional columns are ignored) + + Output DataFrame schema + ----------------------- + query_id : str — same ``query_id`` as the input + query_text: str — original query text (first occurrence per query) + doc_id : str — document identifier + rrf_score : float — fused RRF score (higher = more relevant) + text : str — document text (first occurrence per ``doc_id``) + Rows are sorted by ``rrf_score`` descending within each ``query_id``. + + Parameters + ---------- + k : int + RRF damping factor. The standard value is ``60`` (default). + Larger values reduce the influence of top-ranked documents. + + Examples + -------- + :: + + import pandas as pd + from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator + + op = RRFAggregatorOperator(k=60) + df = pd.DataFrame({ + "query_id": ["q1", "q1", "q1", "q1"], + "query_text": ["inflation causes"] * 4, + "step_idx": [0, 0, 1, 1 ], + "doc_id": ["d1", "d2", "d1", "d3"], + "text": ["t1", "t2", "t1", "t3"], + "rank": [1, 2, 1, 2 ], + }) + result = op.run(df) + # d1 appears in both steps at rank 1 → highest RRF score + """ + + def __init__(self, *, k: int = 60) -> None: + super().__init__() + self._k = k + + # ------------------------------------------------------------------ + # AbstractOperator interface + # ------------------------------------------------------------------ + + def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame: + if not isinstance(data, pd.DataFrame): + raise TypeError(f"RRFAggregatorOperator expects a pd.DataFrame, got {type(data).__name__!r}.") + required = {"query_id", "query_text", "step_idx", "doc_id", "text", "rank"} + missing = required - set(data.columns) + if missing: + raise ValueError( + f"Input DataFrame is missing required column(s): {sorted(missing)}. " f"Expected: {sorted(required)}." 
+ ) + return data.copy() + + def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + """Compute RRF scores, group by query_id, sort by score descending.""" + k = self._k + rows: List[Dict[str, Any]] = [] + + for query_id, qgroup in data.groupby("query_id", sort=False): + query_text = str(qgroup["query_text"].iloc[0]) + + rrf_scores: Dict[str, float] = defaultdict(float) + first_text: Dict[str, str] = {} + + # Process each step's ranked list + for _step_idx, sgroup in qgroup.groupby("step_idx", sort=True): + # Sort by rank ascending so rank=1 is processed first + for _, row in sgroup.sort_values("rank").iterrows(): + doc_id = str(row["doc_id"]) + rank = int(row["rank"]) + rrf_scores[doc_id] += 1.0 / (rank + k) + if doc_id not in first_text: + first_text[doc_id] = str(row["text"]) + + for doc_id, score in sorted(rrf_scores.items(), key=lambda kv: kv[1], reverse=True): + rows.append( + { + "query_id": query_id, + "query_text": query_text, + "doc_id": doc_id, + "rrf_score": score, + "text": first_text.get(doc_id, ""), + } + ) + + if not rows: + return pd.DataFrame(columns=["query_id", "query_text", "doc_id", "rrf_score", "text"]) + + return pd.DataFrame(rows) + + def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + return data diff --git a/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py new file mode 100644 index 0000000000..14371969a8 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py @@ -0,0 +1,512 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Operator that re-ranks retrieved documents using an LLM-based selection agent.""" + +from __future__ import annotations + +import json +import logging +import os + +import requests +from typing import Any, Dict, List, Optional + +import pandas as pd + +from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.graph.cpu_operator import CPUOperator +from nemo_retriever.nim.chat_completions import invoke_chat_completion_step + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Prompt rendering (verbatim content of 01_v0.j2, rendered via Python) +# --------------------------------------------------------------------------- + +_ROLE = """\ +You are a document re-ranker agent, which is the final stage in an information retrieval pipeline. + + +You are given a search query and a list of retrieved candidate documents that are potentially relevant to \ +the given query. Your goal is to help the users identify the most relevant documents to the given query \ +from the list of candidate documents. +""" + +_RELEVANCE_DEFINITION = """\ + + +- You should be careful, in the context of this task, what it means to be a "query", "document", and \ +"relevant" can sometimes be very complex and might not follow the traditional definition of these terms \ +in standard re-ranking and retrieval. +- In standard re-ranking/retrieval, a query is usually a user question (like a web search query), the \ +document is some sort of content that provides information (e.g., a web page), and these two are considered \ +relevant if the document provides information that answers the user's query. +- However, in our setting, this could be different. 
Here are some examples: + * the query is a programming problem and documents are programming language syntax references. A document \ +is relevant if it contains the reference for the programming syntax used for solving the problem. + * both query and documents are descriptions programming problems and a query and document are relevant if \ +the same approach is used to solve them. + * the query is a math problem and documents are theorems. Relevant documents (theorems) are the ones \ +that are useful for solving the math problem. + * the query and document are both math problems. A query and a document are relevant if the same theorem \ +is used for solving them. + * the query is a task description (e.g., for an API programmer) and documents are descriptions of \ +available APIs. Relevant documents (e.g., APIs) are the ones needed for completing the task. +- This is not an exhaustive list. These are just some examples to show you the complexity of queries, \ +documents, and the concept of relevance in this task. +- Note that even here, the relevant documents are still the ones that are useful for a user who is \ +searching for the given query. But the relation is more nuanced. +- You should analyze the query and the available documents. And then reason about what could be a meaningful \ +definition of relevance in this case, and what the user could be looking for. +- Moreover, sometimes, the query could be even a prompt that is given to a Large Language Model (LLM) and \ +the user wants to find the useful documents for the LLM that help answering/solving this prompt. +""" + +_WORKFLOW_TEMPLATE = """\ + + +* You are given a search query and a list of candidate documents. You have access to the ID and content of \ +each candidate document. +* You should read the query carefully and understand it. +{extended_relevance_line}\ +* Then you should compare the query with each one of the candidate documents. In this comparison, you want \ +to identify if the document is relevant/useful for the given query and to what extent. +* Select the {top_k} most relevant candidate documents for the given query. +* Note that just selecting the most relevant documents is not enough. You should identify the relative level \ +of relevance between the query and selected documents. This helps you sort the selected documents later \ +based on how relevant they are to the query. +* Once you have this information, you should call the "log_selected_documents" function to report the final \ +results and signal the completion of the task. +* Note that the selected document IDs must be reported in the decreasing level of relevance. I.e., The \ +first document in the list is the most relevant, the second is the second most relevant, and so on. This \ +is similar to what a search engine (e.g., Google Search) does (it shows you the relevant results in a \ +sorted order, where the most relevant results appear on top of the list). +""" + +_THINKING_TIPS = """ + + +* you have access to a "think" tool that you can use for complex thinking and analysis. Here are examples \ +of cases where the think tool might be useful: + - complex analysis and thinking to understand the meaning and intent of the query. E.g., what is the \ +user trying to find with this query? what kind of information is helpful for the user? + - extended thinking to analyze how each candidate document could or could not be relevant to the given query. + - reasoning to identify the relative level of relevance between the query and selected documents. 
It \ +helps you sort the documents correctly when reporting the final answer. +""" + + +def _render_selection_prompt(top_k: int, *, extended_relevance: bool = False) -> str: + """Render the selection agent system prompt (verbatim 01_v0.j2 logic).""" + parts = [_ROLE] + if extended_relevance: + parts.append(_RELEVANCE_DEFINITION) + ext_line = ( + "* As explained above, reason and figure out what the meaning of relevance is in this case, " + "and what could be relevant and useful information for the given query.\n" + if extended_relevance + else "" + ) + parts.append(_WORKFLOW_TEMPLATE.format(top_k=top_k, extended_relevance_line=ext_line)) + parts.append(_THINKING_TIPS) + return "".join(parts) + + +# --------------------------------------------------------------------------- +# Operator +# --------------------------------------------------------------------------- + + +class SelectionAgentOperator(AbstractOperator, CPUOperator): + """Re-rank a set of retrieved documents using an LLM-based selection agent. + + For each ``query_id`` group in the input DataFrame, the operator runs an + agentic LLM loop that reads the query and all candidate documents, then + calls a ``log_selected_documents`` tool to report the final ranked list. + The loop also has access to a ``think`` tool for extended reasoning. + + The system prompt matches the retrieval-bench ``01_v0.j2`` template verbatim, + with an optional ``extended_relevance`` mode for complex retrieval tasks. + + Input DataFrame schema + ---------------------- + query_id : str — unique query identifier + query_text : str — original query text shown to the LLM + doc_id : str — unique document identifier + text : str — document text content shown to the LLM + (any additional columns are ignored) + + Output DataFrame schema + ----------------------- + query_id : str — same ``query_id`` as the input + doc_id : str — selected document ID + rank : int — 1-indexed rank (1 = most relevant) + message : str — LLM explanation of the selection + + Parameters + ---------- + llm_model : str + Model identifier forwarded to the endpoint. + invoke_url : str + Full ``/v1/chat/completions`` endpoint URL. + top_k : int + Number of documents to select per query. Defaults to ``5``. + api_key : str, optional + Literal API key **or** an ``"os.environ/VAR_NAME"`` reference. + max_tokens : int, optional + Upper bound on tokens in each LLM response. + max_steps : int + Maximum agentic loop iterations per query. Defaults to ``10``. + extended_relevance : bool + When ``True``, include the ``_RELEVANCE_DEFINITION`` block in the + system prompt for tasks with non-standard relevance definitions. + Defaults to ``False``. + system_prompt_override : str, optional + Fully custom system prompt. Use ``{top_k}`` as a placeholder. + text_truncation : int + Maximum characters of each document's text shown to the LLM. + Defaults to ``2000``. + parallel_tool_calls : bool + When ``False``, ask the endpoint to emit at most one tool call per turn + (sent via ``extra_body``). Defaults to ``True``. + base_url : str, optional + Deprecated alias for ``invoke_url``. Prefer ``invoke_url``.
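+ + Examples + -------- + A minimal sketch, assuming ``candidates_df`` follows the input schema above + (``api_key`` uses the ``os.environ/`` reference convention):: + + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + op = SelectionAgentOperator( + invoke_url="https://integrate.api.nvidia.com/v1/chat/completions", + llm_model="nvidia/llama-3.3-nemotron-super-49b-v1", + top_k=5, + api_key="os.environ/NVIDIA_API_KEY", + ) + ranked_df = op.run(candidates_df) + # ranked_df columns: query_id, doc_id, rank, message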
+ """ + + _NVIDIA_BUILD_ENDPOINT = "https://integrate.api.nvidia.com/v1/chat/completions" + + def __init__( + self, + *, + llm_model: str, + invoke_url: Optional[str] = None, + top_k: int = 5, + api_key: Optional[str] = None, + max_tokens: Optional[int] = None, + max_steps: int = 10, + extended_relevance: bool = False, + system_prompt_override: Optional[str] = None, + text_truncation: int = 2000, + parallel_tool_calls: bool = True, + base_url: Optional[str] = None, + ) -> None: + super().__init__() + self._llm_model = llm_model + self._top_k = top_k + self._api_key = api_key + self._max_tokens = max_tokens + self._max_steps = max_steps + self._extended_relevance = extended_relevance + self._system_prompt_override = system_prompt_override + self._text_truncation = text_truncation + self._parallel_tool_calls = parallel_tool_calls + + if invoke_url is not None: + self._invoke_url = invoke_url + elif base_url is not None: + import warnings + + warnings.warn( + "SelectionAgentOperator: 'base_url' is deprecated, use 'invoke_url' instead.", + DeprecationWarning, + stacklevel=2, + ) + self._invoke_url = base_url.rstrip("/") + "/v1/chat/completions" + else: + self._invoke_url = self._NVIDIA_BUILD_ENDPOINT + + # ------------------------------------------------------------------ + # AbstractOperator interface + # ------------------------------------------------------------------ + + def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame: + """Validate that *data* is a DataFrame with the required columns.""" + if not isinstance(data, pd.DataFrame): + raise TypeError(f"SelectionAgentOperator expects a pd.DataFrame, got {type(data).__name__!r}.") + required = {"query_id", "query_text", "doc_id", "text"} + missing = required - set(data.columns) + if missing: + raise ValueError( + f"Input DataFrame is missing required column(s): {sorted(missing)}. " f"Expected: {sorted(required)}." + ) + return data.copy() + + def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + """Run the selection agent loop for each query group.""" + rows: List[Dict[str, Any]] = [] + + for query_id, group in data.groupby("query_id", sort=False): + query_text = str(group["query_text"].iloc[0]) + docs = [{"id": str(row["doc_id"]), "text": str(row["text"])} for _, row in group.iterrows()] + result = self._select_documents(query_text, docs) + message = result.get("message", "") + for rank, doc_id in enumerate(result.get("doc_ids", []), 1): + rows.append( + { + "query_id": query_id, + "doc_id": doc_id, + "rank": rank, + "message": message, + } + ) + + if not rows: + return pd.DataFrame(columns=["query_id", "doc_id", "rank", "message"]) + + return pd.DataFrame(rows) + + def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + return data + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _resolve_api_key(self) -> Optional[str]: + api_key = self._api_key + if api_key is not None and api_key.strip().startswith("os.environ/"): + var = api_key.strip().removeprefix("os.environ/") + value = os.environ.get(var) + if value is None: + raise ValueError( + f"Environment variable '{var}' is not set. 
" f"Set it with: export {var}=" + ) + return value + return api_key + + def _build_system_prompt(self, top_k: int) -> str: + if self._system_prompt_override: + return self._system_prompt_override.format(top_k=top_k) + return _render_selection_prompt(top_k, extended_relevance=self._extended_relevance) + + def _build_tools(self, top_k: int, valid_doc_ids: List[str]) -> List[Dict[str, Any]]: + """Return the two tool specs for the selection agent loop.""" + return [ + { + "type": "function", + "function": { + "name": "think", + "description": ( + "Use this tool to think through complex analysis before making a decision. " + "It logs your reasoning without making any changes. Use it to compare " + "documents against the query or to reason about relevance." + ), + "parameters": { + "type": "object", + "properties": { + "thought": { + "type": "string", + "description": "Your reasoning or analysis.", + } + }, + "required": ["thought"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "log_selected_documents", + "description": ( + f"Records the {top_k} most relevant documents and ends the task. " + f"Call this when you have finished evaluating all candidate documents. " + f"The doc_ids list must be sorted from most to least relevant. " + f"Valid document IDs are: {valid_doc_ids}." + ), + "parameters": { + "type": "object", + "required": ["doc_ids", "message"], + "properties": { + "doc_ids": { + "type": "array", + "items": {"type": "string"}, + "description": ( + f"The IDs of the {top_k} most relevant documents, sorted from " + "most to least relevant. Must be valid document IDs from the candidates." + ), + }, + "message": { + "type": "string", + "description": "A brief explanation of your selection and the relevance ordering.", + }, + }, + }, + }, + }, + ] + + def _build_user_message(self, query_text: str, docs: List[Dict[str, Any]]) -> Dict[str, Any]: + """Format query + candidate documents as a multi-part user message.""" + content: List[Dict[str, Any]] = [ + {"type": "text", "text": f"Query:\n{query_text}"}, + {"type": "text", "text": "Candidate Documents:"}, + ] + seen: set[str] = set() + for doc in docs: + doc_id = doc["id"] + if doc_id in seen: + continue + seen.add(doc_id) + content.append({"type": "text", "text": f"Doc ID: {doc_id}"}) + text = doc.get("text", "").strip() + if text: + truncated = text[: self._text_truncation] + if len(text) > self._text_truncation: + truncated += "..." 
+ content.append({"type": "text", "text": f"Doc Text: {truncated}"}) + return {"role": "user", "content": content} + + def _select_documents( + self, + query_text: str, + docs: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """Run the agentic selection loop for a single query.""" + valid_ids = list(dict.fromkeys(d["id"] for d in docs)) + feasible_k = min(self._top_k, len(valid_ids)) + + system_prompt = self._build_system_prompt(feasible_k) + tools = self._build_tools(feasible_k, valid_ids) + valid_id_set = set(valid_ids) + api_key = self._resolve_api_key() + + messages: List[Dict[str, Any]] = [ + {"role": "system", "content": system_prompt}, + self._build_user_message(query_text, docs), + ] + + extra_body: Dict[str, Any] = {} + if not self._parallel_tool_calls: + extra_body["parallel_tool_calls"] = False + + for _step in range(self._max_steps): + try: + response = invoke_chat_completion_step( + invoke_url=self._invoke_url, + messages=messages, + model=self._llm_model, + api_key=api_key, + tools=tools, + tool_choice="auto", + max_tokens=self._max_tokens, + extra_body=extra_body or None, + ) + except TimeoutError as exc: + logger.warning( + "SelectionAgentOperator: LLM call timed out on step %d for query %r: %s", + _step, + query_text, + exc, + exc_info=True, + ) + break + except RuntimeError as exc: + logger.warning( + "SelectionAgentOperator: LLM retries exhausted on step %d for query %r: %s", + _step, + query_text, + exc, + exc_info=True, + ) + break + except requests.RequestException as exc: + logger.warning( + "SelectionAgentOperator: LLM HTTP error on step %d for query %r: %s", + _step, + query_text, + exc, + exc_info=True, + ) + break + except json.JSONDecodeError as exc: + logger.warning( + "SelectionAgentOperator: LLM returned invalid JSON on step %d for query %r: %s", + _step, + query_text, + exc, + exc_info=True, + ) + break + + if not response.get("choices"): + logger.warning("SelectionAgentOperator: empty choices in API response on step %d", _step) + break + choice = response["choices"][0] + msg = choice["message"] + finish_reason = choice.get("finish_reason") + + # Append the assistant turn to history + assistant_turn: Dict[str, Any] = {"role": "assistant"} + if msg.get("content"): + assistant_turn["content"] = msg["content"] + tool_calls = msg.get("tool_calls") or [] + if tool_calls: + assistant_turn["tool_calls"] = tool_calls + messages.append(assistant_turn) + + if finish_reason == "stop" or not tool_calls: + messages.append( + { + "role": "user", + "content": "Please call log_selected_documents to report your final selection.", + } + ) + continue + + tool_messages: List[Dict[str, Any]] = [] + should_end = False + end_kwargs: Dict[str, Any] = {} + + for tc in tool_calls: + tc_id = tc.get("id", "") + fn = tc.get("function", {}) + try: + fn_args = json.loads(fn.get("arguments", "{}")) + except json.JSONDecodeError: + tool_messages.append( + {"role": "tool", "tool_call_id": tc_id, "content": "Error: could not parse tool arguments."} + ) + continue + + if fn.get("name") == "think": + tool_messages.append( + {"role": "tool", "tool_call_id": tc_id, "content": "Your thought has been logged."} + ) + + elif fn.get("name") == "log_selected_documents": + raw_doc_ids = fn_args.get("doc_ids", []) + if isinstance(raw_doc_ids, str): + try: + raw_doc_ids = json.loads(raw_doc_ids) + except json.JSONDecodeError: + raw_doc_ids = [] + doc_ids = [d for d in raw_doc_ids if d in valid_id_set][:feasible_k] + if not doc_ids and raw_doc_ids: + logger.warning( + "SelectionAgentOperator: LLM returned 
%d doc_id(s) for query %r " + "but none matched the candidate set — possible hallucination. " + "Returned IDs: %s", + len(raw_doc_ids), + query_text, + raw_doc_ids[:10], + ) + end_kwargs = {"doc_ids": doc_ids, "message": fn_args.get("message", "")} + should_end = True + + else: + tool_messages.append( + { + "role": "tool", + "tool_call_id": tc_id, + "content": f"Error: unknown tool '{fn.get('name')}'.", + } + ) + + if should_end: + return end_kwargs + + messages.extend(tool_messages) + + return { + "doc_ids": [], + "message": "Selection agent reached max steps without completing.", + } diff --git a/nemo_retriever/src/nemo_retriever/graph/subquery_operator.py b/nemo_retriever/src/nemo_retriever/graph/subquery_operator.py new file mode 100644 index 0000000000..d5f0e6a3c1 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/graph/subquery_operator.py @@ -0,0 +1,364 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Operator that expands a query DataFrame into sub-queries via an LLM.""" + +from __future__ import annotations + +import json +import logging +import os + +import requests +from typing import Any, List, Literal, Optional + +import pandas as pd + +from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.graph.cpu_operator import CPUOperator +from nemo_retriever.nim.chat_completions import invoke_chat_completions + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Built-in strategy prompts +# --------------------------------------------------------------------------- + +_PROMPTS: dict[str, str] = { + "decompose": """\ +You are a query decomposition assistant for a retrieval system. + +Given a search query, break it down into up to {max_subqueries} distinct sub-queries that \ +together cover all aspects and angles of the original query. Generate as many sub-queries as \ +are genuinely useful — do not pad with redundant ones just to hit the maximum. Each sub-query \ +should target a specific facet, making it easier for a dense retrieval system to find all \ +relevant documents. + +Rules: +- Each sub-query must be self-contained and meaningful on its own. +- Sub-queries should be diverse and complementary, not redundant. +- Use clear, precise language suited for dense embedding retrieval. +- Output a JSON array of strings only — no explanation, no markdown fences.""", + "hyde": """\ +You are a Hypothetical Document Embedding (HyDE) assistant for a retrieval system. + +Given a search query, generate up to {max_subqueries} short hypothetical document passages \ +(2–4 sentences each) that would directly answer or address the query. Generate as many as are \ +genuinely useful — fewer is fine if the query is simple. These passages will be used as queries \ +to a dense retrieval system to find real, similar documents. + +Rules: +- Each passage should read like a real document excerpt that answers the query. +- Vary the style and perspective across passages (e.g., academic, technical, narrative). +- Be factually plausible; focus on covering the query intent. +- Output a JSON array of strings only — no explanation, no markdown fences.""", + "multi_perspective": """\ +You are a multi-perspective query expansion assistant for a retrieval system. 
+ +Given a search query, generate up to {max_subqueries} reformulations from different angles, \ +perspectives, or levels of specificity to maximise recall in a dense retrieval system. Only \ +generate reformulations that add genuine coverage — do not pad. + +Rules: +- Vary terminology: use synonyms, technical vs. casual language, acronyms vs. full names. +- Vary scope: broad overview queries alongside narrow, specific ones. +- Vary form: declarative statements, questions, and entity-focused queries. +- Each reformulation must have a meaningfully different surface form from the others. +- Output a JSON array of strings only — no explanation, no markdown fences.""", +} + + +# --------------------------------------------------------------------------- +# Operator +# --------------------------------------------------------------------------- + + +class SubQueryGeneratorOperator(AbstractOperator, CPUOperator): + """Expand each query row into sub-query rows using an LLM. + + The operator calls an LLM once per input query and explodes the result into + one output row per generated sub-query. The LLM decides how many + sub-queries to generate, up to ``max_subqueries``. This + makes it a natural upstream stage for a retrieval operator: the downstream + operator can retrieve documents independently for every sub-query row, and + a subsequent aggregation step (e.g. RRF) can merge the per-sub-query + ranked lists back into a single ranking per ``query_id``. + + Input DataFrame schema + ---------------------- + query_id : str — unique identifier for the query + query_text : str — the search query text + (any additional columns are passed through unchanged) + + Output DataFrame schema + ----------------------- + query_id : str — same ``query_id`` as the input row + query_text : str — original query text (preserved for context) + subquery_idx : int — 0-based position within the generated sub-query group + subquery_text : str — the generated sub-query text + (additional input columns are forwarded to every expanded row) + + Parameters + ---------- + llm_model : str + Model identifier forwarded to the chat-completions endpoint, + e.g. ``"nvidia/llama-3.3-nemotron-super-49b-v1"``. + max_subqueries : int + Maximum number of sub-queries the LLM may generate per query. + The LLM will generate fewer if the query does not warrant the maximum. + Defaults to ``4``. + strategy : {"decompose", "hyde", "multi_perspective"} + Built-in sub-query generation strategy. + + ``"decompose"`` + Break the query into complementary sub-aspects (default). + ``"hyde"`` + Generate hypothetical answer passages (HyDE). + ``"multi_perspective"`` + Rewrite the query from diverse angles to maximise recall. + api_key : str, optional + Literal API key **or** an ``"os.environ/VAR_NAME"`` reference that is + resolved at call time. + invoke_url : str, optional + Full ``/v1/chat/completions`` endpoint URL. Defaults to the NVIDIA + build endpoint when omitted (requires ``api_key`` / ``NVIDIA_API_KEY``). + base_url : str, optional + Deprecated alias for ``invoke_url``. Prefer ``invoke_url``. + max_tokens : int, optional + Upper bound on tokens in the LLM response. + system_prompt_override : str, optional + Fully custom system prompt. Use ``{max_subqueries}`` as a placeholder. + When provided, ``strategy`` is ignored.
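+ + Notes + ----- + Every built-in strategy prompt asks the LLM to return a JSON array of + strings (e.g. ``["what drives inflation?", "inflation and monetary policy"]``). + If the LLM call fails or its response cannot be parsed as a JSON array, the + operator falls back to the original query as the single sub-query, so + downstream retrieval always receives at least one row per input query.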
+ + Examples + -------- + Standalone use:: + + import pandas as pd + from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator + + op = SubQueryGeneratorOperator( + llm_model="nvidia/llama-3.3-nemotron-super-49b-v1", + invoke_url="https://integrate.api.nvidia.com/v1/chat/completions", + max_subqueries=5, + ) + df = pd.DataFrame({ + "query_id": ["q1", "q2"], + "query_text": ["What causes inflation?", "How do vaccines work?"], + }) + result = op.run(df) + # result has up to 10 rows: ≤5 sub-queries × 2 original queries + + Composing into a graph:: + + from nemo_retriever.graph import InprocessExecutor + from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator + + graph = ( + SubQueryGeneratorOperator(llm_model="gpt-4o", max_subqueries=4) + >> RetrievalOperator(retriever=my_retriever) + >> RRFAggregatorOperator() + ) + executor = InprocessExecutor(graph) + result_df = executor.ingest(query_df) + """ + + _NVIDIA_BUILD_ENDPOINT = "https://integrate.api.nvidia.com/v1/chat/completions" + + def __init__( + self, + *, + llm_model: str, + max_subqueries: int = 4, + strategy: Literal["decompose", "hyde", "multi_perspective"] = "decompose", + api_key: Optional[str] = None, + invoke_url: Optional[str] = None, + base_url: Optional[str] = None, + max_tokens: Optional[int] = None, + system_prompt_override: Optional[str] = None, + ) -> None: + super().__init__() + self._llm_model = llm_model + self._max_subqueries = max_subqueries + self._strategy = strategy + self._api_key = api_key + self._max_tokens = max_tokens + self._system_prompt_override = system_prompt_override + + # Resolve invoke_url: explicit > deprecated base_url > default NVIDIA endpoint + if invoke_url is not None: + self._invoke_url = invoke_url + elif base_url is not None: + import warnings + + warnings.warn( + "SubQueryGeneratorOperator: 'base_url' is deprecated, use 'invoke_url' instead.", + DeprecationWarning, + stacklevel=2, + ) + self._invoke_url = base_url.rstrip("/") + "/v1/chat/completions" + else: + self._invoke_url = self._NVIDIA_BUILD_ENDPOINT + + # ------------------------------------------------------------------ + # AbstractOperator interface + # ------------------------------------------------------------------ + + def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame: + """Normalise *data* to a DataFrame with ``query_id`` / ``query_text`` columns. + + Accepted input types + -------------------- + ``pd.DataFrame`` + Must contain at least ``query_id`` and ``query_text`` columns. + ``list[str]`` + Plain query strings; ``query_id`` values are auto-assigned as + ``"q0"``, ``"q1"``, … + ``list[tuple[str, str]]`` or ``list[list[str, str]]`` + ``(query_id, query_text)`` pairs. + """ + if isinstance(data, pd.DataFrame): + missing = {"query_id", "query_text"} - set(data.columns) + if missing: + raise ValueError( + f"Input DataFrame is missing required column(s): {sorted(missing)}. " + "Expected at minimum: 'query_id' and 'query_text'." + ) + return data.copy() + + if isinstance(data, list) and data: + first = data[0] + if isinstance(first, str): + return pd.DataFrame( + { + "query_id": [f"q{i}" for i in range(len(data))], + "query_text": list(data), + } + ) + if isinstance(first, (tuple, list)) and len(first) == 2: + return pd.DataFrame(data, columns=["query_id", "query_text"]) + + raise TypeError( + f"Unsupported input type {type(data).__name__!r}. " + "Pass a pd.DataFrame with 'query_id' and 'query_text' columns, " + "a list[str], or a list[tuple[str, str]]." 
+ ) + + def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + """Generate sub-queries for every row and explode to one row per sub-query.""" + system_prompt = self._build_system_prompt() + + passthrough_cols = [c for c in data.columns if c not in ("query_id", "query_text")] + rows: List[dict[str, Any]] = [] + + for _, row in data.iterrows(): + subqueries = self._generate_one(row["query_text"], system_prompt) + for idx, sq in enumerate(subqueries): + new_row: dict[str, Any] = { + "query_id": row["query_id"], + "query_text": row["query_text"], + "subquery_idx": idx, + "subquery_text": sq, + } + for col in passthrough_cols: + new_row[col] = row[col] + rows.append(new_row) + + if not rows: + return pd.DataFrame(columns=["query_id", "query_text", "subquery_idx", "subquery_text"]) + + return pd.DataFrame(rows) + + def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + return data + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _resolve_api_key(self) -> Optional[str]: + api_key = self._api_key + if api_key is not None and api_key.strip().startswith("os.environ/"): + var = api_key.strip().removeprefix("os.environ/") + value = os.environ.get(var) + if value is None: + raise ValueError( + f"Environment variable '{var}' is not set. " f"Set it with: export {var}=" + ) + return value + return api_key + + def _build_system_prompt(self) -> str: + template = self._system_prompt_override or _PROMPTS[self._strategy] + return template.format(max_subqueries=self._max_subqueries) + + def _generate_one(self, query: str, system_prompt: str) -> List[str]: + """Call the LLM and return a list of sub-query strings for *query*.""" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"Query: {query}"}, + ] + extra_body: dict[str, Any] = {} + if self._max_tokens is not None: + extra_body["max_tokens"] = self._max_tokens + + api_key = self._resolve_api_key() # raises ValueError immediately on config error + + try: + results = invoke_chat_completions( + invoke_url=self._invoke_url, + messages_list=[messages], + model=self._llm_model, + api_key=api_key, + extra_body=extra_body or None, + ) + except TimeoutError as exc: + logger.warning("SubQueryGeneratorOperator: LLM call timed out for query %r: %s", query, exc, exc_info=True) + return [query] + except RuntimeError as exc: + logger.warning( + "SubQueryGeneratorOperator: LLM retries exhausted for query %r: %s", query, exc, exc_info=True + ) + return [query] + except requests.RequestException as exc: + logger.warning("SubQueryGeneratorOperator: LLM HTTP error for query %r: %s", query, exc, exc_info=True) + return [query] + except json.JSONDecodeError as exc: + logger.warning( + "SubQueryGeneratorOperator: LLM returned invalid JSON for query %r: %s", query, exc, exc_info=True + ) + return [query] + raw = results[0].strip() + return _parse_json_list(raw, fallback=query) + + +# --------------------------------------------------------------------------- +# Module-level helpers (no instance state — easier to test in isolation) +# --------------------------------------------------------------------------- + + +def _parse_json_list(raw: str, *, fallback: str) -> List[str]: + """Parse a JSON array from *raw*, stripping markdown fences if present. + + Returns *[fallback]* when parsing fails so downstream stages always + receive at least one sub-query row. 
+ """ + text = raw + found_fence = False + for fence in ("```json", "```"): + if text.startswith(fence): + text = text[len(fence) :] + found_fence = True + break + if found_fence and text.endswith("```"): + text = text[:-3] + text = text.strip() + + try: + parsed = json.loads(text) + if isinstance(parsed, list) and parsed and all(isinstance(s, str) for s in parsed): + return parsed + except json.JSONDecodeError: + pass + + return [fallback] diff --git a/nemo_retriever/src/nemo_retriever/nim/chat_completions.py b/nemo_retriever/src/nemo_retriever/nim/chat_completions.py index 98849c7dc4..2fdebe4576 100644 --- a/nemo_retriever/src/nemo_retriever/nim/chat_completions.py +++ b/nemo_retriever/src/nemo_retriever/nim/chat_completions.py @@ -17,6 +17,8 @@ from typing import Any, Dict, List, Optional, Sequence +from nemo_retriever.nim.nim import _parse_invoke_urls, _post_with_retries + def extract_chat_completion_text(response_json: Any) -> str: """Extract generated text from an OpenAI-compatible chat completions response.""" @@ -86,6 +88,68 @@ def invoke_chat_completions( client.shutdown() +def invoke_chat_completion_step( + *, + invoke_url: str, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + api_key: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: str = "auto", + timeout_s: float = 120.0, + temperature: float = 0.0, + max_tokens: Optional[int] = None, + extra_body: Optional[Dict[str, Any]] = None, + max_retries: int = 10, + max_429_retries: int = 5, +) -> Dict[str, Any]: + """Single synchronous tool-aware chat completion call. + + Parameters + ---------- + invoke_url + Full ``/v1/chat/completions`` endpoint URL. + messages + OpenAI-format message list for this single request. + tools + List of OpenAI tool-spec dicts. When provided, ``tool_choice`` is also + forwarded so the model knows which tools it may call. + tool_choice + ``"auto"`` (default) lets the model decide; ``"none"`` suppresses tool + use; or a specific tool name dict. + """ + token = (api_key or "").strip() + headers: Dict[str, str] = {"Accept": "application/json", "Content-Type": "application/json"} + if token: + headers["Authorization"] = f"Bearer {token}" + + invoke_urls = _parse_invoke_urls(invoke_url) + endpoint_url = invoke_urls[0] + + payload: Dict[str, Any] = { + "messages": messages, + "temperature": temperature, + } + if model: + payload["model"] = model + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if tools: + payload["tools"] = tools + payload["tool_choice"] = tool_choice + if extra_body: + payload.update(extra_body) + + return _post_with_retries( + invoke_url=endpoint_url, + payload=payload, + headers=headers, + timeout_s=float(timeout_s), + max_retries=int(max_retries), + max_429_retries=int(max_429_retries), + ) + + def invoke_chat_completions_images( *, invoke_url: str, diff --git a/nemo_retriever/tests/test_agentic_operators.py b/nemo_retriever/tests/test_agentic_operators.py new file mode 100644 index 0000000000..78ae3dd1e2 --- /dev/null +++ b/nemo_retriever/tests/test_agentic_operators.py @@ -0,0 +1,596 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Smoke tests for the agentic retrieval operators. 
+
+Run with:
+    cd nemo_retriever && uv run pytest tests/test_agentic_operators.py -v
+"""
+
+from __future__ import annotations
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pandas as pd
+import pytest
+
+# ---------------------------------------------------------------------------
+# RRFAggregatorOperator — pure pandas, no mocking needed
+# ---------------------------------------------------------------------------
+
+
+class TestRRFAggregatorOperator:
+    def _make_input(self):
+        """Two queries; q1 has doc d1 in both steps, q2 has one step."""
+        return pd.DataFrame(
+            {
+                "query_id": ["q1", "q1", "q1", "q1", "q2", "q2"],
+                "query_text": ["inflation"] * 4 + ["vaccines"] * 2,
+                "step_idx": [0, 0, 1, 1, 0, 0],
+                "doc_id": ["d1", "d2", "d1", "d3", "d4", "d5"],
+                "text": ["t1", "t2", "t1", "t3", "t4", "t5"],
+                "rank": [1, 2, 1, 2, 1, 2],
+            }
+        )
+
+    def test_rrf_scores_correct(self):
+        from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+        op = RRFAggregatorOperator(k=60)
+        result = op.run(self._make_input())
+
+        q1 = result[result["query_id"] == "q1"].set_index("doc_id")
+        k = 60
+        # d1 appears in step 0 rank 1 and step 1 rank 1
+        expected_d1 = 1 / (1 + k) + 1 / (1 + k)
+        # d2 appears only in step 0 rank 2
+        expected_d2 = 1 / (2 + k)
+        assert abs(q1.loc["d1", "rrf_score"] - expected_d1) < 1e-10
+        assert abs(q1.loc["d2", "rrf_score"] - expected_d2) < 1e-10
+
+    def test_sorted_descending_per_query(self):
+        from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+        op = RRFAggregatorOperator(k=60)
+        result = op.run(self._make_input())
+
+        for _, grp in result.groupby("query_id"):
+            scores = grp["rrf_score"].tolist()
+            assert scores == sorted(scores, reverse=True), "Scores not sorted descending"
+
+    def test_text_carried_through(self):
+        from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+        op = RRFAggregatorOperator(k=60)
+        result = op.run(self._make_input())
+        q1 = result[result["query_id"] == "q1"].set_index("doc_id")
+        assert q1.loc["d1", "text"] == "t1"
+
+    def test_output_schema(self):
+        from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+        op = RRFAggregatorOperator(k=60)
+        result = op.run(self._make_input())
+        assert set(result.columns) >= {"query_id", "query_text", "doc_id", "rrf_score", "text"}
+
+    def test_missing_column_raises(self):
+        from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator
+
+        op = RRFAggregatorOperator(k=60)
+        bad_df = pd.DataFrame({"query_id": ["q1"], "query_text": ["x"]})
+        with pytest.raises(ValueError, match="missing required column"):
+            op.run(bad_df)
+
+
+# ---------------------------------------------------------------------------
+# Prompt rendering — pure Python, no mocking needed
+# ---------------------------------------------------------------------------
+
+
+class TestPromptRendering:
+    def test_react_prompt_no_extended_relevance(self):
+        from nemo_retriever.graph.react_agent_operator import _render_react_agent_prompt
+
+        prompt = _render_react_agent_prompt(10, with_init_docs=True, enforce_top_k=True, extended_relevance=False)
+        # Goal, workflow, and best-practices sections are all rendered.
+        assert "You are a retrieval agent" in prompt
+        assert "call the search tool multiple times" in prompt
+        assert "increase the **Recall**" in prompt
+        assert "RELEVANCE_DEFINITION" not in prompt
+        assert "exactly the 10 most relevant" in prompt
+        assert "TIP" in prompt
+
+    def test_react_prompt_with_extended_relevance(self):
+        from nemo_retriever.graph.react_agent_operator import _render_react_agent_prompt
+
+        prompt = _render_react_agent_prompt(5,
with_init_docs=False, enforce_top_k=False, extended_relevance=True) + assert "RELEVANCE_DEFINITION" in prompt + assert "exactly the 5" not in prompt + assert "TIP" not in prompt + + def test_selection_prompt_no_extended_relevance(self): + from nemo_retriever.graph.selection_agent_operator import _render_selection_prompt + + prompt = _render_selection_prompt(5, extended_relevance=False) + assert "" in prompt + assert "" in prompt + assert "THINKING TIPS" in prompt + assert "RELEVANCE_DEFINITION" not in prompt + assert "5 most relevant" in prompt + + def test_selection_prompt_with_extended_relevance(self): + from nemo_retriever.graph.selection_agent_operator import _render_selection_prompt + + prompt = _render_selection_prompt(5, extended_relevance=True) + assert "RELEVANCE_DEFINITION" in prompt + assert "As explained above" in prompt + + +# --------------------------------------------------------------------------- +# SelectionAgentOperator — mock invoke_chat_completion_step +# --------------------------------------------------------------------------- + + +def _make_tool_call_response(fn_name: str, fn_args: dict, tc_id: str = "call_1") -> dict: + """Build a canned /v1/chat/completions response with one tool call.""" + return { + "choices": [ + { + "message": { + "content": None, + "tool_calls": [ + { + "id": tc_id, + "type": "function", + "function": {"name": fn_name, "arguments": json.dumps(fn_args)}, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + + +class TestSelectionAgentOperator: + def _make_input(self): + return pd.DataFrame( + { + "query_id": ["q1", "q1", "q1"], + "query_text": ["What causes inflation?"] * 3, + "doc_id": ["d1", "d2", "d3"], + "text": ["monetary policy doc", "supply chain doc", "unrelated doc"], + } + ) + + @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") + def test_happy_path_selects_docs(self, mock_step): + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + # LLM immediately calls log_selected_documents + mock_step.return_value = _make_tool_call_response( + "log_selected_documents", + {"doc_ids": ["d1", "d2"], "message": "d1 most relevant"}, + ) + + op = SelectionAgentOperator( + llm_model="test-model", + invoke_url="http://localhost/v1/chat/completions", + top_k=2, + ) + result = op.run(self._make_input()) + + assert list(result.columns) == ["query_id", "doc_id", "rank", "message"] + assert result["query_id"].tolist() == ["q1", "q1"] + assert result["doc_id"].tolist() == ["d1", "d2"] + assert result["rank"].tolist() == [1, 2] + + @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") + def test_think_then_select(self, mock_step): + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + # First call: think; second call: log_selected_documents + mock_step.side_effect = [ + _make_tool_call_response("think", {"thought": "let me reason..."}), + _make_tool_call_response("log_selected_documents", {"doc_ids": ["d3"], "message": "only d3"}), + ] + + op = SelectionAgentOperator( + llm_model="test-model", + invoke_url="http://localhost/v1/chat/completions", + top_k=1, + ) + result = op.run(self._make_input()) + + assert result["doc_id"].tolist() == ["d3"] + assert mock_step.call_count == 2 + + @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") + def test_extended_relevance_in_prompt(self, mock_step): + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + captured_prompts 
= [] + + def capture_and_respond(**kwargs): + captured_prompts.append(kwargs["messages"][0]["content"]) + return _make_tool_call_response("log_selected_documents", {"doc_ids": ["d1"], "message": "ok"}) + + mock_step.side_effect = capture_and_respond + + op = SelectionAgentOperator( + llm_model="test-model", + invoke_url="http://localhost/v1/chat/completions", + top_k=1, + extended_relevance=True, + ) + op.run(self._make_input()) + + assert "RELEVANCE_DEFINITION" in captured_prompts[0] + + +# --------------------------------------------------------------------------- +# ReActAgentOperator — mock retriever_fn + invoke_chat_completion_step +# --------------------------------------------------------------------------- + + +class TestReActAgentOperator: + def _make_input(self): + return pd.DataFrame( + { + "query_id": ["q1"], + "query_text": ["What causes inflation?"], + } + ) + + def _make_retriever(self, docs=None): + """Return a mock retriever_fn that returns canned docs.""" + if docs is None: + docs = [{"doc_id": "d1", "text": "monetary policy"}, {"doc_id": "d2", "text": "supply chains"}] + + def retriever_fn(query_text: str, top_k: int): + return docs[:top_k] + + return retriever_fn + + @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") + def test_simple_mode_retrieve_then_final(self, mock_step): + from nemo_retriever.graph.react_agent_operator import ReActAgentOperator + + # 1) agent calls retrieve("subquery"), 2) agent calls final_results + mock_step.side_effect = [ + _make_tool_call_response("retrieve", {"query": "inflation monetary policy"}), + _make_tool_call_response( + "final_results", + {"doc_ids": ["d1", "d2"], "message": "found them", "search_successful": "true"}, + ), + ] + + op = ReActAgentOperator( + invoke_url="http://localhost/v1/chat/completions", + llm_model="test-model", + retriever_fn=self._make_retriever(), + user_msg_type="simple", + target_top_k=2, + ) + result = op.run(self._make_input()) + + assert set(result.columns) >= {"query_id", "query_text", "step_idx", "doc_id", "text", "rank"} + assert result["query_id"].unique().tolist() == ["q1"] + # step 0 is the retrieve tool call result + assert 0 in result["step_idx"].values + assert "d1" in result["doc_id"].values + + @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") + def test_with_results_mode_initial_retrieval(self, mock_step): + from nemo_retriever.graph.react_agent_operator import ReActAgentOperator + + # with_results=True → retriever_fn called once before LLM, then LLM immediately calls final_results + mock_step.return_value = _make_tool_call_response( + "final_results", + {"doc_ids": ["d1"], "message": "ok", "search_successful": "true"}, + ) + retriever = MagicMock(return_value=[{"doc_id": "d1", "text": "monetary policy"}]) + + op = ReActAgentOperator( + invoke_url="http://localhost/v1/chat/completions", + llm_model="test-model", + retriever_fn=retriever, + user_msg_type="with_results", + target_top_k=1, + ) + result = op.run(self._make_input()) + + # retriever was called upfront (step_idx=0) before any LLM step + assert retriever.call_count >= 1 + assert 0 in result["step_idx"].values + + @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") + def test_output_row_structure(self, mock_step): + from nemo_retriever.graph.react_agent_operator import ReActAgentOperator + + mock_step.side_effect = [ + _make_tool_call_response("retrieve", {"query": "q"}), + _make_tool_call_response( + "final_results", {"doc_ids": ["d1"], "message": 
"ok", "search_successful": "true"} + ), + ] + + op = ReActAgentOperator( + invoke_url="http://localhost/v1/chat/completions", + llm_model="test-model", + retriever_fn=self._make_retriever(), + user_msg_type="simple", + ) + result = op.run(self._make_input()) + + assert (result["rank"] >= 1).all() + assert result["step_idx"].dtype in (int, "int64") + assert result["doc_id"].notna().all() + + @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") + @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") + def test_pipeline_end_to_end_with_mocks(self, mock_react_step, mock_selection_step): + """Wire ReAct → RRF → Selection with mocks; verify final output shape. + + Each operator imports invoke_chat_completion_step into its own module + namespace, so both must be patched independently. + """ + from nemo_retriever.graph.executor import InprocessExecutor + from nemo_retriever.graph.react_agent_operator import ReActAgentOperator + from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + # ReAct: retrieve once, then final_results + mock_react_step.side_effect = [ + _make_tool_call_response("retrieve", {"query": "inflation"}), + _make_tool_call_response( + "final_results", {"doc_ids": ["d1", "d2"], "message": "ok", "search_successful": "true"} + ), + ] + # Selection: immediately log_selected_documents + mock_selection_step.return_value = _make_tool_call_response( + "log_selected_documents", {"doc_ids": ["d1"], "message": "d1 best"} + ) + + def retriever_fn(query_text, top_k): + return [{"doc_id": "d1", "text": "monetary policy"}, {"doc_id": "d2", "text": "supply chains"}] + + pipeline = ( + ReActAgentOperator( + invoke_url="http://localhost/v1/chat/completions", + llm_model="test-model", + retriever_fn=retriever_fn, + user_msg_type="simple", + target_top_k=1, + ) + >> RRFAggregatorOperator(k=60) + >> SelectionAgentOperator( + invoke_url="http://localhost/v1/chat/completions", + llm_model="test-model", + top_k=1, + ) + ) + + query_df = pd.DataFrame({"query_id": ["q1"], "query_text": ["What causes inflation?"]}) + result = InprocessExecutor(pipeline).ingest(query_df) + + assert list(result.columns) == ["query_id", "doc_id", "rank", "message"] + assert result["query_id"].tolist() == ["q1"] + assert result["rank"].tolist() == [1] + + +# --------------------------------------------------------------------------- +# _parse_json_list — pure Python, no mocking needed +# --------------------------------------------------------------------------- + + +class TestParseJsonList: + def _parse(self, raw, fallback="orig"): + from nemo_retriever.graph.subquery_operator import _parse_json_list + + return _parse_json_list(raw, fallback=fallback) + + def test_plain_json_array(self): + assert self._parse('["a", "b", "c"]') == ["a", "b", "c"] + + def test_json_fence(self): + assert self._parse('```json\n["x", "y"]\n```') == ["x", "y"] + + def test_plain_fence(self): + assert self._parse('```\n["x"]\n```') == ["x"] + + def test_trailing_fence_without_leading_not_stripped(self): + # A JSON string that happens to end with ``` but has no leading fence. + # It should still parse because the trailing strip must NOT fire. 
+ raw = '["valid"]```' + result = self._parse(raw) + assert result == ["orig"] # malformed JSON → fallback + + def test_malformed_json_returns_fallback(self): + assert self._parse("not json at all", fallback="q") == ["q"] + + def test_empty_list_returns_fallback(self): + assert self._parse("[]", fallback="q") == ["q"] + + def test_non_string_items_returns_fallback(self): + assert self._parse("[1, 2, 3]", fallback="q") == ["q"] + + def test_mixed_types_returns_fallback(self): + assert self._parse('["a", 1]', fallback="q") == ["q"] + + +# --------------------------------------------------------------------------- +# SubQueryGeneratorOperator.preprocess — no LLM calls needed +# --------------------------------------------------------------------------- + + +class TestSubQueryPreprocess: + def _op(self): + from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator + + return SubQueryGeneratorOperator(llm_model="test-model") + + def test_dataframe_accepted(self): + op = self._op() + df = pd.DataFrame({"query_id": ["q1"], "query_text": ["hello"]}) + result = op.preprocess(df) + assert list(result["query_id"]) == ["q1"] + + def test_dataframe_missing_query_id_raises(self): + op = self._op() + bad = pd.DataFrame({"query_text": ["hello"]}) + with pytest.raises(ValueError, match="query_id"): + op.preprocess(bad) + + def test_dataframe_missing_query_text_raises(self): + op = self._op() + bad = pd.DataFrame({"query_id": ["q1"]}) + with pytest.raises(ValueError, match="query_text"): + op.preprocess(bad) + + def test_list_of_strings_auto_ids(self): + op = self._op() + result = op.preprocess(["alpha", "beta"]) + assert result["query_id"].tolist() == ["q0", "q1"] + assert result["query_text"].tolist() == ["alpha", "beta"] + + def test_list_of_tuples(self): + op = self._op() + result = op.preprocess([("id1", "alpha"), ("id2", "beta")]) + assert result["query_id"].tolist() == ["id1", "id2"] + + def test_unsupported_type_raises(self): + op = self._op() + with pytest.raises(TypeError): + op.preprocess({"query_id": "q1", "query_text": "hello"}) + + +class TestSubQueryGeneratorOperator: + """Tests for _build_system_prompt and _generate_one.""" + + def _op(self, **kwargs): + from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator + + return SubQueryGeneratorOperator(llm_model="test-model", **kwargs) + + # -- _build_system_prompt ------------------------------------------------- + + def test_decompose_prompt_contains_max_subqueries(self): + op = self._op(strategy="decompose", max_subqueries=6) + prompt = op._build_system_prompt() + assert "6" in prompt + assert "decompos" in prompt.lower() + + def test_hyde_prompt_contains_max_subqueries(self): + op = self._op(strategy="hyde", max_subqueries=3) + prompt = op._build_system_prompt() + assert "3" in prompt + assert "hypothetical" in prompt.lower() + + def test_multi_perspective_prompt_contains_max_subqueries(self): + op = self._op(strategy="multi_perspective", max_subqueries=5) + prompt = op._build_system_prompt() + assert "5" in prompt + assert "perspective" in prompt.lower() + + def test_system_prompt_override_used_instead_of_strategy(self): + op = self._op(system_prompt_override="Custom prompt max={max_subqueries}", max_subqueries=2) + assert op._build_system_prompt() == "Custom prompt max=2" + + # -- _generate_one -------------------------------------------------------- + + @patch("nemo_retriever.graph.subquery_operator.invoke_chat_completions") + def test_generate_one_happy_path(self, mock_invoke): + 
mock_invoke.return_value = ['["sub1", "sub2", "sub3"]'] + op = self._op(max_subqueries=4) + result = op._generate_one("What causes inflation?", "system prompt") + assert result == ["sub1", "sub2", "sub3"] + mock_invoke.assert_called_once() + + @patch("nemo_retriever.graph.subquery_operator.invoke_chat_completions") + def test_generate_one_fenced_json(self, mock_invoke): + mock_invoke.return_value = ['```json\n["a", "b"]\n```'] + op = self._op() + assert op._generate_one("q", "sys") == ["a", "b"] + + @patch("nemo_retriever.graph.subquery_operator.invoke_chat_completions") + def test_generate_one_malformed_json_falls_back(self, mock_invoke): + mock_invoke.return_value = ["not valid json"] + op = self._op() + assert op._generate_one("original query", "sys") == ["original query"] + + @patch("nemo_retriever.graph.subquery_operator.invoke_chat_completions") + def test_generate_one_llm_error_falls_back(self, mock_invoke): + mock_invoke.side_effect = RuntimeError("connection timeout") + op = self._op() + assert op._generate_one("original query", "sys") == ["original query"] + + +# --------------------------------------------------------------------------- +# SelectionAgentOperator.preprocess — no LLM calls needed +# --------------------------------------------------------------------------- + + +class TestSelectionAgentPreprocess: + def _op(self): + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + return SelectionAgentOperator(llm_model="test-model", invoke_url="http://localhost/v1/chat/completions") + + def test_valid_dataframe_accepted(self): + op = self._op() + df = pd.DataFrame({"query_id": ["q1"], "query_text": ["q"], "doc_id": ["d1"], "text": ["t"]}) + result = op.preprocess(df) + assert len(result) == 1 + + def test_missing_doc_id_raises(self): + op = self._op() + bad = pd.DataFrame({"query_id": ["q1"], "query_text": ["q"], "text": ["t"]}) + with pytest.raises(ValueError, match="doc_id"): + op.preprocess(bad) + + def test_missing_text_raises(self): + op = self._op() + bad = pd.DataFrame({"query_id": ["q1"], "query_text": ["q"], "doc_id": ["d1"]}) + with pytest.raises(ValueError, match="text"): + op.preprocess(bad) + + def test_non_dataframe_raises(self): + op = self._op() + with pytest.raises(TypeError): + op.preprocess([{"query_id": "q1"}]) + + +# --------------------------------------------------------------------------- +# SelectionAgentOperator — max_steps exhausted fallback +# --------------------------------------------------------------------------- + + +class TestSelectionAgentMaxSteps: + @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") + def test_max_steps_exhausted_returns_empty(self, mock_step): + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + # LLM only ever calls think — never log_selected_documents + mock_step.return_value = _make_tool_call_response("think", {"thought": "still thinking..."}) + + op = SelectionAgentOperator( + llm_model="test-model", + invoke_url="http://localhost/v1/chat/completions", + top_k=2, + max_steps=3, + ) + df = pd.DataFrame( + { + "query_id": ["q1", "q1"], + "query_text": ["What causes inflation?"] * 2, + "doc_id": ["d1", "d2"], + "text": ["doc one", "doc two"], + } + ) + result = op.run(df) + + assert len(result) == 0 + assert list(result.columns) == ["query_id", "doc_id", "rank", "message"] + assert mock_step.call_count == 3
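+
+
+# ---------------------------------------------------------------------------
+# SubQueryGeneratorOperator internals: illustrative sketches
+# ---------------------------------------------------------------------------
+
+
+class TestSubQueryProcessExplosion:
+    """Minimal sketch of the documented ``process`` explode behaviour.
+
+    The canned LLM reply is a hypothetical example; the assertions follow
+    the output schema documented on ``SubQueryGeneratorOperator``.
+    """
+
+    @patch("nemo_retriever.graph.subquery_operator.invoke_chat_completions")
+    def test_process_emits_one_row_per_subquery(self, mock_invoke):
+        from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator
+
+        mock_invoke.return_value = ['["sub a", "sub b"]']
+        op = SubQueryGeneratorOperator(llm_model="test-model")
+        df = pd.DataFrame({"query_id": ["q1"], "query_text": ["hello"], "extra": ["kept"]})
+
+        result = op.process(df)
+
+        assert result["subquery_text"].tolist() == ["sub a", "sub b"]
+        assert result["subquery_idx"].tolist() == [0, 1]
+        # additional input columns are forwarded to every expanded row
+        assert result["extra"].tolist() == ["kept", "kept"]
+
+
+class TestResolveApiKey:
+    """Sketch of ``_resolve_api_key``; ``FAKE_KEY_VAR`` is an arbitrary name."""
+
+    def _op(self, api_key):
+        from nemo_retriever.graph.subquery_operator import SubQueryGeneratorOperator
+
+        return SubQueryGeneratorOperator(llm_model="test-model", api_key=api_key)
+
+    def test_env_reference_resolved(self, monkeypatch):
+        monkeypatch.setenv("FAKE_KEY_VAR", "secret")
+        assert self._op("os.environ/FAKE_KEY_VAR")._resolve_api_key() == "secret"
+
+    def test_env_reference_missing_raises(self, monkeypatch):
+        monkeypatch.delenv("FAKE_KEY_VAR", raising=False)
+        with pytest.raises(ValueError, match="FAKE_KEY_VAR"):
+            self._op("os.environ/FAKE_KEY_VAR")._resolve_api_key()
+
+    def test_literal_key_passed_through(self):
+        assert self._op("literal-key")._resolve_api_key() == "literal-key"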