Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo_retriever/harness/test_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ presets:

datasets:
bo20:
path: /home/jdyer/datasets/bo20
path: /datasets/nv-ingest/bo20
query_csv: null
input_type: pdf
recall_required: false
Expand Down
226 changes: 137 additions & 89 deletions nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@
from nemo_retriever.params import ExtractParams
from nemo_retriever.params import StoreParams
from nemo_retriever.params import TextChunkParams
from nemo_retriever.params import VdbUploadParams
from nemo_retriever.model import VL_EMBED_MODEL, VL_RERANK_MODEL
from nemo_retriever.params.models import BatchTuningParams
from nemo_retriever.params.models import BatchTuningParams, LanceDbParams
from nemo_retriever.utils.input_files import resolve_input_patterns
from nemo_retriever.utils.remote_auth import resolve_remote_api_key
from nemo_retriever.vector_store.lancedb_store import handle_lancedb

logger = logging.getLogger(__name__)
app = typer.Typer()
Expand Down Expand Up @@ -126,46 +126,77 @@ def _configure_logging(log_file: Optional[Path], *, debug: bool = False) -> tupl
return fh, original_stdout, original_stderr


def _ensure_lancedb_table(uri: str, table_name: str) -> None:
from nemo_retriever.vector_store.lancedb_utils import lancedb_schema
import lancedb
import pyarrow as pa

Path(uri).mkdir(parents=True, exist_ok=True)
db = lancedb.connect(uri)
try:
db.open_table(table_name)
return
except Exception:
pass
schema = lancedb_schema()
empty = pa.table({f.name: [] for f in schema}, schema=schema)
db.create_table(table_name, data=empty, schema=schema, mode="create")


def _write_runtime_summary(
runtime_metrics_dir: Optional[Path],
runtime_metrics_prefix: Optional[str],
payload: dict[str, object],
metrics_output_file: Optional[Path] = None,
) -> None:
if runtime_metrics_dir is None and not runtime_metrics_prefix:
return
if runtime_metrics_dir is not None or runtime_metrics_prefix:
target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve()
target_dir.mkdir(parents=True, exist_ok=True)
prefix = (runtime_metrics_prefix or "run").strip() or "run"
target = target_dir / f"{prefix}.runtime.summary.json"
target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")

if metrics_output_file is not None:
out_path = Path(metrics_output_file).expanduser()
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")


target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve()
target_dir.mkdir(parents=True, exist_ok=True)
prefix = (runtime_metrics_prefix or "run").strip() or "run"
target = target_dir / f"{prefix}.runtime.summary.json"
target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
def _page_keys_from_df(result_df) -> list[str]:
"""Return stable page/input-unit keys from a pipeline result DataFrame."""
if result_df is None or result_df.empty:
return []

source_column = next((c for c in ("source_id", "path", "source_path") if c in result_df.columns), None)
if source_column is not None and "page_number" in result_df.columns:
key_df = result_df[[source_column, "page_number"]].dropna()
return [f"{source}\x1f{page}" for source, page in key_df.itertuples(index=False, name=None)]

def _count_input_units(result_df) -> int:
if "source_id" in result_df.columns:
return int(result_df["source_id"].nunique())
if "source_path" in result_df.columns:
return int(result_df["source_path"].nunique())
if source_column is not None:
return [str(v) for v in result_df[source_column].dropna().tolist()]

if "page_number" in result_df.columns:
return [str(v) for v in result_df["page_number"].dropna().tolist()]

return []


def _count_processed_pages_from_df(result_df) -> int:
    """Return the number of distinct processed pages in *result_df*.

    Deduplicates the stable page keys produced by ``_page_keys_from_df``;
    when no keys can be derived from the frame's columns, falls back to
    the raw output row count.
    """
    distinct_keys = set(_page_keys_from_df(result_df))
    if not distinct_keys:
        return int(len(result_df.index))
    return len(distinct_keys)


def _extract_page_key_batch(batch):
    """Ray ``map_batches`` helper: project a pandas batch down to its page keys.

    Returns a single-column DataFrame (``_page_key``) so the caller can
    group and count distinct pages without materializing the full rows.
    """
    import pandas as pd

    page_keys = _page_keys_from_df(batch)
    return pd.DataFrame({"_page_key": page_keys})


def _count_processed_pages_from_dataset(dataset, *, fallback_rows: int) -> int:
try:
columns = set(dataset.columns())
except Exception:
return int(fallback_rows)

if not columns.intersection({"source_id", "path", "source_path", "page_number"}):
return int(fallback_rows)

try:
key_ds = dataset.map_batches(_extract_page_key_batch, batch_format="pandas")
if int(key_ds.count()) == 0:
Comment on lines +183 to +192
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Silent exception swallow in _count_processed_pages_from_dataset

The first except Exception block returns a fallback value with no logging at all. If dataset.columns() raises (e.g., due to a serialisation bug or disconnected Ray cluster), the caller silently gets an incorrect page count and no signal that anything went wrong. The no-bare-except rule requires broad catches to at minimum log with exc_info=True.

Suggested change
columns = set(dataset.columns())
except Exception:
return int(fallback_rows)
if not columns.intersection({"source_id", "path", "source_path", "page_number"}):
return int(fallback_rows)
try:
key_ds = dataset.map_batches(_extract_page_key_batch, batch_format="pandas")
if int(key_ds.count()) == 0:
try:
columns = set(dataset.columns())
except Exception:
logger.warning("Could not read Ray Dataset columns; falling back to output row count.", exc_info=True)
return int(fallback_rows)

Rule Used: Never use bare 'except:' that silently swallows er... (source)

Prompt To Fix With AI
This is a comment left during a code review.
Path: nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py
Line: 183-192

Comment:
**Silent exception swallow in `_count_processed_pages_from_dataset`**

The first `except Exception` block returns a fallback value with no logging at all. If `dataset.columns()` raises (e.g., due to a serialisation bug or disconnected Ray cluster), the caller silently gets an incorrect page count and no signal that anything went wrong. The `no-bare-except` rule requires broad catches to at minimum log with `exc_info=True`.

```suggestion
    try:
        columns = set(dataset.columns())
    except Exception:
        logger.warning("Could not read Ray Dataset columns; falling back to output row count.", exc_info=True)
        return int(fallback_rows)
```

**Rule Used:** Never use bare 'except:' that silently swallows er... ([source](https://app.greptile.com/review/custom-context?memory=no-bare-except))

How can I resolve this? If you propose a fix, please make it concise.

return 0
return int(key_ds.groupby("_page_key").count().count())
except Exception:
logger.warning("Could not estimate processed pages from Ray Dataset; falling back to output row count.", exc_info=True)
return int(fallback_rows)


def _resolve_file_patterns(input_path: Path, input_type: str) -> list[str]:
import glob as _glob

Expand Down Expand Up @@ -310,6 +341,13 @@ def main(
runtime_metrics_dir: Optional[Path] = typer.Option(None, "--runtime-metrics-dir", path_type=Path),
runtime_metrics_prefix: Optional[str] = typer.Option(None, "--runtime-metrics-prefix"),
detection_summary_file: Optional[Path] = typer.Option(None, "--detection-summary-file", path_type=Path),
metrics_output_file: Optional[Path] = typer.Option(
None,
"--metrics-output-file",
path_type=Path,
dir_okay=False,
help="JSON file path to write structured run metrics (used by the harness).",
),
log_file: Optional[Path] = typer.Option(None, "--log-file", path_type=Path, dir_okay=False),
) -> None:
_ = ctx
Expand All @@ -328,7 +366,6 @@ def main(
os.environ["RAY_LOG_TO_DRIVER"] = "1" if ray_log_to_driver else "0"

lancedb_uri = str(Path(lancedb_uri).expanduser().resolve())
_ensure_lancedb_table(lancedb_uri, LANCEDB_TABLE)

remote_api_key = resolve_remote_api_key(api_key)
extract_remote_api_key = remote_api_key
Expand Down Expand Up @@ -535,41 +572,66 @@ def main(

ingestor = ingestor.embed(embed_params)

# VDB upload runs inside the graph — rows stream to the configured
# backend as they are produced, so we never need to collect the entire
# result set on the driver just for the write. Index creation happens
# automatically in GraphIngestor._finalize_vdb() after the pipeline.
ingestor = ingestor.vdb_upload(
VdbUploadParams(
lancedb=LanceDbParams(
lancedb_uri=lancedb_uri,
table_name=LANCEDB_TABLE,
hybrid=hybrid,
overwrite=True,
),
)
)

# ------------------------------------------------------------------
# Execute the graph via the executor
# ------------------------------------------------------------------
logger.info("Starting ingestion of %s ...", input_path)
ingest_start = time.perf_counter()

# GraphIngestor.ingest() builds the Graph, creates the executor,
# and calls executor.ingest(file_patterns) returning:
# calls executor.ingest(file_patterns), and finalizes the VDB index.
# batch mode -> materialized ray.data.Dataset
# inprocess mode -> pandas.DataFrame
result = ingestor.ingest()

ingestion_only_total_time = time.perf_counter() - ingest_start

# ------------------------------------------------------------------
# Collect results
# Collect results only when downstream features need the full DataFrame.
# Page/row metrics stay separate: PPS is pages/sec, while num_rows tracks
# output rows after any content explosion.
# ------------------------------------------------------------------
if run_mode == "batch":
import ray

ray_download_start = time.perf_counter()
ingest_local_results = result.take_all()
ray_download_time = time.perf_counter() - ray_download_start

import pandas as pd

result_df = pd.DataFrame(ingest_local_results)
num_rows = _count_input_units(result_df)
needs_result_df = detection_summary_file is not None or save_intermediate is not None
if needs_result_df:
ray_download_start = time.perf_counter()
ingest_local_results = result.take_all()
ray_download_time = time.perf_counter() - ray_download_start

import pandas as pd

result_df = pd.DataFrame(ingest_local_results)
processed_pages = _count_processed_pages_from_df(result_df)
output_rows = int(len(result_df.index))
else:
ray_download_time = 0.0
result_df = None
output_rows = int(result.count())
processed_pages = _count_processed_pages_from_dataset(result, fallback_rows=output_rows)
else:
import pandas as pd

result_df = result
ingest_local_results = result_df.to_dict("records")
ray_download_time = 0.0
num_rows = _count_input_units(result_df)
processed_pages = _count_processed_pages_from_df(result_df)
output_rows = int(len(result_df.index))

if save_intermediate is not None:
out_dir = Path(save_intermediate).expanduser().resolve()
Expand All @@ -589,44 +651,49 @@ def main(
collect_detection_summary_from_df(result_df),
)

# ------------------------------------------------------------------
# Write to LanceDB
# ------------------------------------------------------------------
lancedb_write_start = time.perf_counter()
handle_lancedb(ingest_local_results, lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite")
lancedb_write_time = time.perf_counter() - lancedb_write_start

# ------------------------------------------------------------------
# Recall / BEIR evaluation
# ------------------------------------------------------------------
import lancedb as _lancedb_mod

db = _lancedb_mod.connect(lancedb_uri)
table = db.open_table(LANCEDB_TABLE)

if int(table.count_rows()) == 0:
logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode)
def _empty_summary(reason_label: str) -> None:
_write_runtime_summary(
runtime_metrics_dir,
runtime_metrics_prefix,
{
"run_mode": run_mode,
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"input_pages": int(processed_pages),
"num_pages": int(processed_pages),
"num_rows": int(output_rows),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"lancedb_write_secs": float(lancedb_write_time),
"lancedb_write_secs": 0.0,
"evaluation_secs": 0.0,
"total_secs": float(time.perf_counter() - ingest_start),
"evaluation_mode": evaluation_mode,
"evaluation_metrics": {},
"recall_details": bool(recall_details),
"lancedb_uri": str(lancedb_uri),
"lancedb_table": str(LANCEDB_TABLE),
"skip_reason": reason_label,
},
metrics_output_file=metrics_output_file,
)

import lancedb as _lancedb_mod

db = _lancedb_mod.connect(lancedb_uri)
try:
table = db.open_table(LANCEDB_TABLE)
except Exception:
logger.warning("LanceDB table %r was not created; skipping %s evaluation.", LANCEDB_TABLE, evaluation_mode)
_empty_summary("lancedb_table_missing")
if run_mode == "batch":
ray.shutdown()
return

if int(table.count_rows()) == 0:
logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode)
_empty_summary("lancedb_table_empty")
if run_mode == "batch":
ray.shutdown()
return
Expand Down Expand Up @@ -675,27 +742,7 @@ def main(
query_csv_path = Path(query_csv)
if not query_csv_path.exists():
logger.warning("Query CSV not found at %s; skipping recall evaluation.", query_csv_path)
_write_runtime_summary(
runtime_metrics_dir,
runtime_metrics_prefix,
{
"run_mode": run_mode,
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"lancedb_write_secs": float(lancedb_write_time),
"evaluation_secs": 0.0,
"total_secs": float(time.perf_counter() - ingest_start),
"evaluation_mode": evaluation_mode,
"evaluation_metrics": {},
"recall_details": bool(recall_details),
"lancedb_uri": str(lancedb_uri),
"lancedb_table": str(LANCEDB_TABLE),
},
)
_empty_summary("query_csv_missing")
if run_mode == "batch":
ray.shutdown()
return
Expand Down Expand Up @@ -733,12 +780,12 @@ def main(
{
"run_mode": run_mode,
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"input_pages": int(processed_pages),
"num_pages": int(processed_pages),
"num_rows": int(output_rows),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"lancedb_write_secs": float(lancedb_write_time),
"lancedb_write_secs": 0.0,
"evaluation_secs": float(evaluation_total_time),
"total_secs": float(total_time),
"evaluation_mode": evaluation_mode,
Expand All @@ -748,21 +795,22 @@ def main(
"lancedb_uri": str(lancedb_uri),
"lancedb_table": str(LANCEDB_TABLE),
},
metrics_output_file=metrics_output_file,
)

if run_mode == "batch":
ray.shutdown()

print_run_summary(
num_rows,
processed_pages,
Path(input_path),
hybrid,
lancedb_uri,
LANCEDB_TABLE,
total_time,
ingestion_only_total_time,
ray_download_time,
lancedb_write_time,
0.0,
evaluation_total_time,
evaluation_metrics,
evaluation_label=evaluation_label,
Expand Down
2 changes: 2 additions & 0 deletions nemo_retriever/src/nemo_retriever/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from nemo_retriever.graph.graph_pipeline_registry import GraphPipelineRegistry, default_registry
from nemo_retriever.graph.pipeline_graph import Graph, Node
from nemo_retriever.graph.store_operator import StoreOperator
from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator
from nemo_retriever.graph.webhook_operator import WebhookNotifyOperator

__all__ = [
Expand All @@ -33,6 +34,7 @@
"RayDataExecutor",
"StoreOperator",
"UDFOperator",
"VDBUploadOperator",
"WebhookNotifyOperator",
"default_registry",
]
Expand Down
Loading
Loading