Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/native/cuda/nvidia/blas_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@ struct BlasUtils<Device::Type::kNvidia> {
return CUDA_R_32F;
}

static auto GetComputeType(DataType dtype) {
if (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16)
return CUBLAS_COMPUTE_32F;
return CUBLAS_COMPUTE_32F_FAST_TF32;
}
static auto GetComputeType(DataType) { return CUBLAS_COMPUTE_32F; }
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改动是必要的嘛?如果不是请不要修改。

};

} // namespace infini::ops
Expand Down
26 changes: 15 additions & 11 deletions src/native/cuda/ops/gemm/blas.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef INFINI_OPS_CUDA_GEMM_BLAS_H_
#define INFINI_OPS_CUDA_GEMM_BLAS_H_

#include <unordered_map>
#include <utility>

#include "base/gemm.h"
Expand Down Expand Up @@ -32,7 +33,8 @@ class BlasGemm : public Gemm {
void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
std::optional<float> beta, std::optional<int> trans_a,
std::optional<int> trans_b, Tensor c) const override {
Backend::BlasSetStream(GetHandle(),
auto& handle{GetHandle(c.device())};
Backend::BlasSetStream(handle,
static_cast<typename Backend::Stream>(stream_));

const auto& alpha_value{alpha.value_or(alpha_)};
Expand All @@ -46,7 +48,7 @@ class BlasGemm : public Gemm {
const void* beta_ptr{GetBetaPtr(beta_value, c.dtype())};

Backend::BlasGemmStridedBatchedEx(
GetHandle(), op_a, op_b, swap_a_and_b_ ? n_ : m_,
handle, op_a, op_b, swap_a_and_b_ ? n_ : m_,
swap_a_and_b_ ? m_ : n_, k_, alpha_ptr,
swap_a_and_b_ ? b.data() : a.data(),
BlasUtils<Backend::kDeviceType>::GetDataType(swap_a_and_b_ ? b.dtype()
Expand Down Expand Up @@ -92,15 +94,17 @@ class BlasGemm : public Gemm {
: Backend::BLAS_OP_N;
}

// TODO: This static singleton is not thread-safe under concurrent access
// from multiple host threads. Add proper synchronization in the future.
static typename Backend::BlasHandle& GetHandle() {
static typename Backend::BlasHandle handle = []() {
typename Backend::BlasHandle h;
Backend::BlasCreate(&h);
return h;
}();
return handle;
static typename Backend::BlasHandle& GetHandle(const Device& device) {
static thread_local std::unordered_map<int, typename Backend::BlasHandle>
handles;

auto it{handles.find(device.index())};
if (it == handles.end()) {
typename Backend::BlasHandle handle;
Backend::BlasCreate(&handle);
it = handles.emplace(device.index(), handle).first;
}
return it->second;
}

bool a_is_col_major_{false};
Expand Down
18 changes: 12 additions & 6 deletions src/operator.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef INFINI_OPS_OPERATOR_H_
#define INFINI_OPS_OPERATOR_H_

#include <atomic>
#include <cassert>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -138,13 +139,15 @@ class Operator : public OperatorBase {
// Generation counter for lazy cache invalidation. Bumped by
// `clear_cache()`; the next `call()` detects the mismatch and
// destroys all cached operator instances.
static inline std::size_t cache_generation_{0};
static inline std::atomic_size_t cache_generation_{0};

public:
// Invalidate the operator cache. Cached operators are destroyed on the
// next `call()` invocation. Intended for test isolation — production
// code should never call this.
static void clear_cache() { ++cache_generation_; }
static void clear_cache() {
cache_generation_.fetch_add(1, std::memory_order_acq_rel);
}
template <typename... Args>
static auto Make(const Config& config, const Tensor tensor, Args&&... args) {
std::unique_ptr<Operator> op_ptr;
Expand Down Expand Up @@ -193,13 +196,16 @@ class Operator : public OperatorBase {
template <typename... Args>
static auto Call(const Handle& handle, const Config& config,
const Args&... args) {
static std::unordered_map<detail::CacheKey, std::unique_ptr<Operator>>
static thread_local std::unordered_map<detail::CacheKey,
std::unique_ptr<Operator>>
cache;
static std::size_t generation{0};
static thread_local std::size_t generation{0};

if (generation != cache_generation_) {
const auto current_generation{
cache_generation_.load(std::memory_order_acquire)};
if (generation != current_generation) {
cache.clear();
generation = cache_generation_;
generation = current_generation;
}

auto key = detail::CacheKey::Build(config.implementation_index(), args...);
Expand Down
32 changes: 31 additions & 1 deletion tests/test_cpp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_cpp_operator_call_instantiation_smoke(tmp_path):
_compiler("CXX", "c++"),
"-std=c++17",
"-Werror",
"-pthread",
f"-I{include_dir}",
str(source),
f"-L{library_dir}",
Expand Down Expand Up @@ -72,7 +73,9 @@ def _run(command):
r"""
#include <infini/ops.h>

#include <atomic>
#include <cmath>
#include <thread>

int main() {
float input_data[3] = {1.0f, 2.0f, 3.0f};
Expand Down Expand Up @@ -109,7 +112,34 @@ def _run(command):
return 1;
}

return 0;
std::atomic<int> failures{0};
auto threaded_call = [&]() {
float threaded_output[3] = {0.0f, 0.0f, 0.0f};
infini::ops::Tensor threaded_out(threaded_output, shape, data_type,
device);
for (int i = 0; i < 100; ++i) {
infini::ops::Add::Call(handle, config, input, other, threaded_out);
if (std::fabs(threaded_output[0] - 5.0f) > 1e-6f ||
std::fabs(threaded_output[1] - 7.0f) > 1e-6f ||
std::fabs(threaded_output[2] - 9.0f) > 1e-6f) {
failures.fetch_add(1, std::memory_order_relaxed);
}
threaded_output[0] = 0.0f;
threaded_output[1] = 0.0f;
threaded_output[2] = 0.0f;
}
};

std::thread t0(threaded_call);
std::thread t1(threaded_call);
std::thread t2(threaded_call);
std::thread t3(threaded_call);
t0.join();
t1.join();
t2.join();
t3.join();

return failures.load(std::memory_order_relaxed) == 0 ? 0 : 1;
}
"""
).lstrip()
13 changes: 13 additions & 0 deletions tests/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@
from tests.utils import Payload, get_stream, randn_strided


@pytest.fixture(autouse=True)
def _strict_cuda_fp32_reference():
old_matmul_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
old_cudnn_allow_tf32 = torch.backends.cudnn.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
try:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try 前面请空一行。

yield
finally:
torch.backends.cuda.matmul.allow_tf32 = old_matmul_allow_tf32
torch.backends.cudnn.allow_tf32 = old_cudnn_allow_tf32


@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize(
"a_shape, b_shape, c_shape, a_strides, b_strides, c_strides",
Expand Down
Loading