diff --git a/src/native/cuda/nvidia/blas_utils.h b/src/native/cuda/nvidia/blas_utils.h index 45644ae38..c4b135070 100644 --- a/src/native/cuda/nvidia/blas_utils.h +++ b/src/native/cuda/nvidia/blas_utils.h @@ -18,11 +18,7 @@ struct BlasUtils { return CUDA_R_32F; } - static auto GetComputeType(DataType dtype) { - if (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16) - return CUBLAS_COMPUTE_32F; - return CUBLAS_COMPUTE_32F_FAST_TF32; - } + static auto GetComputeType(DataType) { return CUBLAS_COMPUTE_32F; } }; } // namespace infini::ops diff --git a/src/native/cuda/ops/gemm/blas.h b/src/native/cuda/ops/gemm/blas.h index 6bac0d6a6..eae97bb69 100644 --- a/src/native/cuda/ops/gemm/blas.h +++ b/src/native/cuda/ops/gemm/blas.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_CUDA_GEMM_BLAS_H_ #define INFINI_OPS_CUDA_GEMM_BLAS_H_ +#include #include #include "base/gemm.h" @@ -32,7 +33,8 @@ class BlasGemm : public Gemm { void operator()(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const override { - Backend::BlasSetStream(GetHandle(), + auto& handle{GetHandle(c.device())}; + Backend::BlasSetStream(handle, static_cast(stream_)); const auto& alpha_value{alpha.value_or(alpha_)}; @@ -46,7 +48,7 @@ class BlasGemm : public Gemm { const void* beta_ptr{GetBetaPtr(beta_value, c.dtype())}; Backend::BlasGemmStridedBatchedEx( - GetHandle(), op_a, op_b, swap_a_and_b_ ? n_ : m_, + handle, op_a, op_b, swap_a_and_b_ ? n_ : m_, swap_a_and_b_ ? m_ : n_, k_, alpha_ptr, swap_a_and_b_ ? b.data() : a.data(), BlasUtils::GetDataType(swap_a_and_b_ ? b.dtype() @@ -92,15 +94,17 @@ class BlasGemm : public Gemm { : Backend::BLAS_OP_N; } - // TODO: This static singleton is not thread-safe under concurrent access - // from multiple host threads. Add proper synchronization in the future. - static typename Backend::BlasHandle& GetHandle() { - static typename Backend::BlasHandle handle = []() { - typename Backend::BlasHandle h; - Backend::BlasCreate(&h); - return h; - }(); - return handle; + static typename Backend::BlasHandle& GetHandle(const Device& device) { + static thread_local std::unordered_map + handles; + + auto it{handles.find(device.index())}; + if (it == handles.end()) { + typename Backend::BlasHandle handle; + Backend::BlasCreate(&handle); + it = handles.emplace(device.index(), handle).first; + } + return it->second; } bool a_is_col_major_{false}; diff --git a/src/operator.h b/src/operator.h index 257d62de7..1b13d45df 100644 --- a/src/operator.h +++ b/src/operator.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_OPERATOR_H_ #define INFINI_OPS_OPERATOR_H_ +#include #include #include #include @@ -138,13 +139,15 @@ class Operator : public OperatorBase { // Generation counter for lazy cache invalidation. Bumped by // `clear_cache()`; the next `call()` detects the mismatch and // destroys all cached operator instances. - static inline std::size_t cache_generation_{0}; + static inline std::atomic_size_t cache_generation_{0}; public: // Invalidate the operator cache. Cached operators are destroyed on the // next `call()` invocation. Intended for test isolation — production // code should never call this. - static void clear_cache() { ++cache_generation_; } + static void clear_cache() { + cache_generation_.fetch_add(1, std::memory_order_acq_rel); + } template static auto Make(const Config& config, const Tensor tensor, Args&&... args) { std::unique_ptr op_ptr; @@ -193,13 +196,16 @@ class Operator : public OperatorBase { template static auto Call(const Handle& handle, const Config& config, const Args&... args) { - static std::unordered_map> + static thread_local std::unordered_map> cache; - static std::size_t generation{0}; + static thread_local std::size_t generation{0}; - if (generation != cache_generation_) { + const auto current_generation{ + cache_generation_.load(std::memory_order_acquire)}; + if (generation != current_generation) { cache.clear(); - generation = cache_generation_; + generation = current_generation; } auto key = detail::CacheKey::Build(config.implementation_index(), args...); diff --git a/tests/test_cpp_api.py b/tests/test_cpp_api.py index 86c0c1600..9ff71c3e9 100644 --- a/tests/test_cpp_api.py +++ b/tests/test_cpp_api.py @@ -19,6 +19,7 @@ def test_cpp_operator_call_instantiation_smoke(tmp_path): _compiler("CXX", "c++"), "-std=c++17", "-Werror", + "-pthread", f"-I{include_dir}", str(source), f"-L{library_dir}", @@ -72,7 +73,9 @@ def _run(command): r""" #include + #include #include + #include int main() { float input_data[3] = {1.0f, 2.0f, 3.0f}; @@ -109,7 +112,34 @@ def _run(command): return 1; } - return 0; + std::atomic failures{0}; + auto threaded_call = [&]() { + float threaded_output[3] = {0.0f, 0.0f, 0.0f}; + infini::ops::Tensor threaded_out(threaded_output, shape, data_type, + device); + for (int i = 0; i < 100; ++i) { + infini::ops::Add::Call(handle, config, input, other, threaded_out); + if (std::fabs(threaded_output[0] - 5.0f) > 1e-6f || + std::fabs(threaded_output[1] - 7.0f) > 1e-6f || + std::fabs(threaded_output[2] - 9.0f) > 1e-6f) { + failures.fetch_add(1, std::memory_order_relaxed); + } + threaded_output[0] = 0.0f; + threaded_output[1] = 0.0f; + threaded_output[2] = 0.0f; + } + }; + + std::thread t0(threaded_call); + std::thread t1(threaded_call); + std::thread t2(threaded_call); + std::thread t3(threaded_call); + t0.join(); + t1.join(); + t2.join(); + t3.join(); + + return failures.load(std::memory_order_relaxed) == 0 ? 0 : 1; } """ ).lstrip() diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 224390d15..d8224039a 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -5,6 +5,19 @@ from tests.utils import Payload, get_stream, randn_strided +@pytest.fixture(autouse=True) +def _strict_cuda_fp32_reference(): + old_matmul_allow_tf32 = torch.backends.cuda.matmul.allow_tf32 + old_cudnn_allow_tf32 = torch.backends.cudnn.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + try: + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = old_matmul_allow_tf32 + torch.backends.cudnn.allow_tf32 = old_cudnn_allow_tf32 + + @pytest.mark.auto_act_and_assert @pytest.mark.parametrize( "a_shape, b_shape, c_shape, a_strides, b_strides, c_strides",