diff --git a/blas/impl/KokkosBlas_Concepts.hpp b/blas/impl/KokkosBlas_Concepts.hpp new file mode 100644 index 0000000000..a9045c1460 --- /dev/null +++ b/blas/impl/KokkosBlas_Concepts.hpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project + +#ifndef KOKKOSBLAS_CONCEPTS_HPP +#define KOKKOSBLAS_CONCEPTS_HPP + +#include +#include "KokkosBlas_util.hpp" + +namespace KokkosBlas { + +template +concept TransposeOperation = is_trans_v; + +template +concept BlasLevel2 = is_level2_v; + +template +concept BlasLevel3 = is_level3_v; + +} // namespace KokkosBlas + +#endif // KOKKOSBLAS_CONCEPTS_HPP diff --git a/blas/impl/KokkosBlas_util.hpp b/blas/impl/KokkosBlas_util.hpp index 704aaf5c4e..58bcbd629f 100644 --- a/blas/impl/KokkosBlas_util.hpp +++ b/blas/impl/KokkosBlas_util.hpp @@ -189,6 +189,24 @@ struct Algo { using Pbtrs = Level2; }; +template +struct is_level3 : std::false_type {}; + +template <> +struct is_level3 : std::true_type {}; + +template <> +struct is_level3 : std::true_type {}; + +template <> +struct is_level3 : std::true_type {}; + +template <> +struct is_level3 : std::true_type {}; + +template +static constexpr bool is_level3_v = is_level3::value; + template struct is_level2 : std::false_type {}; diff --git a/blas/unit_test/Test_Blas.hpp b/blas/unit_test/Test_Blas.hpp index 51ad4e6306..7717227c00 100644 --- a/blas/unit_test/Test_Blas.hpp +++ b/blas/unit_test/Test_Blas.hpp @@ -3,6 +3,9 @@ #ifndef TEST_BLAS_HPP #define TEST_BLAS_HPP +// Blas concepts +#include "Test_Blas_Concepts.hpp" + // Blas 1 #include "Test_Blas1_abs.hpp" #include "Test_Blas1_asum.hpp" diff --git a/blas/unit_test/Test_Blas_Concepts.hpp b/blas/unit_test/Test_Blas_Concepts.hpp new file mode 100644 index 0000000000..fe024b77c9 --- /dev/null +++ b/blas/unit_test/Test_Blas_Concepts.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project +#ifndef TEST_BLAS_CONCEPTS_HPP +#define TEST_BLAS_CONCEPTS_HPP + +#include +#include "KokkosBlas_Concepts.hpp" +#include + +namespace Test { + +template +struct DummyFunctor {}; + +void test_blas_concepts() { + // Check that the concepts compile for valid types + static_assert(KokkosBlas::TransposeOperation); + static_assert(KokkosBlas::TransposeOperation); + static_assert(KokkosBlas::TransposeOperation); + + // Check for level 2 concepts + static_assert(KokkosBlas::BlasLevel2); + static_assert(KokkosBlas::BlasLevel2); + static_assert(KokkosBlas::BlasLevel2); + static_assert(KokkosBlas::BlasLevel2); + + static_assert(!KokkosBlas::BlasLevel2); + static_assert(!KokkosBlas::BlasLevel2); + static_assert(!KokkosBlas::BlasLevel2); + static_assert(!KokkosBlas::BlasLevel2); + + // Check for level 3 concepts + static_assert(KokkosBlas::BlasLevel3); + static_assert(KokkosBlas::BlasLevel3); + static_assert(KokkosBlas::BlasLevel3); + static_assert(KokkosBlas::BlasLevel3); + + static_assert(!KokkosBlas::BlasLevel3); + static_assert(!KokkosBlas::BlasLevel3); + static_assert(!KokkosBlas::BlasLevel3); + static_assert(!KokkosBlas::BlasLevel3); +} + +void test_concepts_in_functor() { + [[maybe_unused]] DummyFunctor dummy_no_trans; + [[maybe_unused]] DummyFunctor dummy_trans; + [[maybe_unused]] DummyFunctor dummy_conj_trans; +} +} // namespace Test + +TEST_F(TestCategory, blas_concepts) { ::Test::test_blas_concepts(); } +TEST_F(TestCategory, concepts_in_functor) { ::Test::test_concepts_in_functor(); } + +#endif // TEST_BLAS_CONCEPTS_HPP