Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
22017e8
store pre ba result
akshay-krishnan Jan 19, 2026
bd91d3b
track visualization script
akshay-krishnan Jan 20, 2026
ac5c784
trajectory aligner sim3 use
akshay-krishnan Jan 21, 2026
6aa3493
Replacing VGGSfM tracker with VGGT tracker module
Jan 26, 2026
90b252d
Decrease min_score to increase connectivity
nantonzhang Jan 27, 2026
a1c3c34
Use Metis by default
nantonzhang Jan 27, 2026
a7d6c90
exclude all result folders
nantonzhang Jan 27, 2026
562448f
change default to no ba
nantonzhang Jan 27, 2026
b9286e4
Add more metrics to eval code
nantonzhang Jan 27, 2026
883d601
merging updates
akshay-krishnan Jan 27, 2026
806606a
save more metrics for debugging
akshay-krishnan Jan 27, 2026
a5cf200
Merge remote-tracking branch 'upstream/ba-debug' into vggt_tracking_e…
nantonzhang Jan 28, 2026
8ce3153
more detailed eval code
nantonzhang Jan 28, 2026
eceb70b
optimize cacher
nantonzhang Jan 28, 2026
86f568a
skip use_nonlinear_sim3_alignment for now
nantonzhang Jan 28, 2026
01d9c81
resolving comments
Feb 5, 2026
0943229
adding 2 parameters in yaml file
Feb 5, 2026
0e5920b
Merge pull request #1039 from borglab/pr-1035
hkhanuja Feb 5, 2026
de2242d
Fixing dependencies
Feb 5, 2026
195f8b9
Merge pull request #1040 from borglab/pr-1035
hkhanuja Feb 5, 2026
040b108
Resolving comments
Feb 6, 2026
d2018e6
Merge pull request #1041 from borglab/pr-1035
hkhanuja Feb 6, 2026
6f02af8
Minor filtering changes to frontend/vggt.py
Feb 6, 2026
7c18df4
Merge pull request #1042 from borglab/pr-1035
hkhanuja Feb 6, 2026
c3e99b8
Pinning GTSAM version
Feb 7, 2026
5c40d50
Pinning GTSAM version in pyproject
Feb 7, 2026
a921e56
Fixing noise model in tests
Feb 7, 2026
04b5313
Fixing tests
Feb 7, 2026
62bdf26
Fixing tests
Feb 7, 2026
8080920
Merge pull request #1049 from borglab/pr-1035
hkhanuja Feb 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ data/
# Data dumped by GTSFM directory
debug/
plots/
**/*result*/
results/
result_metrics/
*.html
Expand Down
33 changes: 12 additions & 21 deletions gtsfm/cluster_optimizer/cluster_vggt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ def _resize_to_square_tensor(image: np.ndarray, target_size: int) -> torch.Tenso
return (tensor.squeeze(0)) / 255.0


def _load_vggt_inputs(loader, indices: list[int], target_size: int):
def _load_vggt_inputs(loader, indices: list[int], mode: str):
"""Load and preprocess a batch of images for VGGT."""

def resize_transform(arr: np.ndarray) -> torch.Tensor:
return _resize_to_square_tensor(arr, target_size)

return loader.load_image_batch_vggt(indices, target_size, resize_transform)
return loader.load_image_batch_vggt_loader(indices, mode=mode)


def _resolve_vggt_model(cache_key: Hashable | None, loader_kwargs: dict[str, Any] | None) -> Any | None:
Expand Down Expand Up @@ -113,15 +109,14 @@ class ClusterVGGT(ClusterOptimizerBase):
def __init__(
self,
weights_path: Optional[str] = None,
image_load_resolution: int = 1024,
inference_resolution: int = 518,
conf_threshold: float = 5.0,
max_num_points: int = 100000,
tracking: bool = False,
tracking_max_query_pts: int = 1000,
tracking_query_frame_num: int = 4,
tracking_fine_tracking: bool = True,
track_vis_thresh: float = 0.2,
tracking_max_query_pts: int = 2048,
tracking_query_frame_num: int = 3,
Comment thread
hkhanuja marked this conversation as resolved.
track_vis_thresh: float = 0.05,
track_conf_thresh: float = 0.2,
keypoint_extractor: str = "aliked+sp+sift",
camera_type: str = "PINHOLE",
seed: int = 42,
scene_dir: Optional[str] = None,
Expand Down Expand Up @@ -154,15 +149,14 @@ def __init__(
run_bundle_adjustment_on_parent=run_bundle_adjustment_on_parent,
)
self._weights_path = Path(weights_path) if weights_path is not None else None
self._image_load_resolution = image_load_resolution
self._inference_resolution = inference_resolution
self._conf_threshold = conf_threshold
self._max_points_for_colmap = max_num_points
self._tracking = tracking
self._tracking_max_query_pts = tracking_max_query_pts
self._tracking_query_frame_num = tracking_query_frame_num
self._tracking_fine_tracking = tracking_fine_tracking
self._track_vis_thresh = track_vis_thresh
self._track_conf_thresh = track_conf_thresh
self._keypoint_extractor = keypoint_extractor
self._camera_type = camera_type
self._max_reproj_error = max_reproj_error
self._seed = seed
Expand Down Expand Up @@ -216,8 +210,6 @@ def _maybe_set_model_kw(key: str, value: Any) -> None:
def __repr__(self) -> str:
components = [
f"weights_path={self._weights_path}",
f"image_load_resolution={self._image_load_resolution}",
f"inference_resolution={self._inference_resolution}",
f"camera_type={self._camera_type}",
f"dtype={self._dtype}",
f"use_sparse_attention={self._use_sparse_attention}",
Expand Down Expand Up @@ -253,15 +245,14 @@ def create_computation_graph(
image_names = tuple(str(image_filenames[idx]) for idx in keys)

config = VggtConfiguration(
vggt_fixed_resolution=self._inference_resolution,
img_load_resolution=self._image_load_resolution,
confidence_threshold=self._conf_threshold,
max_num_points=self._max_points_for_colmap,
tracking=self._tracking,
max_query_pts=self._tracking_max_query_pts,
query_frame_num=self._tracking_query_frame_num,
fine_tracking=self._tracking_fine_tracking,
track_vis_thresh=self._track_vis_thresh,
track_conf_thresh=self._track_conf_thresh,
keypoint_extractor=self._keypoint_extractor,
dtype=self._dtype,
model_ctor_kwargs=self._model_ctor_kwargs.copy(),
use_sparse_attention=self._use_sparse_attention,
Expand All @@ -270,7 +261,7 @@ def create_computation_graph(
)

image_batch_graph, original_coords_graph = delayed(_load_vggt_inputs, nout=2)(
context.loader, global_indices, self._image_load_resolution
context.loader, global_indices, mode="crop" # mode is fixed to "crop"
Comment thread
hkhanuja marked this conversation as resolved.
Outdated
)

result_graph = delayed(_run_vggt_pipeline)(
Expand Down
18 changes: 9 additions & 9 deletions gtsfm/configs/vggt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ loader:
_target_: gtsfm.loader.Olsson
dataset_dir: ??? # Required: set to the dataset root on the command line.
images_dir: null
max_resolution: 760
max_resolution: 518 # VGGT's recommended maximum resolution; do not change. The image loading mode is fixed to "crop".
Comment thread
hkhanuja marked this conversation as resolved.
Outdated
Comment thread
hkhanuja marked this conversation as resolved.
Outdated

image_pairs_generator:
_target_: gtsfm.retriever.image_pairs_generator.ImagePairsGenerator
Expand All @@ -18,32 +18,32 @@ image_pairs_generator:
retriever:
_target_: gtsfm.retriever.Similarity
num_matched: 5
min_score: 0.79
min_score: 0.25
batch_size: 16

graph_partitioner:
_target_: gtsfm.graph_partitioner.Single
_target_: gtsfm.graph_partitioner.Metis

cluster_optimizer:
_target_: gtsfm.cluster_optimizer.Cacher
optimizer:
_target_: gtsfm.cluster_optimizer.cluster_vggt.ClusterVGGT
weights_path: null
image_load_resolution: 1024
inference_resolution: 518
conf_threshold: 5.0
max_num_points: 100000
tracking: true
tracking_max_query_pts: 512
tracking_max_query_pts: 2048
tracking_query_frame_num: 3
tracking_fine_tracking: false
track_vis_thresh: 0.2
keypoint_extractor: aliked+sp+sift
track_vis_thresh: 0.05
track_conf_thresh: 0.2
max_reproj_error: 0 # 0 disables filtering based on reprojection error
camera_type: PINHOLE
drop_outlier_after_camera_merging: false
drop_child_if_merging_fail: true
drop_camera_with_no_track: true
seed: 42
plot_reprojection_histograms: true
run_bundle_adjustment_on_leaf: false
run_bundle_adjustment_on_parent: true
run_bundle_adjustment_on_parent: false
model_cache_key: null
149 changes: 135 additions & 14 deletions gtsfm/evaluation/compare_colmap_outputs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Script to compare two reconstructions in Colmap's output format.

Authors: Ayush Baid
Authors: Ayush Baid, Xinan Zhang
"""

import argparse
import csv
import json
import os
from typing import Dict, Tuple
import textwrap
from typing import Dict, List, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt
import pycolmap
from gtsam import Point3, Pose3, Rot3, Similarity3
from scipy.spatial.transform import Rotation
Expand Down Expand Up @@ -53,7 +57,108 @@ def align_with_colmap(
return aSb, aligned_dict


def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) -> None:
def _stack_camera_centers(wTi_list: List[Pose3]) -> np.ndarray:
    """Stack the translation of every pose into one array, or return an empty (0, 3) array for an empty list."""
    centers = [pose.translation() for pose in wTi_list]
    # np.stack raises ValueError on an empty sequence, so fall back to an explicit empty array.
    return np.stack(centers) if centers else np.empty((0, 3))


def _draw_local_z_arrows(ax, wTi_list: List[Pose3], arrow_len: float, color: str) -> None:
    """Draw one arrow per pose, from the pose origin to the point `arrow_len` along the pose's local +z axis."""
    for pose in wTi_list:
        origin = pose.transformFrom(Point3(0.0, 0.0, 0.0))
        tip = pose.transformFrom(Point3(0.0, 0.0, arrow_len))
        direction = tip - origin
        ax.quiver(
            origin[0],
            origin[1],
            origin[2],
            direction[0],
            direction[1],
            direction[2],
            color=color,
            linewidth=0.5,
            arrow_length_ratio=0.2,
            alpha=0.6,
        )


def plot_camera_centers(
    baseline_wTi_list: List[Pose3],
    current_wTi_list: List[Pose3],
    output_dirpath: str,
    title: Optional[str] = None,
) -> None:
    """Save a 3D scatter plot of baseline and current camera centers.

    Each camera is drawn as a point at its center plus a short arrow along its local +z axis.
    Baseline cameras are blue and current cameras are orange. The figure is written to
    `camera_centers.png` inside `output_dirpath`.

    Args:
        baseline_wTi_list: Poses of the baseline reconstruction. May be empty.
        current_wTi_list: Poses of the reconstruction being evaluated. May be empty.
        output_dirpath: Directory in which the PNG is saved.
        title: Optional figure title. Explicit newlines are kept, and every line is wrapped to 80 characters.
    """
    baseline_centers = _stack_camera_centers(baseline_wTi_list)
    current_centers = _stack_camera_centers(current_wTi_list)

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111, projection="3d")

    # Scale the arrows with the spread of the baseline cameras so they stay visible but unobtrusive.
    if baseline_centers.size:
        center = baseline_centers.mean(axis=0)
        mean_radius = np.linalg.norm(baseline_centers - center, axis=1).mean()
        arrow_len = max(mean_radius * 0.15, 1e-3)
    else:
        arrow_len = 1.0

    _draw_local_z_arrows(ax, baseline_wTi_list, arrow_len, "tab:blue")
    _draw_local_z_arrows(ax, current_wTi_list, arrow_len, "tab:orange")

    for centers, color, label in (
        (baseline_centers, "tab:blue", "baseline"),
        (current_centers, "tab:orange", "current"),
    ):
        if centers.size:
            ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], s=10, c=color, label=label)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    if baseline_centers.size or current_centers.size:
        # Skip the legend when nothing was plotted; there would be no labelled artists to list.
        ax.legend(loc="best")

    # textwrap.wrap() turns newlines into spaces, so wrap each line separately to keep
    # the line breaks the caller put into the title.
    wrapped = "\n".join("\n".join(textwrap.wrap(line, width=80)) for line in title.splitlines()) if title else ""
    if wrapped:
        fig.suptitle(wrapped, fontsize=9, y=0.98)
    # Reserve the top strip of the figure for the (possibly multi-line) title.
    fig.tight_layout(rect=[0, 0, 1, 0.92])
    fig.savefig(os.path.join(output_dirpath, "camera_centers.png"), dpi=300)
    plt.close(fig)


def export_metrics_group_to_csv(metrics_group: GtsfmMetricsGroup, output_path: str) -> None:
    """Write every metric of a metrics group to a two-column CSV file.

    The file has the header `metric_name,value` followed by one row per metric. A metric with
    `dim == 0` is written as a float with six decimals, or as an empty string when its data is
    None. Any other metric is written as the JSON dump of its summary, with sorted keys.

    Args:
        metrics_group: Group whose `metrics` are exported.
        output_path: Path of the CSV file to create or overwrite.
    """
    columns = ["metric_name", "value"]

    def _as_cell(metric) -> str:
        # Non-scalar metrics are represented by their summary, not by the raw data.
        if metric.dim != 0:
            return json.dumps(metric.summary, sort_keys=True)
        if metric.data is None:
            return ""
        return f"{float(metric.data):.6f}"

    # Build all rows before opening the file, so a formatting error leaves no partial file behind.
    records = [{"metric_name": metric.name, "value": _as_cell(metric)} for metric in metrics_group.metrics]

    with open(output_path, "w", newline="") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=columns)
        csv_writer.writeheader()
        csv_writer.writerows(records)


def _format_pose_auc(metrics_group: GtsfmMetricsGroup) -> str:
    """Summarize the pose AUC metrics of a group as one comma-separated string.

    Only metrics whose name starts with `pose_auc_@` are considered. Metrics whose data is None
    or cannot be converted to a float are skipped. Each remaining metric contributes
    `<name without "pose_auc_">=<value with three decimals>`.

    Args:
        metrics_group: Group whose `metrics` are scanned.

    Returns:
        The formatted entries joined by ", ", or an empty string when no metric qualifies.
    """
    entries = []
    for metric in metrics_group.metrics:
        if metric.name.startswith("pose_auc_@") and metric.data is not None:
            try:
                auc = float(metric.data)
            except (TypeError, ValueError):
                # Data that is not a plain number is silently left out of the summary.
                pass
            else:
                threshold = metric.name.replace("pose_auc_", "")
                entries.append(f"{threshold}={auc:.3f}")
    return ", ".join(entries)


def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) -> GtsfmMetricsGroup:
"""Compare the pose metrics between two reconstructions (Colmap format).

Args:
Expand Down Expand Up @@ -82,25 +187,32 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str)
len(common_fnames),
)

baseline_wTi_list: List[Pose3] = []
current_wTi_list: List[Optional[Pose3]] = []
for fname, wTi in baseline_wTi_dict.items():
baseline_wTi_list.append(wTi)
current_wTi_list.append(current_wTi_dict.get(fname))

if not args.use_pycolmap_alignment:
aSb = align.sim3_from_Pose3_maps_robust(baseline_wTi_dict, current_wTi_dict)
current_wTi_dict = transform.Pose3_map_with_sim3(aSb, current_wTi_dict)
aSb = align.sim3_from_optional_Pose3s(baseline_wTi_list, current_wTi_list)
current_wTi_list = transform.optional_Pose3s_with_sim3(aSb, current_wTi_list)
Comment thread
hkhanuja marked this conversation as resolved.
current_wTi_dict = {fname: aSb.transformFrom(pose) for fname, pose in current_wTi_dict.items()}

i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_wTi_dict)
i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_wTi_list)

wRi_aligned, wti_aligned = metric_utils.get_rotations_translations_from_poses(current_wTi_dict)
baseline_wRi, baseline_wti = metric_utils.get_rotations_translations_from_poses(baseline_wTi_dict)
wRi_aligned_list, wti_aligned_list = metric_utils.get_rotations_translations_from_poses(current_wTi_list)
baseline_wRi_list, baseline_wti_list = metric_utils.get_rotations_translations_from_poses(baseline_wTi_list)

metrics = []
metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned, baseline_wRi))
metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned, baseline_wti))
metrics.append(metric_utils.compute_translation_angle_metric(baseline_wTi_dict, current_wTi_dict))
metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_list, baseline_wRi_list))
metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_list, baseline_wti_list))
metrics.append(metric_utils.compute_translation_angle_metric(baseline_wTi_list, current_wTi_list))
relative_rotation_error_metric = metric_utils.compute_relative_rotation_angle_metric(
i2Ri1_dict_gt, current_wTi_dict, store_full_data=True
i2Ri1_dict_gt, current_wTi_list, store_full_data=True
)
metrics.append(relative_rotation_error_metric)
relative_translation_error_metric = metric_utils.compute_relative_translation_angle_metric(
i2Ui1_dict_gt, current_wTi_dict, store_full_data=True
i2Ui1_dict_gt, current_wTi_list, store_full_data=True
)
metrics.append(relative_translation_error_metric)

Expand All @@ -114,7 +226,14 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str)

ba_pose_metrics = GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=metrics)

auc_text = _format_pose_auc(ba_pose_metrics)
title = eval_dirpath
if auc_text:
title = f"{title}\nPose AUC: {auc_text}"
plot_camera_centers(baseline_wTi_list, list(current_wTi_dict.values()), output_dirpath, title=title)

save_metrics_reports([ba_pose_metrics], metrics_path=output_dirpath)
return ba_pose_metrics


if __name__ == "__main__":
Expand All @@ -139,4 +258,6 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str)

os.makedirs(args.output, exist_ok=True)

compare_poses(args.baseline, args.current, args.output)
ba_pose_metrics = compare_poses(args.baseline, args.current, args.output)
export_metrics_group_to_csv(ba_pose_metrics, os.path.join(args.output, f"{ba_pose_metrics.name}.csv"))

Loading