Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gtsfm/configs/megaloc_sift_frontend.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ image_pairs_generator:
# Retriever: switched to BatchedSimilarityRetriever — it works for any global descriptor.
# It computes cosine similarity between descriptor vectors in row batches and selects top-k pairs.
retriever:
_target_: gtsfm.retriever.Similarity
_target_: gtsfm.retriever.BatchedSimilarity
num_matched: 20 # Find top-20 most similar images per query
min_score: 0.3 # Minimum cosine similarity threshold

Expand Down
117 changes: 117 additions & 0 deletions gtsfm/configs/unified_megaloc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Unified MegaLoc front-end configuration (MegaLoc retrieval + SuperPoint/LightGlue matching).

# @package _global_
_target_: gtsfm.scene_optimizer.SceneOptimizer

loader:
_target_: gtsfm.loader.Olsson

image_pairs_generator:
_target_: gtsfm.retriever.image_pairs_generator.ImagePairsGenerator
global_descriptor:
_target_: gtsfm.frontend.cacher.global_descriptor_cacher.GlobalDescriptorCacher
global_descriptor_obj:
_target_: gtsfm.frontend.global_descriptor.MegaLoc
retriever:
_target_: gtsfm.retriever.BatchedSimilarity
num_matched: 15
min_score: 0.4
batch_size: 16

graph_partitioner:
_target_: gtsfm.graph_partitioner.Single

cluster_optimizer:
_target_: gtsfm.cluster_optimizer.Multiview
correspondence_generator:
_target_: gtsfm.frontend.correspondence_generator.det_desc_correspondence_generator.DetDescCorrespondenceGenerator

detector_descriptor:
_target_: gtsfm.frontend.cacher.detector_descriptor_cacher.DetectorDescriptorCacher
detector_descriptor_obj:
_target_: gtsfm.frontend.detector_descriptor.superpoint.SuperPointDetectorDescriptor
max_keypoints: 5000

matcher:
_target_: gtsfm.frontend.cacher.matcher_cacher.MatcherCacher
matcher_obj:
_target_: gtsfm.frontend.matcher.lightglue_matcher.LightGlueMatcher
features: "superpoint"

two_view_estimator:
_target_: gtsfm.two_view_estimator_cacher.TwoViewEstimatorCacher
two_view_estimator_obj:
_target_: gtsfm.two_view_estimator.TwoViewEstimator
bundle_adjust_2view: True
eval_threshold_px: 4 # in px
ba_reproj_error_thresholds: [0.5]
bundle_adjust_2view_maxiters: 100

verifier:
_target_: gtsfm.frontend.verifier.ransac.Ransac
use_intrinsics_in_verification: True
estimation_threshold_px: 4 # for H/E/F estimators

triangulation_options:
_target_: gtsfm.data_association.point3d_initializer.TriangulationOptions
mode:
_target_: gtsfm.data_association.point3d_initializer.TriangulationSamplingMode
value: NO_RANSAC

inlier_support_processor:
_target_: gtsfm.two_view_estimator.InlierSupportProcessor
min_num_inliers_est_model: 15
min_inlier_ratio_est_model: 0.1
save_gtsfm_data: True
save_3d_viz: False
save_two_view_viz: False
pose_angular_error_thresh: 5 # degrees
multiview_optimizer:
_target_: gtsfm.multi_view_optimizer.MultiViewOptimizer

# comment out to not run
view_graph_estimator:
_target_: gtsfm.view_graph_estimator.cycle_consistent_rotation_estimator.CycleConsistentRotationViewGraphEstimator
edge_error_aggregation_criterion: MIN_EDGE_ERROR

rot_avg_module:
_target_: gtsfm.averaging.rotation.shonan.ShonanRotationAveraging
weight_by_inliers: True

trans_avg_module:
_target_: gtsfm.averaging.translation.averaging_1dsfm.TranslationAveraging1DSFM
robust_measurement_noise: True
projection_sampling_method: SAMPLE_INPUT_MEASUREMENTS
reject_outliers: False
use_all_tracks_for_averaging: True
use_relative_camera_poses: True

data_association_module:
_target_: gtsfm.data_association.data_assoc.DataAssociation
min_track_len: 3
triangulation_options:
_target_: gtsfm.data_association.point3d_initializer.TriangulationOptions
reproj_error_threshold: 10
mode:
_target_: gtsfm.data_association.point3d_initializer.TriangulationSamplingMode
value: RANSAC_SAMPLE_UNIFORM
max_num_hypotheses: 100
save_track_patches_viz: False

bundle_adjustment_module:
_target_: gtsfm.bundle.bundle_adjustment.BundleAdjustmentOptimizer
reproj_error_thresholds: [10, 5, 3] # for (multistage) post-optimization filtering
robust_measurement_noise: True
shared_calib: False
cam_pose3_prior_noise_sigma: 0.1
calibration_prior_noise_sigma: 1e-5
measurement_noise_sigma: 1.0

# comment out to not run
dense_multiview_optimizer:
_target_: gtsfm.densify.mvs_patchmatchnet.MVSPatchmatchNet

gaussian_splatting_optimizer:
_target_: gtsfm.splat.gaussian_splatting.GaussianSplatting
cfg:
_target_: gtsfm.splat.gaussian_splatting.Config
3 changes: 3 additions & 0 deletions gtsfm/retriever/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@
from .batched_similarity_retriever import BatchedSimilarityRetriever
from .joint_similarity_sequential_retriever import JointSimilaritySequentialRetriever
from .sequential_retriever import SequentialRetriever
from .similarity_retriever import SimilarityRetriever

# Short aliases referenced by Hydra configs, e.g. `_target_: gtsfm.retriever.BatchedSimilarity`.
Exhaustive = ExhaustiveRetriever
JointSimilaritySequential = JointSimilaritySequentialRetriever
Sequential = SequentialRetriever
Similarity = SimilarityRetriever
BatchedSimilarity = BatchedSimilarityRetriever

__all__ = [
    "BatchedSimilarity",
    "Exhaustive",
    "JointSimilaritySequential",
    "Sequential",
    "Similarity",
]
216 changes: 216 additions & 0 deletions gtsfm/retriever/batched_similarity_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
"""Row-batched similarity retriever optimized for GPU execution.

This retriever computes exact inner-product similarity (equivalent to cosine
similarity for normalized descriptors) using a memory-efficient row-batched
approach. Unlike SimilarityRetriever which materializes the full N×N matrix
in blocks, this implementation processes row batches and immediately reduces
with top-k, achieving O(batch × N) memory instead of O(N²).

Authors: Kathir Gounder
"""

import os
import time
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import torch

import gtsfm.utils.logger as logger_utils
from gtsfm.products.visibility_graph import VisibilityGraph
from gtsfm.retriever.retriever_base import RetrieverBase

logger = logger_utils.get_logger()


class BatchedSimilarityRetriever(RetrieverBase):
    """Exact top-k inner-product retriever with O(batch_size x N) peak memory.

    For L2-normalized global descriptors, inner product equals cosine
    similarity, so results match SimilarityRetriever. Unlike that retriever,
    the full N x N similarity matrix is never materialized: each row batch is
    scored against all N descriptors and reduced with top-k immediately.
    """

    def __init__(
        self,
        num_matched: int,
        min_score: float = 0.1,
        batch_size: int = 1024,
    ) -> None:
        """Initialize the row-batched similarity retriever.

        Args:
            num_matched: Number of top matches to return per query image.
            min_score: Minimum similarity score threshold; candidate pairs
                scoring below this are discarded.
            batch_size: Number of query rows scored per batch. Tune based on
                GPU memory; peak per-batch memory is roughly
                batch_size * N * 4 bytes.
        """
        self._num_matched = num_matched
        self._min_score = min_score
        self._batch_size = batch_size
        # Per-query (match_index, score) lists from the most recent search;
        # consumed and then cleared by save_diagnostics().
        self._latest_query_results: Optional[List[List[Tuple[int, float]]]] = None

    def __repr__(self) -> str:
        return f"""
        BatchedSimilarityRetriever:
            Num. frames matched: {self._num_matched}
            Minimum score: {self._min_score}
            Batch size: {self._batch_size}
            Device: {"cuda" if torch.cuda.is_available() else "cpu"}
        """

    def set_num_matched(self, n: int) -> None:
        """Set the number of matched frames returned per query."""
        self._num_matched = n

    def get_image_pairs(
        self,
        global_descriptors: Optional[List[np.ndarray]],
        image_fnames: List[str],
        plots_output_dir: Optional[Path] = None,
    ) -> VisibilityGraph:
        """Compute potential image pairs using batched similarity search.

        Args:
            global_descriptors: Normalized global descriptors (e.g., NetVLAD, MegaLoc).
            image_fnames: File names of the images.
            plots_output_dir: Directory to save diagnostic text files.

        Returns:
            List of (i1, i2) image pairs with i1 < i2.

        Raises:
            ValueError: If global_descriptors is None or empty.
        """
        if not global_descriptors:
            raise ValueError("Global descriptors must be provided and non-empty")

        num_images = len(global_descriptors)
        # One call handles both the float32 cast and C-contiguity.
        descriptors = np.ascontiguousarray(np.stack(global_descriptors), dtype=np.float32)

        start_time = time.time()
        scores, indices = self._batched_search(descriptors)
        search_time = time.time() - start_time

        logger.info(
            "Batched similarity search: N=%d, D=%d, time=%.2fs",
            num_images,
            descriptors.shape[1],
            search_time,
        )

        pairs, per_query_results = self._collect_pairs(scores, indices, num_images)
        self._latest_query_results = per_query_results

        logger.info("Found %d pairs using BatchedSimilarityRetriever.", len(pairs))

        if plots_output_dir:
            self.save_diagnostics(image_fnames, pairs, plots_output_dir)

        return pairs

    def _batched_search(self, descriptors: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Row-batched similarity search with O(batch x N) memory.

        For each batch of query rows, computes similarity against all database
        images, masks invalid pairs (j <= i) with a vectorized broadcast
        comparison (the original per-row Python loop was O(batch) interpreter
        work per batch), applies the score threshold, and extracts top-k.

        Args:
            descriptors: (N, D) array of normalized descriptors.

        Returns:
            scores: (N, k) top-k similarity scores per query (-inf for
                slots invalidated by the masks).
            indices: (N, k) indices of top-k matches per query.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        n = len(descriptors)
        # At most n - 1 candidates exist per query (self-match excluded).
        k = min(self._num_matched, n - 1)

        desc_gpu = torch.from_numpy(descriptors).to(device)
        col_ids = torch.arange(n, device=device)

        all_scores = torch.empty(n, k, dtype=torch.float32)
        all_indices = torch.empty(n, k, dtype=torch.int64)

        for start in range(0, n, self._batch_size):
            end = min(start + self._batch_size, n)

            # (batch, N) similarity matrix for this batch of query rows.
            sim = desc_gpu[start:end] @ desc_gpu.T

            # Enforce i < j and remove self-matches: for global row i, invalidate
            # columns [0, i]. Broadcasting (n,) <= (batch, 1) -> (batch, N) mask.
            row_ids = torch.arange(start, end, device=device).unsqueeze(1)
            sim.masked_fill_(col_ids <= row_ids, float("-inf"))

            # Apply the score threshold so below-threshold slots are filtered
            # downstream by the same -inf sentinel.
            if self._min_score is not None:
                sim.masked_fill_(sim < self._min_score, float("-inf"))

            scores_batch, indices_batch = torch.topk(sim, k=k, dim=1)
            all_scores[start:end] = scores_batch.cpu()
            all_indices[start:end] = indices_batch.cpu()

        return all_scores.numpy(), all_indices.numpy()

    def _collect_pairs(
        self, scores: np.ndarray, indices: np.ndarray, num_images: int
    ) -> Tuple[VisibilityGraph, List[List[Tuple[int, float]]]]:
        """Convert per-query top-k results into a flat pair list.

        The i < j constraint was already enforced during the search, so
        invalidated slots simply carry -inf scores and are skipped here.

        Args:
            scores: (N, k) top-k similarity scores per query.
            indices: (N, k) indices of top-k matches per query.
            num_images: Total number of images N.

        Returns:
            Tuple of (pair list, per-query (match_index, score) lists).
        """
        pairs: List[Tuple[int, int]] = []
        per_query_results: List[List[Tuple[int, float]]] = []

        k = scores.shape[1]
        for i in range(num_images):
            query_matches: List[Tuple[int, float]] = []
            for slot in range(k):
                score = float(scores[i, slot])
                # -inf marks slots invalidated by the triangular/threshold
                # masks. (topk indices are always >= 0, so no index check.)
                if not np.isfinite(score):
                    continue
                j = int(indices[i, slot])
                pairs.append((i, j))
                query_matches.append((j, score))
            per_query_results.append(query_matches)

        return pairs, per_query_results

    def save_diagnostics(
        self,
        image_fnames: List[str],
        pairs: VisibilityGraph,
        plots_output_dir: Optional[Path],
    ) -> None:
        """Save retrieval diagnostics to text files.

        Note: Unlike SimilarityRetriever, we do NOT save a dense heatmap image
        since this retriever is designed for scales where N x N is prohibitive.

        Args:
            image_fnames: File names of the images.
            pairs: Retrieved (i1, i2) pairs.
            plots_output_dir: Output directory; no-op when None.
        """
        if plots_output_dir is None:
            return

        plots_output_dir.mkdir(parents=True, exist_ok=True)

        # Save the flat pair list.
        pairs_path = plots_output_dir / "retrieved_pairs.txt"
        with open(pairs_path, "w") as f:
            f.write("# BatchedSimilarityRetriever Pairs\n")
            f.write(f"# Num Pairs: {len(pairs)}\n")
            f.write(f"# Min Score: {self._min_score}\n")
            f.write("# Format: Index1 Index2 Name1 Name2\n")
            for i, j in pairs:
                f.write(f"{i} {j} {image_fnames[i]} {image_fnames[j]}\n")
        logger.info("Saved pair list to %s", pairs_path)

        # Save ranked per-query scores (only available right after a search).
        if self._latest_query_results is None:
            return

        ranked_path = plots_output_dir / "similarity_named_pairs.txt"
        with open(ranked_path, "w") as f:
            f.write("# Format: score name_i name_j\n")
            for i, matches in enumerate(self._latest_query_results):
                name_i = image_fnames[i]
                for j, score in matches:
                    f.write(f"{score:.4f} {name_i} {image_fnames[j]}\n")
        logger.info("Saved ranked pairs to %s", ranked_path)

        self._latest_query_results = None