Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions src/native/cuda/ops/gemm/blas.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef INFINI_OPS_CUDA_GEMM_BLAS_H_
#define INFINI_OPS_CUDA_GEMM_BLAS_H_

#include <unordered_map>
#include <utility>

#include "base/gemm.h"
Expand Down Expand Up @@ -32,7 +33,8 @@ class BlasGemm : public Gemm {
void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
std::optional<float> beta, std::optional<int> trans_a,
std::optional<int> trans_b, Tensor c) const override {
Backend::BlasSetStream(GetHandle(),
auto& handle{GetHandle(c.device())};
Backend::BlasSetStream(handle,
static_cast<typename Backend::Stream>(stream_));

const auto& alpha_value{alpha.value_or(alpha_)};
Expand All @@ -46,7 +48,7 @@ class BlasGemm : public Gemm {
const void* beta_ptr{GetBetaPtr(beta_value, c.dtype())};

Backend::BlasGemmStridedBatchedEx(
GetHandle(), op_a, op_b, swap_a_and_b_ ? n_ : m_,
handle, op_a, op_b, swap_a_and_b_ ? n_ : m_,
swap_a_and_b_ ? m_ : n_, k_, alpha_ptr,
swap_a_and_b_ ? b.data() : a.data(),
BlasUtils<Backend::kDeviceType>::GetDataType(swap_a_and_b_ ? b.dtype()
Expand Down Expand Up @@ -92,15 +94,17 @@ class BlasGemm : public Gemm {
: Backend::BLAS_OP_N;
}

// TODO: This static singleton is not thread-safe under concurrent access
// from multiple host threads. Add proper synchronization in the future.
static typename Backend::BlasHandle& GetHandle() {
static typename Backend::BlasHandle handle = []() {
typename Backend::BlasHandle h;
Backend::BlasCreate(&h);
return h;
}();
return handle;
static typename Backend::BlasHandle& GetHandle(const Device& device) {
static thread_local std::unordered_map<int, typename Backend::BlasHandle>
handles;

auto it{handles.find(device.index())};
if (it == handles.end()) {
typename Backend::BlasHandle handle;
Backend::BlasCreate(&handle);
it = handles.emplace(device.index(), handle).first;
}
return it->second;
}

bool a_is_col_major_{false};
Expand Down
18 changes: 12 additions & 6 deletions src/operator.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef INFINI_OPS_OPERATOR_H_
#define INFINI_OPS_OPERATOR_H_

#include <atomic>
#include <cassert>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -138,13 +139,15 @@ class Operator : public OperatorBase {
// Generation counter for lazy cache invalidation. Bumped by
// `clear_cache()`; the next `call()` detects the mismatch and
// destroys all cached operator instances.
static inline std::size_t cache_generation_{0};
static inline std::atomic_size_t cache_generation_{0};

public:
// Invalidate the operator cache. Cached operators are destroyed on the
// next `call()` invocation. Intended for test isolation — production
// code should never call this.
static void clear_cache() { ++cache_generation_; }
static void clear_cache() {
cache_generation_.fetch_add(1, std::memory_order_acq_rel);
}
template <typename... Args>
static auto Make(const Config& config, const Tensor tensor, Args&&... args) {
std::unique_ptr<Operator> op_ptr;
Expand Down Expand Up @@ -193,13 +196,16 @@ class Operator : public OperatorBase {
template <typename... Args>
static auto Call(const Handle& handle, const Config& config,
const Args&... args) {
static std::unordered_map<detail::CacheKey, std::unique_ptr<Operator>>
static thread_local std::unordered_map<detail::CacheKey,
std::unique_ptr<Operator>>
cache;
static std::size_t generation{0};
static thread_local std::size_t generation{0};

if (generation != cache_generation_) {
const auto current_generation{
cache_generation_.load(std::memory_order_acquire)};
if (generation != current_generation) {
cache.clear();
generation = cache_generation_;
generation = current_generation;
}

auto key = detail::CacheKey::Build(config.implementation_index(), args...);
Expand Down
178 changes: 167 additions & 11 deletions tests/test_cpp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,27 @@


def test_cpp_operator_call_instantiation_smoke(tmp_path):
_compile_and_run(tmp_path, "add_smoke", _ADD_SMOKE_SOURCE)


def test_cpp_operator_call_thread_local_cache_regression(tmp_path):
_compile_and_run(tmp_path, "thread_local_cache_probe", _THREAD_LOCAL_CACHE_SOURCE)


def _compile_and_run(tmp_path, stem, source_text):
install_prefix = _install_prefix()
include_dir = install_prefix / "include"
include_dir = _include_dir(install_prefix)
library_dir = _library_dir(install_prefix)
source = tmp_path / "add_smoke.cc"
binary = tmp_path / "add_smoke"
source.write_text(_ADD_SMOKE_SOURCE)
source = tmp_path / f"{stem}.cc"
binary = tmp_path / stem
source.write_text(source_text)

_run(
[
_compiler("CXX", "c++"),
"-std=c++17",
"-Werror",
"-pthread",
f"-I{include_dir}",
str(source),
f"-L{library_dir}",
Expand All @@ -37,16 +46,44 @@ def _install_prefix():
if prefix:
return Path(prefix)

pytest.skip("`INFINIOPS_INSTALL_PREFIX` is not set.")
try:
import infini
except ImportError:
pytest.skip("INFINIOPS_INSTALL_PREFIX is not set and infini is not installed.")

return Path(infini.__file__).resolve().parent


def _candidate_prefixes(prefix):
yield prefix

try:
import infini
except ImportError:
return

package_prefix = Path(infini.__file__).resolve().parent
if package_prefix != prefix:
yield package_prefix


def _include_dir(prefix):
for candidate in _candidate_prefixes(prefix):
include_dir = candidate / "include"
if (include_dir / "infini" / "ops.h").exists():
return include_dir

pytest.skip(f"infini/ops.h was not found under {prefix}.")


def _library_dir(prefix):
for name in ("lib", "lib64"):
library_dir = prefix / name
if (library_dir / "libinfiniops.so").exists():
return library_dir
for candidate in _candidate_prefixes(prefix):
for name in ("lib", "lib64", "."):
library_dir = candidate / name
if (library_dir / "libinfiniops.so").exists():
return library_dir

pytest.skip(f"`libinfiniops.so` was not found under `{prefix}`.")
pytest.skip(f"libinfiniops.so was not found under {prefix}.")


def _compiler(env_name, default):
Expand All @@ -72,7 +109,9 @@ def _run(command):
r"""
#include <infini/ops.h>

#include <atomic>
#include <cmath>
#include <thread>

int main() {
float input_data[3] = {1.0f, 2.0f, 3.0f};
Expand Down Expand Up @@ -109,7 +148,124 @@ def _run(command):
return 1;
}

return 0;
std::atomic<int> failures{0};
auto threaded_call = [&]() {
float threaded_output[3] = {0.0f, 0.0f, 0.0f};
infini::ops::Tensor threaded_out(threaded_output, shape, data_type,
device);
for (int i = 0; i < 100; ++i) {
infini::ops::Add::Call(handle, config, input, other, threaded_out);
if (std::fabs(threaded_output[0] - 5.0f) > 1e-6f ||
std::fabs(threaded_output[1] - 7.0f) > 1e-6f ||
std::fabs(threaded_output[2] - 9.0f) > 1e-6f) {
failures.fetch_add(1, std::memory_order_relaxed);
}
threaded_output[0] = 0.0f;
threaded_output[1] = 0.0f;
threaded_output[2] = 0.0f;
}
};

std::thread t0(threaded_call);
std::thread t1(threaded_call);
std::thread t2(threaded_call);
std::thread t3(threaded_call);
t0.join();
t1.join();
t2.join();
t3.join();

return failures.load(std::memory_order_relaxed) == 0 ? 0 : 1;
}
"""
).lstrip()


_THREAD_LOCAL_CACHE_SOURCE = textwrap.dedent(
r"""
#include <infini/ops.h>

#include <atomic>
#include <thread>

namespace infini::ops {

class ThreadLocalCacheProbe : public Operator<ThreadLocalCacheProbe> {
public:
ThreadLocalCacheProbe(const Tensor input, Tensor out) {}

virtual void operator()(const Tensor input, Tensor out) const = 0;
};

template <>
struct ActiveDevicesImpl<ThreadLocalCacheProbe> {
using type = List<Device::Type::kCpu>;
};

template <>
class Operator<ThreadLocalCacheProbe, Device::Type::kCpu>
: public ThreadLocalCacheProbe {
public:
Operator(const Tensor input, Tensor out)
: ThreadLocalCacheProbe{input, out},
owner_thread_id_{std::this_thread::get_id()} {
constructions.fetch_add(1, std::memory_order_relaxed);
}

void operator()(const Tensor input, Tensor out) const override {
if (owner_thread_id_ != std::this_thread::get_id()) {
cross_thread_calls.fetch_add(1, std::memory_order_relaxed);
}
}

static std::atomic<int> constructions;

static std::atomic<int> cross_thread_calls;

private:
std::thread::id owner_thread_id_;
};

std::atomic<int> Operator<ThreadLocalCacheProbe,
Device::Type::kCpu>::constructions{0};

std::atomic<int> Operator<ThreadLocalCacheProbe,
Device::Type::kCpu>::cross_thread_calls{0};

} // namespace infini::ops

int main() {
float input_data[1] = {1.0f};
float output_data[1] = {0.0f};

const infini::ops::Tensor::Shape shape{1};
const infini::ops::Device device{infini::ops::Device::Type::kCpu};
const infini::ops::DataType data_type{infini::ops::DataType::kFloat32};
infini::ops::Tensor input(input_data, shape, data_type, device);
infini::ops::Tensor output(output_data, shape, data_type, device);
infini::ops::Handle handle;
infini::ops::Config config;

using Probe = infini::ops::ThreadLocalCacheProbe;
using ProbeImpl = infini::ops::Operator<
Probe, infini::ops::Device::Type::kCpu>;
Probe::clear_cache();
ProbeImpl::constructions.store(0, std::memory_order_relaxed);
ProbeImpl::cross_thread_calls.store(0, std::memory_order_relaxed);

Probe::Call(handle, config, input, output);

std::thread cache_probe_thread([&]() {
Probe::Call(handle, config, input, output);
});
cache_probe_thread.join();

if (ProbeImpl::cross_thread_calls.load(std::memory_order_relaxed) != 0) {
return 1;
}

return ProbeImpl::constructions.load(std::memory_order_relaxed) == 2 ? 0
: 1;
}
"""
).lstrip()
Loading