Skip to content

fix(cuda): make GEMM caches safe for threaded callers#631

Open
voltjia wants to merge 1 commit into
masterfrom
codex/fix-issue-627
Open

fix(cuda): make GEMM caches safe for threaded callers#631
voltjia wants to merge 1 commit into
masterfrom
codex/fix-issue-627

Conversation

@voltjia
Copy link
Copy Markdown
Collaborator

@voltjia voltjia commented Jun 2, 2026

Summary

Fixes #627.

  • Make CUDA GEMM BLAS handles thread-local and keyed by tensor device index, so concurrent host threads and devices do not share one process-global cuBLAS handle.
  • Make Operator::Call cache storage thread-local, with an atomic generation counter for cache invalidation.
  • Use strict CUBLAS_COMPUTE_32F for NVIDIA GEMM compute type to avoid fp32 TF32 differences.
  • Extend the C++ API smoke test with concurrent Operator::Call calls and update GEMM tests to compare against strict fp32 PyTorch reference behavior.

Validation

Remote NVIDIA host, container infiniops-ci/nvidia:latest:

git diff --check
python3 scripts/generate_torch_ops.py
python3 -m pip install --no-build-isolation --no-deps . \
  --config-settings=cmake.define.AUTO_DETECT_DEVICES=OFF \
  --config-settings=cmake.define.AUTO_DETECT_BACKENDS=OFF \
  --config-settings=cmake.define.WITH_CPU=ON \
  --config-settings=cmake.define.WITH_NVIDIA=ON
INFINIOPS_INSTALL_PREFIX=/usr/local \
  python3 -m pytest -q tests/test_cpp_api.py tests/test_gemm.py --devices nvidia

Result: 2000 passed, 1001 skipped in 9.08s.

Also ran:

python3 -m ruff format --check tests/test_cpp_api.py tests/test_gemm.py
python3 -m ruff check tests/test_cpp_api.py tests/test_gemm.py

Result: 2 files already formatted; All checks passed.

A prior WITH_TORCH=ON focused run built successfully, then failed GEMM comparisons because PyTorch reference GEMM was still using TF32 while this PR switches native NVIDIA GEMM to strict fp32. The GEMM test now disables PyTorch TF32 in the module reference path for parity with the new behavior.

@voltjia voltjia requested a review from a team June 2, 2026 13:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

CUDA GEMM: process-wide operator/BLAS caches are unsafe for multi-thread + multi-device callers

1 participant