diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index 9076ad33e6..f0e7687bdb 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -305,10 +305,13 @@ def _hits_to_keys(raw_hits: List[List[Dict[str, Any]]]) -> List[List[str]]: keys: List[str] = [] for h in hits: page_number = h["page_number"] - source = h["source"] + raw_source = h["source"] + # source may be a bare path string or a JSON object {"source_id": "..."}. + source_map = _parse_mapping(raw_source) + source = source_map.get("source_id", raw_source) if source_map else raw_source # Prefer explicit `pdf_page` column; fall back to derived form. if page_number is not None and source: - filename = Path(source).stem + filename = Path(str(source)).stem keys.append(f"{filename}_{str(page_number)}") else: logger.warning( diff --git a/nemo_retriever/src/nemo_retriever/text_embed/processor.py b/nemo_retriever/src/nemo_retriever/text_embed/processor.py index 81dd4b8a6f..6628291125 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/processor.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/processor.py @@ -14,7 +14,8 @@ from nv_ingest_api.internal.transform.embed_text import transform_create_text_embeddings_internal from nemo_retriever.io.dataframe import validate_primitives_dataframe -from nemo_retriever.vector_store.lancedb_store import LanceDBConfig, write_embeddings_to_lancedb +from nemo_retriever.params.models import LanceDbParams +from nemo_retriever.vector_store.lancedb_store import write_embeddings_to_lancedb logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ def embed_text_from_primitives_df( *, transform_config: TextEmbeddingSchema, task_config: Optional[Dict[str, Any]] = None, - lancedb: Optional[LanceDBConfig] = None, + lancedb: Optional[LanceDbParams] = None, trace_info: Optional[Dict[str, Any]] = None, ) -> Tuple[pd.DataFrame, Dict[str, Any]]: """Generate 
embeddings for supported content types and write to metadata.""" diff --git a/nemo_retriever/src/nemo_retriever/vector_store/__init__.py b/nemo_retriever/src/nemo_retriever/vector_store/__init__.py index 1c05e4e8f2..6c28b5c9ab 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/__init__.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/__init__.py @@ -3,16 +3,21 @@ # SPDX-License-Identifier: Apache-2.0 from .__main__ import app +from .lancedb_backend import LanceDBBackend from .lancedb_store import ( - LanceDBConfig, create_lancedb_index, write_embeddings_to_lancedb, write_text_embeddings_dir_to_lancedb, ) +from .vdb import VectorStore +from .vdb_records import build_vdb_records, build_vdb_records_from_dicts __all__ = [ "app", - "LanceDBConfig", + "LanceDBBackend", + "VectorStore", + "build_vdb_records", + "build_vdb_records_from_dicts", "create_lancedb_index", "write_embeddings_to_lancedb", "write_text_embeddings_dir_to_lancedb", diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_backend.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_backend.py new file mode 100644 index 0000000000..5bca6b6796 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_backend.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""LanceDB implementation of the :class:`VectorStore` interface.""" + +from __future__ import annotations + +import logging +from typing import Any, Sequence + +from nemo_retriever.params.models import LanceDbParams +from nemo_retriever.vector_store.lancedb_utils import infer_vector_dim, lancedb_schema +from nemo_retriever.vector_store.vdb import VectorStore + +logger = logging.getLogger(__name__) + + +class LanceDBBackend(VectorStore): + """LanceDB vector store backend. 
+ + Lazily connects and creates the table on the first :meth:`write_rows` + call so that the embedding dimension can be inferred from the data. + """ + + def __init__(self, params: LanceDbParams | None = None) -> None: + self._params = params or LanceDbParams() + self._db: Any = None + self._table: Any = None + + def open_table(self) -> None: + """Open an existing LanceDB table without creating it. + + Used by the driver to run post-pipeline finalization (e.g. index + creation) after distributed workers have written all rows. + """ + import lancedb + + self._db = lancedb.connect(uri=self._params.lancedb_uri) + self._table = self._db.open_table(self._params.table_name) + + def create_table(self, *, dim: int, **kwargs: Any) -> None: + import lancedb + + self._db = lancedb.connect(uri=self._params.lancedb_uri) + schema = lancedb_schema(vector_dim=dim) + mode = "overwrite" if self._params.overwrite else "create" + self._table = self._db.create_table( + self._params.table_name, + schema=schema, + mode=mode, + ) + + def write_rows(self, rows: Sequence[dict[str, Any]], **kwargs: Any) -> None: + if not rows: + return + if self._table is None: + self.create_table(dim=infer_vector_dim(list(rows))) + self._table.add(list(rows)) + + def create_index(self, **kwargs: Any) -> None: + if self._table is None: + return + if not self._params.create_index: + return + + from nemo_retriever.vector_store.lancedb_store import create_lancedb_index + + try: + create_lancedb_index(self._table, cfg=self._params) + except RuntimeError: + logger.warning( + "Index creation failed (likely too few rows for %d partitions); skipping.", + self._params.num_partitions, + exc_info=True, + ) diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py index 39165a7c57..a8706c58c9 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py 
@@ -6,43 +6,18 @@ import json import logging -from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple # noqa: F401 +from typing import Any, Dict, List, Optional from datetime import timedelta from nv_ingest_client.util.vdb.lancedb import LanceDB -from nemo_retriever.vector_store.lancedb_utils import lancedb_schema +from nemo_retriever.params.models import LanceDbParams +from nemo_retriever.vector_store.vdb_records import build_vdb_records, build_vdb_records_from_dicts import pandas as pd -import lancedb logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class LanceDBConfig: - """ - Minimal config for writing embeddings into LanceDB. - - This module is intentionally lightweight: it can be used by the text-embedding - stage (`nemo_retriever.text_embed.stage`) and by the vector-store CLI (`nemo_retriever.vector_store.stage`). - """ - - uri: str = "lancedb" - table_name: str = "nv-ingest" - overwrite: bool = True - - # Optional index creation (recommended for recall/search runs). - create_index: bool = True - index_type: str = "IVF_HNSW_SQ" - metric: str = "l2" - num_partitions: int = 16 - num_sub_vectors: int = 256 - - hybrid: bool = False - fts_language: str = "English" - - def _read_text_embeddings_json_df(path: Path) -> pd.DataFrame: """ Read a `*.text_embeddings.json` file emitted by `nemo_retriever.text_embed.stage`. @@ -86,110 +61,7 @@ def _iter_text_embeddings_json_files(input_dir: Path, *, recursive: bool) -> Lis return sorted([p for p in files if p.is_file()]) -def _safe_str(x: Any) -> str: - return "" if x is None else str(x) - - -def _extract_source_path_and_id(meta: Dict[str, Any]) -> Tuple[str, str]: - """ - Extract a stable source path/id from metadata. 
- - Prefers: - - metadata.source_metadata.source_id - - metadata.source_metadata.source_name - - metadata.custom_content.path - """ - source = meta.get("source_metadata") if isinstance(meta.get("source_metadata"), dict) else {} - source_id = source.get("source_id") or "" - source_name = source.get("source_name") or "" - - custom = meta.get("custom_content") if isinstance(meta.get("custom_content"), dict) else {} - custom_path = custom.get("path") or custom.get("input_pdf") or custom.get("pdf_path") or "" - - path = _safe_str(custom_path or source_id or source_name) - sid = _safe_str(source_id or path or source_name) - return path, sid - - -def _extract_page_number(meta: Dict[str, Any]) -> int: - cm = meta.get("content_metadata") if isinstance(meta.get("content_metadata"), dict) else {} - page = cm.get("hierarchy", {}).get("page", -1) - try: - return int(page) - except Exception: - return -1 - - -def _build_lancedb_rows_from_df(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Transform an embeddings-enriched primitives DataFrame into LanceDB rows. - - Rows include: - - vector (embedding) - - pdf_basename - - page_number - - pdf_page (basename_page) - - source_id - - path - """ - out: List[Dict[str, Any]] = [] - - for row in rows: - meta = row.get("metadata") - if not isinstance(meta, dict): - continue - - embedding = meta.get("embedding") - if embedding is None: - continue - - # Normalize embedding to list[float] - if not isinstance(embedding, list): - try: - embedding = list(embedding) # type: ignore[arg-type] - except Exception: - continue - meta.pop("embedding", None) # Remove embedding from metadata to save space in LanceDB. 
- # path, source_id = _extract_source_path_and_id(meta) - path = row.get("path", "") - source_id = meta.get("source_path", path) - # page_number = _extract_page_number(meta) - page_number = row.get("page_number", -1) - p = Path(path) if path else None - filename = p.name if p is not None else "" - pdf_basename = p.stem if p is not None else "" - pdf_page = f"{pdf_basename}_{page_number}" if (pdf_basename and page_number >= 0) else "" - - if page_number == -1: - logger.debug("Unable to determine page number for %s", path) - - out.append( - { - "vector": embedding, - "pdf_page": pdf_page, - "filename": filename, - "pdf_basename": pdf_basename, - "page_number": int(page_number), - "source": source_id, - "source_id": source_id, - "path": path, - "text": row.get("text", ""), - "metadata": str(meta), - } - ) - - return out - - -def _infer_vector_dim(rows: Sequence[Dict[str, Any]]) -> int: - for r in rows: - v = r.get("vector") - if isinstance(v, list) and v: - return int(len(v)) - return 0 - - -def create_lancedb_index(table: Any, *, cfg: LanceDBConfig, text_column: str = "text") -> None: +def create_lancedb_index(table: Any, *, cfg: LanceDbParams, text_column: str = "text") -> None: """Create vector (IVF_HNSW_SQ) and optionally FTS indices on a LanceDB table.""" try: table.create_index( @@ -216,48 +88,25 @@ def create_lancedb_index(table: Any, *, cfg: LanceDBConfig, text_column: str = " table.wait_for_index([index_stub.name], timeout=timedelta(seconds=600)) -def _write_rows_to_lancedb(rows: Sequence[Dict[str, Any]], *, cfg: LanceDBConfig) -> None: - if not rows: - logger.warning("No embeddings rows provided; nothing to write to LanceDB.") - return - - dim = _infer_vector_dim(rows) - if dim <= 0: - raise ValueError("Failed to infer embedding dimension from rows.") - - try: - import lancedb # type: ignore - except Exception as e: - raise RuntimeError( - "LanceDB write requested but dependencies are missing. " - "Install `lancedb` and `pyarrow` in this environment." 
- ) from e - - db = lancedb.connect(uri=cfg.uri) - - schema = lancedb_schema(vector_dim=dim) - - mode = "overwrite" if cfg.overwrite else "append" - table = db.create_table(cfg.table_name, data=list(rows), schema=schema, mode=mode) - - if cfg.create_index: - create_lancedb_index(table, cfg=cfg) - - -def write_embeddings_to_lancedb(df_with_embeddings: pd.DataFrame, *, cfg: LanceDBConfig) -> None: +def write_embeddings_to_lancedb(df_with_embeddings: pd.DataFrame, *, cfg: LanceDbParams) -> None: """ - Write embeddings found in `df_with_embeddings.metadata.embedding` to LanceDB. + Write embeddings found in *df_with_embeddings* to LanceDB. - This is used programmatically by `nemo_retriever.text_embed.stage.embed_text_from_primitives_df(...)`. + This is used programmatically by ``nemo_retriever.text_embed.stage``. """ - rows = _build_lancedb_rows_from_df(df_with_embeddings) - _write_rows_to_lancedb(rows, cfg=cfg) + from nemo_retriever.vector_store.lancedb_backend import LanceDBBackend + + records = build_vdb_records(df_with_embeddings) + backend = LanceDBBackend(cfg) + backend.write_rows(records) + if cfg.create_index: + backend.create_index() def write_text_embeddings_dir_to_lancedb( input_dir: Path, *, - cfg: LanceDBConfig, + cfg: LanceDbParams, recursive: bool = False, limit: Optional[int] = None, ) -> Dict[str, Any]: @@ -273,7 +122,7 @@ def write_text_embeddings_dir_to_lancedb( skipped = 0 failed = 0 - lancedb = LanceDB(uri=cfg.uri, table_name=cfg.table_name, overwrite=cfg.overwrite) + lancedb_client = LanceDB(uri=cfg.lancedb_uri, table_name=cfg.table_name, overwrite=cfg.overwrite) results = [] @@ -290,10 +139,10 @@ def write_text_embeddings_dir_to_lancedb( "processed": 0, "skipped": 0, "failed": 0, - "lancedb": {"uri": cfg.uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, + "lancedb": {"uri": cfg.lancedb_uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, } - lancedb.run(results) + lancedb_client.run(results) return { "input_dir": 
str(input_dir), @@ -301,31 +150,36 @@ def write_text_embeddings_dir_to_lancedb( "processed": processed, "skipped": skipped, "failed": failed, - # "rows_written": len(all_rows), - "lancedb": {"uri": cfg.uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, + "lancedb": {"uri": cfg.lancedb_uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, } def handle_lancedb( - rows: Path, + rows: Any, uri: str, table_name: str, hybrid: bool = False, mode: str = "overwrite", -) -> Dict[str, Any]: +) -> None: + """Write pipeline results to LanceDB. + + Accepts *rows* as a ``pd.DataFrame`` or ``list[dict]`` (e.g. from + ``take_all()``). Converts to canonical VDB records, writes via + :class:`LanceDBBackend`, and creates search indices. """ - Handle LanceDB writing for a batch pipeline run. + from nemo_retriever.vector_store.lancedb_backend import LanceDBBackend - This is used by `nemo_retriever.examples.batch_pipeline.run(...)` after the embedding stage. + if isinstance(rows, pd.DataFrame): + records = build_vdb_records(rows) + else: + records = build_vdb_records_from_dicts(rows) - Reads `*.text_embeddings.json` files from `input_dir`, extracts embeddings, and uploads to LanceDB. + params = LanceDbParams( + lancedb_uri=uri, + table_name=table_name, + hybrid=hybrid, + overwrite=(mode == "overwrite"), ) - """ - lancedb_config = LanceDBConfig( - uri=uri, table_name=table_name, hybrid=hybrid - ) # Use the same LanceDB config for writing and recall. - db = lancedb.connect(uri=lancedb_config.uri) - cleaned_rows = _build_lancedb_rows_from_df(rows) - _write_rows_to_lancedb(cleaned_rows, cfg=lancedb_config) - table = db.open_table(lancedb_config.table_name) # Ensure table is open and metadata is updated before proceeding. 
- create_lancedb_index(table, cfg=lancedb_config) + backend = LanceDBBackend(params) + backend.write_rows(records) + backend.create_index() diff --git a/nemo_retriever/src/nemo_retriever/vector_store/stage.py b/nemo_retriever/src/nemo_retriever/vector_store/stage.py index ba3f994a17..7c59d718b9 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/stage.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/stage.py @@ -10,7 +10,8 @@ import typer from rich.console import Console -from nemo_retriever.vector_store.lancedb_store import LanceDBConfig, write_text_embeddings_dir_to_lancedb +from nemo_retriever.params.models import LanceDbParams +from nemo_retriever.vector_store.lancedb_store import write_text_embeddings_dir_to_lancedb console = Console() app = typer.Typer(help="Vector store stage: upload stage5 embeddings to a vector DB (LanceDB).") @@ -54,8 +55,8 @@ def run( - `page_number`: page number from `metadata.content_metadata.page_number` - `path` / `source_id`: source identifiers """ - cfg = LanceDBConfig( - uri=str(lancedb_uri), + cfg = LanceDbParams( + lancedb_uri=str(lancedb_uri), table_name=str(table_name), overwrite=bool(overwrite), create_index=bool(create_index), @@ -73,7 +74,7 @@ def run( ) console.print( f"[green]Done[/green] files={info['n_files']} processed={info['processed']} skipped={info['skipped']} " - f"failed={info['failed']} lancedb_uri={cfg.uri} table={cfg.table_name}" + f"failed={info['failed']} lancedb_uri={cfg.lancedb_uri} table={cfg.table_name}" ) diff --git a/nemo_retriever/src/nemo_retriever/vector_store/vdb.py b/nemo_retriever/src/nemo_retriever/vector_store/vdb.py new file mode 100644 index 0000000000..a704c02245 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/vector_store/vdb.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Abstract base class for vector store backends. 
+ +Backends receive rows in the canonical VDB record format produced by +:func:`nemo_retriever.vector_store.vdb_records.build_vdb_records`. +Write-path only; retrieval support will be added with the second backend. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Sequence + + +class VectorStore(ABC): + """Abstract base for vector store backends.""" + + @abstractmethod + def create_table(self, *, dim: int, **kwargs: Any) -> None: + """Create or reset the storage table / index.""" + + @abstractmethod + def write_rows(self, rows: Sequence[dict[str, Any]], **kwargs: Any) -> None: + """Write a batch of canonical VDB records.""" + + @abstractmethod + def create_index(self, **kwargs: Any) -> None: + """Build search indices after all writes complete.""" diff --git a/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py b/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py new file mode 100644 index 0000000000..668e2a89aa --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Canonical VDB record builder. + +Converts a pandas DataFrame (the graph pipeline's output format) into a list +of backend-neutral VDB record dicts. Every VDB backend in ``nemo_retriever`` +consumes this record format — it is the single source of truth for the +DataFrame → VDB record contract. 
+ +Canonical record schema (matches ``retriever.py`` query expectations):: + + vector : list[float] # embedding + text : str # content + metadata : str # JSON string (round-trips via json.loads) + source : str # JSON string {"source_id": "..."} + page_number : int + pdf_page : str # "basename_pagenum" + pdf_basename : str + filename : str + source_id : str + path : str +""" + +from __future__ import annotations + +from typing import Any, Dict, List + +import pandas as pd + +from nemo_retriever.vector_store.lancedb_utils import build_lancedb_row + + +def build_vdb_records( + df: pd.DataFrame, + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", + text_column: str = "text", + include_text: bool = True, +) -> List[Dict[str, Any]]: + """Convert a post-embed DataFrame into canonical VDB records. + + Rows without a valid embedding are silently skipped. + """ + rows: List[Dict[str, Any]] = [] + for row in df.itertuples(index=False): + row_out = build_lancedb_row( + row, + embedding_column=embedding_column, + embedding_key=embedding_key, + text_column=text_column, + include_text=include_text, + ) + if row_out is not None: + rows.append(row_out) + return rows + + +def build_vdb_records_from_dicts( + records: List[Dict[str, Any]], + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", + text_column: str = "text", + include_text: bool = True, +) -> List[Dict[str, Any]]: + """Convert a list of dicts (e.g. 
from ``take_all()``) into canonical VDB records.""" + if not records: + return [] + df = pd.DataFrame(records) + return build_vdb_records( + df, + embedding_column=embedding_column, + embedding_key=embedding_key, + text_column=text_column, + include_text=include_text, + ) diff --git a/nemo_retriever/tests/test_vdb_record_contract.py b/nemo_retriever/tests/test_vdb_record_contract.py new file mode 100644 index 0000000000..6fc2f5e54e --- /dev/null +++ b/nemo_retriever/tests/test_vdb_record_contract.py @@ -0,0 +1,377 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the canonical VDB record contract, VectorStore ABC, and LanceDB backend. + +These tests validate: + - build_vdb_records produces the correct canonical record format + - build_vdb_records_from_dicts (transitional list[dict] path) matches + - VectorStore ABC enforces the interface contract + - LanceDBBackend round-trips data correctly + - handle_lancedb now uses the canonical builder (regression) +""" + +from __future__ import annotations + +import copy +import json + +import pandas as pd +import pytest + + +# --------------------------------------------------------------------------- +# Shared fixture +# --------------------------------------------------------------------------- + + +def _make_sample_dataframe() -> pd.DataFrame: + """Build a minimal DataFrame matching the graph pipeline's post-embed output.""" + embedding = [0.1, 0.2, 0.3, 0.4] + metadata = { + "embedding": embedding, + "source_path": "/data/test.pdf", + "content_metadata": {"hierarchy": {"page": 0}}, + } + return pd.DataFrame( + [ + { + "metadata": metadata, + "text_embeddings_1b_v2": {"embedding": embedding, "info_msg": None}, + "text": "Hello world", + "path": "/data/test.pdf", + "page_number": 0, + "page_elements_v3_num_detections": 5, + "page_elements_v3_counts_by_label": {"text": 3, "table": 2}, + } + ] + ) + + +# 
--------------------------------------------------------------------------- +# Canonical record builder tests +# --------------------------------------------------------------------------- + + +class TestBuildVdbRecords: + def test_produces_all_required_fields(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + assert len(rows) == 1 + row = rows[0] + for field in ( + "vector", + "text", + "metadata", + "source", + "page_number", + "pdf_page", + "pdf_basename", + "source_id", + "path", + "filename", + ): + assert field in row, f"Missing required field: {field}" + + def test_metadata_is_valid_json(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + meta = json.loads(rows[0]["metadata"]) + assert isinstance(meta, dict) + assert "page_number" in meta + + def test_metadata_includes_detection_counts(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + meta = json.loads(rows[0]["metadata"]) + assert meta["page_elements_v3_num_detections"] == 5 + assert meta["page_elements_v3_counts_by_label"] == {"text": 3, "table": 2} + + def test_source_is_json_object(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + source = json.loads(rows[0]["source"]) + assert isinstance(source, dict) + assert "source_id" in source + + def test_does_not_mutate_input(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + original_meta = copy.deepcopy(df.iloc[0]["metadata"]) + + build_vdb_records(df) + + assert df.iloc[0]["metadata"] == original_meta + + def test_vector_is_embedding(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + 
df = _make_sample_dataframe() + rows = build_vdb_records(df) + + assert rows[0]["vector"] == [0.1, 0.2, 0.3, 0.4] + + def test_skips_rows_without_embedding(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = pd.DataFrame( + [ + { + "metadata": {"source_path": "/data/test.pdf"}, + "text": "No embedding here", + "path": "/data/test.pdf", + "page_number": 0, + } + ] + ) + rows = build_vdb_records(df) + assert len(rows) == 0 + + def test_empty_dataframe(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = pd.DataFrame() + rows = build_vdb_records(df) + assert rows == [] + + def test_text_content(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + assert rows[0]["text"] == "Hello world" + + def test_include_text_false(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df, include_text=False) + assert rows[0]["text"] == "" + + +# --------------------------------------------------------------------------- +# Transitional list[dict] builder tests +# --------------------------------------------------------------------------- + + +class TestBuildVdbRecordsFromDicts: + def test_matches_dataframe_path(self): + """list[dict] path should produce identical output to DataFrame path.""" + from nemo_retriever.vector_store.vdb_records import build_vdb_records, build_vdb_records_from_dicts + + df = _make_sample_dataframe() + records = df.to_dict("records") + + from_df = build_vdb_records(df) + from_dicts = build_vdb_records_from_dicts(records) + + assert len(from_df) == len(from_dicts) + assert from_df[0]["vector"] == from_dicts[0]["vector"] + assert from_df[0]["text"] == from_dicts[0]["text"] + assert from_df[0]["path"] == from_dicts[0]["path"] + assert from_df[0]["page_number"] == from_dicts[0]["page_number"] + + def 
test_empty_list(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records_from_dicts + + assert build_vdb_records_from_dicts([]) == [] + + +# --------------------------------------------------------------------------- +# VectorStore ABC tests +# --------------------------------------------------------------------------- + + +class TestVectorStoreABC: + def test_cannot_instantiate_directly(self): + from nemo_retriever.vector_store.vdb import VectorStore + + with pytest.raises(TypeError): + VectorStore() + + def test_subclass_must_implement_all_methods(self): + from nemo_retriever.vector_store.vdb import VectorStore + + class Incomplete(VectorStore): + def create_table(self, *, dim, **kwargs): + pass + + with pytest.raises(TypeError): + Incomplete() + + def test_lancedb_backend_is_valid_subclass(self): + from nemo_retriever.vector_store.lancedb_backend import LanceDBBackend + from nemo_retriever.vector_store.vdb import VectorStore + + assert issubclass(LanceDBBackend, VectorStore) + + +# --------------------------------------------------------------------------- +# LanceDB backend tests +# --------------------------------------------------------------------------- + + +class TestLanceDBBackend: + def test_write_rows_creates_table_lazily(self, tmp_path): + from nemo_retriever.params.models import LanceDbParams + from nemo_retriever.vector_store.lancedb_backend import LanceDBBackend + import lancedb + + params = LanceDbParams(lancedb_uri=str(tmp_path / "test_db"), table_name="test_table", create_index=False) + backend = LanceDBBackend(params) + + rows = [ + { + "vector": [0.1, 0.2, 0.3], + "text": "hello", + "metadata": "{}", + "source": "{}", + "page_number": 0, + "pdf_page": "", + "pdf_basename": "", + "filename": "", + "source_id": "", + "path": "", + } + ] + backend.write_rows(rows) + + db = lancedb.connect(str(tmp_path / "test_db")) + table = db.open_table("test_table") + assert table.count_rows() == 1 + + def 
test_multiple_writes_accumulate(self, tmp_path): + from nemo_retriever.params.models import LanceDbParams + from nemo_retriever.vector_store.lancedb_backend import LanceDBBackend + import lancedb + + params = LanceDbParams(lancedb_uri=str(tmp_path / "test_db"), table_name="test_table", create_index=False) + backend = LanceDBBackend(params) + + row_template = { + "vector": [0.1, 0.2, 0.3], + "text": "hello", + "metadata": "{}", + "source": "{}", + "page_number": 0, + "pdf_page": "", + "pdf_basename": "", + "filename": "", + "source_id": "", + "path": "", + } + + backend.write_rows([row_template]) + backend.write_rows([row_template, row_template]) + + db = lancedb.connect(str(tmp_path / "test_db")) + table = db.open_table("test_table") + assert table.count_rows() == 3 + + def test_empty_writes_are_noop(self, tmp_path): + from nemo_retriever.params.models import LanceDbParams + from nemo_retriever.vector_store.lancedb_backend import LanceDBBackend + + params = LanceDbParams(lancedb_uri=str(tmp_path / "test_db"), table_name="test_table", create_index=False) + backend = LanceDBBackend(params) + + backend.write_rows([]) + assert backend._table is None + + def test_create_index_noop_when_no_writes(self, tmp_path): + from nemo_retriever.params.models import LanceDbParams + from nemo_retriever.vector_store.lancedb_backend import LanceDBBackend + + params = LanceDbParams(lancedb_uri=str(tmp_path / "test_db"), table_name="test_table") + backend = LanceDBBackend(params) + backend.create_index() # should not raise + + +# --------------------------------------------------------------------------- +# Regression: handle_lancedb now uses canonical builder +# --------------------------------------------------------------------------- + + +class TestHandleLancedbRegression: + def test_handle_lancedb_writes_valid_json_metadata(self, tmp_path): + """After refactoring, handle_lancedb should produce valid JSON metadata.""" + from nemo_retriever.vector_store.lancedb_store import 
handle_lancedb + import lancedb + + df = _make_sample_dataframe() + rows = df.to_dict("records") + + uri = str(tmp_path / "test_db") + handle_lancedb(rows, uri, "test_table", mode="overwrite") + + db = lancedb.connect(uri) + table = db.open_table("test_table") + result = table.to_pandas() + + assert len(result) == 1 + meta_str = result.iloc[0]["metadata"] + meta = json.loads(meta_str) + assert isinstance(meta, dict) + assert "page_number" in meta + + def test_handle_lancedb_accepts_dataframe(self, tmp_path): + """handle_lancedb should now accept a DataFrame directly.""" + from nemo_retriever.vector_store.lancedb_store import handle_lancedb + import lancedb + + df = _make_sample_dataframe() + + uri = str(tmp_path / "test_db") + handle_lancedb(df, uri, "test_table", mode="overwrite") + + db = lancedb.connect(uri) + table = db.open_table("test_table") + assert table.count_rows() == 1 + + def test_handle_lancedb_round_trip_preserves_text(self, tmp_path): + """Text content should survive the write→read round-trip.""" + from nemo_retriever.vector_store.lancedb_store import handle_lancedb + import lancedb + + df = _make_sample_dataframe() + rows = df.to_dict("records") + + uri = str(tmp_path / "test_db") + handle_lancedb(rows, uri, "test_table", mode="overwrite") + + db = lancedb.connect(uri) + table = db.open_table("test_table") + result = table.to_pandas() + + assert result.iloc[0]["text"] == "Hello world" + + def test_handle_lancedb_round_trip_preserves_path(self, tmp_path): + from nemo_retriever.vector_store.lancedb_store import handle_lancedb + import lancedb + + df = _make_sample_dataframe() + rows = df.to_dict("records") + + uri = str(tmp_path / "test_db") + handle_lancedb(rows, uri, "test_table", mode="overwrite") + + db = lancedb.connect(uri) + table = db.open_table("test_table") + result = table.to_pandas() + + assert result.iloc[0]["path"] == "/data/test.pdf" + assert result.iloc[0]["page_number"] == 0