diff --git a/.gitignore b/.gitignore index 29a4be5f9..62fe37926 100644 --- a/.gitignore +++ b/.gitignore @@ -155,6 +155,7 @@ data/ # Data dumped by GTSFM directory debug/ plots/ +**/*result*/ results/ result_metrics/ *.html diff --git a/environment_linux.yml b/environment_linux.yml index 46639b053..479688b81 100644 --- a/environment_linux.yml +++ b/environment_linux.yml @@ -56,7 +56,7 @@ dependencies: - pydegensac - colour - trimesh[easy] - - gtsam==4.3a0 + - gtsam-develop==4.3a1.dev202602040056 - pydot # dust3r/mast3r - roma diff --git a/environment_linux_cpuonly.yml b/environment_linux_cpuonly.yml index 03827ba55..00955e1f3 100644 --- a/environment_linux_cpuonly.yml +++ b/environment_linux_cpuonly.yml @@ -56,7 +56,7 @@ dependencies: - pydegensac - colour - trimesh[easy] - - gtsam==4.3a0 + - gtsam-develop==4.3a1.dev202602040056 - pydot - roma - tqdm diff --git a/environment_mac.yml b/environment_mac.yml index 97303fb95..59eb40997 100644 --- a/environment_mac.yml +++ b/environment_mac.yml @@ -57,7 +57,7 @@ dependencies: - pydegensac - colour - trimesh[easy] - - gtsam==4.3a0 + - gtsam-develop==4.3a1.dev202602040056 - pydot # dust3r/mast3r - roma diff --git a/gtsfm/averaging/translation/averaging_1dsfm.py b/gtsfm/averaging/translation/averaging_1dsfm.py index 41e276e16..2c4463176 100644 --- a/gtsfm/averaging/translation/averaging_1dsfm.py +++ b/gtsfm/averaging/translation/averaging_1dsfm.py @@ -51,7 +51,7 @@ MAX_PROJECTION_DIRECTIONS = 2000 OUTLIER_WEIGHT_THRESHOLD = 0.125 -NOISE_MODEL_DIMENSION = 3 # chordal distances on Unit3 +NOISE_MODEL_DIMENSION = 2 # chordal distances on Unit3 NOISE_MODEL_SIGMA = 0.01 HUBER_LOSS_K = 1.3 # default value from GTSAM @@ -73,7 +73,7 @@ L = symbol_shorthand.B # for track (landmark) translation variables RelativeDirectionsDict = AnnotatedGraph[Unit3] -DUMMY_NOISE_MODEL = gtsam.noiseModel.Isotropic.Sigma(3, 1e-2) # MFAS does not use this. +DUMMY_NOISE_MODEL = gtsam.noiseModel.Isotropic.Sigma(2, 1e-2) # MFAS does not use this. class TranslationAveraging1DSFM(TranslationAveragingBase): diff --git a/gtsfm/bundle/bundle_adjustment.py b/gtsfm/bundle/bundle_adjustment.py index 7bcb3b4a1..d897a66d9 100644 --- a/gtsfm/bundle/bundle_adjustment.py +++ b/gtsfm/bundle/bundle_adjustment.py @@ -12,13 +12,7 @@ import gtsam # type: ignore import numpy as np from dask.delayed import Delayed -from gtsam import ( - BetweenFactorPose3, - NonlinearFactorGraph, - PriorFactorPose3, - PriorFactorPoint3, - Values, -) +from gtsam import BetweenFactorPose3, NonlinearFactorGraph, PriorFactorPoint3, PriorFactorPose3, Values from gtsam.noiseModel import Diagonal, Isotropic, Robust, mEstimator # type: ignore from gtsam.symbol_shorthand import K, P, X # type: ignore diff --git a/gtsfm/cluster_merging.py b/gtsfm/cluster_merging.py index 3a30fa1a7..6ab654c47 100644 --- a/gtsfm/cluster_merging.py +++ b/gtsfm/cluster_merging.py @@ -7,9 +7,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional, Tuple +import gtsam import numpy as np from dask.distributed import Client, Future -from gtsam import Similarity3, Pose3 +from gtsam import Similarity3, Pose3, UnaryMeasurementPose3, TrajectoryAlignerSim3 import gtsfm.utils.logger as logger_utils import gtsfm.common.types as gtsfm_types @@ -35,6 +36,56 @@ _SCENE_LABEL_ATTR = "_gtsfm_cluster_label" +def _create_unary_measurements(scene: GtsfmData) -> list[UnaryMeasurementPose3]: + # TODO(akshay-krishnan): investigate using a scene-dependent noise model + # perhaps * np.exp(-len(scene.get_valid_camera_indices()) / 100.0) + noise_model = gtsam.noiseModel.Diagonal.Sigmas(np.array([1e-2, 1e-2, 1e-2, 1e-1, 1e-1, 1e-1])) + unary_measurements = [] + for i, camera in scene.get_camera_poses().items(): + if camera is None: + continue + unary_measurement = UnaryMeasurementPose3(i, camera, noise_model) + unary_measurements.append(unary_measurement) + return unary_measurements + + +def merge_scenes_with_sim3_nonlinear(parent_scene: GtsfmData, children_scenes: list[GtsfmData]) -> GtsfmData: + if len(children_scenes) == 0: + return parent_scene + + aTi_measurements = _create_unary_measurements(parent_scene) + parent_camera_ids = set(parent_scene.get_valid_camera_indices()) + valid_child_scenes = [] + + for i, child_scene in enumerate(children_scenes): + child_camera_ids = set(child_scene.get_valid_camera_indices()) + common_camera_ids = parent_camera_ids & child_camera_ids + if len(common_camera_ids) == 0: + logger.warning("Child scene %d has insufficient overlap with parent, skipping", i) + continue + valid_child_scenes.append(child_scene) + + if len(valid_child_scenes) == 0: + return parent_scene + + aTi_measurements = _create_unary_measurements(parent_scene) + bTi_measurements = [_create_unary_measurements(child_scene) for child_scene in valid_child_scenes] + aligner = TrajectoryAlignerSim3(aTi_measurements, bTi_measurements) + result = aligner.solve() + + opt_aTi = {i: result.atPose3(i) for i in parent_scene.get_valid_camera_indices() if i in result.keys()} + + merged = parent_scene + for i, aTi in opt_aTi.items(): + merged.update_camera_pose(i, aTi) + + for i, child_scene in enumerate(valid_child_scenes): + opt_bSa = result.atSimilarity3(gtsam.Symbol("S", i).key()) + opt_aSb = opt_bSa.inverse() + merged = merged.merged_with(child_scene, opt_aSb) # type: ignore + return merged + + @dataclass(frozen=True) class MergedNodeResult: """Results of merging child scenes with parent scenes in the reconstruction tree. @@ -224,9 +275,7 @@ def _get_pose_metrics( aligned_result_data = result_data.align_via_sim3_and_transform(poses_gt) return metrics_utils.compute_ba_pose_metrics( - gt_wTi=poses_gt, - computed_wTi=aligned_result_data.get_camera_poses(), - save_dir=save_dir, + gt_wTi=poses_gt, computed_wTi=aligned_result_data.get_camera_poses(), save_dir=save_dir, store_full_data=True ) @@ -333,6 +382,8 @@ def _drop_outlier_tracks(scene: GtsfmData) -> GtsfmData: Returns: The scene with outlier tracks dropped. """ + if scene.number_tracks() == 0: + return scene track_errors: list[float] = [] tracks = scene.tracks() cameras = scene.cameras() @@ -406,6 +457,7 @@ def combine_results( drop_camera_with_no_track: bool = True, drop_child_if_merging_fail: bool = True, store_full_data: bool = False, + use_nonlinear_sim3_alignment: bool = False, ) -> MergedNodeResult: """Run the merging and parent BA pipeline using already-transformed children. @@ -469,17 +521,27 @@ def _finalize_result(result_scene: Optional[GtsfmData]) -> MergedNodeResult: merged = current _log_scene_reprojection_stats(merged, "Current node", plot_histograms=plot_reprojection_histograms) - # Merge all children into the merged scene. - for i, child in enumerate(valid_child_scenes): - merged = _align_and_merge_results(merged, child, drop_if_merging_fails=drop_child_if_merging_fail) - _log_scene_reprojection_stats(merged, f"Merged with child #{i+1}", plot_histograms=plot_reprojection_histograms) + if use_nonlinear_sim3_alignment: + merged = merge_scenes_with_sim3_nonlinear(merged, valid_child_scenes) + _log_scene_reprojection_stats( + merged, "Merged with children (nonlinear alignment)", plot_histograms=plot_reprojection_histograms + ) + else: + # Merge all children into the merged scene. + for i, child in enumerate(valid_child_scenes): + merged = _align_and_merge_results(merged, child, drop_if_merging_fails=drop_child_if_merging_fail) + _log_scene_reprojection_stats( + merged, f"Merged with child #{i+1}", plot_histograms=plot_reprojection_histograms + ) _propagate_scene_metadata(merged, metadata_source) - if drop_outlier_after_camera_merging and merged is not None and merged.number_tracks() > 0: - merged = _drop_outlier_tracks(merged) + if merged is None: + return _finalize_result(None) if not run_bundle_adjustment_on_parent: + if drop_outlier_after_camera_merging: + merged = _drop_outlier_tracks(merged) return _finalize_result(merged) # Log cameras that have no supporting track measurements before running BA. @@ -498,6 +560,14 @@ def _finalize_result(result_scene: Optional[GtsfmData]) -> MergedNodeResult: "merged result (with ba)", plot_histograms=plot_reprojection_histograms, ) + if drop_outlier_after_camera_merging: + merged_with_ba = _drop_outlier_tracks(merged_with_ba) + _log_scene_reprojection_stats( + merged_with_ba, + "merged result (with ba + outlier filtering)", + plot_histograms=plot_reprojection_histograms, + ) + # TODO: the order here is different from the merging order above, we should fix this. if merged.has_gaussian_splats(): logger.info("🫱🏻‍🫲🏽 Merging Gaussians") diff --git a/gtsfm/cluster_optimizer/cluster_optimizer_base.py b/gtsfm/cluster_optimizer/cluster_optimizer_base.py index 3a2b3deb4..bf6496d38 100644 --- a/gtsfm/cluster_optimizer/cluster_optimizer_base.py +++ b/gtsfm/cluster_optimizer/cluster_optimizer_base.py @@ -102,6 +102,7 @@ def __init__( drop_outlier_after_camera_merging: bool = True, plot_reprojection_histograms: bool = True, run_bundle_adjustment_on_parent: bool = True, + run_bundle_adjustment_on_leaf: bool = False, output_worker: None | str = None, ) -> None: self.drop_child_if_merging_fail = drop_child_if_merging_fail @@ -109,6 +110,7 @@ def __init__( self.drop_outlier_after_camera_merging = drop_outlier_after_camera_merging self.plot_reprojection_histograms = plot_reprojection_histograms self.run_bundle_adjustment_on_parent = run_bundle_adjustment_on_parent + self.run_bundle_adjustment_on_leaf = run_bundle_adjustment_on_leaf self._pose_angular_error_thresh = pose_angular_error_thresh self._output_worker = output_worker diff --git a/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py b/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py index b59e37e00..e320bdc67 100644 --- a/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py +++ b/gtsfm/cluster_optimizer/cluster_optimizer_cacher.py @@ -3,6 +3,7 @@ from __future__ import annotations import hashlib +import os import typing from pathlib import Path from typing import Optional, TYPE_CHECKING @@ -34,18 +35,26 @@ class ClusterOptimizerCacher(ClusterOptimizerBase): """Caches the delayed bundle result produced by a cluster optimizer.""" - def __init__(self, optimizer: ClusterOptimizerBase) -> None: + def __init__(self, optimizer: ClusterOptimizerBase, cache_subdir: Optional[str] = None) -> None: """Initializes the cacher with the actual cluster optimizer object. Args: optimizer: cluster optimizer to use in case of cache miss. + cache_subdir: Optional subdirectory (relative to cache root) for storing cache entries. """ + run_bundle_adjustment_on_leaf = getattr(optimizer, "run_bundle_adjustment_on_leaf", None) + if run_bundle_adjustment_on_leaf is None: + run_bundle_adjustment_on_leaf = getattr(optimizer, "_run_bundle_adjustment_on_leaf", False) super().__init__( pose_angular_error_thresh=optimizer.pose_angular_error_thresh, + run_bundle_adjustment_on_leaf=run_bundle_adjustment_on_leaf, + run_bundle_adjustment_on_parent=getattr(optimizer, "run_bundle_adjustment_on_parent", True), output_worker=optimizer._output_worker, ) self._optimizer = optimizer self._optimizer_hash = hashlib.sha1(repr(optimizer).encode()).hexdigest() + self._cache_subdir = cache_subdir if cache_subdir is not None else os.getenv("GTSFM_CACHE_SUBDIR") + self._cache_root_path = self._resolve_cache_root(self._cache_subdir) def __repr__(self) -> str: return repr(self._optimizer) @@ -64,20 +73,38 @@ def __getstate__(self) -> dict[str, object]: return { "_optimizer": self._optimizer, "_optimizer_hash": self._optimizer_hash, + "_cache_subdir": self._cache_subdir, } def __setstate__(self, state: dict[str, object]) -> None: """Restore state and keep worker routing consistent.""" self._optimizer = typing.cast(ClusterOptimizerBase, state["_optimizer"]) self._optimizer_hash = typing.cast(str, state["_optimizer_hash"]) + self._cache_subdir = typing.cast(Optional[str], state.get("_cache_subdir")) + self._cache_root_path = self._resolve_cache_root(self._cache_subdir) # Re-initialize the base class to mimic the constructor. + run_bundle_adjustment_on_leaf = getattr(self._optimizer, "run_bundle_adjustment_on_leaf", None) + if run_bundle_adjustment_on_leaf is None: + run_bundle_adjustment_on_leaf = getattr(self._optimizer, "_run_bundle_adjustment_on_leaf", False) super().__init__( pose_angular_error_thresh=self._optimizer.pose_angular_error_thresh, + run_bundle_adjustment_on_leaf=run_bundle_adjustment_on_leaf, + run_bundle_adjustment_on_parent=getattr(self._optimizer, "run_bundle_adjustment_on_parent", True), output_worker=self._optimizer._output_worker, ) + @staticmethod + def _resolve_cache_root(cache_subdir: Optional[str]) -> Path: + """Resolve the cache root path, optionally using a subdirectory or absolute override.""" + if not cache_subdir: + return CACHE_ROOT_PATH + subdir_path = Path(cache_subdir) + if subdir_path.is_absolute(): + return subdir_path + return CACHE_ROOT_PATH / subdir_path + def _get_cache_path(self, cache_key: str) -> Path: - return CACHE_ROOT_PATH / "cluster_optimizer" / f"{cache_key}.pbz2" + return self._cache_root_path / "cluster_optimizer" / f"{cache_key}.pbz2" def _hash_one_view_data(self, one_view_data: Optional["OneViewData"]) -> str: """Compute a stable hash for OneViewData contents.""" @@ -127,11 +154,25 @@ def _save_result_to_cache(self, result: GtsfmData, cache_path: Path) -> GtsfmDat io_utils.write_to_bz2_file(result, cache_path) return result + def _save_cached_result_outputs(self, result: GtsfmData, results_path: Path) -> None: + """Persist cached outputs expected by downstream tooling. + + Currently used to re-export VGGT reconstructions in COLMAP text format. + """ + if "VGGT" not in type(self._optimizer).__name__: + return + target_dir = results_path / "vggt" + target_dir.mkdir(parents=True, exist_ok=True) + result.export_as_colmap_text(target_dir) + def create_computation_graph(self, context: ClusterContext) -> ClusterComputationGraph | None: cached_result = self._load_result_from_cache(context) if cached_result is not None: cached_graph: Delayed = delayed(lambda r: r, pure=False)(cached_result) - return ClusterComputationGraph(io_tasks=tuple(), metric_tasks=tuple(), sfm_result=cached_graph) + io_tasks = ( + delayed(self._save_cached_result_outputs, pure=False)(cached_graph, context.output_paths.results), + ) + return ClusterComputationGraph(io_tasks=io_tasks, metric_tasks=tuple(), sfm_result=cached_graph) computation = self._optimizer.create_computation_graph(context) if computation is None or computation.sfm_result is None: diff --git a/gtsfm/cluster_optimizer/cluster_vggt.py b/gtsfm/cluster_optimizer/cluster_vggt.py index 9c934d50b..aab76bbd9 100644 --- a/gtsfm/cluster_optimizer/cluster_vggt.py +++ b/gtsfm/cluster_optimizer/cluster_vggt.py @@ -9,12 +9,15 @@ import torch import torch.nn.functional as F from dask.delayed import Delayed, delayed +from gtsam import Pose3 +import gtsfm.common.types as gtsfm_types import gtsfm.frontend.vggt as vggt +import gtsfm.utils.metrics as metrics_utils from gtsfm.cluster_optimizer.cluster_optimizer_base import ClusterComputationGraph, ClusterContext, ClusterOptimizerBase from gtsfm.common.gtsfm_data import GtsfmData from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup -from gtsfm.frontend.vggt import VggtConfiguration +from gtsfm.frontend.vggt import VggtConfiguration, VggtReconstruction from gtsfm.products.visibility_graph import visibility_graph_keys from gtsfm.ui.gtsfm_process import UiMetadata from gtsfm.utils.logger import get_logger @@ -32,13 +35,9 @@ def _resize_to_square_tensor(image: np.ndarray, target_size: int) -> torch.Tenso return (tensor.squeeze(0)) / 255.0 -def _load_vggt_inputs(loader, indices: list[int], target_size: int): +def _load_vggt_inputs(loader, indices: list[int], mode: str): """Load and preprocess a batch of images for VGGT.""" - - def resize_transform(arr: np.ndarray) -> torch.Tensor: - return _resize_to_square_tensor(arr, target_size) - - return loader.load_image_batch_vggt(indices, target_size, resize_transform) + return vggt.load_image_batch_vggt_loader(loader, indices, mode=mode) def _resolve_vggt_model(cache_key: Hashable | None, loader_kwargs: dict[str, Any] | None) -> Any | None: @@ -65,7 +64,7 @@ def _run_vggt_pipeline( model_cache_key: Hashable | None = None, loader_kwargs: dict[str, Any] | None = None, **kwargs, -) -> GtsfmData: +) -> VggtReconstruction: torch.manual_seed(seed) np.random.seed(seed) if torch.cuda.is_available(): @@ -83,45 +82,104 @@ def _run_vggt_pipeline( cached_model = _resolve_vggt_model(model_cache_key, loader_kwargs) if cached_model is not None: kwargs = {**kwargs, "model": cached_model} - return vggt.run_reconstruction_gtsfm_data_only(image_batch, **kwargs) + return vggt.run_reconstruction(image_batch, **kwargs) def _save_reconstruction_as_text( result: GtsfmData, results_path: Path, + *, + subdir: str = "vggt", ) -> None: - target_dir = results_path / "vggt" + target_dir = results_path / subdir target_dir.mkdir(parents=True, exist_ok=True) result.export_as_colmap_text(target_dir) -def _aggregate_vggt_metrics(result: GtsfmData) -> GtsfmMetricsGroup: - num_cameras = len(result.get_valid_camera_indices()) - num_points3d = result.number_tracks() - return GtsfmMetricsGroup( - "vggt_runtime_metrics", - [ - GtsfmMetric("num_cameras", num_cameras), - GtsfmMetric("num_points3d", num_points3d), - ], +def _save_pre_ba_reconstruction_as_text( + pre_ba_result: Optional[GtsfmData], + results_path: Path, +) -> None: + if pre_ba_result is None: + return + _save_reconstruction_as_text(pre_ba_result, results_path, subdir="vggt_pre_ba") + + +def _get_pose_metrics( + result_data: GtsfmData, + cameras_gt: list[Optional[gtsfm_types.CAMERA_TYPE]], + save_dir: Optional[str] = None, +) -> GtsfmMetricsGroup: + """Compute pose metrics for a VGGT result after aligning with ground truth.""" + image_idxs = list(result_data._image_info.keys()) + poses_gt: dict[int, Pose3] = {} + for i in image_idxs: + if i >= len(cameras_gt): + continue + camera = cameras_gt[i] + if camera is not None: + poses_gt[i] = camera.pose() + if len(poses_gt) == 0: + return GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=[]) + aligned_result_data = result_data.align_via_sim3_and_transform(poses_gt) + computed_wTi: dict[int, Optional[Pose3]] = {i: pose for i, pose in aligned_result_data.get_camera_poses().items()} + return metrics_utils.compute_ba_pose_metrics( + gt_wTi=poses_gt, + computed_wTi=computed_wTi, + save_dir=save_dir, + store_full_data=True, ) +def _aggregate_vggt_metrics( + result: GtsfmData, + cameras_gt: Optional[list[Optional[gtsfm_types.CAMERA_TYPE]]] = None, + pre_ba_result: Optional[GtsfmData] = None, + *, + save_dir: Optional[str] = None, +) -> list[GtsfmMetricsGroup]: + def _build_metrics_group(scene: GtsfmData, name: str) -> GtsfmMetricsGroup: + metrics_group = GtsfmMetricsGroup( + name, + [ + GtsfmMetric("num_cameras", len(scene.get_valid_camera_indices())), + GtsfmMetric("num_points3d", scene.number_tracks()), + ], + ) + if cameras_gt is not None: + metrics_group.extend(_get_pose_metrics(scene, cameras_gt, save_dir=save_dir)) + return metrics_group + + metrics_groups = [_build_metrics_group(result, "cluster_vggt_metrics")] + if pre_ba_result is not None: + metrics_groups.append(_build_metrics_group(pre_ba_result, "cluster_vggt_pre_ba_metrics")) + return metrics_groups + + +def _extract_post_ba_result(result: VggtReconstruction) -> GtsfmData: + """Extract the post-BA reconstruction from the VGGT pipeline output.""" + return result.gtsfm_data + + +def _extract_pre_ba_result(result: VggtReconstruction) -> Optional[GtsfmData]: + """Extract the optional pre-BA reconstruction for debugging.""" + return result.pre_ba_data + + class ClusterVGGT(ClusterOptimizerBase): """Cluster optimizer that runs VGGT to generate COLMAP-style reconstructions.""" def __init__( self, weights_path: Optional[str] = None, - image_load_resolution: int = 1024, - inference_resolution: int = 518, conf_threshold: float = 5.0, max_num_points: int = 100000, tracking: bool = False, - tracking_max_query_pts: int = 1000, - tracking_query_frame_num: int = 4, - tracking_fine_tracking: bool = True, - track_vis_thresh: float = 0.2, + tracking_max_query_pts: int = 2048, + tracking_query_frame_num: int = 3, + track_vis_thresh: float = 0.05, + track_conf_thresh: float = 0.2, + keypoint_extractor: str = "aliked+sp+sift", camera_type: str = "PINHOLE", seed: int = 42, scene_dir: Optional[str] = None, @@ -137,6 +195,7 @@ def __init__( enable_protection: bool = False, extra_model_kwargs: Optional[dict[str, Any]] = None, run_bundle_adjustment_on_leaf: bool = False, + store_pre_ba_result: bool = False, run_bundle_adjustment_on_parent: bool = True, max_reproj_error: float = 8.0, plot_reprojection_histograms: bool = True, @@ -154,15 +213,14 @@ def __init__( run_bundle_adjustment_on_parent=run_bundle_adjustment_on_parent, ) self._weights_path = Path(weights_path) if weights_path is not None else None - self._image_load_resolution = image_load_resolution - self._inference_resolution = inference_resolution self._conf_threshold = conf_threshold self._max_points_for_colmap = max_num_points self._tracking = tracking self._tracking_max_query_pts = tracking_max_query_pts self._tracking_query_frame_num = tracking_query_frame_num - self._tracking_fine_tracking = tracking_fine_tracking self._track_vis_thresh = track_vis_thresh + self._track_conf_thresh = track_conf_thresh + self._keypoint_extractor = keypoint_extractor self._camera_type = camera_type self._max_reproj_error = max_reproj_error self._seed = seed @@ -170,6 +228,7 @@ def __init__( self._use_sparse_attention = use_sparse_attention self._dtype = inference_dtype self._run_bundle_adjustment_on_leaf = run_bundle_adjustment_on_leaf + self._store_pre_ba_result = store_pre_ba_result if fast_dtype is not None: if self._dtype is None: self._dtype = fast_dtype @@ -216,8 +275,6 @@ def _maybe_set_model_kw(key: str, value: Any) -> None: def __repr__(self) -> str: components = [ f"weights_path={self._weights_path}", - f"image_load_resolution={self._image_load_resolution}", - f"inference_resolution={self._inference_resolution}", f"camera_type={self._camera_type}", f"dtype={self._dtype}", f"use_sparse_attention={self._use_sparse_attention}", @@ -253,27 +310,29 @@ def create_computation_graph( image_names = tuple(str(image_filenames[idx]) for idx in keys) config = VggtConfiguration( - vggt_fixed_resolution=self._inference_resolution, - img_load_resolution=self._image_load_resolution, confidence_threshold=self._conf_threshold, max_num_points=self._max_points_for_colmap, tracking=self._tracking, max_query_pts=self._tracking_max_query_pts, query_frame_num=self._tracking_query_frame_num, - fine_tracking=self._tracking_fine_tracking, track_vis_thresh=self._track_vis_thresh, + track_conf_thresh=self._track_conf_thresh, + keypoint_extractor=self._keypoint_extractor, dtype=self._dtype, model_ctor_kwargs=self._model_ctor_kwargs.copy(), use_sparse_attention=self._use_sparse_attention, run_bundle_adjustment_on_leaf=self._run_bundle_adjustment_on_leaf, + store_pre_ba_result=self._store_pre_ba_result, max_reproj_error=self._max_reproj_error, ) + # mode is fixed to "crop", it resizes the width to 518 while maintaining aspect ratio and only if + # height is > 518 then crops image_batch_graph, original_coords_graph = delayed(_load_vggt_inputs, nout=2)( - context.loader, global_indices, self._image_load_resolution + context.loader, global_indices, mode="crop" ) - result_graph = delayed(_run_vggt_pipeline)( + reconstruction_graph = delayed(_run_vggt_pipeline)( image_batch_graph, seed=self._seed, original_coords=original_coords_graph, @@ -285,8 +344,18 @@ def create_computation_graph( loader_kwargs=self._loader_kwargs or None, cluster_label=context.label, ) - - metrics_tasks = [delayed(_aggregate_vggt_metrics)(result_graph)] + result_graph = delayed(_extract_post_ba_result)(reconstruction_graph) + pre_ba_result_graph = delayed(_extract_pre_ba_result)(reconstruction_graph) + + cameras_gt = [context.one_view_data_dict[idx].camera_gt for idx in range(context.num_images)] + metrics_tasks = [ + delayed(_aggregate_vggt_metrics)( + result_graph, + cameras_gt=cameras_gt, + pre_ba_result=pre_ba_result_graph, + save_dir=str(context.output_paths.metrics), + ) + ] io_tasks: list[Delayed] = [] with self._output_annotation(): @@ -296,6 +365,12 @@ def create_computation_graph( context.output_paths.results, ) ) + io_tasks.append( + delayed(_save_pre_ba_reconstruction_as_text)( + pre_ba_result_graph, + context.output_paths.results, + ) + ) return ClusterComputationGraph( io_tasks=tuple(io_tasks), diff --git a/gtsfm/common/gtsfm_data.py b/gtsfm/common/gtsfm_data.py index 0d8b59908..298587ee6 100644 --- a/gtsfm/common/gtsfm_data.py +++ b/gtsfm/common/gtsfm_data.py @@ -540,6 +540,14 @@ def get_tracks(self) -> List[SfmTrack]: """Returns all tracks.""" return self._tracks + def update_camera_pose(self, index: int, pose: Pose3) -> None: + """Updates the pose of a camera at index.""" + if index not in self._cameras: + raise ValueError(f"Camera at index {index} not found") + K = self._cameras[index].calibration() + new_camera = gtsfm_types.get_camera_class_for_calibration(K)(pose, K) + self._cameras[index] = new_camera + def add_camera(self, index: int, camera: gtsfm_types.CAMERA_TYPE) -> None: """Adds camera at index if not already present.""" if camera is None: @@ -846,11 +854,12 @@ def align_via_sim3_and_transform(self, aTi: dict[int, Pose3]) -> "GtsfmData": def get_metrics(self, suffix: str, store_full_data: bool = False) -> List[GtsfmMetric]: """Helper to get bundle adjustment metrics from a GtsfmData object with a suffix for metric names.""" metrics = [] - metrics.append(GtsfmMetric(name="number_cameras", data=len(self.get_valid_camera_indices()))) - metrics.append(GtsfmMetric("number_tracks" + suffix, self.number_tracks())) + metrics.append(GtsfmMetric(name=f"number_images{suffix}", data=self.number_images())) + metrics.append(GtsfmMetric(name=f"number_cameras{suffix}", data=len(self.get_valid_camera_indices()))) + metrics.append(GtsfmMetric(name=f"number_tracks{suffix}", data=self.number_tracks())) metrics.append( GtsfmMetric( - name="3d_track_lengths" + suffix, + name=f"3d_track_lengths{suffix}", data=self.get_track_lengths(), plot_type=GtsfmMetric.PlotType.HISTOGRAM, store_full_data=store_full_data, diff --git a/gtsfm/configs/vggt.yaml b/gtsfm/configs/vggt.yaml index 979cca861..4fc525117 100644 --- a/gtsfm/configs/vggt.yaml +++ b/gtsfm/configs/vggt.yaml @@ -7,7 +7,6 @@ loader: _target_: gtsfm.loader.Olsson dataset_dir: ??? # Required: set to the dataset root on the command line. images_dir: null - max_resolution: 760 image_pairs_generator: _target_: gtsfm.retriever.image_pairs_generator.ImagePairsGenerator @@ -18,26 +17,26 @@ image_pairs_generator: retriever: _target_: gtsfm.retriever.Similarity num_matched: 5 - min_score: 0.79 + min_score: 0.25 batch_size: 16 graph_partitioner: - _target_: gtsfm.graph_partitioner.Single + _target_: gtsfm.graph_partitioner.Metis cluster_optimizer: _target_: gtsfm.cluster_optimizer.Cacher optimizer: _target_: gtsfm.cluster_optimizer.cluster_vggt.ClusterVGGT weights_path: null - image_load_resolution: 1024 - inference_resolution: 518 conf_threshold: 5.0 max_num_points: 100000 tracking: true - tracking_max_query_pts: 512 + tracking_max_query_pts: 2048 tracking_query_frame_num: 3 - tracking_fine_tracking: false - track_vis_thresh: 0.2 + keypoint_extractor: aliked+sp+sift + track_vis_thresh: 0.05 + track_conf_thresh: 0.2 + max_reproj_error: 0 # 0.0 means no filtering based on reproj error camera_type: PINHOLE drop_outlier_after_camera_merging: false drop_child_if_merging_fail: true @@ -45,5 +44,5 @@ cluster_optimizer: seed: 42 plot_reprojection_histograms: true run_bundle_adjustment_on_leaf: false - run_bundle_adjustment_on_parent: true + run_bundle_adjustment_on_parent: false model_cache_key: null diff --git a/gtsfm/configs/vggt_megaloc.yaml b/gtsfm/configs/vggt_megaloc.yaml new file mode 100644 index 000000000..d4cbe08a7 --- /dev/null +++ b/gtsfm/configs/vggt_megaloc.yaml @@ -0,0 +1,54 @@ +# VGGT cluster-only configuration. + +# @package _global_ +_target_: gtsfm.scene_optimizer.SceneOptimizer + +loader: + _target_: gtsfm.loader.Olsson + dataset_dir: ??? # Required: set to the dataset root on the command line. + images_dir: null + max_resolution: 760 + +image_pairs_generator: + _target_: gtsfm.retriever.image_pairs_generator.ImagePairsGenerator + global_descriptor: + _target_: gtsfm.frontend.cacher.global_descriptor_cacher.GlobalDescriptorCacher + global_descriptor_obj: + _target_: gtsfm.frontend.global_descriptor.MegaLoc + retriever: + _target_: gtsfm.retriever.Similarity + num_matched: 15 + min_score: 0.4 + batch_size: 16 + +graph_partitioner: + _target_: gtsfm.graph_partitioner.Metis + +cluster_optimizer: + _target_: gtsfm.cluster_optimizer.Cacher + optimizer: + _target_: gtsfm.cluster_optimizer.cluster_vggt.ClusterVGGT + weights_path: null + image_load_resolution: 1024 + inference_resolution: 518 + conf_threshold: 5.0 + max_num_points: 100000 + tracking: true + tracking_max_query_pts: 512 + tracking_query_frame_num: 3 + tracking_fine_tracking: false + track_vis_thresh: 0.2 + track_conf_thresh: 0.2 + max_reproj_error: 0 # 0.0 means no filtering based on reproj error + camera_type: PINHOLE + drop_outlier_after_camera_merging: false + drop_child_if_merging_fail: true + drop_camera_with_no_track: true + seed: 42 + plot_reprojection_histograms: true + run_bundle_adjustment_on_leaf: false + run_bundle_adjustment_on_parent: true + model_cache_key: null + store_pre_ba_result: true + +use_nonlinear_sim3_alignment: false diff --git a/gtsfm/evaluation/compare_colmap_outputs.py b/gtsfm/evaluation/compare_colmap_outputs.py index 6e163d697..a84fd3145 100644 --- a/gtsfm/evaluation/compare_colmap_outputs.py +++ b/gtsfm/evaluation/compare_colmap_outputs.py @@ -1,13 +1,17 @@ """Script to compare two reconstructions in Colmap's output format. -Authors: Ayush Baid +Authors: Ayush Baid, Xinan Zhang """ import argparse +import csv +import json import os -from typing import Dict, Tuple +import textwrap +from typing import Dict, List, Optional, Tuple import numpy as np +import matplotlib.pyplot as plt import pycolmap from gtsam import Point3, Pose3, Rot3, Similarity3 from scipy.spatial.transform import Rotation @@ -53,7 +57,108 @@ def align_with_colmap( return aSb, aligned_dict -def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) -> None: +def plot_camera_centers( + baseline_wTi_list: List[Pose3], + current_wTi_list: List[Pose3], + output_dirpath: str, + title: Optional[str] = None, +) -> None: + """Save a 3D scatter plot of baseline and current camera centers.""" + baseline_centers = np.stack([pose.translation() for pose in baseline_wTi_list]) + current_centers_list = [pose.translation() for pose in current_wTi_list] + current_centers = np.stack(current_centers_list) if current_centers_list else np.empty((0, 3)) + + fig = plt.figure(figsize=(7, 7)) + ax = fig.add_subplot(111, projection="3d") + if baseline_centers.size: + center = baseline_centers.mean(axis=0) + mean_radius = np.linalg.norm(baseline_centers - center, axis=1).mean() + arrow_len = max(mean_radius * 0.15, 1e-3) + else: + arrow_len = 1.0 + + for pose in baseline_wTi_list: + origin = pose.transformFrom(Point3(0.0, 0.0, 0.0)) + tip = pose.transformFrom(Point3(0.0, 0.0, arrow_len)) + direction = tip - origin + ax.quiver( + origin[0], origin[1], origin[2], + direction[0], direction[1], direction[2], + color="tab:blue", linewidth=0.5, arrow_length_ratio=0.2, alpha=0.6 + ) + for pose in current_wTi_list: + origin = pose.transformFrom(Point3(0.0, 0.0, 0.0)) + tip = pose.transformFrom(Point3(0.0, 0.0, arrow_len)) + direction = tip - origin + ax.quiver( + origin[0], origin[1], origin[2], + direction[0], direction[1], direction[2], + color="tab:orange", linewidth=0.5, arrow_length_ratio=0.2, alpha=0.6 + ) + + ax.scatter( + baseline_centers[:, 0], + baseline_centers[:, 1], + baseline_centers[:, 2], + s=10, + c="tab:blue", + label="baseline", + ) + if current_centers.size: + ax.scatter( + current_centers[:, 0], + current_centers[:, 1], + current_centers[:, 2], + s=10, + c="tab:orange", + label="current", + ) + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + ax.legend(loc="best") + wrapped = "\n".join(textwrap.wrap(title, width=80)) if title else "" + if wrapped: + fig.suptitle(wrapped, fontsize=9, y=0.98) + fig.tight_layout(rect=[0, 0, 1, 0.92]) + fig.savefig(os.path.join(output_dirpath, "camera_centers.png"), dpi=300) + plt.close(fig) + + +def export_metrics_group_to_csv(metrics_group: GtsfmMetricsGroup, output_path: str) -> None: + """Export a metrics group to a CSV file.""" + rows: List[Dict[str, str]] = [] + for metric in metrics_group.metrics: + if metric.dim == 0: + value = "" if metric.data is None else f"{float(metric.data):.6f}" + rows.append({"metric_name": metric.name, "value": value}) + else: + summary_json = json.dumps(metric.summary, sort_keys=True) + rows.append({"metric_name": metric.name, "value": summary_json}) + + with open(output_path, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=["metric_name", "value"]) + writer.writeheader() + writer.writerows(rows) + + +def _format_pose_auc(metrics_group: GtsfmMetricsGroup) -> str: + auc_parts = [] + for metric in metrics_group.metrics: + if not metric.name.startswith("pose_auc_@"): + continue + if metric.data is None: + continue + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + suffix = metric.name.replace("pose_auc_", "") + auc_parts.append(f"{suffix}={value:.3f}") + return ", ".join(auc_parts) + + +def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) -> GtsfmMetricsGroup: """Compare the pose metrics between two reconstructions (Colmap format). Args: @@ -82,18 +187,25 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) len(common_fnames), ) + baseline_wTi_list: List[Pose3] = [] + current_wTi_list: List[Optional[Pose3]] = [] + for fname, wTi in baseline_wTi_dict.items(): + baseline_wTi_list.append(wTi) + current_wTi_list.append(current_wTi_dict.get(fname)) + if not args.use_pycolmap_alignment: aSb = align.sim3_from_Pose3_maps_robust(baseline_wTi_dict, current_wTi_dict) - current_wTi_dict = transform.Pose3_map_with_sim3(aSb, current_wTi_dict) + current_wTi_list = transform.optional_Pose3s_with_sim3(aSb, current_wTi_list) + current_wTi_dict = {fname: aSb.transformFrom(pose) for fname, pose in current_wTi_dict.items()} i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_wTi_dict) - wRi_aligned, wti_aligned = metric_utils.get_rotations_translations_from_poses(current_wTi_dict) - baseline_wRi, baseline_wti = metric_utils.get_rotations_translations_from_poses(baseline_wTi_dict) + wRi_aligned_dict, wti_aligned_dict = metric_utils.get_rotations_translations_from_poses(current_wTi_dict) + baseline_wRi_dict, baseline_wti_dict = metric_utils.get_rotations_translations_from_poses(baseline_wTi_dict) metrics = [] - metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned, baseline_wRi)) - metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned, baseline_wti)) + metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_dict, baseline_wRi_dict)) + metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_dict, baseline_wti_dict)) metrics.append(metric_utils.compute_translation_angle_metric(baseline_wTi_dict, current_wTi_dict)) relative_rotation_error_metric = metric_utils.compute_relative_rotation_angle_metric( i2Ri1_dict_gt, current_wTi_dict, store_full_data=True @@ -114,7 +226,14 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) ba_pose_metrics = GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=metrics) + auc_text = _format_pose_auc(ba_pose_metrics) + title = eval_dirpath + if auc_text: + title = f"{title}\nPose AUC: {auc_text}" + plot_camera_centers(baseline_wTi_list, list(current_wTi_dict.values()), output_dirpath, title=title) + save_metrics_reports([ba_pose_metrics], metrics_path=output_dirpath) + return ba_pose_metrics if __name__ == "__main__": @@ -139,4 +258,5 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str) os.makedirs(args.output, exist_ok=True) - compare_poses(args.baseline, args.current, args.output) + ba_pose_metrics = compare_poses(args.baseline, args.current, args.output) + export_metrics_group_to_csv(ba_pose_metrics, os.path.join(args.output, f"{ba_pose_metrics.name}.csv")) diff --git a/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py new file mode 100644 index 000000000..0ad746992 --- /dev/null +++ b/gtsfm/evaluation/compare_colmap_outputs_by_cluster.py @@ -0,0 +1,913 @@ +"""Compare COLMAP reconstructions using image-name alignment. + +This script walks a results tree, finds cluster reconstructions under a given subfolder +name (default: "vggt"), and evaluates camera pose quality against a COLMAP baseline. +""" + +from __future__ import annotations + +import argparse +import csv +import json +import os +import textwrap +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from gtsam import Pose3, Rot3 + +import gtsfm.utils.logger as logger_utils +import gtsfm.utils.metrics as metric_utils +from gtsfm.evaluation.metrics import GtsfmMetric, GtsfmMetricsGroup +from gtsfm.utils import align + +logger = logger_utils.get_logger() + + +def _read_images_txt_with_names(images_txt: Path) -> Dict[str, Pose3]: + """Read poses from COLMAP images.txt keyed by image NAME.""" + if not images_txt.exists(): + raise FileNotFoundError(f"{images_txt} does not exist.") + + with images_txt.open("r") as f: + lines = f.readlines() + + poses_by_name: Dict[str, Pose3] = {} + for line in lines: + if not line.strip() or line.startswith("#"): + continue + parts = line.split() + if len(parts) < 10: + continue + _image_id_str, qw, qx, qy, qz, tx, ty, tz, _camera_id = parts[:9] + img_fname = " ".join(parts[9:]) + iRw = Rot3(float(qw), float(qx), float(qy), float(qz)) + wTi = Pose3(iRw, np.array([tx, ty, tz], dtype=np.float64)).inverse() + if not np.isfinite(wTi.translation()).all(): + logger.warning("Skipping non-finite pose for %s in %s", img_fname, images_txt) + continue + poses_by_name[img_fname] = wTi + return poses_by_name + + +def _read_images_txt_with_names_and_cameras(images_txt: Path) -> Tuple[Dict[str, Pose3], Dict[str, int]]: + """Read poses and camera ids from COLMAP images.txt keyed by image NAME.""" + if not images_txt.exists(): + raise FileNotFoundError(f"{images_txt} does not exist.") + + with images_txt.open("r") as f: + lines = f.readlines() + + poses_by_name: Dict[str, Pose3] = {} + camera_by_name: Dict[str, int] = {} + for line in lines: + if not line.strip() or line.startswith("#"): + continue + parts = line.split() + if len(parts) < 10: + continue + _image_id_str, qw, qx, qy, qz, tx, ty, tz, camera_id_str = parts[:9] + img_fname = " ".join(parts[9:]) + iRw = Rot3(float(qw), float(qx), float(qy), float(qz)) + wTi = Pose3(iRw, np.array([tx, ty, tz], dtype=np.float64)).inverse() + if not np.isfinite(wTi.translation()).all(): + logger.warning("Skipping non-finite pose for %s in %s", img_fname, images_txt) + continue + camera_id = int(camera_id_str) + poses_by_name[img_fname] = wTi + camera_by_name[img_fname] = camera_id + return poses_by_name, camera_by_name + + +def _read_cameras_txt_with_ids(cameras_txt: Path) -> Dict[int, Dict[str, float]]: + """Read camera intrinsics from COLMAP cameras.txt keyed by CAMERA_ID.""" + if not cameras_txt.exists(): + raise FileNotFoundError(f"{cameras_txt} does not exist.") + + with cameras_txt.open("r") as f: + lines = f.readlines() + + cameras_by_id: Dict[int, Dict[str, float]] = {} + for line in lines[3:]: + if line.startswith("#") or not line.strip(): + continue + parts = line.split() + camera_id = int(parts[0]) + model = parts[1] + width = int(parts[2]) + height = int(parts[3]) + params = list(map(float, parts[4:])) + if model == "SIMPLE_PINHOLE": + fx, cx, cy = params + fy = fx + k1 = 0.0 + k2 = 0.0 + elif model == "SIMPLE_RADIAL": + fx, cx, cy, k1 = params + k2 = 0.0 + fy = fx + elif model == "RADIAL": + fx, cx, cy, k1, k2 = params + fy = fx + elif model == "PINHOLE": + fx, fy, cx, cy = params + k1 = 0.0 + k2 = 0.0 + elif model == "OPENCV": + fx, fy, cx, cy, k1, k2, _p1, _p2, *_rest = params + elif model == "OPENCV_FISHEYE": + fx, fy, cx, cy, k1, k2, *_rest = params + else: + logger.warning("Unsupported camera model %s; skipping camera_id=%d", model, camera_id) + continue + cameras_by_id[camera_id] = { + "model": model, + "width": float(width), + "height": float(height), + "fx": fx, + "fy": fy, + "cx": cx, + "cy": cy, + "k1": k1, + "k2": k2, + } + return cameras_by_id + + +def _find_cluster_recon_dirs(root: Path, recon_name: str) -> Iterable[Path]: + """Yield directories that match the recon_name and contain images.txt.""" + for dirpath, dirnames, filenames in os.walk(root): + if os.path.basename(dirpath) != recon_name: + continue + if "images.txt" in filenames: + yield Path(dirpath) + + +def _build_pose_lists( + baseline_poses: Dict[str, Pose3], + current_poses: Dict[str, Pose3], + cluster_label: str, +) -> Tuple[List[str], List[Pose3], List[Pose3]]: + """Align poses by image NAME and return matched pose lists.""" + common_names = sorted(set(baseline_poses.keys()) & set(current_poses.keys())) + if not common_names: + missing_in_baseline = sorted(set(current_poses.keys()) - set(baseline_poses.keys())) + missing_in_current = sorted(set(baseline_poses.keys()) - set(current_poses.keys())) + if missing_in_baseline: + logger.warning( + "No common images for %s; missing in baseline (sample): %s", + cluster_label, + ", ".join(missing_in_baseline[:5]), + ) + if missing_in_current: + logger.warning( + "No common images for %s; missing in current (sample): %s", + cluster_label, + ", ".join(missing_in_current[:5]), + ) + else: + logger.info("Common images for %s: %d", cluster_label, len(common_names)) + baseline_list = [baseline_poses[name] for name in common_names] + current_list = [current_poses[name] for name in common_names] + return common_names, baseline_list, current_list + + +def _compute_pose_metrics(baseline_list: List[Pose3], current_aligned_list: List[Pose3]) -> GtsfmMetricsGroup: + """Compute the same pose metrics as compare_colmap_outputs, without plotting.""" + baseline_dict = {i: pose for i, pose in enumerate(baseline_list)} + current_dict = {i: pose for i, pose in enumerate(current_aligned_list)} + i2Ri1_dict_gt, i2Ui1_dict_gt = metric_utils.get_all_relative_rotations_translations(baseline_dict) + wRi_aligned_dict, wti_aligned_dict = metric_utils.get_rotations_translations_from_poses(current_dict) + baseline_wRi_dict, baseline_wti_dict = metric_utils.get_rotations_translations_from_poses(baseline_dict) + + metrics = [] + metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_dict, baseline_wRi_dict)) + metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_dict, baseline_wti_dict)) + metrics.append(metric_utils.compute_translation_angle_metric(baseline_dict, current_dict)) + relative_rotation_error_metric = metric_utils.compute_relative_rotation_angle_metric( + i2Ri1_dict_gt, current_dict, store_full_data=True + ) + metrics.append(relative_rotation_error_metric) + relative_translation_error_metric = metric_utils.compute_relative_translation_angle_metric( + i2Ui1_dict_gt, current_dict, store_full_data=True + ) + metrics.append(relative_translation_error_metric) + thresholds_deg = (1.0, 2.5, 5.0, 10.0, 20.0) + if relative_rotation_error_metric.data is not None: + rotation_angular_errors = np.asarray(relative_rotation_error_metric.data) + rotation_auc_values = metric_utils.pose_auc(rotation_angular_errors, thresholds_deg) + metrics.extend( + [ + GtsfmMetric(f"rotation_auc_@{threshold}_deg", auc) + for threshold, auc in zip(thresholds_deg, rotation_auc_values) + ] + ) + if relative_translation_error_metric.data is not None: + translation_angular_errors = np.asarray(relative_translation_error_metric.data) + translation_auc_values = metric_utils.pose_auc(translation_angular_errors, thresholds_deg) + metrics.extend( + [ + GtsfmMetric(f"translation_auc_@{threshold}_deg", auc) + for threshold, auc in zip(thresholds_deg, translation_auc_values) + ] + ) + metrics.extend( + metric_utils.compute_pose_auc_metric( + relative_rotation_error_metric.data, relative_translation_error_metric.data, thresholds_deg=thresholds_deg + ) + ) + + return GtsfmMetricsGroup(name="ba_pose_error_metrics", metrics=metrics) + + +def _estimate_sim3_ransac( + baseline_list: List[Pose3], + current_list: List[Pose3], + max_hypotheses: int, + inlier_thresh: float, + rng: np.random.Generator, + cluster_label: str, +) -> align.Similarity3: + """Estimate Sim(3) using simple RANSAC over camera centers with refit on inliers.""" + n_to_align = len(baseline_list) + if n_to_align < 2: + logger.warning("SIM(3) alignment uses at least 2 frames; Skipping for %s", cluster_label) + return align.Similarity3(Rot3(), np.zeros(3), 1.0) + + baseline_centers = np.stack([pose.translation() for pose in baseline_list]) + current_centers = np.stack([pose.translation() for pose in current_list]) + best_inliers: Optional[np.ndarray] = None + best_count = -1 + best_mean_error = float("inf") + best_aSb: Optional[align.Similarity3] = None + + for _ in range(max_hypotheses): + sample_idx = rng.choice(n_to_align, size=2, replace=False) + baseline_sample = {i: baseline_list[idx] for i, idx in enumerate(sample_idx)} + current_sample = {i: current_list[idx] for i, idx in enumerate(sample_idx)} + try: + aSb_candidate = align.sim3_from_Pose3_maps(baseline_sample, current_sample) + except Exception: + continue + transformed = np.stack([aSb_candidate.transformFrom(p) for p in current_centers]) + errors = np.linalg.norm(baseline_centers - transformed, axis=1) + inliers = errors <= inlier_thresh + count = int(np.count_nonzero(inliers)) + mean_error = float(errors[inliers].mean()) if count > 0 else float("inf") + if count > best_count or (count == best_count and mean_error < best_mean_error): + best_count = count + best_mean_error = mean_error + best_inliers = inliers + best_aSb = aSb_candidate + + if best_aSb is None or best_inliers is None: + logger.warning("Robust Sim3 failed; falling back to all-poses alignment for %s", cluster_label) + baseline_dict = {i: pose for i, pose in enumerate(baseline_list)} + current_dict = {i: pose for i, pose in enumerate(current_list)} + return align.sim3_from_Pose3_maps(baseline_dict, current_dict) + + inlier_indices = np.where(best_inliers)[0] + if len(inlier_indices) < 2: + logger.warning( + "Robust Sim3 inliers too few (%d/%d); using best hypothesis for %s", + len(inlier_indices), + n_to_align, + cluster_label, + ) + return best_aSb + + baseline_inliers = {i: baseline_list[idx] for i, idx in enumerate(inlier_indices)} + current_inliers = {i: current_list[idx] for i, idx in enumerate(inlier_indices)} + aSb_refit = align.sim3_from_Pose3_maps(baseline_inliers, current_inliers) + logger.info( + "Robust Sim3 for %s: inliers=%d/%d, thresh=%.3f", + cluster_label, + len(inlier_indices), + n_to_align, + inlier_thresh, + ) + return aSb_refit + + +def _align_poses( + baseline_list: List[Pose3], + current_list: List[Pose3], + use_ransac: bool, + max_hypotheses: int, + inlier_thresh: float, + rng: np.random.Generator, + cluster_label: str, +) -> Tuple[List[Pose3], align.Similarity3]: + """Align current poses to baseline using Sim(3), optionally with RANSAC+refit.""" + baseline_dict = {i: pose for i, pose in enumerate(baseline_list)} + current_dict = {i: pose for i, pose in enumerate(current_list)} + if use_ransac: + aSb = _estimate_sim3_ransac(baseline_list, current_list, max_hypotheses, inlier_thresh, rng, cluster_label) + else: + aSb = align.sim3_from_Pose3_maps(baseline_dict, current_dict) + current_aligned_list = [aSb.transformFrom(pose) for pose in current_list] + return current_aligned_list, aSb + + +def _plot_camera_centers( + baseline_list: List[Pose3], + current_list: List[Pose3], + output_path: Path, + title: str, +) -> None: + """Save a 3D scatter plot of baseline and current camera centers.""" + baseline_centers = np.stack([pose.translation() for pose in baseline_list]) + current_centers_list = [pose.translation() for pose in current_list] + current_centers = np.stack(current_centers_list) if current_centers_list else np.empty((0, 3)) + + fig = plt.figure(figsize=(7, 7)) + ax = fig.add_subplot(111, projection="3d") + if baseline_centers.size: + center = baseline_centers.mean(axis=0) + mean_radius = np.linalg.norm(baseline_centers - center, axis=1).mean() + arrow_len = max(mean_radius * 0.15, 1e-3) + else: + arrow_len = 1.0 + + for pose in baseline_list: + origin = pose.transformFrom(np.array([0.0, 0.0, 0.0])) + tip = pose.transformFrom(np.array([0.0, 0.0, arrow_len])) + direction = tip - origin + ax.quiver( + origin[0], + origin[1], + origin[2], + direction[0], + direction[1], + direction[2], + color="tab:blue", + linewidth=0.5, + arrow_length_ratio=0.2, + alpha=0.6, + ) + for pose in current_list: + origin = pose.transformFrom(np.array([0.0, 0.0, 0.0])) + tip = pose.transformFrom(np.array([0.0, 0.0, arrow_len])) + direction = tip - origin + ax.quiver( + origin[0], + origin[1], + origin[2], + direction[0], + direction[1], + direction[2], + color="tab:orange", + linewidth=0.5, + arrow_length_ratio=0.2, + alpha=0.6, + ) + + ax.scatter( + baseline_centers[:, 0], + baseline_centers[:, 1], + baseline_centers[:, 2], + s=10, + c="tab:blue", + label="baseline", + ) + if current_centers.size: + ax.scatter( + current_centers[:, 0], + current_centers[:, 1], + current_centers[:, 2], + s=10, + c="tab:orange", + label="current", + ) + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + ax.legend(loc="best") + if title: + wrapped_lines = [] + for line in title.splitlines(): + wrapped_lines.extend(textwrap.wrap(line, width=80) or [""]) + wrapped = "\n".join(wrapped_lines) + else: + wrapped = "" + if wrapped: + fig.suptitle(wrapped, fontsize=9, y=0.98) + fig.tight_layout(rect=[0, 0, 1, 0.92]) + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _summarize_pose_errors( + baseline_list: List[Pose3], + current_aligned_list: List[Pose3], + cluster_label: str, +) -> None: + """Log median/mean absolute pose errors after alignment.""" + if not baseline_list or not current_aligned_list: + return + rot_errors_deg: List[float] = [] + trans_errors: List[float] = [] + for baseline_pose, current_pose in zip(baseline_list, current_aligned_list): + rel = baseline_pose.between(current_pose) + rot_vec = Rot3.Logmap(rel.rotation()) + rot_errors_deg.append(float(np.rad2deg(np.linalg.norm(rot_vec)))) + trans_errors.append(float(np.linalg.norm(rel.translation()))) + logger.info( + "Pose errors for %s: rot_deg median=%.3f mean=%.3f; trans median=%.3f mean=%.3f", + cluster_label, + float(np.median(rot_errors_deg)), + float(np.mean(rot_errors_deg)), + float(np.median(trans_errors)), + float(np.mean(trans_errors)), + ) + + +def _plot_pose_auc_boxplot(auc_values_by_label: Dict[str, List[float]], output_path: Path, title: str) -> None: + """Save box plots for AUC metrics across all clusters.""" + preferred_order = ["@1.0_deg", "@2.5_deg", "@5.0_deg", "@10.0_deg", "@20.0_deg"] + labels = [label for label in preferred_order if auc_values_by_label.get(label)] + if not labels: + labels = sorted(auc_values_by_label.keys()) + data = [auc_values_by_label[label] for label in labels if auc_values_by_label.get(label)] + labels = [label for label in labels if auc_values_by_label.get(label)] + if not data: + return + + fig = plt.figure(figsize=(6, 4)) + ax = fig.add_subplot(111) + ax.boxplot(data, vert=True, patch_artist=True) + ax.set_title(title) + ax.set_ylabel("AUC") + ax.set_xticks(range(1, len(labels) + 1)) + ax.set_xticklabels(labels, rotation=30, ha="right") + stats_lines = [] + for label, values in zip(labels, data): + mean_val = float(np.mean(values)) + median_val = float(np.median(values)) + stats_lines.append(f"{label}: mean={mean_val:.3f}, med={median_val:.3f}") + if stats_lines: + ax.text( + 0.02, + 0.98, + "\n".join(stats_lines), + transform=ax.transAxes, + va="top", + ha="left", + fontsize=8, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), + ) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _plot_pose_auc_vs_input_images( + auc_by_label_and_count: Dict[str, List[Tuple[int, float]]], + output_path: Path, +) -> None: + """Plot pose AUC at each threshold vs. input image count across clusters.""" + preferred_order = ["@1.0_deg", "@2.5_deg", "@5.0_deg", "@10.0_deg", "@20.0_deg"] + labels = [label for label in preferred_order if auc_by_label_and_count.get(label)] + if not labels: + labels = sorted(auc_by_label_and_count.keys()) + + fig = plt.figure(figsize=(7, 5)) + ax = fig.add_subplot(111) + for label in labels: + pairs = auc_by_label_and_count.get(label, []) + if not pairs: + continue + pairs_sorted = sorted(pairs, key=lambda pair: pair[0]) + x_vals = [pair[0] for pair in pairs_sorted] + y_vals = [pair[1] for pair in pairs_sorted] + ax.plot(x_vals, y_vals, marker="o", linewidth=1.0, markersize=4, alpha=0.85, label=label) + + ax.set_title("Pose AUC vs input images (all clusters)") + ax.set_xlabel("input images (current count)") + ax.set_ylabel("AUC") + ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.5) + ax.legend(loc="best", fontsize=8) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _plot_intrinsics_deltas_boxplot(deltas: Dict[str, List[float]], output_path: Path, title: str) -> None: + """Save box plots for normalized intrinsics deltas for a cluster.""" + labels = ["delta_fx_norm", "delta_fy_norm", "delta_cx_norm", "delta_cy_norm"] + data = [deltas.get(label, []) for label in labels] + if not any(data): + return + + fig = plt.figure(figsize=(6, 4)) + ax = fig.add_subplot(111) + ax.boxplot(data, vert=True, patch_artist=True) + ax.set_title(title) + ax.set_ylabel("normalized by baseline value") + ax.set_xticks(range(1, len(labels) + 1)) + ax.set_xticklabels(labels, rotation=20, ha="right") + stats_lines = [] + for label, values in zip(labels, data): + if not values: + continue + mean_val = float(np.mean(values)) + median_val = float(np.median(values)) + stats_lines.append(f"{label}: mean={mean_val:.3f}, med={median_val:.3f}") + if stats_lines: + ax.text( + 0.02, + 0.98, + "\n".join(stats_lines), + transform=ax.transAxes, + va="top", + ha="left", + fontsize=8, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), + ) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _plot_fov_deltas_boxplot(deltas: Dict[str, List[float]], output_path: Path, title: str) -> None: + """Save box plots for FOV deltas (degrees) for a cluster.""" + labels = ["delta_fovx_deg", "delta_fovy_deg"] + data = [deltas.get(label, []) for label in labels] + if not any(data): + return + + fig = plt.figure(figsize=(6, 4)) + ax = fig.add_subplot(111) + ax.boxplot(data, vert=True, patch_artist=True) + ax.set_title(title) + ax.set_ylabel("degrees") + ax.set_xticks(range(1, len(labels) + 1)) + ax.set_xticklabels(labels, rotation=20, ha="right") + stats_lines = [] + for label, values in zip(labels, data): + if not values: + continue + mean_val = float(np.mean(values)) + median_val = float(np.median(values)) + stats_lines.append(f"{label}: mean={mean_val:.3f}, med={median_val:.3f}") + if stats_lines: + ax.text( + 0.02, + 0.98, + "\n".join(stats_lines), + transform=ax.transAxes, + va="top", + ha="left", + fontsize=8, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8), + ) + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=300) + plt.close(fig) + + +def _print_metrics(label: str, metrics_group: GtsfmMetricsGroup) -> None: + logger.info("=== %s ===", label) + for metric in metrics_group.metrics: + if metric.dim == 0: + value = "" if metric.data is None else f"{float(metric.data):.6f}" + logger.info("%s: %s", metric.name, value) + else: + logger.info("%s: %s", metric.name, json.dumps(metric.summary, sort_keys=True)) + + +def _format_auc(metrics_group: GtsfmMetricsGroup, prefix: str) -> str: + auc_parts = [] + prefix_token = f"{prefix}_@" + for metric in metrics_group.metrics: + if not metric.name.startswith(prefix_token): + continue + if metric.data is None: + continue + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + suffix = metric.name.replace(f"{prefix}_", "") + auc_parts.append(f"{suffix}={value:.3f}") + return ", ".join(auc_parts) + + +def export_metrics_group_to_csv( + metrics_group: GtsfmMetricsGroup, + cluster_label: str, + baseline_count: int, + current_count: int, + common_count: int, + output_path: Path, + rows: List[Dict[str, str]], +) -> None: + """Append metrics for a cluster into a shared CSV row list.""" + auc_values: List[float] = [] + for metric in metrics_group.metrics: + if metric.dim == 0: + value = "" if metric.data is None else f"{float(metric.data):.6f}" + if metric.name.startswith("pose_auc_@") and metric.data is not None: + try: + auc_values.append(float(metric.data)) + except (TypeError, ValueError): + pass + else: + value = json.dumps(metric.summary, sort_keys=True) + rows.append( + { + "cluster": cluster_label, + "baseline_count": str(baseline_count), + "current_count": str(current_count), + "common_count": str(common_count), + "metric_name": metric.name, + "value": value, + } + ) + if auc_values: + rows.append( + { + "cluster": cluster_label, + "baseline_count": str(baseline_count), + "current_count": str(current_count), + "common_count": str(common_count), + "metric_name": "pose_auc_avg", + "value": f"{float(np.mean(auc_values)):.6f}", + } + ) + + if output_path.exists() and output_path.stat().st_size > 0: + return + + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", newline="") as csvfile: + writer = csv.DictWriter( + csvfile, + fieldnames=["cluster", "baseline_count", "current_count", "common_count", "metric_name", "value"], + ) + writer.writeheader() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", required=True, help="Path to baseline COLMAP directory.") + parser.add_argument("--root", required=True, help="Root directory to traverse for cluster reconstructions.") + parser.add_argument("--recon_name", default="vggt", help="Subdirectory name for reconstructions.") + parser.add_argument( + "--csv_output", + default=None, + help="Optional path to a single CSV file for all cluster metrics.", + ) + parser.add_argument( + "--fig_output_dir", + default=None, + help="Optional directory to save per-cluster camera_centers.png plots.", + ) + parser.add_argument( + "--robust_sim3", + action="store_true", + default=False, + help="Use simple RANSAC+refit for Sim(3) alignment.", + ) + parser.add_argument( + "--robust_sim3_max_hypotheses", + type=int, + default=200, + help="Max RANSAC hypotheses for robust Sim(3) alignment.", + ) + parser.add_argument( + "--robust_sim3_inlier_thresh", + type=float, + default=0.1, + help="Inlier threshold on camera-center error for robust Sim(3) alignment.", + ) + parser.add_argument( + "--robust_sim3_seed", + type=int, + default=0, + help="Random seed for robust Sim(3) alignment.", + ) + args = parser.parse_args() + + baseline_images = Path(args.baseline) / "images.txt" + baseline_cameras_txt = Path(args.baseline) / "cameras.txt" + baseline_poses, baseline_camera_by_name = _read_images_txt_with_names_and_cameras(baseline_images) + baseline_cameras = _read_cameras_txt_with_ids(baseline_cameras_txt) + fig_output_dir = Path(args.fig_output_dir) if args.fig_output_dir else None + if fig_output_dir is None and args.csv_output: + fig_output_dir = Path(args.csv_output).parent / "cluster_camera_centers" + + root = Path(args.root) + recon_dirs = sorted(_find_cluster_recon_dirs(root, args.recon_name)) + if not recon_dirs: + raise FileNotFoundError(f"No reconstructions named '{args.recon_name}' with images.txt under {root}") + + logger.info("Found %d reconstructions under %s", len(recon_dirs), root) + + csv_rows: List[Dict[str, str]] = [] + all_pose_auc_values: Dict[str, List[float]] = {} + all_pose_auc_by_label_and_count: Dict[str, List[Tuple[int, float]]] = {} + all_rotation_auc_values: Dict[str, List[float]] = {} + all_translation_auc_values: Dict[str, List[float]] = {} + all_intrinsics_deltas: Dict[str, List[float]] = { + "delta_fx_norm": [], + "delta_fy_norm": [], + "delta_cx_norm": [], + "delta_cy_norm": [], + } + all_fov_deltas: Dict[str, List[float]] = { + "delta_fovx_deg": [], + "delta_fovy_deg": [], + } + rng = np.random.default_rng(args.robust_sim3_seed) + for recon_dir in recon_dirs: + current_images = recon_dir / "images.txt" + current_cameras_txt = recon_dir / "cameras.txt" + current_poses, current_camera_by_name = _read_images_txt_with_names_and_cameras(current_images) + try: + current_cameras = _read_cameras_txt_with_ids(current_cameras_txt) + except FileNotFoundError: + logger.warning("Missing cameras.txt for %s; skipping intrinsics comparison.", recon_dir) + current_cameras = {} + common_names, baseline_list, current_list = _build_pose_lists( + baseline_poses, current_poses, cluster_label=str(recon_dir) + ) + baseline_count = len(baseline_poses) + current_count = len(current_poses) + common_count = len(common_names) + if len(common_names) < 2: + logger.warning( + "Skipping %s (baseline=%d, current=%d, common=%d)", + recon_dir, + baseline_count, + current_count, + common_count, + ) + continue + current_aligned_list, _aSb = _align_poses( + baseline_list, + current_list, + use_ransac=args.robust_sim3, + max_hypotheses=args.robust_sim3_max_hypotheses, + inlier_thresh=args.robust_sim3_inlier_thresh, + rng=rng, + cluster_label=str(recon_dir), + ) + metrics_group = _compute_pose_metrics(baseline_list, current_aligned_list) + _summarize_pose_errors(baseline_list, current_aligned_list, str(recon_dir)) + intrinsics_deltas: Dict[str, List[float]] = { + "delta_fx_norm": [], + "delta_fy_norm": [], + "delta_cx_norm": [], + "delta_cy_norm": [], + } + fov_deltas: Dict[str, List[float]] = { + "delta_fovx_deg": [], + "delta_fovy_deg": [], + } + for name in common_names: + base_cam_id = baseline_camera_by_name.get(name) + curr_cam_id = current_camera_by_name.get(name) + if base_cam_id is None or curr_cam_id is None: + continue + base = baseline_cameras.get(base_cam_id) + curr = current_cameras.get(curr_cam_id) + if base is None or curr is None: + continue + base_w, base_h = base["width"], base["height"] + curr_w, curr_h = curr["width"], curr["height"] + if curr_w > 0 and curr_h > 0 and (base_w != curr_w or base_h != curr_h): + sx = base_w / curr_w + sy = base_h / curr_h + curr_fx = curr["fx"] * sx + curr_fy = curr["fy"] * sy + curr_cx = curr["cx"] * sx + curr_cy = curr["cy"] * sy + else: + curr_fx = curr["fx"] + curr_fy = curr["fy"] + curr_cx = curr["cx"] + curr_cy = curr["cy"] + if base_w > 0 and base_h > 0: + if base["fx"] != 0: + intrinsics_deltas["delta_fx_norm"].append(abs(base["fx"] - curr_fx) / abs(base["fx"])) + if base["fy"] != 0: + intrinsics_deltas["delta_fy_norm"].append(abs(base["fy"] - curr_fy) / abs(base["fy"])) + if base["cx"] != 0: + intrinsics_deltas["delta_cx_norm"].append(abs(base["cx"] - curr_cx) / abs(base["cx"])) + if base["cy"] != 0: + intrinsics_deltas["delta_cy_norm"].append(abs(base["cy"] - curr_cy) / abs(base["cy"])) + base_fovx = 2.0 * np.degrees(np.arctan(base_w / (2.0 * base["fx"]))) + base_fovy = 2.0 * np.degrees(np.arctan(base_h / (2.0 * base["fy"]))) + curr_fovx = 2.0 * np.degrees(np.arctan(base_w / (2.0 * curr_fx))) + curr_fovy = 2.0 * np.degrees(np.arctan(base_h / (2.0 * curr_fy))) + fov_deltas["delta_fovx_deg"].append(abs(base_fovx - curr_fovx)) + fov_deltas["delta_fovy_deg"].append(abs(base_fovy - curr_fovy)) + for key, values in intrinsics_deltas.items(): + all_intrinsics_deltas[key].extend(values) + for key, values in fov_deltas.items(): + all_fov_deltas[key].extend(values) + if args.csv_output: + export_metrics_group_to_csv( + metrics_group, + cluster_label=str(recon_dir), + baseline_count=baseline_count, + current_count=current_count, + common_count=common_count, + output_path=Path(args.csv_output), + rows=csv_rows, + ) + else: + _print_metrics(str(recon_dir), metrics_group) + if fig_output_dir is not None: + safe_name = str(recon_dir).replace(os.sep, "__") + plot_path = fig_output_dir / f"{safe_name}_camera_centers.png" + pose_auc_text = _format_auc(metrics_group, "pose_auc") + rotation_auc_text = _format_auc(metrics_group, "rotation_auc") + translation_auc_text = _format_auc(metrics_group, "translation_auc") + title_lines = [f"{recon_dir}"] + if pose_auc_text: + title_lines.append(f"Pose AUC: {pose_auc_text}") + if rotation_auc_text: + title_lines.append(f"Rotation AUC: {rotation_auc_text}") + if translation_auc_text: + title_lines.append(f"Translation AUC: {translation_auc_text}") + title = "\n".join(title_lines) + _plot_camera_centers(baseline_list, current_aligned_list, plot_path, title) + # Intrinsics stats are annotated in the plot; no terminal logging. + for metric in metrics_group.metrics: + if metric.name.startswith("pose_auc_@") and metric.data is not None: + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + label = metric.name.replace("pose_auc_", "") + all_pose_auc_values.setdefault(label, []).append(value) + all_pose_auc_by_label_and_count.setdefault(label, []).append((current_count, value)) + elif metric.name.startswith("rotation_auc_@") and metric.data is not None: + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + label = metric.name.replace("rotation_auc_", "") + all_rotation_auc_values.setdefault(label, []).append(value) + elif metric.name.startswith("translation_auc_@") and metric.data is not None: + try: + value = float(metric.data) + except (TypeError, ValueError): + continue + label = metric.name.replace("translation_auc_", "") + all_translation_auc_values.setdefault(label, []).append(value) + + if args.csv_output and csv_rows: + output_path = Path(args.csv_output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("a", newline="") as csvfile: + writer = csv.DictWriter( + csvfile, + fieldnames=["cluster", "baseline_count", "current_count", "common_count", "metric_name", "value"], + ) + writer.writerows(csv_rows) + if fig_output_dir is not None and all_pose_auc_values: + auc_plot_path = fig_output_dir / "pose_auc_boxplot_all_clusters.png" + _plot_pose_auc_boxplot(all_pose_auc_values, auc_plot_path, "Pose AUC by threshold (all clusters)") + if fig_output_dir is not None and all_pose_auc_by_label_and_count: + auc_vs_images_plot_path = fig_output_dir / "pose_auc_vs_input_images.png" + _plot_pose_auc_vs_input_images(all_pose_auc_by_label_and_count, auc_vs_images_plot_path) + if fig_output_dir is not None and all_rotation_auc_values: + rotation_auc_plot_path = fig_output_dir / "rotation_auc_boxplot_all_clusters.png" + _plot_pose_auc_boxplot( + all_rotation_auc_values, rotation_auc_plot_path, "Rotation AUC by threshold (all clusters)" + ) + if fig_output_dir is not None and all_translation_auc_values: + translation_auc_plot_path = fig_output_dir / "translation_auc_boxplot_all_clusters.png" + _plot_pose_auc_boxplot( + all_translation_auc_values, + translation_auc_plot_path, + "Translation AUC by threshold (all clusters)", + ) + if fig_output_dir is not None and any(all_intrinsics_deltas.values()): + intrinsics_plot_path = fig_output_dir / "intrinsics_deltas_all_clusters.png" + _plot_intrinsics_deltas_boxplot( + all_intrinsics_deltas, + intrinsics_plot_path, + "Intrinsics Δ (normalized, all clusters)", + ) + if fig_output_dir is not None and any(all_fov_deltas.values()): + fov_plot_path = fig_output_dir / "fov_deltas_all_clusters.png" + _plot_fov_deltas_boxplot( + all_fov_deltas, + fov_plot_path, + "FOV Δ (degrees, all clusters)", + ) + + +if __name__ == "__main__": + main() diff --git a/gtsfm/evaluation/metrics.py b/gtsfm/evaluation/metrics.py index 470ad6bb5..23b4b3f9f 100644 --- a/gtsfm/evaluation/metrics.py +++ b/gtsfm/evaluation/metrics.py @@ -208,7 +208,7 @@ def get_metric_as_dict(self) -> Dict[str, Any]: The metric as a dict representation explained above. """ if self._dim == 0: - return {self._name: self._data.tolist()} + return {self._name: round(self._data.tolist(), 4)} metric_dict = {SUMMARY_KEY: self.summary} if self._data is not None: metric_dict[FULL_DATA_KEY] = self._data.tolist() diff --git a/gtsfm/frontend/vggt.py b/gtsfm/frontend/vggt.py index bdb4bab9c..369c74bae 100644 --- a/gtsfm/frontend/vggt.py +++ b/gtsfm/frontend/vggt.py @@ -9,13 +9,14 @@ from importlib.machinery import ModuleSpec from pathlib import Path from types import ModuleType -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch -import torch.nn.functional as F from gtsam import Point2, Point3 +from PIL import Image as PILImage from torch.amp import autocast as amp_autocast # type: ignore +from torchvision import transforms as TF from gtsfm.bundle.bundle_adjustment import BundleAdjustmentOptimizer from gtsfm.common.gtsfm_data import GtsfmData @@ -97,14 +98,11 @@ def _import_from_vanilla_vggt(module_suffix: str) -> ModuleType: logger.info("⚡ FastVGGT enabled via thirdparty/FastVGGT.") else: logger.info("📷 Using vanilla VGGT (FastVGGT submodule not detected).") +from vggt.utils.geometry import unproject_depth_map_to_point_map # type: ignore from vggt.utils.helper import randomly_limit_trues # type: ignore from vggt.utils.load_fn import load_and_preprocess_images_square # type: ignore from vggt.utils.pose_enc import pose_encoding_to_extri_intri # type: ignore -from gtsfm.frontend.anysplat import ( - batchify_unproject_depth_map_to_point_map as _anysplat_batchify_unproject, -) # type: ignore - DEFAULT_FIXED_RESOLUTION = 518 _DTYPE_ALIASES: dict[str, torch.dtype] = { @@ -138,12 +136,144 @@ def _resolve_dtype_argument(arg: Optional[Union[str, torch.dtype]]) -> Optional[ raise TypeError(f"Unsupported dtype specifier of type {type(arg)!r}: {arg!r}") +def load_image_batch_vggt_loader(loader, indices: List[int], mode="crop"): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, + but VGGT model can also work well with different shapes. + + Args: + loader: Loader instance providing ``get_image``. + indices: List of image indices to load. + mode (str, optional): Preprocessing mode, either "crop" or "pad". + - "crop" (default): Sets width to 518px and center crops height if needed. + - "pad": Preserves all pixels by making the largest dimension 518px + and padding the smaller dimension to reach a square shape. + + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty or if mode is invalid + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - When mode="crop": The function ensures width=518px while maintaining aspect ratio + and height is center-cropped if larger than 518px + - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio + and the smaller dimension is padded to reach a square shape (518x518) + - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements + """ + # Check for empty list + if len(indices) == 0: + raise ValueError("At least 1 image is required") + + # Validate mode + if mode not in ["crop", "pad"]: + raise ValueError("Mode must be either 'crop' or 'pad'") + + images = [] + shapes = set() + to_tensor = TF.ToTensor() + target_size = 518 + + # First process all images and collect their shapes + for idx in indices: + # Open image + img = loader.get_image(idx).value_array + + img = PILImage.fromarray(img) + + width, height = img.size + + if mode == "pad": + # Make the largest dimension 518px while maintaining aspect ratio + if width >= height: + new_width = target_size + new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14 + else: + new_height = target_size + new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14 + else: # mode == "crop" + # Original behavior: set width to 518px + new_width = target_size + # Calculate height maintaining aspect ratio, divisible by 14 + new_height = round(height * (new_width / width) / 14) * 14 + + # Resize with new dimensions (width, height) + img = img.resize((new_width, new_height), PILImage.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # Center crop height if it's larger than 518 (only in crop mode) + if mode == "crop" and new_height > target_size: + start_y = (new_height - target_size) // 2 + img = img[:, start_y : start_y + target_size, :] + + # For pad mode, pad to make a square of target_size x target_size + if mode == "pad": + h_padding = target_size - img.shape[1] + w_padding = target_size - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + # Pad with white (value=1.0) + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + # Check if we have different shapes + # In theory our model can also work well with different shapes + if len(shapes) > 1: + logger.warning("Found images with different shapes: %s", shapes) + # Find maximum dimensions + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + # Pad images if necessary + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) # concatenate images + + # Ensure correct shape when single image + if len(indices) == 1: + # Verify shape is (1, C, H, W) + if images.dim() == 3: + images = images.unsqueeze(0) + + height, width = images.shape[-2], images.shape[-1] + coords = np.tile([0.0, 0.0, float(width), float(height), float(width), float(height)], (len(indices), 1)) + original_coords_tensor = torch.from_numpy(coords).float() + + return images, original_coords_tensor + + @dataclass class VggtConfiguration: """Configuration for the high-level VGGT reconstruction pipeline.""" - img_load_resolution: int = 1024 - vggt_fixed_resolution: int = DEFAULT_FIXED_RESOLUTION seed: int = 42 confidence_threshold: float = 5.0 max_num_points: int = 100000 @@ -151,14 +281,15 @@ class VggtConfiguration: model_ctor_kwargs: dict[str, Any] = field(default_factory=dict) use_sparse_attention: bool = False run_bundle_adjustment_on_leaf: bool = False + store_pre_ba_result: bool = False # Tracking-specific parameters: tracking: bool = True - max_query_pts: int = 1000 - query_frame_num: int = 4 - keypoint_extractor: str = "aliked+sp" - fine_tracking: bool = True - track_vis_thresh: float = 0.2 + max_query_pts: int = 2048 + query_frame_num: int = 3 + track_vis_thresh: float = 0.05 + track_conf_thresh: float = 0.2 + keypoint_extractor: str = "aliked+sp+sift" max_reproj_error: float = 8.0 @@ -168,7 +299,7 @@ class VggtOutput: # TODO(Frank): derive from base class shared with AnySplat (i device: torch.device dtype: torch.dtype - resized_images: torch.Tensor + images: torch.Tensor extrinsic: torch.Tensor intrinsic: torch.Tensor depth_map: torch.Tensor @@ -181,7 +312,8 @@ class VggtReconstruction: """Outputs from the VGGT reconstruction helper. Attributes: - gtsfm_data: Sparse scene estimate including cameras and tracks in original image coordinates. + gtsfm_data: Sparse scene estimate (post-BA if enabled) in original image coordinates. + pre_ba_data: Optional sparse scene estimate before bundle adjustment (debug-only). points_3d: Dense point cloud filtered by VGGT confidence scores. points_rgb: Per-point RGB colors aligned with ``points_3d``. tracking_result: Optional dense tracking payload in the square VGGT coordinate frame. @@ -190,6 +322,7 @@ class VggtReconstruction: gtsfm_data: GtsfmData points_3d: np.ndarray points_rgb: np.ndarray + pre_ba_data: GtsfmData | None = None tracking_result: "VGGTTrackingResult | None" = None def visualize_tracks( @@ -310,79 +443,14 @@ def load_model( return model -def _rescale_intrinsic_for_original_resolution( - intrinsic: np.ndarray, - reconstruction_resolution: int, - image_width: float, - image_height: float, -) -> np.ndarray: - """Adapt intrinsics estimated on a square crop back to the original image size.""" - resized = intrinsic.copy() - # print('image_width, image_height: ', image_width, image_height) - resize_ratio = max(image_width, image_height) / float(reconstruction_resolution) - resized[:2, :] *= resize_ratio - resized[0, 2] = image_width / 2.0 - resized[1, 2] = image_height / 2.0 - return resized - - -def _convert_measurement_to_original_resolution( - uv: Tuple[float, float], - original_coord: np.ndarray, - inference_resolution: int, - img_load_resolution: int, - *, - measurement_in_load_resolution: bool = False, -) -> Tuple[float, float]: - """Convert VGGT coordinates back to the original image coordinate system. - - Args: - uv: Input measurement in either inference or load resolution space. - original_coord: Metadata describing the crop location within the padded square, expressed at load resolution. - inference_resolution: Resolution of VGGT inference grid. - img_load_resolution: Resolution used when images were padded/resized prior to inference. - measurement_in_load_resolution: Set ``True`` if ``uv`` already lives in the load resolution. - """ - - x_infer, y_infer = uv - x1, y1 = original_coord[0], original_coord[1] - width, height = original_coord[4], original_coord[5] - - # VGGT runs on the ``img_load_resolution`` square; vggt_output down-samples that square to the - # (typically smaller) ``inference_resolution``. Undo that downscale so we can use the crop - # metadata stored in ``original_coord``. - if measurement_in_load_resolution: - x_load = x_infer - y_load = y_infer - else: - scale_back_to_load = float(img_load_resolution) / float(inference_resolution) - x_load = x_infer * scale_back_to_load - y_load = y_infer * scale_back_to_load - - # ``original_coord`` encodes the location of the original, possibly rectangular, image within - # the padded square (in *load* resolution). Remove the padding and scale the remaining pixels - # back to the native resolution. - max_side = float(max(width, height)) - resize_ratio = max_side / float(img_load_resolution) - u = (x_load - x1) * resize_ratio - v = (y_load - y1) * resize_ratio - - max_u = max(width - 0.5, 0.0) - max_v = max(height - 0.5, 0.0) - u = float(np.clip(u, 0.0, max_u)) - v = float(np.clip(v, 0.0, max_v)) - return u, v - - def _high_confidence_pointcloud(config: VggtConfiguration, vggt_output: VggtOutput) -> Tuple[np.ndarray, np.ndarray]: """Convert raw VGGT predictions into point attributes.""" points_3d = vggt_output.dense_points.to(torch.float32).cpu().numpy() - points_rgb = (vggt_output.resized_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) * 255).astype( - np.uint8 - ) + points_rgb = (vggt_output.images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8) depth_conf_np = vggt_output.depth_confidence.to(torch.float32).cpu().numpy() - conf_mask = depth_conf_np >= config.confidence_threshold + conf_threshold = min(config.confidence_threshold, depth_conf_np.mean() - depth_conf_np.std()) + conf_mask = depth_conf_np >= conf_threshold conf_mask = randomly_limit_trues(conf_mask, config.max_num_points) # limit number of points if asked return points_3d[conf_mask], points_rgb[conf_mask] @@ -400,9 +468,6 @@ def _is_point_in_front_of_camera(camera, point_xyz: np.ndarray, *, epsilon: floa return float(z_val) > epsilon - - - def _convert_vggt_outputs_to_gtsfm_data( *, vggt_output: VggtOutput, @@ -413,7 +478,7 @@ def _convert_vggt_outputs_to_gtsfm_data( points_3d: np.ndarray, points_rgb: np.ndarray, tracking_result: VGGTTrackingResult | None = None, -) -> GtsfmData: +) -> tuple[GtsfmData, GtsfmData | None]: """Convert raw VGGT predictions into ``GtsfmData``.""" extrinsic_np = vggt_output.extrinsic.to(torch.float32).cpu().numpy() @@ -427,9 +492,8 @@ def _convert_vggt_outputs_to_gtsfm_data( image_width = float(original_coords_np[local_idx, 4]) image_height = float(original_coords_np[local_idx, 5]) - scaled_intrinsic = _rescale_intrinsic_for_original_resolution( - intrinsic_np[local_idx], config.vggt_fixed_resolution, image_width, image_height - ) + scaled_intrinsic = intrinsic_np[local_idx] + camera = torch_utils.camera_from_matrices(extrinsic_np[local_idx], scaled_intrinsic) gtsfm_data.add_camera(global_idx, camera) # type: ignore[arg-type] gtsfm_data.set_image_info( @@ -446,19 +510,27 @@ def _convert_vggt_outputs_to_gtsfm_data( if tracking_result: # track masks according to visibility, reprojection error, etc + max_reproj_error = float(config.max_reproj_error) track_mask = tracking_result.visibilities > config.track_vis_thresh - inlier_num = track_mask.sum(0) - valid_mask = inlier_num >= 2 # a track is invalid if without two inliers + confidence_threshold = config.track_conf_thresh + confidence_threshold = min( + confidence_threshold, np.mean(tracking_result.confidences) + np.std(tracking_result.confidences) + ) if tracking_result.confidences is not None: - valid_mask = np.logical_and(valid_mask, tracking_result.confidences > config.confidence_threshold) - valid_idx = np.nonzero(valid_mask)[0] + track_mask = np.logical_and(track_mask, tracking_result.confidences > confidence_threshold) - max_reproj_error = float(config.max_reproj_error) enforce_reproj_filter = ( tracking_result.points_3d is not None and np.isfinite(max_reproj_error) and max_reproj_error > 0.0 ) + inlier_num = track_mask.sum(0) + min_measurements = 2 + valid_mask = inlier_num >= min_measurements # a track is invalid if without two inliers + valid_idx = np.nonzero(valid_mask)[0] + + logger.info("num points 3d: %d, num valid idx: %d", tracking_result.points_3d.shape[0], len(valid_idx)) + for valid_id in valid_idx: rgb: np.ndarray if tracking_result.colors is not None and valid_id < tracking_result.colors.shape[0]: @@ -478,46 +550,41 @@ def _convert_vggt_outputs_to_gtsfm_data( camera = gtsfm_data.get_camera(global_idx) if not _is_point_in_front_of_camera(camera, point_xyz): continue - rescaled_u, rescaled_v = _convert_measurement_to_original_resolution( - (float(u), float(v)), - original_coords_np[local_id], - config.vggt_fixed_resolution, - config.img_load_resolution, - measurement_in_load_resolution=True, - ) if enforce_reproj_filter: projected = camera.project(gtsam_point) proj_u = float(projected[0]) proj_v = float(projected[1]) - reproj_err = float(np.hypot(rescaled_u - proj_u, rescaled_v - proj_v)) + reproj_err = float(np.hypot(u - proj_u, v - proj_v)) max_error_for_track = max(max_error_for_track, reproj_err) - per_track_measurements.append((global_idx, rescaled_u, rescaled_v)) + per_track_measurements.append((global_idx, u, v)) - if len(per_track_measurements) < 2: + if len(per_track_measurements) < min_measurements: continue if enforce_reproj_filter and max_error_for_track > max_reproj_error: continue - track = torch_utils.colored_track_from_point(point_xyz, rgb) - for global_idx, rescaled_u, rescaled_v in per_track_measurements: - track.addMeasurement(global_idx, Point2(rescaled_u, rescaled_v)) + for global_idx, float_u, float_v in per_track_measurements: + track.addMeasurement(global_idx, Point2(float_u, float_v)) gtsfm_data.add_track(track) + gtsfm_data_pre_ba: GtsfmData | None = None if config.run_bundle_adjustment_on_leaf: + if config.store_pre_ba_result: + gtsfm_data_pre_ba = gtsfm_data if gtsfm_data.number_tracks() == 0: logger.warning("Skipping bundle adjustment because VGGT produced no valid tracks.") else: try: gtsfm_data, should_run_ba = data_utils.remove_cameras_with_no_tracks(gtsfm_data, "node-level BA") if not should_run_ba: - return gtsfm_data - optimizer = BundleAdjustmentOptimizer() + return gtsfm_data, gtsfm_data_pre_ba + optimizer = BundleAdjustmentOptimizer(robust_measurement_noise=False, calibration_prior_noise_sigma=10) gtsfm_data_with_ba, _ = optimizer.run_simple_ba(gtsfm_data, verbose=False) - return gtsfm_data_with_ba + return gtsfm_data_with_ba, gtsfm_data_pre_ba except Exception as exc: logger.warning("⚠️ Failed to run bundle adjustment: %s", exc) - return gtsfm_data + return gtsfm_data, gtsfm_data_pre_ba def _offload_vggt_model(model: Optional[VGGT]) -> None: @@ -538,11 +605,7 @@ def run_VGGT( model: Optional[VGGT] = None, weights_path: PathLike | None = None, ) -> VggtOutput: - """Run VGGT on a batch of images and return raw model predictions. - - Set ``return_dense_points`` to ``True`` to additionally compute the full per-pixel - point cloud using the optional AnySplat acceleration path (when available). - """ + """Run VGGT and unproject depth using the geometry helper.""" if images.ndim != 4 or images.shape[1] != 3: raise ValueError("VGGT expects images shaped as (N, 3, H, W).") @@ -570,22 +633,10 @@ def run_VGGT( assert model is not None images = images.to(resolved_device, dtype=resolved_dtype) - res = cfg.vggt_fixed_resolution if cfg else DEFAULT_FIXED_RESOLUTION - resized_images = F.interpolate(images, size=(res, res), mode="bilinear", align_corners=False, antialias=True) - # print('resized_images: ', resized_images.shape) 518, 518 - - # FastVGGT requires the model to know the actual patch grid dimensions used for token merging. - patch_w = max(1, resized_images.shape[-1] // getattr(model.aggregator, "patch_size", 14)) - patch_h = max(1, resized_images.shape[-2] // getattr(model.aggregator, "patch_size", 14)) - if hasattr(model, "update_patch_dimensions"): - try: - model.update_patch_dimensions(patch_w, patch_h) - except Exception as exc: # pragma: no cover - best effort for FastVGGT compatibility - logger.warning("Failed to update VGGT patch dimensions (%dx%d): %s", patch_w, patch_h, exc) # FastVGGT requires the model to know the actual patch grid dimensions used for token merging. - patch_w = max(1, resized_images.shape[-1] // getattr(model.aggregator, "patch_size", 14)) - patch_h = max(1, resized_images.shape[-2] // getattr(model.aggregator, "patch_size", 14)) + patch_w = max(1, images.shape[-1] // getattr(model.aggregator, "patch_size", 14)) + patch_h = max(1, images.shape[-2] // getattr(model.aggregator, "patch_size", 14)) if hasattr(model, "update_patch_dimensions"): try: model.update_patch_dimensions(patch_w, patch_h) @@ -599,28 +650,32 @@ def run_VGGT( with torch.no_grad(): with autocast_ctx: - batched = resized_images.unsqueeze(0) # make into (training) batch of 1 + batched = images.unsqueeze(0) # make into (training) batch of 1 tokens, ps_idx = model.aggregator(batched) # transformer backbone + with torch.cuda.amp.autocast(dtype=torch.float32): pose_enc = model.camera_head(tokens)[-1] extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, batched.shape[-2:]) depth_map, depth_conf = model.depth_head(tokens, batched, ps_idx) - assert _anysplat_batchify_unproject is not None, "Anysplat dependencies not available" - dense_points = _anysplat_batchify_unproject(depth_map, extrinsic, intrinsic) - depth_confidence = depth_conf.squeeze(0) if depth_confidence.ndim == 4 and depth_confidence.shape[-1] == 1: depth_confidence = depth_confidence.squeeze(-1) + depth_map = depth_map.squeeze(0).to(dtype=torch.float32) + extrinsic = extrinsic.squeeze(0).to(dtype=torch.float32) + intrinsic = intrinsic.squeeze(0).to(dtype=torch.float32) + dense_points_np = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic) + dense_points = torch.from_numpy(dense_points_np).to(device=resolved_device, dtype=torch.float32) + return VggtOutput( device=resolved_device, dtype=resolved_dtype, - resized_images=resized_images, - extrinsic=extrinsic.squeeze(0), - intrinsic=intrinsic.squeeze(0), - depth_map=depth_map.squeeze(0), + images=images, + extrinsic=extrinsic, + intrinsic=intrinsic, + depth_map=depth_map, depth_confidence=depth_confidence, - dense_points=dense_points.squeeze(0), + dense_points=dense_points, ) @@ -646,22 +701,17 @@ class VGGTTrackingResult: colors: Optional[np.ndarray] -def _import_predict_tracks(): - """Return the vendored ``predict_tracks`` helper from the VGGT submodule. - - The tracker lives in ``thirdparty/vggt``. We keep this import behind a small helper so that runtime - errors surface with a clear explanation when the submodule is missing. - """ +def _import_vggsfm_utils(): + """Return the vendored vggsfm utilities module from the VGGT submodule.""" try: - from vggt.dependency.track_predict import predict_tracks as _predict_tracks # type: ignore + from vggt.dependency import vggsfm_utils as _vggsfm_utils # type: ignore except ImportError as exc: # pragma: no cover - exercised only when the submodule is absent - # FastVGGT strips the tracker utilities, so fall back to the vanilla VGGT namespace if possible. if _USING_FASTVGGT: try: - tracker_module = _import_from_vanilla_vggt("dependency.track_predict") - logger.info("Using tracker utilities from the vanilla VGGT submodule.") - return tracker_module.predict_tracks # type: ignore[attr-defined] + tracker_module = _import_from_vanilla_vggt("dependency.vggsfm_utils") + logger.info("Using vggsfm utilities from the vanilla VGGT submodule.") + return tracker_module # type: ignore[return-value] except ImportError as fallback_exc: exc = fallback_exc @@ -672,42 +722,19 @@ def _import_predict_tracks(): if _USING_FASTVGGT: hint += " FastVGGT does not bundle the tracker code, so the vanilla VGGT submodule must remain accessible." raise ImportError(hint) from exc - return _predict_tracks + return _vggsfm_utils -def run_vggt_tracking( - images: torch.Tensor, vggt_output: VggtOutput, *, config: Optional[VggtConfiguration] = None +def _run_vggt_head_tracking( + vggt_output: VggtOutput, + *, + model: VGGT, + config: Optional[VggtConfiguration] = None, ) -> VGGTTrackingResult: - """Generate dense feature tracks using the VGGSfM tracker shipped with VGGT. - - Parameters: - images: Tensor shaped ``(num_frames, 3, H, W)`` at the *square* VGGT load resolution. You can reuse - the ``images`` tensor that you passed into :func:`run_reconstruction`; typically this is the output - from ``load_and_preprocess_images_square`` prior to interpolation. - vggt_output: Output from :func:`run_VGGT`. The ``depth_confidence`` and optional ``dense_points`` tensors - are consumed directly, avoiding redundant transfers or recomputation. - config: Optional :class:`VggtConfiguration`. We reuse the existing configuration container because - it already captures the tracker-specific parameters (``max_query_pts``, ``query_frame_num``, etc.). - tracker_kwargs: Optional dictionary to override individual keyword arguments passed to the underlying - :func:`vggt.dependency.track_predict.predict_tracks` function. This is useful if you want to tweak - settings not exposed via :class:`VggtConfiguration`. - - Returns: - :class:`VGGTTrackingResult` aggregating the numpy arrays emitted by the tracker. The visibility scores can - be thresholded manually, e.g. ``mask = result.visibilities > config.vis_thresh``. The tracks are expressed - in the same *square* coordinate frame as ``images``; remember to rescale them back to the original image - crop using :func:`_convert_measurement_to_original_resolution` if you plan to add them to ``GtsfmData``. - - Example: - >>> vggt_output = run_VGGT(image_batch, model=model, dtype=dtype, return_dense_points=True) - >>> cfg = VggtConfiguration() - >>> tracking = run_vggt_tracking(image_batch, vggt_output, config=cfg) - >>> high_quality = tracking.visibilities > cfg.vis_thresh - >>> first_track_pixels = tracking.tracks[:, 0] - """ + """Generate dense feature tracks using the VGGT track head.""" cfg = config or VggtConfiguration() - predict_tracks = _import_predict_tracks() + vggsfm_utils = _import_vggsfm_utils() device = vggt_output.device if device.type != "cuda": @@ -716,47 +743,156 @@ def run_vggt_tracking( "Re-run the pipeline with CUDA available." ) - dtype = torch.float32 # Tracker stack (LightGlue / DINO) expects fp32 inputs. - - if images.device != device or images.dtype != dtype: - logger.info("Moving VGGT tracking inputs to %s (dtype=%s) for DINO attention.", device, dtype) - images = images.to(device=device, dtype=dtype, non_blocking=True) + images = vggt_output.images + if images.device != device or images.dtype != torch.float32: + images = images.to(device=device, dtype=torch.float32, non_blocking=True) + + frame_num = images.shape[0] + query_frame_indexes = vggsfm_utils.generate_rank_by_dino( + images, + query_frame_num=cfg.query_frame_num, + image_size=518, + model_name="dinov2_vitb14_reg", + device=device, + spatial_similarity=False, + ) + if 0 in query_frame_indexes: + query_frame_indexes.remove(0) + query_frame_indexes = [0, *query_frame_indexes] + + extractors = vggsfm_utils.initialize_feature_extractors( + max_query_num=cfg.max_query_pts, + extractor_method=cfg.keypoint_extractor, + device=device, + ) - conf_tensor = vggt_output.depth_confidence.to(device="cpu", dtype=dtype, non_blocking=True) - points_tensor = vggt_output.dense_points.to(device="cpu", dtype=dtype, non_blocking=True) + dense_points = vggt_output.dense_points + depth_confidence = vggt_output.depth_confidence + height, width = images.shape[-2:] + + pred_tracks = [] + pred_vis_scores = [] + pred_conf_scores = [] + pred_world_points = [] + pred_world_points_conf = [] + pred_colors = [] + + for query_index in query_frame_indexes: + query_image = images[query_index] + query_points = vggsfm_utils.extract_keypoints(query_image, extractors, round_keypoints=True) + if query_points is None or query_points.shape[1] == 0: + continue + + query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] + if query_points.shape[1] > cfg.max_query_pts: + query_points = query_points[:, : cfg.max_query_pts] + + query_points_round = query_points.squeeze(0).round().long() + query_points_round[:, 0] = query_points_round[:, 0].clamp(0, width - 1) + query_points_round[:, 1] = query_points_round[:, 1].clamp(0, height - 1) + + pred_color = ( + images[query_index][:, query_points_round[:, 1], query_points_round[:, 0]].permute(1, 0).cpu().numpy() + * 255.0 + ).astype(np.uint8) + + pred_point_3d = dense_points[query_index][query_points_round[:, 1], query_points_round[:, 0]] + + pred_conf = None + if depth_confidence is not None: + pred_conf = depth_confidence[query_index][query_points_round[:, 1], query_points_round[:, 0]] + + if query_points.shape[1] == 0: + continue + + reorder_index = vggsfm_utils.calculate_index_mappings(query_index, frame_num, device=device) + reorder_images = vggsfm_utils.switch_tensor_order([images], reorder_index, dim=0)[0] + + with torch.no_grad(): + with amp_autocast("cuda", dtype=vggt_output.dtype): + aggregated_tokens_list, ps_idx = model.aggregator(reorder_images[None]) + if aggregated_tokens_list and aggregated_tokens_list[0].dtype != torch.float32: + aggregated_tokens_list = [tokens.float() for tokens in aggregated_tokens_list] + with amp_autocast("cuda", dtype=torch.float32): + track_list, vis_scores, conf_scores = model.track_head( + aggregated_tokens_list, + reorder_images[None], + ps_idx, + query_points=query_points, + ) - with torch.no_grad(): - tracks, vis_scores, confidences, points_3d, colors = predict_tracks( - images, - conf=conf_tensor, - points_3d=points_tensor, - masks=None, # ignored anyway ! - max_query_pts=cfg.max_query_pts, - query_frame_num=cfg.query_frame_num, - keypoint_extractor=cfg.keypoint_extractor, - fine_tracking=cfg.fine_tracking, + pred_track = track_list[-1] + pred_track = pred_track.squeeze(0) + vis_scores = vis_scores.squeeze(0) + conf_scores = conf_scores.squeeze(0) + reordered = vggsfm_utils.switch_tensor_order([pred_track, vis_scores, conf_scores], reorder_index, dim=0) + pred_track, pred_vis, pred_conf_score = reordered + + pred_tracks.append(pred_track) + pred_vis_scores.append(pred_vis) + if pred_conf_score is not None: + pred_conf_scores.append(pred_conf_score) + pred_world_points.append(pred_point_3d) + if pred_conf is not None: + pred_world_points_conf.append(pred_conf) + pred_colors.append(pred_color) + + if not pred_tracks: + empty_tracks = np.zeros((frame_num, 0, 2), dtype=np.float32) + empty_vis = np.zeros((frame_num, 0), dtype=np.float32) + empty_conf = np.zeros((0,), dtype=np.float32) if depth_confidence is not None else None + empty_points = np.zeros((0, 3), dtype=np.float32) + empty_colors = np.zeros((0, 3), dtype=np.uint8) + return VGGTTrackingResult( + tracks=empty_tracks, + visibilities=empty_vis, + confidences=empty_conf, + points_3d=empty_points, + colors=empty_colors, ) - # print("images: ", images.shape) - # print("conf_tensor: ", conf_tensor.shape) - # print("tracks: ", tracks.shape) - # print("vis_scores: ", vis_scores.shape) - # print("confidences: ", confidences.shape) - # print("points_3d: ", points_3d.shape) - # print("colors: ", colors.shape) - # # images: torch.Size([4, 3, 1024, 1024]) - # # conf_tensor: torch.Size([4, 518, 518]) - # # tracks: (4, 2901, 2) - # # vis_scores: (4, 2901) - # # confidences: (2901,) - # # points_3d: (2901, 3) - # # colors: (2901, 3) + tracks = torch.cat(pred_tracks, dim=1) + vis_scores = torch.cat(pred_vis_scores, dim=1) + confidences = torch.cat(pred_conf_scores, dim=1) if pred_conf_scores else None + points_3d = torch.cat(pred_world_points, dim=0) if pred_world_points else None + points_3d_conf = torch.cat(pred_world_points_conf, dim=0) if pred_world_points_conf else None + colors = np.concatenate(pred_colors, axis=0) if pred_colors else None + + if points_3d_conf is not None and points_3d is not None: + filtered_flag = points_3d_conf > 1.5 + if int(filtered_flag.sum().item()) > cfg.max_query_pts // 2: + tracks = tracks[:, filtered_flag] + vis_scores = vis_scores[:, filtered_flag] + if confidences is not None: + confidences = confidences[:, filtered_flag] + points_3d = points_3d[filtered_flag] + points_3d_conf = points_3d_conf[filtered_flag] + if colors is not None: + colors = colors[filtered_flag.cpu().numpy()] return VGGTTrackingResult( - tracks=tracks, visibilities=vis_scores, confidences=confidences, points_3d=points_3d, colors=colors + tracks=tracks.float().cpu().numpy(), + visibilities=vis_scores.float().cpu().numpy(), + confidences=confidences.float().cpu().numpy() if confidences is not None else None, + points_3d=points_3d.float().cpu().numpy() if points_3d is not None else None, + colors=colors, ) +def run_vggt_tracking( + vggt_output: VggtOutput, + *, + config: Optional[VggtConfiguration] = None, + model: Optional[VGGT] = None, +) -> VGGTTrackingResult: + """Generate dense feature tracks using the configured VGGT tracking backend.""" + + cfg = config or VggtConfiguration() + if model is None: + raise ValueError("VGGT tracking_head requires a loaded VGGT model.") + return _run_vggt_head_tracking(vggt_output, model=model, config=cfg) + + # --- VGGT reconstruction ------------------------------------------------- @@ -773,8 +909,7 @@ def run_reconstruction( """Run VGGT on a batch of images and convert outputs to ``GtsfmData``. Args: - images: Tensor shaped ``(num_frames, 3, H, W)`` at the *square* VGGT load resolution. You can - obtain this tensor by calling ``load_and_preprocess_images_square`` prior to interpolation. + images: Tensor shaped ``(num_frames, 3, H, W)`` at the VGGT load resolution. image_indices: Sequence of global image indices corresponding to the provided ``images`` batch. image_names: Optional sequence of image filenames corresponding to the provided ``images`` batch. original_coords: Tensor shaped ``(num_frames, 6)`` giving the original image crop metadata @@ -794,21 +929,25 @@ def run_reconstruction( torch.cuda.manual_seed(cfg.seed) torch.cuda.manual_seed_all(cfg.seed) + model_for_tracking = None + if cfg.tracking and model_for_tracking is None: + model_for_tracking = model + vggt_output = run_VGGT(images, config=cfg, model=model, weights_path=weights_path) + tracking_result = None + if cfg.tracking: + tracking_result = run_vggt_tracking(vggt_output, config=cfg, model=model_for_tracking) + if cfg.tracking and vggt_output.device.type == "cuda": - if model is not None: - _offload_vggt_model(model) + if model_for_tracking is not None: + _offload_vggt_model(model_for_tracking) else: torch.cuda.empty_cache() - tracking_result = None - if cfg.tracking: - tracking_result = run_vggt_tracking(images, vggt_output, config=cfg) - points_3d, points_rgb = _high_confidence_pointcloud(cfg, vggt_output) - gtsfm_data = _convert_vggt_outputs_to_gtsfm_data( + gtsfm_data, gtsfm_data_pre_ba = _convert_vggt_outputs_to_gtsfm_data( config=cfg, vggt_output=vggt_output, original_coords=original_coords, @@ -825,6 +964,7 @@ def run_reconstruction( return VggtReconstruction( gtsfm_data=gtsfm_data, + pre_ba_data=gtsfm_data_pre_ba, points_3d=points_3d, points_rgb=points_rgb, tracking_result=tracking_result, @@ -835,8 +975,7 @@ def run_reconstruction_gtsfm_data_only(images: torch.Tensor, **kwargs) -> GtsfmD """Run VGGT on a batch of images and convert outputs to ``GtsfmData``. Args: - images: Tensor shaped ``(num_frames, 3, H, W)`` at the *square* VGGT load resolution. You can - obtain this tensor by calling ``load_and_preprocess_images_square`` prior to interpolation. + images: Tensor shaped ``(num_frames, 3, H, W)`` at the VGGT load resolution. **kwargs: Additional keyword arguments passed to :func:`run_reconstruction`. Returns: @@ -855,6 +994,7 @@ def run_reconstruction_gtsfm_data_only(images: torch.Tensor, **kwargs) -> GtsfmD "VGGT_SUBMODULE_PATH", "LIGHTGLUE_SUBMODULE_PATH", "default_dtype", + "load_image_batch_vggt_loader", "load_and_preprocess_images_square", "resolve_weights_path", "load_model", diff --git a/gtsfm/runner.py b/gtsfm/runner.py index fda8540b8..83c33ce04 100644 --- a/gtsfm/runner.py +++ b/gtsfm/runner.py @@ -150,7 +150,7 @@ def construct_argparser(self) -> argparse.ArgumentParser: ) parser.add_argument("--threads_per_worker", type=int, default=1, help="Number of threads per each worker.") parser.add_argument( - "--worker_memory_limit", type=str, default="16GB", help="Memory limit per worker, e.g. `16GB`" + "--worker_memory_limit", type=str, default="32GB", help="Memory limit per worker, e.g. `16GB`" ) parser.add_argument("--dashboard_port", type=str, default=":8787", help="dask dashboard port number") parser.add_argument( diff --git a/gtsfm/scene_optimizer.py b/gtsfm/scene_optimizer.py index 917f3c089..4d7356b05 100644 --- a/gtsfm/scene_optimizer.py +++ b/gtsfm/scene_optimizer.py @@ -108,14 +108,13 @@ def __init__( output_root: str = DEFAULT_OUTPUT_ROOT, output_worker: Optional[str] = None, plot_reprojection_histograms: bool = True, + use_nonlinear_sim3_merging: bool = False, ) -> None: self.loader = loader self.image_pairs_generator = image_pairs_generator self.graph_partitioner = graph_partitioner self.cluster_optimizer = cluster_optimizer - self._run_bundle_adjustment_on_parent = getattr( - self.cluster_optimizer, "run_bundle_adjustment_on_parent", True - ) + self._run_bundle_adjustment_on_parent = getattr(self.cluster_optimizer, "run_bundle_adjustment_on_parent", True) self._plot_reprojection_histograms = getattr( self.cluster_optimizer, "plot_reprojection_histograms", plot_reprojection_histograms ) @@ -124,7 +123,7 @@ def __init__( ) self._drop_camera_with_no_track = getattr(self.cluster_optimizer, "drop_camera_with_no_track", True) self._drop_child_if_merging_fail = getattr(self.cluster_optimizer, "drop_child_if_merging_fail", True) - + self._use_nonlinear_sim3_merging = use_nonlinear_sim3_merging self.output_root = Path(output_root) if output_worker is not None: self.cluster_optimizer._output_worker = output_worker @@ -231,7 +230,7 @@ def to_context(path: tuple[int, ...], visibility_graph: VisibilityGraph) -> Clus # Returns handles to various outputs: reconstruction, metrics, io_barrier etc. handles_tree = context_tree.map(self._schedule_single_cluster) - # Get the reconstruction handle and run merging to get a tree of merged result handles. + # Get the reconstruction handle and run merging to get a tree of merged result handles. reconstruction_tree = handles_tree.map(lambda handle: handle.reconstruction) cameras_gt = self.loader.get_gt_cameras() @@ -248,6 +247,7 @@ def merge_fn( drop_camera_with_no_track=self._drop_camera_with_no_track, drop_child_if_merging_fail=self._drop_child_if_merging_fail, store_full_data=False, + use_nonlinear_sim3_alignment=self._use_nonlinear_sim3_merging, ) merged_future_tree = submit_tree_map_with_children(client, reconstruction_tree, merge_fn) @@ -270,7 +270,7 @@ def merge_fn( base_metrics_groups.extend(metrics_groups) base_metrics_groups.append(merged_result.metrics) root_merge_future = merge_future - elif metrics_groups: + else: merged_result = merge_future.result() metrics_groups.append(merged_result.metrics) save_metrics_reports(metrics_groups, str(handle.output_paths.metrics)) diff --git a/gtsfm/utils/metrics.py b/gtsfm/utils/metrics.py index a58294fd4..945a74bfd 100644 --- a/gtsfm/utils/metrics.py +++ b/gtsfm/utils/metrics.py @@ -444,6 +444,7 @@ def compute_ba_pose_metrics( gt_wTi: dict[int, Pose3], computed_wTi: dict[int, Optional[Pose3]], save_dir: Optional[str] = None, + store_full_data: bool = False, ) -> GtsfmMetricsGroup: """Compute pose errors w.r.t. GT for the bundle adjustment result. @@ -453,6 +454,7 @@ def compute_ba_pose_metrics( gt_wTi: Dict of ground truth poses keyed by camera id. computed_wTi: Dict of computed poses keyed by camera id. save_dir: Directory to save the metrics plots. + store_full_data: Whether to store full data. Returns: A group of metrics that describe errors associated with a bundle adjustment result (w.r.t. GT). @@ -477,12 +479,20 @@ def compute_ba_pose_metrics( gt_wTi_opt: dict[int, Optional[Pose3]] = {i: pose for i, pose in gt_wTi.items()} translation_angular_errors = get_relative_translation_angles(i2Ui1_gt_opt, computed_wTi_opt, include_none=True) metrics.append( - GtsfmMetric("relative_translation_angle_error_deg", np.array(translation_angular_errors, dtype=np.float32)) + GtsfmMetric( + "relative_translation_angle_error_deg", + np.array(translation_angular_errors, dtype=np.float32), + store_full_data=store_full_data, + ) ) metrics.append(compute_translation_angle_metric(gt_wTi_opt, computed_wTi_opt)) rotation_angular_errors = get_relative_rotation_angles(i2Ri1_gt_opt, computed_wTi_opt, include_none=True) metrics.append( - GtsfmMetric("relative_rotation_angle_error_deg", np.array(rotation_angular_errors, dtype=np.float32)) + GtsfmMetric( + "relative_rotation_angle_error_deg", + np.array(rotation_angular_errors, dtype=np.float32), + store_full_data=store_full_data, + ) ) metrics.extend(compute_pose_auc_metric(rotation_angular_errors, translation_angular_errors, save_dir=save_dir)) diff --git a/gtsfm/visualization/visualize_tracks.py b/gtsfm/visualization/visualize_tracks.py new file mode 100644 index 000000000..cbff0686d --- /dev/null +++ b/gtsfm/visualization/visualize_tracks.py @@ -0,0 +1,372 @@ +"""Visualize reprojection errors for tracks stored in COLMAP text outputs. + +This script reconstructs a GtsfmData object from COLMAP text files, builds a +single loader from a Hydra config, and overlays reprojection error vectors on +each image. Each measurement draws a line between the reprojected point and the +measured 2D keypoint, with an optional dot for the track. + +The script searches `--result_root` recursively for folders containing COLMAP +`cameras.txt`, `images.txt`, and `points3D.txt`, then writes visualizations to +`/tracks_viz/...` mirroring the COLMAP folder structure. +""" + +from __future__ import annotations + +import argparse +import colorsys +import os +from pathlib import Path +from typing import Iterable, List, Set, Tuple + +import cv2 +import hydra +import numpy as np +from hydra.utils import instantiate +from PIL import Image as PILImage +from PIL.Image import Image as PILImageType + +import gtsfm.utils.logger as logger_utils +from gtsfm.common.gtsfm_data import GtsfmData +from gtsfm.loader.loader_base import LoaderBase + +logger = logger_utils.get_logger() + + +def _build_loader( + loader_config: str, + dataset_dir: str, + images_dir: str | None, + max_resolution: int | None, +) -> LoaderBase: + """Instantiate a loader using a Hydra config.""" + overrides: List[str] = [f"dataset_dir={dataset_dir}"] + if images_dir is not None: + overrides.append(f"images_dir={images_dir}") + if max_resolution is not None: + overrides.append(f"max_resolution={max_resolution}") + + config_dir = Path(__file__).resolve().parents[1] / "configs" / "loader" + with hydra.initialize_config_dir(config_dir=str(config_dir), version_base=None): + cfg = hydra.compose(config_name=loader_config, overrides=overrides) + return instantiate(cfg) + + +def _collect_reprojection_pairs( + gtsfm_data: GtsfmData, + camera_idx: int, + allowed_track_indices: Set[int], +) -> List[Tuple[int, np.ndarray, np.ndarray]]: + """Collect (track_idx, measured, reprojected) for a given camera index.""" + camera = gtsfm_data.get_camera(camera_idx) + if camera is None: + return [] + + pairs: List[Tuple[int, np.ndarray, np.ndarray]] = [] + measurements = gtsfm_data.get_measurements_for_camera(camera_idx) + for track_idx, measurement_idx in measurements: + if track_idx not in allowed_track_indices: + continue + track = gtsfm_data.get_track(track_idx) + image_idx, uv_measured = track.measurement(measurement_idx) + assert image_idx == camera_idx, "Measurement image index does not match camera index" + uv_reproj, success = camera.projectSafe(track.point3()) + if not success: + continue + pairs.append((track_idx, np.array(uv_measured, dtype=float), np.array(uv_reproj, dtype=float))) + return pairs + + +def _track_color(track_idx: int) -> Tuple[int, int, int]: + """Assign a consistent, distinguishable RGB color per track index.""" + hue = (track_idx * 0.61803398875) % 1.0 + r, g, b = colorsys.hsv_to_rgb(hue, 0.7, 0.95) + return int(r * 255), int(g * 255), int(b * 255) + + +def _draw_reprojection_overlay( + image_array: np.ndarray, + pairs: Iterable[Tuple[int, np.ndarray, np.ndarray]], + *, + line_color: Tuple[int, int, int], + dot_radius: int, + line_width: int, + draw_measured: bool, + measured_color: Tuple[int, int, int], + scale_u: float, + scale_v: float, + dot_on_measured: bool, + line_only: bool, +) -> PILImageType: + """Draw reprojection overlays on an image using OpenCV.""" + image_rgb = image_array.astype(np.uint8) + image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + line_color_bgr = (line_color[2], line_color[1], line_color[0]) + measured_color_bgr = (measured_color[2], measured_color[1], measured_color[0]) + + for track_idx, uv_measured, uv_reproj in pairs: + x_meas = float(uv_measured[0]) * scale_u + y_meas = float(uv_measured[1]) * scale_v + x_rep = float(uv_reproj[0]) * scale_u + y_rep = float(uv_reproj[1]) * scale_v + reproj_color = _track_color(track_idx) + dot_x, dot_y = (x_meas, y_meas) if dot_on_measured else (x_rep, y_rep) + + pt_rep = (int(round(x_rep)), int(round(y_rep))) + pt_meas = (int(round(x_meas)), int(round(y_meas))) + pt_dot = (int(round(dot_x)), int(round(dot_y))) + + cv2.line(image_bgr, pt_rep, pt_meas, line_color_bgr, thickness=line_width, lineType=cv2.LINE_AA) + if not line_only: + reproj_color_bgr = (reproj_color[2], reproj_color[1], reproj_color[0]) + cv2.circle(image_bgr, pt_dot, dot_radius, reproj_color_bgr, thickness=-1, lineType=cv2.LINE_AA) + if draw_measured: + cv2.circle( + image_bgr, + pt_meas, + dot_radius, + measured_color_bgr, + thickness=max(1, line_width), + lineType=cv2.LINE_AA, + ) + + image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + return PILImage.fromarray(image_rgb) + + +def _resolve_output_name(gtsfm_data: GtsfmData, loader: LoaderBase, image_idx: int) -> str: + """Resolve output filename based on COLMAP image names, with loader fallback.""" + info = gtsfm_data.get_image_info(image_idx) + if info.name: + return Path(info.name).name + filenames = loader.image_filenames() + if 0 <= image_idx < len(filenames): + name = Path(filenames[image_idx]).name + if name: + return name + return f"image_{image_idx:06d}.png" + + +def _build_loader_name_maps(loader: LoaderBase) -> tuple[dict[str, int], dict[str, list[int]]]: + """Build lookup maps from loader filenames to loader indices.""" + filenames = loader.image_filenames() + full_map: dict[str, int] = {} + base_map: dict[str, list[int]] = {} + for idx, name in enumerate(filenames): + full_map[name] = idx + base = Path(name).name + base_map.setdefault(base, []).append(idx) + return full_map, base_map + + +def _resolve_loader_index( + gtsfm_data: GtsfmData, image_idx: int, full_map: dict[str, int], base_map: dict[str, list[int]] +) -> int | None: + """Resolve loader index for a COLMAP image index based on filename.""" + info = gtsfm_data.get_image_info(image_idx) + if info.name: + if info.name in full_map: + return full_map[info.name] + base = Path(info.name).name + if base in base_map: + if len(base_map[base]) > 1: + logger.warning("Multiple loader matches for %s; using first.", base) + return base_map[base][0] + return None + + +def _has_colmap_text_files(directory: str) -> bool: + """Check whether a directory contains COLMAP text outputs.""" + required = {"cameras.txt", "images.txt", "points3D.txt"} + try: + entries = set(os.listdir(directory)) + except FileNotFoundError: + return False + return required.issubset(entries) + + +def _find_colmap_dirs(root_dir: str) -> List[str]: + """Recursively find all subdirectories containing COLMAP text files.""" + matches: List[str] = [] + for dirpath, _, _ in os.walk(root_dir): + if _has_colmap_text_files(dirpath): + matches.append(dirpath) + return matches + + +def _visualize_tracks_for_dir(args: argparse.Namespace, colmap_dir: str, output_dir: str, loader: LoaderBase) -> None: + """Visualize reprojection errors for one COLMAP directory.""" + logger.info("Loading reconstruction from %s", colmap_dir) + try: + gtsfm_data = GtsfmData.read_colmap(colmap_dir) + except Exception as exc: + logger.exception("Skipping %s due to error: %s", colmap_dir, exc) + return + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + full_map, base_map = _build_loader_name_maps(loader) + + camera_indices = sorted(gtsfm_data.get_valid_camera_indices()) + if args.max_images is not None: + camera_indices = camera_indices[: args.max_images] + + num_tracks = gtsfm_data.number_tracks() + if args.max_pairs is not None and num_tracks > args.max_pairs: + rng = np.random.default_rng(args.random_seed) + sampled = rng.choice(num_tracks, size=args.max_pairs, replace=False) + allowed_track_indices = set(int(idx) for idx in sampled) + else: + allowed_track_indices = set(range(num_tracks)) + + for camera_idx in camera_indices: + loader_idx = _resolve_loader_index(gtsfm_data, camera_idx, full_map, base_map) + if loader_idx is None: + logger.warning("Skipping camera %d with no loader match", camera_idx) + continue + + pairs = _collect_reprojection_pairs(gtsfm_data, camera_idx, allowed_track_indices) + if not pairs: + logger.info("No valid measurements for image %d", camera_idx) + continue + + image = loader.get_image(loader_idx) + resized_h, resized_w = image.height, image.width + info = gtsfm_data.get_image_info(camera_idx) + if info.shape is not None: + orig_h, orig_w = info.shape + else: + orig_h, orig_w = resized_h, resized_w + scale_u = resized_w / orig_w if orig_w > 0 else 1.0 + scale_v = resized_h / orig_h if orig_h > 0 else 1.0 + + overlay = _draw_reprojection_overlay( + image.value_array, + pairs, + line_color=tuple(args.line_color), + dot_radius=args.dot_radius, + line_width=args.line_width, + draw_measured=args.draw_measured, + measured_color=tuple(args.measured_color), + scale_u=scale_u, + scale_v=scale_v, + dot_on_measured=args.dot_on_measured, + line_only=args.line_only, + ) + + output_name = _resolve_output_name(gtsfm_data, loader, camera_idx) + output_file = output_path / output_name + overlay.save(output_file) + logger.info("Saved %s", output_file) + + +def visualize_tracks(args: argparse.Namespace) -> None: + """Visualize reprojection errors across all COLMAP directories under result_root.""" + colmap_dirs = _find_colmap_dirs(args.result_root) + if not colmap_dirs: + logger.warning("No COLMAP text directories found under %s", args.result_root) + return + + logger.info("Instantiating loader config=%s", args.loader_config) + loader = _build_loader( + loader_config=args.loader_config, + dataset_dir=args.dataset_dir, + images_dir=args.images_dir, + max_resolution=args.max_resolution, + ) + + viz_root = Path(args.result_root) / "tracks_viz" + for colmap_dir in colmap_dirs: + rel_path = Path(colmap_dir).relative_to(args.result_root) + output_dir = viz_root / rel_path + _visualize_tracks_for_dir(args, colmap_dir, str(output_dir), loader) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Overlay reprojection error vectors on images for COLMAP reconstructions." + ) + parser.add_argument( + "--result_root", + type=str, + required=True, + help="Root directory to recursively search for COLMAP text outputs.", + ) + parser.add_argument( + "--loader_config", + type=str, + default="colmap", + help="Loader config name from gtsfm/configs/loader (e.g., colmap, tanks_and_temples).", + ) + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Dataset root for the loader (passed as loader.dataset_dir).", + ) + parser.add_argument( + "--images_dir", + type=str, + default=None, + help="Optional images directory (passed as loader.images_dir).", + ) + parser.add_argument( + "--max_resolution", + type=int, + default=None, + help="Optional max resolution override for loader.", + ) + parser.add_argument( + "--max_images", + type=int, + default=None, + help="Limit the number of images to visualize.", + ) + parser.add_argument( + "--max_pairs", + type=int, + default=None, + help="Limit the number of tracks drawn across all images (randomly sampled).", + ) + parser.add_argument("--dot_radius", type=int, default=2, help="Radius for reprojection dot.") + parser.add_argument("--line_width", type=int, default=1, help="Line width for reprojection error.") + parser.add_argument( + "--line_color", + type=int, + nargs=3, + default=(255, 0, 0), + help="RGB color for reprojection error lines.", + ) + parser.add_argument( + "--random_seed", + type=int, + default=0, + help="Random seed for sampling tracks when max_pairs is set.", + ) + parser.add_argument( + "--draw_measured", + action="store_true", + help="Draw an outline circle at the measured 2D point.", + ) + parser.add_argument( + "--dot_on_measured", + action="store_true", + help="Draw the colored dot on the measured point instead of the reprojection.", + ) + parser.add_argument( + "--line_only", + action="store_true", + help="Draw only the line; use a small line-colored dot at the line head.", + ) + parser.add_argument( + "--measured_color", + type=int, + nargs=3, + default=(0, 255, 0), + help="RGB color for measured point outlines.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + visualize_tracks(_parse_args()) diff --git a/pyproject.toml b/pyproject.toml index 227f3ec8e..cb9e386d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ dependencies = [ "pydegensac", "colour", "trimesh[easy]", - "gtsam==4.3a0", + "gtsam-develop==4.3a1.dev202602040056", "pydot", # Dust3r/Mast3r diff --git a/tests/averaging/translation/test_averaging_1dsfm.py b/tests/averaging/translation/test_averaging_1dsfm.py index c02413af7..f2670a418 100644 --- a/tests/averaging/translation/test_averaging_1dsfm.py +++ b/tests/averaging/translation/test_averaging_1dsfm.py @@ -48,7 +48,7 @@ def test_binary_measurements_from_dict(self): (0, 2): Unit3(Point3(0, 1, 0)), (1, 2): Unit3(Point3(0, 0, 1)), } - noise_model = gtsam.noiseModel.Isotropic.Sigma(3, 0.1) + noise_model = gtsam.noiseModel.Isotropic.Sigma(2, 0.1) expected_measurement_idxs = set( [(C(i2), C(i1)) for (i1, i2) in w_i2Ui1_dict.keys()] + [(C(i2), L(i1)) for (i1, i2) in w_i2Ui1_dict_tracks.keys()] diff --git a/tests/utils/test_align.py b/tests/utils/test_align.py index dff1fb3fd..f2f3dd1de 100644 --- a/tests/utils/test_align.py +++ b/tests/utils/test_align.py @@ -952,12 +952,12 @@ def test_align_gtsfm_data_via_Sim3_to_poses_skydio32(self) -> None: aligned_metrics = aligned_filtered_data.get_metrics(suffix="_filtered") - assert unaligned_metrics[3].name == "reprojection_errors_filtered_px" - assert aligned_metrics[3].name == "reprojection_errors_filtered_px" + assert unaligned_metrics[4].name == "reprojection_errors_filtered_px" + assert aligned_metrics[4].name == "reprojection_errors_filtered_px" # Reprojection error should be unaffected by Sim(3) alignment. for key in ["min", "max", "median", "mean", "stddev"]: - assert np.isclose(unaligned_metrics[3].summary[key], aligned_metrics[3].summary[key]) + assert np.isclose(unaligned_metrics[4].summary[key], aligned_metrics[4].summary[key]) if __name__ == "__main__": diff --git a/tests/utils/test_metric_utils.py b/tests/utils/test_metric_utils.py index 0c3e09967..aa9c89cab 100644 --- a/tests/utils/test_metric_utils.py +++ b/tests/utils/test_metric_utils.py @@ -177,14 +177,14 @@ def test_get_metrics_for_sfmdata_skydio32(self) -> None: ) metrics = aligned_filtered_data.get_metrics(suffix="_filtered") - assert metrics[0].name == "number_cameras" - assert np.isclose(metrics[0]._data, np.array(5.0, dtype=np.float32)) + assert metrics[1].name == "number_cameras_filtered" + assert np.isclose(metrics[1]._data, np.array(5.0, dtype=np.float32)) - assert metrics[1].name == "number_tracks_filtered" - assert np.isclose(metrics[1]._data, np.array(7.0, dtype=np.float32)) + assert metrics[2].name == "number_tracks_filtered" + assert np.isclose(metrics[2]._data, np.array(7.0, dtype=np.float32)) - assert metrics[2].name == "3d_track_lengths_filtered" - assert metrics[2].summary == { + assert metrics[3].name == "3d_track_lengths_filtered" + assert metrics[3].summary == { "min": 2, "max": 2, "median": 2.0, @@ -195,8 +195,8 @@ def test_get_metrics_for_sfmdata_skydio32(self) -> None: "invalid": 0, } - assert metrics[3].name == "reprojection_errors_filtered_px" - assert metrics[3].summary == {"min": np.nan, "max": np.nan, "median": np.nan, "mean": np.nan, "stddev": np.nan} + assert metrics[4].name == "reprojection_errors_filtered_px" + assert metrics[4].summary == {"min": np.nan, "max": np.nan, "median": np.nan, "mean": np.nan, "stddev": np.nan} def test_compute_percentage_change_improve() -> None: