Changes from 35 commits
61 commits
9089568
added sycl code
asrathore-ai Aug 13, 2025
dedfc2d
Resolve Windows compiler / linker errors.
ssheorey Sep 13, 2025
6f44999
Fix Windows linker errors
ssheorey Sep 13, 2025
a9eaf99
Loading kernels now works in Windows.
ssheorey Sep 15, 2025
2f28ba3
Corrected libtorch path issues
asrathore-ai Sep 23, 2025
3fff3de
Updated correct shape calculation
asrathore-ai Sep 24, 2025
10c0efd
Updated proj changes
asrathore-ai Sep 24, 2025
7b3bfb8
corrected isect code
asrathore-ai Sep 24, 2025
0d25c9e
Update fully fused projection kernels
asrathore-ai Sep 25, 2025
02a596b
wip packed kernel
asrathore-ai Sep 25, 2025
fc801e8
added fully fused projection packed
asrathore-ai Sep 25, 2025
f773859
Update kernel
asrathore-ai Sep 25, 2025
4dfea00
updated training code
asrathore-ai Sep 25, 2025
6624ec0
update steps for memory error
asrathore-ai Sep 25, 2025
6597d44
torch_acc uniform API for both cuda and xpu. Maybe replaced by torch.…
ssheorey Oct 7, 2025
0a472cf
Added forward pass for 2dgs fully fused projection
asrathore-ai Oct 17, 2025
417942e
added backward pass
asrathore-ai Oct 17, 2025
6763ada
update tests
asrathore-ai Oct 17, 2025
f4d0123
Working tests
asrathore-ai Oct 17, 2025
167d0b6
Added rasterize forward kernel
asrathore-ai Oct 17, 2025
a4f90ac
added backward kernel
asrathore-ai Oct 17, 2025
aa6ff08
Update rasterize_to_pixels_2dgs_bwd.cpp
asrathore-ai Nov 11, 2025
6542d4c
Update RasterizeToPixels2DGSBwdKernel.hpp
asrathore-ai Nov 11, 2025
cb5c72a
Update RasterizeToPixels2DGSBwdKernel.hpp
asrathore-ai Nov 11, 2025
4a5bde3
Update RasterizeToPixels2DGSFwdKernel.hpp
asrathore-ai Nov 11, 2025
11604f5
Update rasterize_to_pixels_2dgs_fwd.cpp for correct block size comput…
asrathore-ai Nov 15, 2025
b5c46a9
Update rasterize_to_pixels_2dgs_bwd.cpp for correct block size comput…
asrathore-ai Nov 15, 2025
eb2174c
relocation kernel for MCMC strategy
ssheorey Nov 15, 2025
390ffc2
Merge branch '2d_gaussian_splatting' of github.com:isl-org/gsplat int…
ssheorey Nov 15, 2025
6524660
CI tests and docs should pass now.
ssheorey Nov 16, 2025
e261c04
Use pytorch 2.6 for CI (docs and core_tests)
ssheorey Nov 16, 2025
d03c5f3
black v22 and clang-format
ssheorey Nov 16, 2025
ed7352a
Fix test_rasterization for packed, fix has_xpu check from previous co…
ssheorey Nov 17, 2025
b5bc4f0
#pragma once to follow gsplat convention
ssheorey Nov 17, 2025
b785b20
Merge pull request #2 from isl-org/2d_gaussian_splatting
ssheorey Nov 17, 2025
72fcd3a
Common _wrapper.py for cuda and sycl
ssheorey Nov 18, 2025
570de47
Merge branch 'main' into sycl_code_integration
ssheorey Nov 21, 2025
cabd342
contiguous() tensors to fused_ssim
ssheorey Nov 24, 2025
31c484e
Updated fused_ssim to latest including sycl bugfix.
ssheorey Nov 24, 2025
a70b2f5
empty xpu docs
ssheorey Nov 25, 2025
bc60098
Evaluation results on B580
ssheorey Jan 26, 2026
4f022e4
Windows MSVC fixes
ssheorey Jan 27, 2026
a9f1536
Do not install fused-bilagrid
ssheorey Feb 4, 2026
b315442
Update instructions.
ssheorey Feb 4, 2026
fa4cdd9
Simplify build process.
ssheorey Feb 7, 2026
3cb8262
Initial plan
Copilot Feb 23, 2026
02d094c
Add SYCL/XPU wheel build and publish GitHub Actions workflows
Copilot Feb 23, 2026
0305370
Address PR review comments on SYCL/XPU wheel workflows
Copilot Feb 23, 2026
a5c4e47
Add SYCL/XPU wheel build and publish GitHub Actions workflows
ssheorey Feb 23, 2026
2406dfb
Initial plan
Copilot Feb 23, 2026
a0025c6
Fix XPU workflow: truncate oneAPI version X.Y.Z to X.Y for apt packag…
Copilot Feb 23, 2026
1256a06
Windows wheel workflow
ssheorey Feb 27, 2026
1baa172
Fix
ssheorey Feb 27, 2026
3328824
Fix setup.py for source only wheel
ssheorey Feb 27, 2026
b3a8f5d
installer fix
ssheorey Feb 28, 2026
02b390b
fix
ssheorey Feb 28, 2026
54de18a
Separate Windows and Linux build steps.
ssheorey Feb 28, 2026
1442ff5
fix build type for windows
ssheorey Feb 28, 2026
21f1f85
Update docs
ssheorey Mar 12, 2026
711958b
SYCL CI - wheels and PyPI repo hosting on github pages.
ssheorey Mar 12, 2026
760d963
Add info about pre-built wheels.
ssheorey Mar 12, 2026
8 changes: 4 additions & 4 deletions .github/workflows/core_tests.yml
@@ -18,15 +18,15 @@ jobs:
with:
submodules: 'recursive'

- name: Set up Python 3.8.12
- name: Set up Python 3.9
uses: actions/setup-python@v5
with:
python-version: "3.8.12"
python-version: "3.9"
- name: Install dependencies
run: |
pip install black[jupyter]==22.3.0 pytest
pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cpu
BUILD_NO_CUDA=1 pip install .
pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu
BUILD_NO_CUDA=1 pip install --no-build-isolation .
- name: Run Black Format Check
run: black . gsplat/ tests/ examples/ profiling/ --check
- name: Run Tests.
4 changes: 2 additions & 2 deletions .github/workflows/doc.yml
@@ -26,8 +26,8 @@ jobs:
- name: Install dependencies
run: |
pip install -r docs/requirements.txt
pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cpu
BUILD_NO_CUDA=1 pip install .
pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cpu
BUILD_NO_CUDA=1 pip install --no-build-isolation .

# Get version.
- name: Get version + subdirectory
4 changes: 0 additions & 4 deletions docs/source/apis/utils.rst
@@ -27,8 +27,6 @@ Below are the basic functions that supports the rasterization.

.. autofunction:: rasterize_to_indices_in_range

.. autofunction:: accumulate

.. autofunction:: rasterization_inria_wrapper

2DGS
@@ -41,6 +39,4 @@ Below are the basic functions that supports the rasterization.

.. autofunction:: rasterize_to_indices_in_range_2dgs

.. autofunction:: accumulate_2dgs

.. autofunction:: rasterization_2dgs_inria_wrapper
8 changes: 4 additions & 4 deletions examples/image_fitting.py
@@ -10,7 +10,7 @@
from PIL import Image
from torch import Tensor, optim

from gsplat import rasterization, rasterization_2dgs
from gsplat import torch_acc, rasterization, rasterization_2dgs


class SimpleTrainer:
@@ -21,7 +21,7 @@ def __init__(
gt_image: Tensor,
num_points: int = 2000,
):
self.device = torch.device("cuda:0")
self.device = torch_acc._get_device(0)
self.gt_image = gt_image.to(device=self.device)
self.num_points = num_points

@@ -117,13 +117,13 @@ def train(
packed=False,
)[0]
out_img = renders[0]
torch.cuda.synchronize()
torch_acc.synchronize()
times[0] += time.time() - start
loss = mse_loss(out_img, self.gt_image)
optimizer.zero_grad()
start = time.time()
loss.backward()
torch.cuda.synchronize()
torch_acc.synchronize()
times[1] += time.time() - start
optimizer.step()
print(f"Iteration {iter + 1}/{iterations}, Loss: {loss.item()}")
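The hunks above replace `torch.cuda.synchronize()` with `torch_acc.synchronize()` around the timed forward and backward passes. Accelerator kernels are launched asynchronously, so the wall clock only reflects device work once the queue has been drained. A minimal sketch of that timing pattern follows, written in plain Python so it runs without a GPU. The names `timed` and `fake_synchronize` are hypothetical and not part of gsplat; in real code one would presumably pass `torch_acc.synchronize` instead of the stand-in.

```python
import time
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")


def timed(fn: Callable[[], T], synchronize: Callable[[], None]) -> Tuple[T, float]:
    """Run fn and return (result, elapsed seconds).

    `synchronize` stands in for torch_acc.synchronize (an assumption here);
    it must block until all queued device work has finished.
    """
    start = time.perf_counter()
    result = fn()
    synchronize()  # wait for asynchronous kernels before reading the clock
    return result, time.perf_counter() - start


# Hypothetical stand-in so the sketch runs on a machine without an accelerator.
sync_calls: List[str] = []


def fake_synchronize() -> None:
    sync_calls.append("sync")


value, elapsed = timed(lambda: sum(range(1000)), fake_synchronize)
print(value, len(sync_calls), elapsed >= 0.0)
```

Passing the synchronize function as an argument keeps the timing helper independent of whichever backend `torch_acc` resolves to.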
4 changes: 2 additions & 2 deletions examples/requirements.txt
@@ -19,6 +19,6 @@ tensorboard
tensorly
pyyaml
matplotlib
git+https://github.com/rahul-goel/fused-ssim@328dc9836f513d00c4b5bc38fe30478b4435cbb5
git+https://github.com/harry7557558/fused-bilagrid@90f9788e57d3545e3a033c1038bb9986549632fe
#git+https://github.com/rahul-goel/fused-ssim@328dc9836f513d00c4b5bc38fe30478b4435cbb5
#git+https://github.com/harry7557558/fused-bilagrid@90f9788e57d3545e3a033c1038bb9986549632fe
splines
55 changes: 43 additions & 12 deletions examples/simple_trainer.py
@@ -21,7 +21,7 @@
generate_interpolated_path,
generate_spiral_path,
)
from fused_ssim import fused_ssim
from fusedssim_sycl import fusedssim
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -30,7 +30,7 @@
from typing_extensions import Literal, assert_never
from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed

from gsplat import export_splats
from gsplat import export_splats, torch_acc
from gsplat.compression import PngCompression
from gsplat.distributed import cli
from gsplat.optimizers import SelectiveAdam
@@ -79,13 +79,13 @@ class Config:
# Number of training steps
max_steps: int = 30_000
# Steps to evaluate the model
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
eval_steps: List[int] = field(default_factory=lambda: [2_000, 7_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
save_steps: List[int] = field(default_factory=lambda: [2_000, 7_000, 30_000])
# Whether to save ply file (storage size can be large)
save_ply: bool = False
# Steps to save the model as ply
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
ply_steps: List[int] = field(default_factory=lambda: [2_000, 7_000, 30_000])
# Whether to disable video generation during training and evaluation
disable_video: bool = False
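The new defaults add an early evaluation, checkpoint, and PLY export at step 2,000. The configured values are one-based step numbers, while the training loop counter is zero-based, which is why the checkpoint condition in this file reads `step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1`. A small sketch of how the two fit together is shown here; `StepConfig` and `should_save` are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StepConfig:
    """Hypothetical, trimmed-down stand-in for the trainer's Config."""

    max_steps: int = 30_000
    # Mutable defaults need default_factory so each instance gets its own list.
    save_steps: List[int] = field(default_factory=lambda: [2_000, 7_000, 30_000])


def should_save(step: int, cfg: StepConfig) -> bool:
    """Mirror the trainer's check: `step` is zero-based, save_steps are one-based."""
    return step in [i - 1 for i in cfg.save_steps] or step == cfg.max_steps - 1


cfg = StepConfig()
saving_steps = [s for s in range(cfg.max_steps) if should_save(s, cfg)]
print(saving_steps)
```

With these defaults the final checkpoint is matched by both halves of the condition, so the `max_steps - 1` clause mainly matters when `max_steps` is overridden without updating `save_steps`.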

@@ -227,7 +227,7 @@ def create_splats_with_optimizers(
visible_adam: bool = False,
batch_size: int = 1,
feature_dim: Optional[int] = None,
device: str = "cuda",
device: str = torch_acc._device(0).type,
world_rank: int = 0,
world_size: int = 1,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
@@ -312,7 +312,7 @@ def __init__(
self.world_rank = world_rank
self.local_rank = local_rank
self.world_size = world_size
self.device = f"cuda:{local_rank}"
self.device = str(torch_acc._device(local_rank))

# Where to dump results.
os.makedirs(cfg.result_dir, exist_ok=True)
@@ -681,7 +681,7 @@ def train(self):

# loss
l1loss = F.l1_loss(colors, pixels)
ssimloss = 1.0 - fused_ssim(
ssimloss = 1.0 - fusedssim(
colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid"
)
loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
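The loss in this hunk is a weighted blend of the L1 term and the SSIM term, with `cfg.ssim_lambda` controlling the mix; only the SSIM implementation changes (`fusedssim` instead of `fused_ssim`). The arithmetic can be checked with plain floats, as in the sketch below. The function name `blended_loss` and the numbers are illustrative assumptions, not values taken from the PR.

```python
def blended_loss(l1: float, ssim: float, ssim_lambda: float) -> float:
    """Combine L1 and SSIM terms the same way the trainer does.

    ssim is a similarity score where 1.0 means identical images,
    so (1 - ssim) is the quantity being minimised.
    """
    ssim_loss = 1.0 - ssim
    return l1 * (1.0 - ssim_lambda) + ssim_loss * ssim_lambda


# Illustrative numbers only (not measured results).
loss = blended_loss(l1=0.10, ssim=0.90, ssim_lambda=0.2)
print(round(loss, 4))
```

Setting `ssim_lambda` to 0 reduces the loss to pure L1, and setting it to 1 reduces it to pure `1 - ssim`.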
@@ -735,7 +735,7 @@ def train(self):
# )

if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0:
mem = torch.cuda.max_memory_allocated() / 1024**3
mem = torch_acc.max_memory_allocated() / 1024**3
self.writer.add_scalar("train/loss", loss.item(), step)
self.writer.add_scalar("train/l1loss", l1loss.item(), step)
self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
@@ -753,7 +753,7 @@ def train(self):

# save checkpoint before updating the model
if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1:
mem = torch.cuda.max_memory_allocated() / 1024**3
mem = torch_acc.max_memory_allocated() / 1024**3
stats = {
"mem": mem,
"ellipse_time": time.time() - global_tic,
@@ -923,7 +923,7 @@ def eval(self, step: int, stage: str = "val"):
masks = data["mask"].to(device) if "mask" in data else None
height, width = pixels.shape[1:3]

torch.cuda.synchronize()
torch_acc.synchronize()
tic = time.time()
colors, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
@@ -935,7 +935,7 @@ def eval(self, step: int, stage: str = "val"):
far_plane=cfg.far_plane,
masks=masks,
) # [1, H, W, 3]
torch.cuda.synchronize()
torch_acc.synchronize()
ellipse_time += max(time.time() - tic, 1e-10)

colors = torch.clamp(colors, 0.0, 1.0)
@@ -1177,6 +1177,37 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
step = ckpts[0]["step"]
runner.eval(step=step)
runner.render_traj(step=step)
if cfg.save_ply:
if runner.cfg.app_opt:
# eval at origin to bake the appearance into the colors
rgb = runner.app_module(
features=runner.splats["features"],
embed_ids=None,
dirs=torch.zeros_like(runner.splats["means"][None, :, :]),
sh_degree=runner.cfg.sh_degree,
)
rgb = rgb + runner.splats["colors"]
rgb = torch.sigmoid(rgb).squeeze(0).unsqueeze(1)
sh0 = rgb_to_sh(rgb)
shN = torch.empty([sh0.shape[0], 0, 3], device=sh0.device)
else:
sh0 = runner.splats["sh0"]
shN = runner.splats["shN"]

means = runner.splats["means"]
scales = runner.splats["scales"]
quats = runner.splats["quats"]
opacities = runner.splats["opacities"]
export_splats(
means=means,
scales=scales,
quats=quats,
opacities=opacities,
sh0=sh0,
shN=shN,
format="ply",
save_to=f"{cfg.result_dir}/point_cloud_{step}.ply",
)
if cfg.compression is not None:
runner.run_compression(step=step)
else:
13 changes: 7 additions & 6 deletions examples/simple_trainer_2dgs.py
@@ -29,6 +29,7 @@
rgb_to_sh,
set_random_seed,
)
from gsplat import torch_acc
from gsplat_viewer_2dgs import GsplatViewer, GsplatRenderTabState
from gsplat.rendering import rasterization_2dgs, rasterization_2dgs_inria_wrapper
from gsplat.strategy import DefaultStrategy
@@ -194,7 +195,7 @@ def create_splats_with_optimizers(
sparse_grad: bool = False,
batch_size: int = 1,
feature_dim: Optional[int] = None,
device: str = "cuda",
device: str = torch_acc._device(0).type,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
if init_type == "sfm":
points = torch.from_numpy(parser.points).float()
@@ -257,7 +258,7 @@ def __init__(self, cfg: Config) -> None:
set_random_seed(42)

self.cfg = cfg
self.device = "cuda"
self.device = torch_acc._device(0).type

# Where to dump results.
os.makedirs(cfg.result_dir, exist_ok=True)
@@ -650,7 +651,7 @@ def train(self):
pbar.set_description(desc)

if cfg.tb_every > 0 and step % cfg.tb_every == 0:
mem = torch.cuda.max_memory_allocated() / 1024**3
mem = torch_acc.max_memory_allocated() / 1024**3
self.writer.add_scalar("train/loss", loss.item(), step)
self.writer.add_scalar("train/l1loss", l1loss.item(), step)
self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
@@ -712,7 +713,7 @@ def train(self):

# save checkpoint
if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1:
mem = torch.cuda.max_memory_allocated() / 1024**3
mem = torch_acc.max_memory_allocated() / 1024**3
stats = {
"mem": mem,
"ellipse_time": time.time() - global_tic,
@@ -765,7 +766,7 @@ def eval(self, step: int):
pixels = data["image"].to(device) / 255.0
height, width = pixels.shape[1:3]

torch.cuda.synchronize()
torch_acc.synchronize()
tic = time.time()
(
colors,
@@ -787,7 +788,7 @@ def eval(self, step: int):
) # [1, H, W, 3]
colors = torch.clamp(colors, 0.0, 1.0)
colors = colors[..., :3] # Take RGB channels
torch.cuda.synchronize()
torch_acc.synchronize()
ellipse_time += max(time.time() - tic, 1e-10)

# write images
2 changes: 1 addition & 1 deletion examples/simple_viewer.py
@@ -20,7 +20,7 @@

def main(local_rank: int, world_rank, world_size: int, args):
torch.manual_seed(42)
device = torch.device("cuda", local_rank)
device = torch.device(local_rank)

if args.ckpt is None:
(
60 changes: 52 additions & 8 deletions gsplat/__init__.py
@@ -1,9 +1,12 @@
import warnings
import os
import sys
import torch

from .compression import PngCompression
from .cuda._torch_impl import accumulate
from .cuda._torch_impl_2dgs import accumulate_2dgs
from .cuda._wrapper import (
BACKEND: str = ""
torch_acc = torch.cpu
_force_backend = os.getenv("GSPLAT_BACKEND", "").lower()

from .cuda._wrapper import ( # Default to CUDA imports, works even if no CUDA is available
RollingShutterType,
fully_fused_projection,
fully_fused_projection_2dgs,
@@ -20,6 +23,44 @@
spherical_harmonics,
world_to_cam,
)

if _force_backend == "cuda" or (_force_backend == "" and torch.cuda.is_available()):
BACKEND = "cuda"
torch_acc = torch.cuda
print("gsplat: Using CUDA backend.", file=sys.stderr)
# Functions already imported above

if (
not BACKEND
and _force_backend in ("sycl", "xpu")
or _force_backend == ""
and hasattr(torch, "xpu")
and torch.xpu.is_available()
):
from .sycl._wrapper import ( # Overwrite imports for SYCL backend
RollingShutterType,
fully_fused_projection,
fully_fused_projection_2dgs,
fully_fused_projection_with_ut,
isect_offset_encode,
isect_tiles,
proj,
quat_scale_to_covar_preci,
rasterize_to_indices_in_range,
rasterize_to_indices_in_range_2dgs,
rasterize_to_pixels,
rasterize_to_pixels_2dgs,
rasterize_to_pixels_eval3d,
spherical_harmonics,
world_to_cam,
)

BACKEND = "sycl"
torch_acc = torch.xpu
print("gsplat: Using SYCL XPU backend.", file=sys.stderr)


from .compression import PngCompression
from .exporter import export_splats
from .optimizers import SelectiveAdam
from .rendering import (
@@ -31,7 +72,10 @@
from .strategy import DefaultStrategy, MCMCStrategy, Strategy
from .version import __version__

all = [

__all__ = [
"BACKEND",
"torch_acc",
"PngCompression",
"DefaultStrategy",
"MCMCStrategy",
@@ -47,16 +91,16 @@
"quat_scale_to_covar_preci",
"rasterize_to_pixels",
"world_to_cam",
"accumulate",
"rasterize_to_indices_in_range",
"fully_fused_projection_2dgs",
"rasterize_to_pixels_2dgs",
"rasterize_to_indices_in_range_2dgs",
"accumulate_2dgs",
"rasterization_2dgs_inria_wrapper",
"RollingShutterType",
"fully_fused_projection_with_ut",
"rasterize_to_pixels_eval3d",
"export_splats",
"__version__",
"SelectiveAdam",
# Note: accumulate and accumulate_2dgs are not typically part of the public API
]
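The import block in this file picks a backend from the `GSPLAT_BACKEND` environment variable and from what the installed PyTorch reports as available. In Python, `and` binds tighter than `or`, so the SYCL condition as written groups as `(not BACKEND and forced sycl/xpu) or (no override and XPU available)`. Under that grouping, an auto-detected XPU would replace CUDA even after CUDA had been selected. The sketch below shows one possible reading of the intended priority (explicit override, then CUDA, then SYCL/XPU) with explicit parentheses. That priority is an assumption, and `select_backend` is a hypothetical helper that takes availability flags as arguments so it runs without torch.

```python
def select_backend(force: str, cuda_available: bool, xpu_available: bool) -> str:
    """Return "cuda", "sycl", or "" (CPU fallback) for a given environment.

    `force` plays the role of the lower-cased GSPLAT_BACKEND value.
    Assumed priority: explicit override, then CUDA, then SYCL/XPU.
    """
    backend = ""
    if force == "cuda" or (force == "" and cuda_available):
        backend = "cuda"
    # Explicit parentheses: SYCL is considered only when nothing was chosen yet.
    if not backend and (force in ("sycl", "xpu") or (force == "" and xpu_available)):
        backend = "sycl"
    return backend


cases = [("", True, True), ("", False, True), ("xpu", True, False), ("", False, False)]
for force, cuda, xpu in cases:
    print(repr(force), cuda, xpu, "->", repr(select_backend(force, cuda, xpu)))
```

Keeping the decision in a pure function like this also makes the selection rules easy to unit-test on CI machines that have neither accelerator.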