diff --git a/src/native/cuda/ops/gemm/blas.h b/src/native/cuda/ops/gemm/blas.h index 6bac0d6a6..eae97bb69 100644 --- a/src/native/cuda/ops/gemm/blas.h +++ b/src/native/cuda/ops/gemm/blas.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_CUDA_GEMM_BLAS_H_ #define INFINI_OPS_CUDA_GEMM_BLAS_H_ +#include #include #include "base/gemm.h" @@ -32,7 +33,8 @@ class BlasGemm : public Gemm { void operator()(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const override { - Backend::BlasSetStream(GetHandle(), + auto& handle{GetHandle(c.device())}; + Backend::BlasSetStream(handle, static_cast(stream_)); const auto& alpha_value{alpha.value_or(alpha_)}; @@ -46,7 +48,7 @@ class BlasGemm : public Gemm { const void* beta_ptr{GetBetaPtr(beta_value, c.dtype())}; Backend::BlasGemmStridedBatchedEx( - GetHandle(), op_a, op_b, swap_a_and_b_ ? n_ : m_, + handle, op_a, op_b, swap_a_and_b_ ? n_ : m_, swap_a_and_b_ ? m_ : n_, k_, alpha_ptr, swap_a_and_b_ ? b.data() : a.data(), BlasUtils::GetDataType(swap_a_and_b_ ? b.dtype() @@ -92,15 +94,17 @@ class BlasGemm : public Gemm { : Backend::BLAS_OP_N; } - // TODO: This static singleton is not thread-safe under concurrent access - // from multiple host threads. Add proper synchronization in the future. - static typename Backend::BlasHandle& GetHandle() { - static typename Backend::BlasHandle handle = []() { - typename Backend::BlasHandle h; - Backend::BlasCreate(&h); - return h; - }(); - return handle; + static typename Backend::BlasHandle& GetHandle(const Device& device) { + static thread_local std::unordered_map + handles; + + auto it{handles.find(device.index())}; + if (it == handles.end()) { + typename Backend::BlasHandle handle; + Backend::BlasCreate(&handle); + it = handles.emplace(device.index(), handle).first; + } + return it->second; } bool a_is_col_major_{false}; diff --git a/src/operator.h b/src/operator.h index 257d62de7..1b13d45df 100644 --- a/src/operator.h +++ b/src/operator.h @@ -1,6 +1,7 @@ #ifndef INFINI_OPS_OPERATOR_H_ #define INFINI_OPS_OPERATOR_H_ +#include #include #include #include @@ -138,13 +139,15 @@ class Operator : public OperatorBase { // Generation counter for lazy cache invalidation. Bumped by // `clear_cache()`; the next `call()` detects the mismatch and // destroys all cached operator instances. - static inline std::size_t cache_generation_{0}; + static inline std::atomic_size_t cache_generation_{0}; public: // Invalidate the operator cache. Cached operators are destroyed on the // next `call()` invocation. Intended for test isolation — production // code should never call this. - static void clear_cache() { ++cache_generation_; } + static void clear_cache() { + cache_generation_.fetch_add(1, std::memory_order_acq_rel); + } template static auto Make(const Config& config, const Tensor tensor, Args&&... args) { std::unique_ptr op_ptr; @@ -193,13 +196,16 @@ class Operator : public OperatorBase { template static auto Call(const Handle& handle, const Config& config, const Args&... args) { - static std::unordered_map> + static thread_local std::unordered_map> cache; - static std::size_t generation{0}; + static thread_local std::size_t generation{0}; - if (generation != cache_generation_) { + const auto current_generation{ + cache_generation_.load(std::memory_order_acquire)}; + if (generation != current_generation) { cache.clear(); - generation = cache_generation_; + generation = current_generation; } auto key = detail::CacheKey::Build(config.implementation_index(), args...); diff --git a/tests/test_cpp_api.py b/tests/test_cpp_api.py index 86c0c1600..109ecb14f 100644 --- a/tests/test_cpp_api.py +++ b/tests/test_cpp_api.py @@ -7,18 +7,27 @@ def test_cpp_operator_call_instantiation_smoke(tmp_path): + _compile_and_run(tmp_path, "add_smoke", _ADD_SMOKE_SOURCE) + + +def test_cpp_operator_call_thread_local_cache_regression(tmp_path): + _compile_and_run(tmp_path, "thread_local_cache_probe", _THREAD_LOCAL_CACHE_SOURCE) + + +def _compile_and_run(tmp_path, stem, source_text): install_prefix = _install_prefix() - include_dir = install_prefix / "include" + include_dir = _include_dir(install_prefix) library_dir = _library_dir(install_prefix) - source = tmp_path / "add_smoke.cc" - binary = tmp_path / "add_smoke" - source.write_text(_ADD_SMOKE_SOURCE) + source = tmp_path / f"{stem}.cc" + binary = tmp_path / stem + source.write_text(source_text) _run( [ _compiler("CXX", "c++"), "-std=c++17", "-Werror", + "-pthread", f"-I{include_dir}", str(source), f"-L{library_dir}", @@ -37,16 +46,44 @@ def _install_prefix(): if prefix: return Path(prefix) - pytest.skip("`INFINIOPS_INSTALL_PREFIX` is not set.") + try: + import infini + except ImportError: + pytest.skip("INFINIOPS_INSTALL_PREFIX is not set and infini is not installed.") + + return Path(infini.__file__).resolve().parent + + +def _candidate_prefixes(prefix): + yield prefix + + try: + import infini + except ImportError: + return + + package_prefix = Path(infini.__file__).resolve().parent + if package_prefix != prefix: + yield package_prefix + + +def _include_dir(prefix): + for candidate in _candidate_prefixes(prefix): + include_dir = candidate / "include" + if (include_dir / "infini" / "ops.h").exists(): + return include_dir + + pytest.skip(f"infini/ops.h was not found under {prefix}.") def _library_dir(prefix): - for name in ("lib", "lib64"): - library_dir = prefix / name - if (library_dir / "libinfiniops.so").exists(): - return library_dir + for candidate in _candidate_prefixes(prefix): + for name in ("lib", "lib64", "."): + library_dir = candidate / name + if (library_dir / "libinfiniops.so").exists(): + return library_dir - pytest.skip(f"`libinfiniops.so` was not found under `{prefix}`.") + pytest.skip(f"libinfiniops.so was not found under {prefix}.") def _compiler(env_name, default): @@ -72,7 +109,9 @@ def _run(command): r""" #include + #include #include + #include int main() { float input_data[3] = {1.0f, 2.0f, 3.0f}; @@ -109,7 +148,124 @@ def _run(command): return 1; } - return 0; + std::atomic failures{0}; + auto threaded_call = [&]() { + float threaded_output[3] = {0.0f, 0.0f, 0.0f}; + infini::ops::Tensor threaded_out(threaded_output, shape, data_type, + device); + for (int i = 0; i < 100; ++i) { + infini::ops::Add::Call(handle, config, input, other, threaded_out); + if (std::fabs(threaded_output[0] - 5.0f) > 1e-6f || + std::fabs(threaded_output[1] - 7.0f) > 1e-6f || + std::fabs(threaded_output[2] - 9.0f) > 1e-6f) { + failures.fetch_add(1, std::memory_order_relaxed); + } + threaded_output[0] = 0.0f; + threaded_output[1] = 0.0f; + threaded_output[2] = 0.0f; + } + }; + + std::thread t0(threaded_call); + std::thread t1(threaded_call); + std::thread t2(threaded_call); + std::thread t3(threaded_call); + t0.join(); + t1.join(); + t2.join(); + t3.join(); + + return failures.load(std::memory_order_relaxed) == 0 ? 0 : 1; + } + """ +).lstrip() + + +_THREAD_LOCAL_CACHE_SOURCE = textwrap.dedent( + r""" + #include + + #include + #include + + namespace infini::ops { + + class ThreadLocalCacheProbe : public Operator { + public: + ThreadLocalCacheProbe(const Tensor input, Tensor out) {} + + virtual void operator()(const Tensor input, Tensor out) const = 0; + }; + + template <> + struct ActiveDevicesImpl { + using type = List; + }; + + template <> + class Operator + : public ThreadLocalCacheProbe { + public: + Operator(const Tensor input, Tensor out) + : ThreadLocalCacheProbe{input, out}, + owner_thread_id_{std::this_thread::get_id()} { + constructions.fetch_add(1, std::memory_order_relaxed); + } + + void operator()(const Tensor input, Tensor out) const override { + if (owner_thread_id_ != std::this_thread::get_id()) { + cross_thread_calls.fetch_add(1, std::memory_order_relaxed); + } + } + + static std::atomic constructions; + + static std::atomic cross_thread_calls; + + private: + std::thread::id owner_thread_id_; + }; + + std::atomic Operator::constructions{0}; + + std::atomic Operator::cross_thread_calls{0}; + + } // namespace infini::ops + + int main() { + float input_data[1] = {1.0f}; + float output_data[1] = {0.0f}; + + const infini::ops::Tensor::Shape shape{1}; + const infini::ops::Device device{infini::ops::Device::Type::kCpu}; + const infini::ops::DataType data_type{infini::ops::DataType::kFloat32}; + infini::ops::Tensor input(input_data, shape, data_type, device); + infini::ops::Tensor output(output_data, shape, data_type, device); + infini::ops::Handle handle; + infini::ops::Config config; + + using Probe = infini::ops::ThreadLocalCacheProbe; + using ProbeImpl = infini::ops::Operator< + Probe, infini::ops::Device::Type::kCpu>; + Probe::clear_cache(); + ProbeImpl::constructions.store(0, std::memory_order_relaxed); + ProbeImpl::cross_thread_calls.store(0, std::memory_order_relaxed); + + Probe::Call(handle, config, input, output); + + std::thread cache_probe_thread([&]() { + Probe::Call(handle, config, input, output); + }); + cache_probe_thread.join(); + + if (ProbeImpl::cross_thread_calls.load(std::memory_order_relaxed) != 0) { + return 1; + } + + return ProbeImpl::constructions.load(std::memory_order_relaxed) == 2 ? 0 + : 1; } """ ).lstrip()