Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nemo_retriever/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"uvicorn[standard]>=0.30.0",
"python-multipart>=0.0.9",
# HTTP clients
"aiohttp>=3.9.0",
"httpx>=0.27.0",
"requests>=2.32.5",
"urllib3==2.6.3",
Expand Down
2 changes: 1 addition & 1 deletion nemo_retriever/src/nemo_retriever/audio/asr_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,4 +401,4 @@ def apply_asr_to_df(
"""
params = ASRParams(**(asr_params or {}))
actor = ASRActor(params=params)
return actor(batch_df)
return actor.run(batch_df)
20 changes: 17 additions & 3 deletions nemo_retriever/src/nemo_retriever/chart/cpu_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nemo_retriever.graph.cpu_operator import CPUOperator
from nemo_retriever.nim.nim import NIMClient
from nemo_retriever.params import RemoteRetryParams
from nemo_retriever.chart.shared import graphic_elements_ocr_page_elements
from nemo_retriever.chart.shared import agraphic_elements_ocr_page_elements, graphic_elements_ocr_page_elements

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -143,9 +143,23 @@ def process(self, data: Any, **kwargs: Any) -> Any:
def postprocess(self, data: Any, **kwargs: Any) -> Any:
    """Pass-through hook: returns *data* unchanged (no post-processing step)."""
    return data

def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
async def aprocess(self, data: Any, **kwargs: Any) -> Any:
    """Asynchronously run combined graphic-elements + OCR extraction on *data*.

    Delegates to :func:`agraphic_elements_ocr_page_elements`, supplying this
    actor's configured models, endpoints, credentials and retry settings.
    Caller-supplied ``kwargs`` are forwarded as-is; passing a keyword that
    collides with the actor configuration raises ``TypeError`` at call time.
    """
    actor_config = {
        "graphic_elements_model": self._graphic_elements_model,
        "ocr_model": self._ocr_model,
        "graphic_elements_invoke_url": self._graphic_elements_invoke_url,
        "ocr_invoke_url": self._ocr_invoke_url,
        "api_key": self._api_key,
        "request_timeout_s": self._request_timeout_s,
        "remote_retry": self._remote_retry,
        "inference_batch_size": self._inference_batch_size,
    }
    return await agraphic_elements_ocr_page_elements(data, **actor_config, **kwargs)

async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
try:
return self.run(batch_df, **override_kwargs)
return await self.arun(batch_df, **override_kwargs)
except BaseException as exc:
if isinstance(batch_df, pd.DataFrame):
out = batch_df.copy()
Expand Down
20 changes: 17 additions & 3 deletions nemo_retriever/src/nemo_retriever/chart/gpu_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from nemo_retriever.graph.gpu_operator import GPUOperator
from nemo_retriever.nim.nim import NIMClient
from nemo_retriever.params import RemoteRetryParams
from nemo_retriever.chart.shared import graphic_elements_ocr_page_elements
from nemo_retriever.chart.shared import agraphic_elements_ocr_page_elements, graphic_elements_ocr_page_elements


class GraphicElementsActor(AbstractOperator, GPUOperator):
Expand Down Expand Up @@ -88,9 +88,23 @@ def process(self, data: Any, **kwargs: Any) -> Any:
def postprocess(self, data: Any, **kwargs: Any) -> Any:
    """Pass-through hook: returns *data* unchanged (no post-processing step)."""
    return data

def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
async def aprocess(self, data: Any, **kwargs: Any) -> Any:
    """Asynchronously run combined graphic-elements + OCR extraction on *data*.

    Delegates to :func:`agraphic_elements_ocr_page_elements`, supplying this
    actor's configured models, endpoints, credentials and retry settings.
    Caller-supplied ``kwargs`` are forwarded as-is; passing a keyword that
    collides with the actor configuration raises ``TypeError`` at call time.
    """
    actor_config = {
        "graphic_elements_model": self._graphic_elements_model,
        "ocr_model": self._ocr_model,
        "graphic_elements_invoke_url": self._graphic_elements_invoke_url,
        "ocr_invoke_url": self._ocr_invoke_url,
        "api_key": self._api_key,
        "request_timeout_s": self._request_timeout_s,
        "remote_retry": self._remote_retry,
        "inference_batch_size": self._inference_batch_size,
    }
    return await agraphic_elements_ocr_page_elements(data, **actor_config, **kwargs)

async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
try:
return self.run(batch_df, **override_kwargs)
return await self.arun(batch_df, **override_kwargs)
except BaseException as exc:
if isinstance(batch_df, pd.DataFrame):
out = batch_df.copy()
Expand Down
231 changes: 231 additions & 0 deletions nemo_retriever/src/nemo_retriever/chart/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@

import base64
import io
import logging
import time
import traceback

logger = logging.getLogger(__name__)

import pandas as pd
from nemo_retriever.nim.nim import NIMClient, invoke_image_inference_batches
from nemo_retriever.params import RemoteRetryParams
Expand Down Expand Up @@ -576,6 +579,234 @@ def graphic_elements_ocr_page_elements(
return out


async def agraphic_elements_ocr_page_elements(
    batch_df: Any,
    *,
    graphic_elements_model: Any = None,
    ocr_model: Any = None,
    graphic_elements_invoke_url: str = "",
    ocr_invoke_url: str = "",
    api_key: str = "",
    request_timeout_s: float = 120.0,
    remote_retry: RemoteRetryParams | None = None,
    **kwargs: Any,
) -> Any:
    """Async version of :func:`graphic_elements_ocr_page_elements`.

    For each row of *batch_df*, crops the ``"chart"`` detections listed in the
    row's ``page_elements_v3`` column out of the row's ``page_image``, runs
    graphic-elements detection and OCR on every crop (remotely when the
    corresponding invoke URL is set, otherwise with the local model in a
    worker thread), joins both outputs into text, and returns a copy of the
    input frame with two new columns:

    - ``chart``: per-row list of ``{"bbox_xyxy_norm", "text"}`` items
    - ``graphic_elements_ocr_v1``: per-row ``{"timing", "error"}`` metadata

    Per-row failures are recorded in the metadata column instead of aborting
    the whole batch; task cancellation is always propagated.

    Parameters
    ----------
    batch_df:
        A ``pandas.DataFrame``; other container types raise
        ``NotImplementedError``.
    graphic_elements_model / ocr_model:
        Local model objects, required when the matching invoke URL is empty.
    graphic_elements_invoke_url / ocr_invoke_url:
        Remote endpoints; when both are set the two calls run concurrently.
    api_key, request_timeout_s, remote_retry:
        Remote-call credentials, timeout and retry configuration.
    **kwargs:
        Fallback source for URLs, retry counts and ``inference_batch_size``.

    Raises
    ------
    NotImplementedError: if *batch_df* is not a ``pandas.DataFrame``.
    ValueError: if a local model is missing and no remote URL is set.
    """
    import asyncio

    from nemo_retriever.nim.nim import ainvoke_image_inference_batches
    from nemo_retriever.ocr.ocr import (
        _blocks_to_text,
        _crop_all_from_page,
        _extract_remote_ocr_item,
        _np_rgb_to_b64_png,
        _parse_ocr_result,
    )
    from nemo_retriever.utils.table_and_chart import join_graphic_elements_and_ocr_output

    retry = remote_retry or RemoteRetryParams(
        remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)),
        remote_max_retries=int(kwargs.get("remote_max_retries", 10)),
        remote_max_429_retries=int(kwargs.get("remote_max_429_retries", 5)),
    )

    if not isinstance(batch_df, pd.DataFrame):
        raise NotImplementedError("agraphic_elements_ocr_page_elements currently only supports pandas.DataFrame input.")

    # Explicit keyword wins over the **kwargs fallback; empty string means "local".
    ge_url = (graphic_elements_invoke_url or kwargs.get("graphic_elements_invoke_url") or "").strip()
    ocr_url = (ocr_invoke_url or kwargs.get("ocr_invoke_url") or "").strip()
    use_remote_ge = bool(ge_url)
    use_remote_ocr = bool(ocr_url)

    if not use_remote_ge and graphic_elements_model is None:
        raise ValueError("A local `graphic_elements_model` is required when `graphic_elements_invoke_url` is not set.")
    if not use_remote_ocr and ocr_model is None:
        raise ValueError("A local `ocr_model` is required when `ocr_invoke_url` is not set.")

    label_names = _labels_from_model(graphic_elements_model) if graphic_elements_model is not None else []
    inference_batch_size = int(kwargs.get("inference_batch_size", 8))

    async def _invoke_remote(url: str, crop_b64s: List[str], expected: int, what: str) -> List[Any]:
        # Single call site for remote inference + response-count validation
        # (previously duplicated verbatim for each GE/OCR branch).
        items = await ainvoke_image_inference_batches(
            invoke_url=url,
            image_b64_list=crop_b64s,
            api_key=api_key or None,
            timeout_s=float(request_timeout_s),
            max_batch_size=inference_batch_size,
            max_concurrency=int(retry.remote_max_pool_workers),
            max_retries=int(retry.remote_max_retries),
            max_429_retries=int(retry.remote_max_429_retries),
        )
        if len(items) != expected:
            raise RuntimeError(f"Expected {expected} {what} responses, got {len(items)}")
        return items

    def _filter_ge(resp: Any) -> List[Dict[str, Any]]:
        # Keep only detections at or above the YOLOX graphic-elements threshold.
        return [
            d
            for d in _remote_response_to_ge_detections(resp)
            if (d.get("score") or 0.0) >= YOLOX_GRAPHIC_MIN_SCORE
        ]

    all_chart: List[List[Dict[str, Any]]] = []
    all_meta: List[Dict[str, Any]] = []

    t0_total = time.perf_counter()

    for row in batch_df.itertuples(index=False):
        chart_items: List[Dict[str, Any]] = []
        row_error: Any = None

        try:
            pe = getattr(row, "page_elements_v3", None)
            dets: List[Dict[str, Any]] = []
            if isinstance(pe, dict):
                dets = pe.get("detections") or []
            if not isinstance(dets, list):
                dets = []

            page_image = getattr(row, "page_image", None) or {}
            page_image_b64 = page_image.get("image_b64") if isinstance(page_image, dict) else None

            # Rows without a usable page image or without chart crops are
            # passed through with empty results and no error.
            if not isinstance(page_image_b64, str) or not page_image_b64:
                all_chart.append(chart_items)
                all_meta.append({"timing": None, "error": None})
                continue

            crops = _crop_all_from_page(page_image_b64, dets, {"chart"})

            if not crops:
                all_chart.append(chart_items)
                all_meta.append({"timing": None, "error": None})
                continue

            # Base64 encoding is only needed when at least one remote call is made.
            crop_b64s = (
                [_np_rgb_to_b64_png(crop_array) for _, _, crop_array in crops]
                if (use_remote_ge or use_remote_ocr)
                else []
            )

            ge_results: List[List[Dict[str, Any]]] = []
            ocr_results: List[Any] = []

            if use_remote_ge and use_remote_ocr:
                # Both endpoints are remote: overlap the two round-trips.
                ge_items, ocr_items = await asyncio.gather(
                    _invoke_remote(ge_url, crop_b64s, len(crops), "GE"),
                    _invoke_remote(ocr_url, crop_b64s, len(crops), "OCR"),
                )
                ge_results = [_filter_ge(resp) for resp in ge_items]
                ocr_results = [_extract_remote_ocr_item(resp) for resp in ocr_items]
            else:
                if use_remote_ge:
                    ge_items = await _invoke_remote(ge_url, crop_b64s, len(crops), "GE")
                    ge_results = [_filter_ge(resp) for resp in ge_items]
                else:

                    def _run_local_ge() -> List[List[Dict[str, Any]]]:
                        # Blocking local inference, executed off the event loop.
                        results: List[List[Dict[str, Any]]] = []
                        for _, _, crop_array in crops:
                            chw = torch.from_numpy(crop_array).permute(2, 0, 1).contiguous().to(dtype=torch.float32)
                            h, w = crop_array.shape[:2]
                            x = chw.unsqueeze(0)
                            try:
                                pre = graphic_elements_model.preprocess(x)
                            except Exception:
                                # Model may not expose a preprocess step; use the raw tensor.
                                pre = x
                            if isinstance(pre, torch.Tensor) and pre.ndim == 3:
                                pre = pre.unsqueeze(0)
                            pred = graphic_elements_model.invoke(pre, (h, w))
                            ge_dets = _prediction_to_detections(pred, label_names=label_names)
                            results.append([d for d in ge_dets if (d.get("score") or 0.0) >= YOLOX_GRAPHIC_MIN_SCORE])
                        return results

                    ge_results = await asyncio.to_thread(_run_local_ge)

                if use_remote_ocr:
                    ocr_items = await _invoke_remote(ocr_url, crop_b64s, len(crops), "OCR")
                    ocr_results = [_extract_remote_ocr_item(resp) for resp in ocr_items]
                else:

                    def _run_local_ocr() -> List[Any]:
                        # Blocking local OCR, executed off the event loop.
                        return [ocr_model.invoke(crop_array, merge_level="word") for _, _, crop_array in crops]

                    ocr_results = await asyncio.to_thread(_run_local_ocr)

            for crop_i, (label_name, bbox, crop_array) in enumerate(crops):
                crop_hw = (int(crop_array.shape[0]), int(crop_array.shape[1]))
                ge_dets = ge_results[crop_i]
                ocr_preds = ocr_results[crop_i]

                text = join_graphic_elements_and_ocr_output(ge_dets, ocr_preds, crop_hw)

                if not text:
                    # Fall back to plain OCR text when the structured join yields nothing.
                    blocks = _parse_ocr_result(ocr_preds)
                    text = _blocks_to_text(blocks)

                chart_items.append({"bbox_xyxy_norm": bbox, "text": text})

        except asyncio.CancelledError:
            # Never record cancellation as a row error: swallowing it would
            # break task cancellation for the whole pipeline.
            raise
        except BaseException as e:
            logger.warning("graphic-elements+OCR failed: %s: %s", type(e).__name__, e, exc_info=True)
            row_error = {
                "stage": "graphic_elements_ocr_page_elements",
                "type": e.__class__.__name__,
                "message": str(e),
                "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)),
            }

        all_chart.append(chart_items)
        all_meta.append({"timing": None, "error": row_error})

    # A single wall-clock figure is shared by every row of the batch.
    elapsed = time.perf_counter() - t0_total
    for meta in all_meta:
        meta["timing"] = {"seconds": float(elapsed)}

    out = batch_df.copy()
    out["chart"] = all_chart
    out["graphic_elements_ocr_v1"] = all_meta
    return out


# ---------------------------------------------------------------------------
# Combined graphic-elements + OCR Ray Actor
# ---------------------------------------------------------------------------
Loading
Loading