Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo_retriever/harness/test_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ presets:

datasets:
bo20:
path: /home/jdyer/datasets/bo20
path: /datasets/nv-ingest/bo20
query_csv: null
input_type: pdf
recall_required: false
Expand Down
226 changes: 137 additions & 89 deletions nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@
from nemo_retriever.params import ExtractParams
from nemo_retriever.params import StoreParams
from nemo_retriever.params import TextChunkParams
from nemo_retriever.params import VdbUploadParams
from nemo_retriever.model import VL_EMBED_MODEL, VL_RERANK_MODEL
from nemo_retriever.params.models import BatchTuningParams
from nemo_retriever.params.models import BatchTuningParams, LanceDbParams
from nemo_retriever.utils.input_files import resolve_input_patterns
from nemo_retriever.utils.remote_auth import resolve_remote_api_key
from nemo_retriever.vector_store.lancedb_store import handle_lancedb

logger = logging.getLogger(__name__)
app = typer.Typer()
Expand Down Expand Up @@ -126,46 +126,77 @@ def _configure_logging(log_file: Optional[Path], *, debug: bool = False) -> tupl
return fh, original_stdout, original_stderr


def _ensure_lancedb_table(uri: str, table_name: str) -> None:
    """Make sure *table_name* exists at *uri*, creating it empty with the canonical schema if needed."""
    from nemo_retriever.vector_store.lancedb_utils import lancedb_schema
    import lancedb
    import pyarrow as pa

    Path(uri).mkdir(parents=True, exist_ok=True)
    connection = lancedb.connect(uri)
    try:
        connection.open_table(table_name)
        return  # Table already present — nothing to do.
    except Exception:
        pass  # Table missing; fall through and create it below.
    schema = lancedb_schema()
    placeholder = pa.table({field.name: [] for field in schema}, schema=schema)
    connection.create_table(table_name, data=placeholder, schema=schema, mode="create")


def _write_runtime_summary(
runtime_metrics_dir: Optional[Path],
runtime_metrics_prefix: Optional[str],
payload: dict[str, object],
metrics_output_file: Optional[Path] = None,
) -> None:
if runtime_metrics_dir is None and not runtime_metrics_prefix:
return
if runtime_metrics_dir is not None or runtime_metrics_prefix:
target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve()
target_dir.mkdir(parents=True, exist_ok=True)
prefix = (runtime_metrics_prefix or "run").strip() or "run"
target = target_dir / f"{prefix}.runtime.summary.json"
target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")

if metrics_output_file is not None:
out_path = Path(metrics_output_file).expanduser()
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")


target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve()
target_dir.mkdir(parents=True, exist_ok=True)
prefix = (runtime_metrics_prefix or "run").strip() or "run"
target = target_dir / f"{prefix}.runtime.summary.json"
target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
def _page_keys_from_df(result_df) -> list[str]:
"""Return stable page/input-unit keys from a pipeline result DataFrame."""
if result_df is None or result_df.empty:
return []

source_column = next((c for c in ("source_id", "path", "source_path") if c in result_df.columns), None)
if source_column is not None and "page_number" in result_df.columns:
key_df = result_df[[source_column, "page_number"]].dropna()
return [f"{source}\x1f{page}" for source, page in key_df.itertuples(index=False, name=None)]

def _count_input_units(result_df) -> int:
if "source_id" in result_df.columns:
return int(result_df["source_id"].nunique())
if "source_path" in result_df.columns:
return int(result_df["source_path"].nunique())
if source_column is not None:
return [str(v) for v in result_df[source_column].dropna().tolist()]

if "page_number" in result_df.columns:
return [str(v) for v in result_df["page_number"].dropna().tolist()]

return []


def _count_processed_pages_from_df(result_df) -> int:
    """Count distinct processed pages in *result_df*, falling back to the raw row count."""
    unique_keys = set(_page_keys_from_df(result_df))
    if not unique_keys:
        # No usable key columns — every output row counts as one page.
        return int(len(result_df.index))
    return len(unique_keys)


def _extract_page_key_batch(batch):
    """Map a pandas batch to a one-column frame of its stable page keys (for Ray map_batches)."""
    import pandas as pd

    return pd.DataFrame({"_page_key": _page_keys_from_df(batch)})


def _count_processed_pages_from_dataset(dataset, *, fallback_rows: int) -> int:
try:
columns = set(dataset.columns())
except Exception:
return int(fallback_rows)

if not columns.intersection({"source_id", "path", "source_path", "page_number"}):
return int(fallback_rows)

try:
key_ds = dataset.map_batches(_extract_page_key_batch, batch_format="pandas")
if int(key_ds.count()) == 0:
Comment on lines +183 to +192
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Silent exception swallow in _count_processed_pages_from_dataset

The first except Exception block returns a fallback value with no logging at all. If dataset.columns() raises (e.g., due to a serialisation bug or disconnected Ray cluster), the caller silently gets an incorrect page count and no signal that anything went wrong. The no-bare-except rule requires broad catches to at minimum log with exc_info=True.

Suggested change
columns = set(dataset.columns())
except Exception:
return int(fallback_rows)
if not columns.intersection({"source_id", "path", "source_path", "page_number"}):
return int(fallback_rows)
try:
key_ds = dataset.map_batches(_extract_page_key_batch, batch_format="pandas")
if int(key_ds.count()) == 0:
try:
columns = set(dataset.columns())
except Exception:
logger.warning("Could not read Ray Dataset columns; falling back to output row count.", exc_info=True)
return int(fallback_rows)

Rule Used: Never use bare 'except:' that silently swallows er... (source)

Prompt To Fix With AI
This is a comment left during a code review.
Path: nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py
Line: 183-192

Comment:
**Silent exception swallow in `_count_processed_pages_from_dataset`**

The first `except Exception` block returns a fallback value with no logging at all. If `dataset.columns()` raises (e.g., due to a serialisation bug or disconnected Ray cluster), the caller silently gets an incorrect page count and no signal that anything went wrong. The `no-bare-except` rule requires broad catches to at minimum log with `exc_info=True`.

```suggestion
    try:
        columns = set(dataset.columns())
    except Exception:
        logger.warning("Could not read Ray Dataset columns; falling back to output row count.", exc_info=True)
        return int(fallback_rows)
```

**Rule Used:** Never use bare 'except:' that silently swallows er... ([source](https://app.greptile.com/review/custom-context?memory=no-bare-except))

How can I resolve this? If you propose a fix, please make it concise.

return 0
return int(key_ds.groupby("_page_key").count().count())
except Exception:
logger.warning("Could not estimate processed pages from Ray Dataset; falling back to output row count.", exc_info=True)
return int(fallback_rows)


def _resolve_file_patterns(input_path: Path, input_type: str) -> list[str]:
import glob as _glob

Expand Down Expand Up @@ -309,6 +340,13 @@ def main(
runtime_metrics_dir: Optional[Path] = typer.Option(None, "--runtime-metrics-dir", path_type=Path),
runtime_metrics_prefix: Optional[str] = typer.Option(None, "--runtime-metrics-prefix"),
detection_summary_file: Optional[Path] = typer.Option(None, "--detection-summary-file", path_type=Path),
metrics_output_file: Optional[Path] = typer.Option(
None,
"--metrics-output-file",
path_type=Path,
dir_okay=False,
help="JSON file path to write structured run metrics (used by the harness).",
),
log_file: Optional[Path] = typer.Option(None, "--log-file", path_type=Path, dir_okay=False),
) -> None:
_ = ctx
Expand All @@ -327,7 +365,6 @@ def main(
os.environ["RAY_LOG_TO_DRIVER"] = "1" if ray_log_to_driver else "0"

lancedb_uri = str(Path(lancedb_uri).expanduser().resolve())
_ensure_lancedb_table(lancedb_uri, LANCEDB_TABLE)

remote_api_key = resolve_remote_api_key(api_key)
extract_remote_api_key = remote_api_key
Expand Down Expand Up @@ -528,41 +565,66 @@ def main(

ingestor = ingestor.embed(embed_params)

# VDB upload runs inside the graph — rows stream to the configured
# backend as they are produced, so we never need to collect the entire
# result set on the driver just for the write. Index creation happens
# automatically in GraphIngestor._finalize_vdb() after the pipeline.
ingestor = ingestor.vdb_upload(
VdbUploadParams(
lancedb=LanceDbParams(
lancedb_uri=lancedb_uri,
table_name=LANCEDB_TABLE,
hybrid=hybrid,
overwrite=True,
),
)
)

# ------------------------------------------------------------------
# Execute the graph via the executor
# ------------------------------------------------------------------
logger.info("Starting ingestion of %s ...", input_path)
ingest_start = time.perf_counter()

# GraphIngestor.ingest() builds the Graph, creates the executor,
# and calls executor.ingest(file_patterns) returning:
# calls executor.ingest(file_patterns), and finalizes the VDB index.
# batch mode -> materialized ray.data.Dataset
# inprocess mode -> pandas.DataFrame
result = ingestor.ingest()

ingestion_only_total_time = time.perf_counter() - ingest_start

# ------------------------------------------------------------------
# Collect results
# Collect results only when downstream features need the full DataFrame.
# Page/row metrics stay separate: PPS is pages/sec, while num_rows tracks
# output rows after any content explosion.
# ------------------------------------------------------------------
if run_mode == "batch":
import ray

ray_download_start = time.perf_counter()
ingest_local_results = result.take_all()
ray_download_time = time.perf_counter() - ray_download_start

import pandas as pd

result_df = pd.DataFrame(ingest_local_results)
num_rows = _count_input_units(result_df)
needs_result_df = detection_summary_file is not None or save_intermediate is not None
if needs_result_df:
ray_download_start = time.perf_counter()
ingest_local_results = result.take_all()
ray_download_time = time.perf_counter() - ray_download_start

import pandas as pd

result_df = pd.DataFrame(ingest_local_results)
processed_pages = _count_processed_pages_from_df(result_df)
output_rows = int(len(result_df.index))
else:
ray_download_time = 0.0
result_df = None
output_rows = int(result.count())
processed_pages = _count_processed_pages_from_dataset(result, fallback_rows=output_rows)
else:
import pandas as pd

result_df = result
ingest_local_results = result_df.to_dict("records")
ray_download_time = 0.0
num_rows = _count_input_units(result_df)
processed_pages = _count_processed_pages_from_df(result_df)
output_rows = int(len(result_df.index))

if save_intermediate is not None:
out_dir = Path(save_intermediate).expanduser().resolve()
Expand All @@ -582,44 +644,49 @@ def main(
collect_detection_summary_from_df(result_df),
)

# ------------------------------------------------------------------
# Write to LanceDB
# ------------------------------------------------------------------
lancedb_write_start = time.perf_counter()
handle_lancedb(ingest_local_results, lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite")
lancedb_write_time = time.perf_counter() - lancedb_write_start

# ------------------------------------------------------------------
# Recall / BEIR evaluation
# ------------------------------------------------------------------
import lancedb as _lancedb_mod

db = _lancedb_mod.connect(lancedb_uri)
table = db.open_table(LANCEDB_TABLE)

if int(table.count_rows()) == 0:
logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode)
def _empty_summary(reason_label: str) -> None:
_write_runtime_summary(
runtime_metrics_dir,
runtime_metrics_prefix,
{
"run_mode": run_mode,
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"input_pages": int(processed_pages),
"num_pages": int(processed_pages),
"num_rows": int(output_rows),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"lancedb_write_secs": float(lancedb_write_time),
"lancedb_write_secs": 0.0,
"evaluation_secs": 0.0,
"total_secs": float(time.perf_counter() - ingest_start),
"evaluation_mode": evaluation_mode,
"evaluation_metrics": {},
"recall_details": bool(recall_details),
"lancedb_uri": str(lancedb_uri),
"lancedb_table": str(LANCEDB_TABLE),
"skip_reason": reason_label,
},
metrics_output_file=metrics_output_file,
)

import lancedb as _lancedb_mod

db = _lancedb_mod.connect(lancedb_uri)
try:
table = db.open_table(LANCEDB_TABLE)
except Exception:
logger.warning("LanceDB table %r was not created; skipping %s evaluation.", LANCEDB_TABLE, evaluation_mode)
_empty_summary("lancedb_table_missing")
if run_mode == "batch":
ray.shutdown()
return

if int(table.count_rows()) == 0:
logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode)
_empty_summary("lancedb_table_empty")
if run_mode == "batch":
ray.shutdown()
return
Expand Down Expand Up @@ -666,27 +733,7 @@ def main(
query_csv_path = Path(query_csv)
if not query_csv_path.exists():
logger.warning("Query CSV not found at %s; skipping recall evaluation.", query_csv_path)
_write_runtime_summary(
runtime_metrics_dir,
runtime_metrics_prefix,
{
"run_mode": run_mode,
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"lancedb_write_secs": float(lancedb_write_time),
"evaluation_secs": 0.0,
"total_secs": float(time.perf_counter() - ingest_start),
"evaluation_mode": evaluation_mode,
"evaluation_metrics": {},
"recall_details": bool(recall_details),
"lancedb_uri": str(lancedb_uri),
"lancedb_table": str(LANCEDB_TABLE),
},
)
_empty_summary("query_csv_missing")
if run_mode == "batch":
ray.shutdown()
return
Expand Down Expand Up @@ -722,12 +769,12 @@ def main(
{
"run_mode": run_mode,
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"input_pages": int(processed_pages),
"num_pages": int(processed_pages),
"num_rows": int(output_rows),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"lancedb_write_secs": float(lancedb_write_time),
"lancedb_write_secs": 0.0,
"evaluation_secs": float(evaluation_total_time),
"total_secs": float(total_time),
"evaluation_mode": evaluation_mode,
Expand All @@ -737,21 +784,22 @@ def main(
"lancedb_uri": str(lancedb_uri),
"lancedb_table": str(LANCEDB_TABLE),
},
metrics_output_file=metrics_output_file,
)

if run_mode == "batch":
ray.shutdown()

print_run_summary(
num_rows,
processed_pages,
Path(input_path),
hybrid,
lancedb_uri,
LANCEDB_TABLE,
total_time,
ingestion_only_total_time,
ray_download_time,
lancedb_write_time,
0.0,
evaluation_total_time,
evaluation_metrics,
evaluation_label=evaluation_label,
Expand Down
2 changes: 2 additions & 0 deletions nemo_retriever/src/nemo_retriever/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from nemo_retriever.graph.graph_pipeline_registry import GraphPipelineRegistry, default_registry
from nemo_retriever.graph.pipeline_graph import Graph, Node
from nemo_retriever.graph.store_operator import StoreOperator
from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator

__all__ = [
"AbstractExecutor",
Expand All @@ -32,6 +33,7 @@
"RayDataExecutor",
"StoreOperator",
"UDFOperator",
"VDBUploadOperator",
"default_registry",
]

Expand Down
Loading