Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion gtsfm/cluster_optimizer/cluster_vggt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Hashable, Optional, Union
from typing import Any, Hashable, Literal, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -248,6 +248,8 @@ def __init__(
weights_path: Optional[str] = None,
conf_threshold: float = 5.0,
max_num_points: int = 100000,
dense_points_head: Literal["depth", "point"] = "depth",
intrinsics_source: Literal["pose", "reprojection"] = "pose",
tracking: bool = False,
tracking_max_query_pts: int = 2048,
tracking_query_frame_num: int = 3,
Expand Down Expand Up @@ -309,6 +311,8 @@ def __init__(
self._weights_path = Path(weights_path) if weights_path is not None else None
self._conf_threshold = conf_threshold
self._max_points_for_colmap = max_num_points
self._dense_points_head = dense_points_head
self._intrinsics_source = intrinsics_source
self._tracking = tracking
self._tracking_max_query_pts = tracking_max_query_pts
self._tracking_query_frame_num = tracking_query_frame_num
Expand Down Expand Up @@ -417,6 +421,8 @@ def create_computation_graph(
config = VggtConfiguration(
confidence_threshold=self._conf_threshold,
max_num_points=self._max_points_for_colmap,
dense_points_head=self._dense_points_head,
intrinsics_source=self._intrinsics_source,
tracking=self._tracking,
max_query_pts=self._tracking_max_query_pts,
query_frame_num=self._tracking_query_frame_num,
Expand Down
2 changes: 2 additions & 0 deletions gtsfm/configs/vggt_megaloc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ cluster_optimizer:
weights_path: null
conf_threshold: 0.1
max_num_points: 100000
dense_points_head: point
intrinsics_source: reprojection
tracking: true
tracking_max_query_pts: 512
tracking_query_frame_num: 3
Expand Down
2 changes: 2 additions & 0 deletions gtsfm/configs/vggt_megaloc_phototourism.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ cluster_optimizer:
weights_path: null
conf_threshold: 0.1
max_num_points: 100000
dense_points_head: point
intrinsics_source: reprojection
tracking: true
tracking_max_query_pts: 512
tracking_query_frame_num: 3
Expand Down
95 changes: 87 additions & 8 deletions gtsfm/frontend/vggt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from importlib.machinery import ModuleSpec
from pathlib import Path
from types import ModuleType
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -311,6 +311,8 @@ class VggtConfiguration:
seed: int = 42
confidence_threshold: float = 5.0
max_num_points: int = 100000
dense_points_head: Literal["depth", "point"] = "depth"
intrinsics_source: Literal["pose", "reprojection"] = "pose"
dtype: Optional[Union[str, torch.dtype]] = None
model_ctor_kwargs: dict[str, Any] = field(default_factory=dict)
use_sparse_attention: bool = False
Expand Down Expand Up @@ -673,6 +675,8 @@ def run_VGGT(
config: Optional[VggtConfiguration] = None,
model: Optional[VGGT] = None,
weights_path: PathLike | None = None,
dense_points_head: Literal["depth", "point"] = "depth",
intrinsics_source: Literal["pose", "reprojection"] = "pose",
) -> VggtOutput:
"""Run VGGT and unproject depth using the geometry helper."""
if images.ndim != 4 or images.shape[1] != 3:
Expand Down Expand Up @@ -725,16 +729,84 @@ def run_VGGT(
pose_enc = model.camera_head(tokens)[-1]
extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, batched.shape[-2:])
depth_map, depth_conf = model.depth_head(tokens, batched, ps_idx)

depth_confidence = depth_conf.squeeze(0)
if depth_confidence.ndim == 4 and depth_confidence.shape[-1] == 1:
depth_confidence = depth_confidence.squeeze(-1)
point_map = point_conf = None
if dense_points_head == "point":
if model.point_head is None:
raise ValueError("dense_points_head='point' requires model.point_head to be available.")
point_map, point_conf = model.point_head(tokens, batched, ps_idx)

depth_map = depth_map.squeeze(0).to(dtype=torch.float32)
extrinsic = extrinsic.squeeze(0).to(dtype=torch.float32)
intrinsic = intrinsic.squeeze(0).to(dtype=torch.float32)
dense_points_np = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic)
dense_points = torch.from_numpy(dense_points_np).to(device=resolved_device, dtype=torch.float32)
if dense_points_head == "point":
assert point_map is not None and point_conf is not None
dense_points = point_map.squeeze(0).to(dtype=torch.float32)
depth_confidence = point_conf.squeeze(0)
elif dense_points_head == "depth":
dense_points_np = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic)
dense_points = torch.from_numpy(dense_points_np).to(device=resolved_device, dtype=torch.float32)
depth_confidence = depth_conf.squeeze(0)
else:
raise ValueError(f"Unknown dense_points_head='{dense_points_head}'. Expected 'depth' or 'point'.")
if depth_confidence.ndim == 4 and depth_confidence.shape[-1] == 1:
depth_confidence = depth_confidence.squeeze(-1)

if intrinsics_source == "reprojection":
if extrinsic.shape[-2:] == (4, 4):
r_cw = extrinsic[:, :3, :3]
t_cw = extrinsic[:, :3, 3]
else:
r_cw = extrinsic[:, :3, :3]
t_cw = extrinsic[:, :3, 3]

h, w = dense_points.shape[1], dense_points.shape[2]
cx = w * 0.5
cy = h * 0.5
points_in_cam = torch.einsum("nij,nhwj->nhwi", r_cw, dense_points) + t_cw[:, None, None, :]
x_flat = points_in_cam[..., 0].reshape(points_in_cam.shape[0], -1)
y_flat = points_in_cam[..., 1].reshape(points_in_cam.shape[0], -1)
z_flat = points_in_cam[..., 2].reshape(points_in_cam.shape[0], -1)

u_flat = torch.arange(w, device=dense_points.device, dtype=torch.float32).view(1, w).repeat(h, 1).reshape(-1)
v_flat = torch.arange(h, device=dense_points.device, dtype=torch.float32).view(h, 1).repeat(1, w).reshape(-1)

conf_flat = depth_confidence.to(torch.float32).reshape(depth_confidence.shape[0], -1)
conf_flat = torch.where(torch.isfinite(conf_flat), conf_flat, torch.full_like(conf_flat, float("-inf")))
topk = min(100, conf_flat.shape[1])
topk_idx = torch.topk(conf_flat, k=topk, dim=1, largest=True, sorted=False).indices

x = torch.gather(x_flat, 1, topk_idx)
y = torch.gather(y_flat, 1, topk_idx)
z = torch.gather(z_flat, 1, topk_idx)
u = torch.gather(u_flat.unsqueeze(0).expand(topk_idx.shape[0], -1), 1, topk_idx)
v = torch.gather(v_flat.unsqueeze(0).expand(topk_idx.shape[0], -1), 1, topk_idx)
eps = 1e-6

fx_samples = (u - cx) * z / x
fy_samples = (v - cy) * z / y
fx_mask = torch.isfinite(fx_samples) & (torch.abs(x) > eps) & (z > eps)
fy_mask = torch.isfinite(fy_samples) & (torch.abs(y) > eps) & (z > eps)

fx_sum = (fx_samples * fx_mask).sum(dim=1)
fy_sum = (fy_samples * fy_mask).sum(dim=1)
fx_count = fx_mask.sum(dim=1)
fy_count = fy_mask.sum(dim=1)

fx_est = fx_sum / fx_count.clamp(min=1)
fy_est = fy_sum / fy_count.clamp(min=1)

fx_fallback = intrinsic[:, 0, 0]
fy_fallback = intrinsic[:, 1, 1]
fx = torch.where(fx_count > 0, fx_est, fx_fallback)
fy = torch.where(fy_count > 0, fy_est, fy_fallback)

intrinsic = intrinsic.clone()
intrinsic[:, 0, 0] = fx
intrinsic[:, 1, 1] = fy
intrinsic[:, 0, 2] = cx
intrinsic[:, 1, 2] = cy
elif intrinsics_source != "pose":
raise ValueError(f"Unknown intrinsics_source='{intrinsics_source}'. Expected 'pose' or 'reprojection'.")

return VggtOutput(
device=resolved_device,
Expand Down Expand Up @@ -1079,7 +1151,14 @@ def run_reconstruction(
if cfg.tracking and model_for_tracking is None:
model_for_tracking = model

vggt_output = run_VGGT(images, config=cfg, model=model, weights_path=weights_path)
vggt_output = run_VGGT(
images,
config=cfg,
model=model,
weights_path=weights_path,
dense_points_head=cfg.dense_points_head,
intrinsics_source=cfg.intrinsics_source,
)

tracking_result = None
if cfg.tracking:
Expand Down