From 1a0bcf606c8f235a8b664fba62e9f0f5fe072acd Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 5 May 2026 02:39:34 +0900 Subject: [PATCH 1/4] fix: conflicts Signed-off-by: Yuuichi Asahi --- ...okkosBatched_InnerGemmFixA_Serial_Impl.hpp | 800 +++++++++-------- .../KokkosBatched_InnerTrsm_Serial_Impl.hpp | 849 +++++++++--------- .../impl/KokkosBatched_LU_Serial_Internal.hpp | 4 +- .../impl/KokkosBatched_LU_Team_Internal.hpp | 4 +- .../impl/KokkosBatched_Trsm_Serial_Impl.hpp | 120 ++- .../KokkosBatched_Trsm_Serial_Internal.hpp | 235 +++-- .../impl/KokkosBatched_Trsm_Team_Internal.hpp | 8 +- .../impl/KokkosBatched_Trsv_Serial_Impl.hpp | 36 +- .../KokkosBatched_Trsv_Serial_Internal.hpp | 100 +-- .../impl/KokkosBatched_Trsv_Team_Internal.hpp | 8 +- .../src/KokkosBatched_InnerGemmFixA_Decl.hpp | 12 +- .../src/KokkosBatched_InnerTrsm_Decl.hpp | 49 +- .../unit_test/Test_Batched_SerialTrsv.hpp | 30 +- blas/impl/KokkosBlas3_trsm_impl.hpp | 64 +- 14 files changed, 1199 insertions(+), 1120 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_InnerGemmFixA_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_InnerGemmFixA_Serial_Impl.hpp index e3be5cfc04..b37b941d5e 100644 --- a/batched/dense/impl/KokkosBatched_InnerGemmFixA_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_InnerGemmFixA_Serial_Impl.hpp @@ -15,34 +15,38 @@ namespace KokkosBatched { /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1], a_04 = A[0 * _as0 + 4 * _as1], a_10 = A[1 * _as0 + 0 * _as1], - a_11 = A[1 * _as0 + 1 * _as1], a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], - a_14 = A[1 * _as0 + 4 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], - a_22 = A[2 * _as0 + 2 * _as1], a_23 = A[2 * _as0 + 3 * _as1], a_24 = A[2 * _as0 + 4 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1], - a_33 = A[3 * _as0 + 3 * _as1], a_34 = A[3 * _as0 + 4 * _as1], a_40 = A[4 * _as0 + 0 * _as1], - a_41 = A[4 * _as0 + 1 * _as1], a_42 = A[4 * _as0 + 2 * _as1], a_43 = A[4 * _as0 + 3 * _as1], - a_44 = A[4 * _as0 + 4 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]), + a_04 = opA(A[0 * m_as0 + 4 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_11 = opA(A[1 * m_as0 + 1 * m_as1]), a_12 = opA(A[1 * m_as0 + 2 * m_as1]), + a_13 = opA(A[1 * m_as0 + 3 * m_as1]), a_14 = opA(A[1 * m_as0 + 4 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_22 = opA(A[2 * m_as0 + 2 * m_as1]), a_23 = opA(A[2 * m_as0 + 3 * m_as1]), + a_24 = opA(A[2 * m_as0 + 4 * m_as1]), a_30 = opA(A[3 * m_as0 + 0 * m_as1]), + a_31 = opA(A[3 * m_as0 + 1 * m_as1]), a_32 = opA(A[3 * m_as0 + 2 * m_as1]), + a_33 = opA(A[3 * m_as0 + 3 * m_as1]), a_34 = opA(A[3 * m_as0 + 4 * m_as1]), + a_40 = opA(A[4 * m_as0 + 0 * m_as1]), a_41 = opA(A[4 * m_as0 + 1 * m_as1]), + a_42 = opA(A[4 * m_as0 + 2 * m_as1]), a_43 = opA(A[4 * m_as0 + 3 * m_as1]), + a_44 = opA(A[4 * m_as0 + 4 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, c_2p, b_3p, c_3p, b_4p, c_4p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ib4 = 4 * _bs0, ic0 = 0 * _cs0, - ic1 = 1 * _cs0, ic2 = 2 * _cs0, ic3 = 3 * _cs0, ic4 = 4 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ib4 = 4 * m_bs0, ic0 = 0 * m_cs0, + ic1 = 1 * m_cs0, ic2 = 2 * m_cs0, ic3 = 3 * m_cs0, ic4 = 4 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; - b_4p = B[ib4 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); + b_4p = opB(B[ib4 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -70,43 +74,46 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 5>::serial_invoke(const ScalarType a c_4p += a_43 * b_3p; c_4p += a_44 * b_4p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; - C[ic4 + p * _cs1] += alpha * c_4p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; + C[ic4 + p * m_cs1] += alpha * c_4p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1], a_10 = A[1 * _as0 + 0 * _as1], a_11 = A[1 * _as0 + 1 * _as1], - a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], a_20 = A[2 * _as0 + 0 * _as1], - a_21 = A[2 * _as0 + 1 * _as1], a_22 = A[2 * _as0 + 2 * _as1], a_23 = A[2 * _as0 + 3 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1], - a_33 = A[3 * _as0 + 3 * _as1], a_40 = A[4 * _as0 + 0 * _as1], a_41 = A[4 * _as0 + 1 * _as1], - a_42 = A[4 * _as0 + 2 * _as1], a_43 = A[4 * _as0 + 3 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]), + a_10 = opA(A[1 * m_as0 + 0 * m_as1]), a_11 = opA(A[1 * m_as0 + 1 * m_as1]), + a_12 = opA(A[1 * m_as0 + 2 * m_as1]), a_13 = opA(A[1 * m_as0 + 3 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_22 = opA(A[2 * m_as0 + 2 * m_as1]), a_23 = opA(A[2 * m_as0 + 3 * m_as1]), + a_30 = opA(A[3 * m_as0 + 0 * m_as1]), a_31 = opA(A[3 * m_as0 + 1 * m_as1]), + a_32 = opA(A[3 * m_as0 + 2 * m_as1]), a_33 = opA(A[3 * m_as0 + 3 * m_as1]), + a_40 = opA(A[4 * m_as0 + 0 * m_as1]), a_41 = opA(A[4 * m_as0 + 1 * m_as1]), + a_42 = opA(A[4 * m_as0 + 2 * m_as1]), a_43 = opA(A[4 * m_as0 + 3 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, c_2p, b_3p, c_3p, /**/ c_4p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, - ic2 = 2 * _cs0, ic3 = 3 * _cs0, ic4 = 4 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, + ic2 = 2 * m_cs0, ic3 = 3 * m_cs0, ic4 = 4 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -129,41 +136,44 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 4>::serial_invoke(const ScalarType a c_4p += a_42 * b_2p; c_4p += a_43 * b_3p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; - C[ic4 + p * _cs1] += alpha * c_4p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; + C[ic4 + p * m_cs1] += alpha * c_4p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_10 = A[1 * _as0 + 0 * _as1], a_11 = A[1 * _as0 + 1 * _as1], a_12 = A[1 * _as0 + 2 * _as1], - a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], a_22 = A[2 * _as0 + 2 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1], - a_40 = A[4 * _as0 + 0 * _as1], a_41 = A[4 * _as0 + 1 * _as1], a_42 = A[4 * _as0 + 2 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_11 = opA(A[1 * m_as0 + 1 * m_as1]), a_12 = opA(A[1 * m_as0 + 2 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_22 = opA(A[2 * m_as0 + 2 * m_as1]), a_30 = opA(A[3 * m_as0 + 0 * m_as1]), + a_31 = opA(A[3 * m_as0 + 1 * m_as1]), a_32 = opA(A[3 * m_as0 + 2 * m_as1]), + a_40 = opA(A[4 * m_as0 + 0 * m_as1]), a_41 = opA(A[4 * m_as0 + 1 * m_as1]), + a_42 = opA(A[4 * m_as0 + 2 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, c_2p, /**/ c_3p, /**/ c_4p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, ic2 = 2 * _cs0, - ic3 = 3 * _cs0, ic4 = 4 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, ic2 = 2 * m_cs0, + ic3 = 3 * m_cs0, ic4 = 4 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -181,39 +191,40 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 3>::serial_invoke(const ScalarType a c_4p += a_41 * b_1p; c_4p += a_42 * b_2p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; - C[ic4 + p * _cs1] += alpha * c_4p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; + C[ic4 + p * m_cs1] += alpha * c_4p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_10 = A[1 * _as0 + 0 * _as1], - a_11 = A[1 * _as0 + 1 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_40 = A[4 * _as0 + 0 * _as1], - a_41 = A[4 * _as0 + 1 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_10 = opA(A[1 * m_as0 + 0 * m_as1]), a_11 = opA(A[1 * m_as0 + 1 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_30 = opA(A[3 * m_as0 + 0 * m_as1]), a_31 = opA(A[3 * m_as0 + 1 * m_as1]), + a_40 = opA(A[4 * m_as0 + 0 * m_as1]), a_41 = opA(A[4 * m_as0 + 1 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, /**/ c_2p, /**/ c_3p, /**/ c_4p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, ic2 = 2 * _cs0, ic3 = 3 * _cs0, - ic4 = 4 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, ic2 = 2 * m_cs0, ic3 = 3 * m_cs0, + ic4 = 4 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -226,25 +237,26 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 2>::serial_invoke(const ScalarType a c_4p = a_40 * b_0p; c_4p += a_41 * b_1p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; - C[ic4 + p * _cs1] += alpha * c_4p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; + C[ic4 + p * m_cs1] += alpha * c_4p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_10 = A[1 * _as0 + 0 * _as1], a_20 = A[2 * _as0 + 0 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_40 = A[4 * _as0 + 0 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_30 = opA(A[3 * m_as0 + 0 * m_as1]), + a_40 = opA(A[4 * m_as0 + 0 * m_as1]); ValueType b_0p, c_0p, /**/ c_1p, @@ -252,10 +264,10 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 1>::serial_invoke(const ScalarType a /**/ c_3p, /**/ c_4p; - const int ib0 = 0 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, ic2 = 2 * _cs0, ic3 = 3 * _cs0, ic4 = 4 * _cs0; + const int ib0 = 0 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, ic2 = 2 * m_cs0, ic3 = 3 * m_cs0, ic4 = 4 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); c_0p = a_00 * b_0p; c_1p = a_10 * b_0p; @@ -263,43 +275,46 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 1>::serial_invoke(const ScalarType a c_3p = a_30 * b_0p; c_4p = a_40 * b_0p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; - C[ic4 + p * _cs1] += alpha * c_4p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; + C[ic4 + p * m_cs1] += alpha * c_4p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1], a_04 = A[0 * _as0 + 4 * _as1], a_10 = A[1 * _as0 + 0 * _as1], - a_11 = A[1 * _as0 + 1 * _as1], a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], - a_14 = A[1 * _as0 + 4 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], - a_22 = A[2 * _as0 + 2 * _as1], a_23 = A[2 * _as0 + 3 * _as1], a_24 = A[2 * _as0 + 4 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1], - a_33 = A[3 * _as0 + 3 * _as1], a_34 = A[3 * _as0 + 4 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]), + a_04 = opA(A[0 * m_as0 + 4 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_11 = opA(A[1 * m_as0 + 1 * m_as1]), a_12 = opA(A[1 * m_as0 + 2 * m_as1]), + a_13 = opA(A[1 * m_as0 + 3 * m_as1]), a_14 = opA(A[1 * m_as0 + 4 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_22 = opA(A[2 * m_as0 + 2 * m_as1]), a_23 = opA(A[2 * m_as0 + 3 * m_as1]), + a_24 = opA(A[2 * m_as0 + 4 * m_as1]), a_30 = opA(A[3 * m_as0 + 0 * m_as1]), + a_31 = opA(A[3 * m_as0 + 1 * m_as1]), a_32 = opA(A[3 * m_as0 + 2 * m_as1]), + a_33 = opA(A[3 * m_as0 + 3 * m_as1]), a_34 = opA(A[3 * m_as0 + 4 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, c_2p, b_3p, c_3p, b_4p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ib4 = 4 * _bs0, ic0 = 0 * _cs0, - ic1 = 1 * _cs0, ic2 = 2 * _cs0, ic3 = 3 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ib4 = 4 * m_bs0, ic0 = 0 * m_cs0, + ic1 = 1 * m_cs0, ic2 = 2 * m_cs0, ic3 = 3 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; - b_4p = B[ib4 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); + b_4p = opB(B[ib4 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -322,40 +337,43 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 5>::serial_invoke(const ScalarType a c_3p += a_33 * b_3p; c_3p += a_34 * b_4p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1], a_04 = A[0 * _as0 + 4 * _as1], a_10 = A[1 * _as0 + 0 * _as1], - a_11 = A[1 * _as0 + 1 * _as1], a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], - a_14 = A[1 * _as0 + 4 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], - a_22 = A[2 * _as0 + 2 * _as1], a_23 = A[2 * _as0 + 3 * _as1], a_24 = A[2 * _as0 + 4 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]), + a_04 = opA(A[0 * m_as0 + 4 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_11 = opA(A[1 * m_as0 + 1 * m_as1]), a_12 = opA(A[1 * m_as0 + 2 * m_as1]), + a_13 = opA(A[1 * m_as0 + 3 * m_as1]), a_14 = opA(A[1 * m_as0 + 4 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_22 = opA(A[2 * m_as0 + 2 * m_as1]), a_23 = opA(A[2 * m_as0 + 3 * m_as1]), + a_24 = opA(A[2 * m_as0 + 4 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, c_2p, b_3p, b_4p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ib4 = 4 * _bs0, ic0 = 0 * _cs0, - ic1 = 1 * _cs0, ic2 = 2 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ib4 = 4 * m_bs0, ic0 = 0 * m_cs0, + ic1 = 1 * m_cs0, ic2 = 2 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; - b_4p = B[ib4 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); + b_4p = opB(B[ib4 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -373,37 +391,38 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 5>::serial_invoke(const ScalarType a c_2p += a_23 * b_3p; c_2p += a_24 * b_4p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1], a_04 = A[0 * _as0 + 4 * _as1], a_10 = A[1 * _as0 + 0 * _as1], - a_11 = A[1 * _as0 + 1 * _as1], a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], - a_14 = A[1 * _as0 + 4 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]), + a_04 = opA(A[0 * m_as0 + 4 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_11 = opA(A[1 * m_as0 + 1 * m_as1]), a_12 = opA(A[1 * m_as0 + 2 * m_as1]), + a_13 = opA(A[1 * m_as0 + 3 * m_as1]), a_14 = opA(A[1 * m_as0 + 4 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, b_3p, b_4p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ib4 = 4 * _bs0, ic0 = 0 * _cs0, - ic1 = 1 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ib4 = 4 * m_bs0, ic0 = 0 * m_cs0, + ic1 = 1 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; - b_4p = B[ib4 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); + b_4p = opB(B[ib4 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -416,33 +435,34 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 5>::serial_invoke(const ScalarType a c_1p += a_13 * b_3p; c_1p += a_14 * b_4p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1], a_04 = A[0 * _as0 + 4 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]), + a_04 = opA(A[0 * m_as0 + 4 * m_as1]); ValueType b_0p, c_0p, b_1p, b_2p, b_3p, b_4p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ib4 = 4 * _bs0, ic0 = 0 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ib4 = 4 * m_bs0, ic0 = 0 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; - b_4p = B[ib4 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); + b_4p = opB(B[ib4 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -450,15 +470,15 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 5>::serial_invoke(const ScalarType a c_0p += a_03 * b_3p; c_0p += a_04 * b_4p; - C[ic0 + p * _cs1] += alpha * c_0p; + C[ic0 + p * m_cs1] += alpha * c_0p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 5>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 5>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -467,52 +487,52 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 5>::serial_invoke(const ScalarType a switch (m * 10 + k) { case 54: { - InnerGemmFixA<5, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<5, 4> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 53: { - InnerGemmFixA<5, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<5, 3> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 52: { - InnerGemmFixA<5, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<5, 2> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 51: { - InnerGemmFixA<5, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<5, 1> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 45: { - InnerGemmFixA<4, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<4, 5> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 35: { - InnerGemmFixA<3, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<3, 5> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 25: { - InnerGemmFixA<2, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<2, 5> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 15: { - InnerGemmFixA<1, 5> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<1, 5> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } default: { if (m < 5 && n < 5) { - InnerGemmFixA<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); + InnerGemmFixA<2, 2> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); for (int i = 0; i < m; i += 2) for (int p = 0; p < k; p += 2) - inner.serial_invoke(alpha, A + i * _as0 + p * _as1, B + p * _bs0, (i + 2 > m ? 1 : 2), n, - (p + 2 > k ? 1 : 2), C + i * _cs0); + inner.serial_invoke(opA, opB, alpha, A + i * m_as0 + p * m_as1, B + p * m_bs0, (i + 2 > m ? 1 : 2), n, + (p + 2 > k ? 1 : 2), C + i * m_cs0); } else { Kokkos::abort("InnerGemmFixA<5,5>::serial_invoke, assert failure (m<5 && n<5)"); } @@ -528,30 +548,32 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<5, 5>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1], a_10 = A[1 * _as0 + 0 * _as1], a_11 = A[1 * _as0 + 1 * _as1], - a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], a_20 = A[2 * _as0 + 0 * _as1], - a_21 = A[2 * _as0 + 1 * _as1], a_22 = A[2 * _as0 + 2 * _as1], a_23 = A[2 * _as0 + 3 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1], - a_33 = A[3 * _as0 + 3 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]), + a_10 = opA(A[1 * m_as0 + 0 * m_as1]), a_11 = opA(A[1 * m_as0 + 1 * m_as1]), + a_12 = opA(A[1 * m_as0 + 2 * m_as1]), a_13 = opA(A[1 * m_as0 + 3 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_22 = opA(A[2 * m_as0 + 2 * m_as1]), a_23 = opA(A[2 * m_as0 + 3 * m_as1]), + a_30 = opA(A[3 * m_as0 + 0 * m_as1]), a_31 = opA(A[3 * m_as0 + 1 * m_as1]), + a_32 = opA(A[3 * m_as0 + 2 * m_as1]), a_33 = opA(A[3 * m_as0 + 3 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, c_2p, b_3p, c_3p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, - ic2 = 2 * _cs0, ic3 = 3 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, + ic2 = 2 * m_cs0, ic3 = 3 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -570,38 +592,40 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 4>::serial_invoke(const ScalarType a c_3p += a_32 * b_2p; c_3p += a_33 * b_3p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_10 = A[1 * _as0 + 0 * _as1], a_11 = A[1 * _as0 + 1 * _as1], a_12 = A[1 * _as0 + 2 * _as1], - a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], a_22 = A[2 * _as0 + 2 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_11 = opA(A[1 * m_as0 + 1 * m_as1]), a_12 = opA(A[1 * m_as0 + 2 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_22 = opA(A[2 * m_as0 + 2 * m_as1]), a_30 = opA(A[3 * m_as0 + 0 * m_as1]), + a_31 = opA(A[3 * m_as0 + 1 * m_as1]), a_32 = opA(A[3 * m_as0 + 2 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, c_2p, /**/ c_3p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, ic2 = 2 * _cs0, - ic3 = 3 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, ic2 = 2 * m_cs0, + ic3 = 3 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -616,36 +640,37 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 3>::serial_invoke(const ScalarType a c_3p += a_31 * b_1p; c_3p += a_32 * b_2p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_10 = A[1 * _as0 + 0 * _as1], - a_11 = A[1 * _as0 + 1 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_10 = opA(A[1 * m_as0 + 0 * m_as1]), a_11 = opA(A[1 * m_as0 + 1 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_30 = opA(A[3 * m_as0 + 0 * m_as1]), a_31 = opA(A[3 * m_as0 + 1 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, /**/ c_2p, /**/ c_3p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, ic2 = 2 * _cs0, ic3 = 3 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, ic2 = 2 * m_cs0, ic3 = 3 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -656,73 +681,75 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 2>::serial_invoke(const ScalarType a c_3p = a_30 * b_0p; c_3p += a_31 * b_1p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_10 = A[1 * _as0 + 0 * _as1], a_20 = A[2 * _as0 + 0 * _as1], - a_30 = A[3 * _as0 + 0 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_30 = opA(A[3 * m_as0 + 0 * m_as1]); ValueType b_0p, c_0p, /**/ c_1p, /**/ c_2p, /**/ c_3p; - const int ib0 = 0 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, ic2 = 2 * _cs0, ic3 = 3 * _cs0; + const int ib0 = 0 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, ic2 = 2 * m_cs0, ic3 = 3 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); c_0p = a_00 * b_0p; c_1p = a_10 * b_0p; c_2p = a_20 * b_0p; c_3p = a_30 * b_0p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; - C[ic3 + p * _cs1] += alpha * c_3p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; + C[ic3 + p * m_cs1] += alpha * c_3p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1], a_10 = A[1 * _as0 + 0 * _as1], a_11 = A[1 * _as0 + 1 * _as1], - a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], a_20 = A[2 * _as0 + 0 * _as1], - a_21 = A[2 * _as0 + 1 * _as1], a_22 = A[2 * _as0 + 2 * _as1], a_23 = A[2 * _as0 + 3 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]), + a_10 = opA(A[1 * m_as0 + 0 * m_as1]), a_11 = opA(A[1 * m_as0 + 1 * m_as1]), + a_12 = opA(A[1 * m_as0 + 2 * m_as1]), a_13 = opA(A[1 * m_as0 + 3 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_22 = opA(A[2 * m_as0 + 2 * m_as1]), a_23 = opA(A[2 * m_as0 + 3 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, c_2p, b_3p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, - ic2 = 2 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, + ic2 = 2 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -737,35 +764,36 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 4>::serial_invoke(const ScalarType a c_2p += a_22 * b_2p; c_2p += a_23 * b_3p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1], a_10 = A[1 * _as0 + 0 * _as1], a_11 = A[1 * _as0 + 1 * _as1], - a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]), + a_10 = opA(A[1 * m_as0 + 0 * m_as1]), a_11 = opA(A[1 * m_as0 + 1 * m_as1]), + a_12 = opA(A[1 * m_as0 + 2 * m_as1]), a_13 = opA(A[1 * m_as0 + 3 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, b_3p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -776,48 +804,48 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 4>::serial_invoke(const ScalarType a c_1p += a_12 * b_2p; c_1p += a_13 * b_3p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_03 = A[0 * _as0 + 3 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_03 = opA(A[0 * m_as0 + 3 * m_as1]); ValueType b_0p, c_0p, b_1p, b_2p, b_3p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ib3 = 3 * _bs0, ic0 = 0 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ib3 = 3 * m_bs0, ic0 = 0 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; - b_3p = B[ib3 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); + b_3p = opB(B[ib3 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; c_0p += a_02 * b_2p; c_0p += a_03 * b_3p; - C[ic0 + p * _cs1] += alpha * c_0p; + C[ic0 + p * m_cs1] += alpha * c_0p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 4>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 4>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -826,47 +854,47 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 4>::serial_invoke(const ScalarType a switch (m * 10 + k) { case 44: { - InnerGemmFixA<4, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<4, 4> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 43: { - InnerGemmFixA<4, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<4, 3> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 42: { - InnerGemmFixA<4, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<4, 2> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 41: { - InnerGemmFixA<4, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<4, 1> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 34: { - InnerGemmFixA<3, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<3, 4> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 24: { - InnerGemmFixA<2, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<2, 4> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 14: { - InnerGemmFixA<1, 4> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<1, 4> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } default: { if (m < 4 && n < 4) { - InnerGemmFixA<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); + InnerGemmFixA<2, 2> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); for (int i = 0; i < m; i += 2) for (int p = 0; p < k; p += 2) - inner.serial_invoke(alpha, A + i * _as0 + p * _as1, B + p * _bs0, (i + 2 > m ? 1 : 2), n, - (p + 2 > k ? 1 : 2), C + i * _cs0); + inner.serial_invoke(opA, opB, alpha, A + i * m_as0 + p * m_as1, B + p * m_bs0, (i + 2 > m ? 1 : 2), n, + (p + 2 > k ? 1 : 2), C + i * m_cs0); } else { Kokkos::abort("InnerGemmFixA<4,4>::serial_invoke, assert failure (m<4 && n<4)"); } @@ -882,25 +910,27 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<4, 4>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_10 = A[1 * _as0 + 0 * _as1], a_11 = A[1 * _as0 + 1 * _as1], a_12 = A[1 * _as0 + 2 * _as1], - a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], a_22 = A[2 * _as0 + 2 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_11 = opA(A[1 * m_as0 + 1 * m_as1]), a_12 = opA(A[1 * m_as0 + 2 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]), + a_22 = opA(A[2 * m_as0 + 2 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p, c_2p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, ic2 = 2 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, ic2 = 2 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -912,33 +942,34 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 3>::serial_invoke(const ScalarType a c_2p += a_21 * b_1p; c_2p += a_22 * b_2p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_10 = A[1 * _as0 + 0 * _as1], - a_11 = A[1 * _as0 + 1 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_10 = opA(A[1 * m_as0 + 0 * m_as1]), a_11 = opA(A[1 * m_as0 + 1 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]), a_21 = opA(A[2 * m_as0 + 1 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, /**/ c_2p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, ic2 = 2 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, ic2 = 2 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -947,64 +978,66 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 2>::serial_invoke(const ScalarType a c_2p = a_20 * b_0p; c_2p += a_21 * b_1p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_10 = A[1 * _as0 + 0 * _as1], a_20 = A[2 * _as0 + 0 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_20 = opA(A[2 * m_as0 + 0 * m_as1]); ValueType b_0p, c_0p, /**/ c_1p, /**/ c_2p; - const int ib0 = 0 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0, ic2 = 2 * _cs0; + const int ib0 = 0 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0, ic2 = 2 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); c_0p = a_00 * b_0p; c_1p = a_10 * b_0p; c_2p = a_20 * b_0p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; - C[ic2 + p * _cs1] += alpha * c_2p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; + C[ic2 + p * m_cs1] += alpha * c_2p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - a_10 = A[1 * _as0 + 0 * _as1], a_11 = A[1 * _as0 + 1 * _as1], a_12 = A[1 * _as0 + 2 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]), + a_11 = opA(A[1 * m_as0 + 1 * m_as1]), a_12 = opA(A[1 * m_as0 + 2 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p, b_2p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; @@ -1013,44 +1046,45 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 3>::serial_invoke(const ScalarType a c_1p += a_11 * b_1p; c_1p += a_12 * b_2p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_02 = opA(A[0 * m_as0 + 2 * m_as1]); ValueType b_0p, c_0p, b_1p, b_2p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ib2 = 2 * _bs0, ic0 = 0 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ib2 = 2 * m_bs0, ic0 = 0 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; - b_2p = B[ib2 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); + b_2p = opB(B[ib2 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; c_0p += a_02 * b_2p; - C[ic0 + p * _cs1] += alpha * c_0p; + C[ic0 + p * m_cs1] += alpha * c_0p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 3>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 3>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1059,37 +1093,37 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 3>::serial_invoke(const ScalarType a switch (m * 10 + k) { case 33: { - InnerGemmFixA<3, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<3, 3> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 32: { - InnerGemmFixA<3, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<3, 2> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 31: { - InnerGemmFixA<3, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<3, 1> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 23: { - InnerGemmFixA<2, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<2, 3> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 13: { - InnerGemmFixA<1, 3> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<1, 3> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } default: { if (m < 3 && n < 3) { - InnerGemmFixA<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); + InnerGemmFixA<2, 2> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); for (int i = 0; i < m; i += 2) for (int p = 0; p < k; p += 2) - inner.serial_invoke(alpha, A + i * _as0 + p * _as1, B + p * _bs0, (i + 2 > m ? 1 : 2), n, - (p + 2 > k ? 1 : 2), C + i * _cs0); + inner.serial_invoke(opA, opB, alpha, A + i * m_as0 + p * m_as1, B + p * m_bs0, (i + 2 > m ? 1 : 2), n, + (p + 2 > k ? 1 : 2), C + i * m_cs0); } else { Kokkos::abort("InnerGemmFixA<3,3>::serial_invoke, assert failure (m<3 && n<3)"); } @@ -1105,94 +1139,94 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<3, 3>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1], a_10 = A[1 * _as0 + 0 * _as1], - a_11 = A[1 * _as0 + 1 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]), + a_10 = opA(A[1 * m_as0 + 0 * m_as1]), a_11 = opA(A[1 * m_as0 + 1 * m_as1]); ValueType b_0p, c_0p, b_1p, c_1p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; c_1p = a_10 * b_0p; c_1p += a_11 * b_1p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_10 = A[1 * _as0 + 0 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_10 = opA(A[1 * m_as0 + 0 * m_as1]); ValueType b_0p, c_0p, /**/ c_1p; - const int ib0 = 0 * _bs0, ic0 = 0 * _cs0, ic1 = 1 * _cs0; + const int ib0 = 0 * m_bs0, ic0 = 0 * m_cs0, ic1 = 1 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); c_0p = a_00 * b_0p; c_1p = a_10 * b_0p; - C[ic0 + p * _cs1] += alpha * c_0p; - C[ic1 + p * _cs1] += alpha * c_1p; + C[ic0 + p * m_cs1] += alpha * c_0p; + C[ic1 + p * m_cs1] += alpha * c_1p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1], a_01 = A[0 * _as0 + 1 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]), a_01 = opA(A[0 * m_as0 + 1 * m_as1]); ValueType b_0p, c_0p, b_1p; - const int ib0 = 0 * _bs0, ib1 = 1 * _bs0, ic0 = 0 * _cs0; + const int ib0 = 0 * m_bs0, ib1 = 1 * m_bs0, ic0 = 0 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; - b_1p = B[ib1 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); + b_1p = opB(B[ib1 + p * m_bs1]); c_0p = a_00 * b_0p; c_0p += a_01 * b_1p; - C[ic0 + p * _cs1] += alpha * c_0p; + C[ic0 + p * m_cs1] += alpha * c_0p; } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 2>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 2>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, @@ -1201,23 +1235,23 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 2>::serial_invoke(const ScalarType a switch (m * 10 + k) { case 22: { - InnerGemmFixA<2, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<2, 2> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 21: { - InnerGemmFixA<2, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<2, 1> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 12: { - InnerGemmFixA<1, 2> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<1, 2> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } case 11: { - InnerGemmFixA<1, 1> inner(_as0, _as1, _bs0, _bs1, _cs0, _cs1); - inner.serial_invoke(alpha, A, B, n, C); + InnerGemmFixA<1, 1> inner(m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1); + inner.serial_invoke(opA, opB, alpha, A, B, n, C); break; } default: { @@ -1234,25 +1268,25 @@ KOKKOS_INLINE_FUNCTION int InnerGemmFixA<2, 2>::serial_invoke(const ScalarType a /// ================== template <> -template -KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 1>::serial_invoke(const ScalarType alpha, +template +KOKKOS_INLINE_FUNCTION int InnerGemmFixA<1, 1>::serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C) { if (n <= 0) return 0; - const ValueType a_00 = A[0 * _as0 + 0 * _as1]; + const ValueType a_00 = opA(A[0 * m_as0 + 0 * m_as1]); ValueType b_0p, c_0p; - const int ib0 = 0 * _bs0, ic0 = 0 * _cs0; + const int ib0 = 0 * m_bs0, ic0 = 0 * m_cs0; for (int p = 0; p < n; ++p) { - b_0p = B[ib0 + p * _bs1]; + b_0p = opB(B[ib0 + p * m_bs1]); c_0p = a_00 * b_0p; - C[ic0 + p * _cs1] += alpha * c_0p; + C[ic0 + p * m_cs1] += alpha * c_0p; } return 0; diff --git a/batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp index f77f936804..c52c158ec3 100644 --- a/batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp @@ -4,6 +4,7 @@ #define KOKKOSBATCHED_INNER_TRSM_SERIAL_IMPL_HPP /// \author Kyungjoo Kim (kyukim@sandia.gov) +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) #include "KokkosBatched_Util.hpp" #include "KokkosBatched_InnerTrsm_Decl.hpp" @@ -16,15 +17,17 @@ namespace KokkosBatched { /// L(m x m) X(m x n) = B (m x n) template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<5>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<5>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_10 = A[1 * _as0 + 0 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1], - a_40 = A[4 * _as0 + 0 * _as1], a_41 = A[4 * _as0 + 1 * _as1], a_42 = A[4 * _as0 + 2 * _as1], - a_43 = A[4 * _as0 + 3 * _as1]; + const ValueType a_10 = op(A[1 * m_as0 + 0 * m_as1]), a_20 = op(A[2 * m_as0 + 0 * m_as1]), + a_21 = op(A[2 * m_as0 + 1 * m_as1]), a_30 = op(A[3 * m_as0 + 0 * m_as1]), + a_31 = op(A[3 * m_as0 + 1 * m_as1]), a_32 = op(A[3 * m_as0 + 2 * m_as1]), + a_40 = op(A[4 * m_as0 + 0 * m_as1]), a_41 = op(A[4 * m_as0 + 1 * m_as1]), + a_42 = op(A[4 * m_as0 + 2 * m_as1]), a_43 = op(A[4 * m_as0 + 3 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[5]; @@ -35,11 +38,11 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<5>::serial_invoke(const Va auto &b_4p = b_p[4]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; - b_3p = B[3 * _bs0 + p * _bs1]; - b_4p = B[4 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; + b_3p = B[3 * m_bs0 + p * m_bs1]; + b_4p = B[4 * m_bs0 + p * m_bs1]; // 0 iteration b_1p -= a_10 * b_0p; @@ -60,10 +63,10 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<5>::serial_invoke(const Va b_4p -= a_43 * b_3p; // store - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; - B[3 * _bs0 + p * _bs1] = b_3p; - B[4 * _bs0 + p * _bs1] = b_4p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; + B[3 * m_bs0 + p * m_bs1] = b_3p; + B[4 * m_bs0 + p * m_bs1] = b_4p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -76,13 +79,15 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<5>::serial_invoke(const Va } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<4>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<4>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_10 = A[1 * _as0 + 0 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1]; + const ValueType a_10 = op(A[1 * m_as0 + 0 * m_as1]), a_20 = op(A[2 * m_as0 + 0 * m_as1]), + a_21 = op(A[2 * m_as0 + 1 * m_as1]), a_30 = op(A[3 * m_as0 + 0 * m_as1]), + a_31 = op(A[3 * m_as0 + 1 * m_as1]), a_32 = op(A[3 * m_as0 + 2 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[4]; @@ -92,10 +97,10 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<4>::serial_invoke(const Va auto &b_3p = b_p[3]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; - b_3p = B[3 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; + b_3p = B[3 * m_bs0 + p * m_bs1]; // 0 iteration b_1p -= a_10 * b_0p; @@ -110,9 +115,9 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<4>::serial_invoke(const Va b_3p -= a_32 * b_2p; // store - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; - B[3 * _bs0 + p * _bs1] = b_3p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; + B[3 * m_bs0 + p * m_bs1] = b_3p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -125,12 +130,14 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<4>::serial_invoke(const Va } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<3>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<3>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_10 = A[1 * _as0 + 0 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1]; + const ValueType a_10 = op(A[1 * m_as0 + 0 * m_as1]), a_20 = op(A[2 * m_as0 + 0 * m_as1]), + a_21 = op(A[2 * m_as0 + 1 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[3]; @@ -139,9 +146,9 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<3>::serial_invoke(const Va auto &b_2p = b_p[2]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; // 0 iteration b_1p -= a_10 * b_0p; @@ -151,8 +158,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<3>::serial_invoke(const Va b_2p -= a_21 * b_1p; // store - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -165,12 +172,13 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<3>::serial_invoke(const Va } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<2>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<2>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_10 = A[1 * _as0 + 0 * _as1]; + const ValueType a_10 = op(A[1 * m_as0 + 0 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[2]; @@ -178,14 +186,14 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<2>::serial_invoke(const Va auto &b_1p = b_p[1]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; // 0 iteration b_1p -= a_10 * b_0p; // store - B[1 * _bs0 + p * _bs1] = b_1p; + B[1 * m_bs0 + p * m_bs1] = b_1p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -199,8 +207,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<2>::serial_invoke(const Va } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<1>::serial_invoke(const ValueType *KOKKOS_RESTRICT /* A */, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT /* A */, const int /* n */, /**/ ValueType *KOKKOS_RESTRICT /* B */) { return 0; @@ -212,130 +220,130 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<1>::serial_invoke(const Va /// L(m x m) X(m x n) = B (m x n) template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<5>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<5>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 5) Kokkos::abort("InnerTrsmLeftLowerUnitDiag<5>::serial_invoke, assert failure (m<=5)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 5: { - InnerTrsmLeftLowerUnitDiag<5> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<5> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 4: { - InnerTrsmLeftLowerUnitDiag<4> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<4> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 3: { - InnerTrsmLeftLowerUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftLowerUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftLowerUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<4>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<4>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 4) Kokkos::abort("InnerTrsmLeftLowerUnitDiag<4>::serial_invoke, assert failure (m<=4)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 4: { - InnerTrsmLeftLowerUnitDiag<4> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<4> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 3: { - InnerTrsmLeftLowerUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftLowerUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftLowerUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<3>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<3>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 3) Kokkos::abort("InnerTrsmLeftLowerUnitDiag<3>::serial_invoke, assert failure (m<=3)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 3: { - InnerTrsmLeftLowerUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftLowerUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftLowerUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<2>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<2>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 2) Kokkos::abort("InnerTrsmLeftLowerUnitDiag<2>::serial_invoke, assert failure (m<=2)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 2: { - InnerTrsmLeftLowerUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftLowerUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<1>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 1) Kokkos::abort("InnerTrsmLeftLowerUnitDiag<1>::serial_invoke, assert failure (m<=1)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 1: { - InnerTrsmLeftLowerUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } @@ -348,29 +356,30 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<1>::serial_invoke(const Va /// L(m x m) X(m x n) = B (m x n) template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<5>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<5>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_10 = A[1 * _as0 + 0 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1], - a_40 = A[4 * _as0 + 0 * _as1], a_41 = A[4 * _as0 + 1 * _as1], a_42 = A[4 * _as0 + 2 * _as1], - a_43 = A[4 * _as0 + 3 * _as1]; + const ValueType a_10 = op(A[1 * m_as0 + 0 * m_as1]), a_20 = op(A[2 * m_as0 + 0 * m_as1]), + a_21 = op(A[2 * m_as0 + 1 * m_as1]), a_30 = op(A[3 * m_as0 + 0 * m_as1]), + a_31 = op(A[3 * m_as0 + 1 * m_as1]), a_32 = op(A[3 * m_as0 + 2 * m_as1]), + a_40 = op(A[4 * m_as0 + 0 * m_as1]), a_41 = op(A[4 * m_as0 + 1 * m_as1]), + a_42 = op(A[4 * m_as0 + 2 * m_as1]), a_43 = op(A[4 * m_as0 + 3 * m_as1]); // const ValueType - // a_00 = A[0*_as0+0*_as1], - // a_11 = A[1*_as0+1*_as1], - // a_22 = A[2*_as0+2*_as1], - // a_33 = A[3*_as0+3*_as1], - // a_44 = A[4*_as0+4*_as1]; - - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1], - inv_a_11 = static_cast(1.0) / A[1 * _as0 + 1 * _as1], - inv_a_22 = static_cast(1.0) / A[2 * _as0 + 2 * _as1], - inv_a_33 = static_cast(1.0) / A[3 * _as0 + 3 * _as1], - inv_a_44 = static_cast(1.0) / A[4 * _as0 + 4 * _as1]; + // a_00 = A[0*m_as0+0*m_as1], + // a_11 = A[1*m_as0+1*m_as1], + // a_22 = A[2*m_as0+2*m_as1], + // a_33 = A[3*m_as0+3*m_as1], + // a_44 = A[4*m_as0+4*m_as1]; + + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]), + inv_a_11 = static_cast(1.0) / op(A[1 * m_as0 + 1 * m_as1]), + inv_a_22 = static_cast(1.0) / op(A[2 * m_as0 + 2 * m_as1]), + inv_a_33 = static_cast(1.0) / op(A[3 * m_as0 + 3 * m_as1]), + inv_a_44 = static_cast(1.0) / op(A[4 * m_as0 + 4 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[5]; @@ -381,11 +390,11 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<5>::serial_invoke(const auto &b_4p = b_p[4]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; - b_3p = B[3 * _bs0 + p * _bs1]; - b_4p = B[4 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; + b_3p = B[3 * m_bs0 + p * m_bs1]; + b_4p = B[4 * m_bs0 + p * m_bs1]; // 0 iteration b_0p *= inv_a_00; /* b_0p /= a_00;*/ @@ -413,11 +422,11 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<5>::serial_invoke(const b_4p *= inv_a_44; /* b_4p /= a_44; */ // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; - B[3 * _bs0 + p * _bs1] = b_3p; - B[4 * _bs0 + p * _bs1] = b_4p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; + B[3 * m_bs0 + p * m_bs1] = b_3p; + B[4 * m_bs0 + p * m_bs1] = b_4p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -431,25 +440,26 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<5>::serial_invoke(const } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<4>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<4>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_10 = A[1 * _as0 + 0 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1], - a_30 = A[3 * _as0 + 0 * _as1], a_31 = A[3 * _as0 + 1 * _as1], a_32 = A[3 * _as0 + 2 * _as1]; + const ValueType a_10 = op(A[1 * m_as0 + 0 * m_as1]), a_20 = op(A[2 * m_as0 + 0 * m_as1]), + a_21 = op(A[2 * m_as0 + 1 * m_as1]), a_30 = op(A[3 * m_as0 + 0 * m_as1]), + a_31 = op(A[3 * m_as0 + 1 * m_as1]), a_32 = op(A[3 * m_as0 + 2 * m_as1]); // const ValueType - // a_00 = A[0*_as0+0*_as1], - // a_11 = A[1*_as0+1*_as1], - // a_22 = A[2*_as0+2*_as1], - // a_33 = A[3*_as0+3*_as1]; + // a_00 = A[0*m_as0+0*m_as1], + // a_11 = A[1*m_as0+1*m_as1], + // a_22 = A[2*m_as0+2*m_as1], + // a_33 = A[3*m_as0+3*m_as1]; - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1], - inv_a_11 = static_cast(1.0) / A[1 * _as0 + 1 * _as1], - inv_a_22 = static_cast(1.0) / A[2 * _as0 + 2 * _as1], - inv_a_33 = static_cast(1.0) / A[3 * _as0 + 3 * _as1]; + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]), + inv_a_11 = static_cast(1.0) / op(A[1 * m_as0 + 1 * m_as1]), + inv_a_22 = static_cast(1.0) / op(A[2 * m_as0 + 2 * m_as1]), + inv_a_33 = static_cast(1.0) / op(A[3 * m_as0 + 3 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[4]; @@ -459,10 +469,10 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<4>::serial_invoke(const auto &b_3p = b_p[3]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; - b_3p = B[3 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; + b_3p = B[3 * m_bs0 + p * m_bs1]; // 0 iteration b_0p *= inv_a_00; /* b_0p /= a_00;*/ @@ -483,10 +493,10 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<4>::serial_invoke(const b_3p *= inv_a_33; /* b_3p /= a_33; */ // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; - B[3 * _bs0 + p * _bs1] = b_3p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; + B[3 * m_bs0 + p * m_bs1] = b_3p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -500,22 +510,23 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<4>::serial_invoke(const } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<3>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<3>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_10 = A[1 * _as0 + 0 * _as1], a_20 = A[2 * _as0 + 0 * _as1], a_21 = A[2 * _as0 + 1 * _as1]; + const ValueType a_10 = op(A[1 * m_as0 + 0 * m_as1]), a_20 = op(A[2 * m_as0 + 0 * m_as1]), + a_21 = op(A[2 * m_as0 + 1 * m_as1]); // const ValueType - // a_00 = A[0*_as0+0*_as1], - // a_11 = A[1*_as0+1*_as1], - // a_22 = A[2*_as0+2*_as1]; + // a_00 = A[0*m_as0+0*m_as1], + // a_11 = A[1*m_as0+1*m_as1], + // a_22 = A[2*m_as0+2*m_as1]; - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1], - inv_a_11 = static_cast(1.0) / A[1 * _as0 + 1 * _as1], - inv_a_22 = static_cast(1.0) / A[2 * _as0 + 2 * _as1]; + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]), + inv_a_11 = static_cast(1.0) / op(A[1 * m_as0 + 1 * m_as1]), + inv_a_22 = static_cast(1.0) / op(A[2 * m_as0 + 2 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[3]; @@ -524,9 +535,9 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<3>::serial_invoke(const auto &b_2p = b_p[2]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; // 0 iteration b_0p *= inv_a_00; /* b_0p /= a_00;*/ @@ -541,9 +552,9 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<3>::serial_invoke(const b_2p *= inv_a_22; /* b_2p /= a_22; */ // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -557,20 +568,20 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<3>::serial_invoke(const } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<2>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<2>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_10 = A[1 * _as0 + 0 * _as1]; + const ValueType a_10 = op(A[1 * m_as0 + 0 * m_as1]); // const ValueType - // a_00 = A[0*_as0+0*_as1], - // a_11 = A[1*_as0+1*_as1]; + // a_00 = A[0*m_as0+0*m_as1], + // a_11 = A[1*m_as0+1*m_as1]; - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1], - inv_a_11 = static_cast(1.0) / A[1 * _as0 + 1 * _as1]; + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]), + inv_a_11 = static_cast(1.0) / op(A[1 * m_as0 + 1 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[2]; @@ -578,8 +589,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<2>::serial_invoke(const auto &b_1p = b_p[1]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; // 0 iteration b_0p *= inv_a_00; /* b_0p /= a_00;*/ @@ -589,8 +600,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<2>::serial_invoke(const b_1p *= inv_a_11; /* b_1p /= a_11; */ // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -604,19 +615,19 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<2>::serial_invoke(const } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<1>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; // const ValueType - // a_00 = A[0*_as0+0*_as1]; + // a_00 = A[0*m_as0+0*m_as1]; - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1]; + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]); auto trsv = [&](const int p) { - B[0 * _bs0 + p * _bs1] *= inv_a_00; /* b_0p /= a_00;*/ + B[0 * m_bs0 + p * m_bs1] *= inv_a_00; /* b_0p /= a_00;*/ }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -635,8 +646,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<1>::serial_invoke(const /// L(m x m) X(m x n) = B (m x n) template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<5>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<5>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 5) @@ -646,36 +657,36 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<5>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 5: { - InnerTrsmLeftLowerNonUnitDiag<5> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<5> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 4: { - InnerTrsmLeftLowerNonUnitDiag<4> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<4> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 3: { - InnerTrsmLeftLowerNonUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftLowerNonUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftLowerNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<4>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<4>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 4) @@ -685,31 +696,31 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<4>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 4: { - InnerTrsmLeftLowerNonUnitDiag<4> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<4> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 3: { - InnerTrsmLeftLowerNonUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftLowerNonUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftLowerNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<3>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<3>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 3) @@ -719,26 +730,26 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<3>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 3: { - InnerTrsmLeftLowerNonUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftLowerNonUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftLowerNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<2>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<2>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 2) @@ -748,21 +759,21 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<2>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 2: { - InnerTrsmLeftLowerNonUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftLowerNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<1>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 1) @@ -772,8 +783,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<1>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 1: { - InnerTrsmLeftLowerNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftLowerNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } @@ -786,16 +797,18 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerNonUnitDiag<1>::serial_invoke(const /// L(m x m) X(m x n) = B (m x n) template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<5>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<5>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], a_03 = A[0 * _as0 + 3 * _as1], - a_04 = A[0 * _as0 + 4 * _as1], - /**/ a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], a_14 = A[1 * _as0 + 4 * _as1], - /**/ a_23 = A[2 * _as0 + 3 * _as1], a_24 = A[2 * _as0 + 4 * _as1], - /**/ a_34 = A[3 * _as0 + 4 * _as1]; + const ValueType a_01 = op(A[0 * m_as0 + 1 * m_as1]), a_02 = op(A[0 * m_as0 + 2 * m_as1]), + a_03 = op(A[0 * m_as0 + 3 * m_as1]), a_04 = op(A[0 * m_as0 + 4 * m_as1]), + /**/ a_12 = op(A[1 * m_as0 + 2 * m_as1]), a_13 = op(A[1 * m_as0 + 3 * m_as1]), + a_14 = op(A[1 * m_as0 + 4 * m_as1]), + /**/ a_23 = op(A[2 * m_as0 + 3 * m_as1]), a_24 = op(A[2 * m_as0 + 4 * m_as1]), + /**/ a_34 = op(A[3 * m_as0 + 4 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[5]; @@ -806,11 +819,11 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<5>::serial_invoke(const Va auto &b_4p = b_p[4]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; - b_3p = B[3 * _bs0 + p * _bs1]; - b_4p = B[4 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; + b_3p = B[3 * m_bs0 + p * m_bs1]; + b_4p = B[4 * m_bs0 + p * m_bs1]; // 0 iteration b_0p -= a_04 * b_4p; @@ -831,10 +844,10 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<5>::serial_invoke(const Va b_0p -= a_01 * b_1p; // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; - B[3 * _bs0 + p * _bs1] = b_3p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; + B[3 * m_bs0 + p * m_bs1] = b_3p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -848,14 +861,16 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<5>::serial_invoke(const Va } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<4>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<4>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], a_03 = A[0 * _as0 + 3 * _as1], - /**/ a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], - /**/ a_23 = A[2 * _as0 + 3 * _as1]; + const ValueType a_01 = op(A[0 * m_as0 + 1 * m_as1]), a_02 = op(A[0 * m_as0 + 2 * m_as1]), + a_03 = op(A[0 * m_as0 + 3 * m_as1]), + /**/ a_12 = op(A[1 * m_as0 + 2 * m_as1]), a_13 = op(A[1 * m_as0 + 3 * m_as1]), + /**/ a_23 = op(A[2 * m_as0 + 3 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[4]; @@ -865,10 +880,10 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<4>::serial_invoke(const Va auto &b_3p = b_p[3]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; - b_3p = B[3 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; + b_3p = B[3 * m_bs0 + p * m_bs1]; // 0 iteration b_0p -= a_03 * b_3p; @@ -883,9 +898,9 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<4>::serial_invoke(const Va b_0p -= a_01 * b_1p; // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -899,13 +914,14 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<4>::serial_invoke(const Va } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<3>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<3>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - /**/ a_12 = A[1 * _as0 + 2 * _as1]; + const ValueType a_01 = op(A[0 * m_as0 + 1 * m_as1]), a_02 = op(A[0 * m_as0 + 2 * m_as1]), + /**/ a_12 = op(A[1 * m_as0 + 2 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[3]; @@ -914,9 +930,9 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<3>::serial_invoke(const Va auto &b_2p = b_p[2]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; // 0 iteration b_0p -= a_02 * b_2p; @@ -926,8 +942,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<3>::serial_invoke(const Va b_0p -= a_01 * b_1p; // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -941,12 +957,13 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<3>::serial_invoke(const Va } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<2>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<2>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_01 = A[0 * _as0 + 1 * _as1]; + const ValueType a_01 = op(A[0 * m_as0 + 1 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[2]; @@ -954,14 +971,14 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<2>::serial_invoke(const Va auto &b_1p = b_p[1]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; // 0 iteration b_0p -= a_01 * b_1p; // store - B[0 * _bs0 + p * _bs1] = b_0p; + B[0 * m_bs0 + p * m_bs1] = b_0p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -975,8 +992,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<2>::serial_invoke(const Va } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<1>::serial_invoke(const ValueType *KOKKOS_RESTRICT /* A */, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT /* A */, const int /* n */, /**/ ValueType *KOKKOS_RESTRICT /* B */) { return 0; @@ -988,92 +1005,92 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<1>::serial_invoke(const Va /// L(m x m) X(m x n) = B (m x n) template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<5>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<5>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 5) Kokkos::abort("InnerTrsmLeftUpperUnitDiag<5>::serial_invoke, assert failure (m<=5)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 5: { - InnerTrsmLeftUpperUnitDiag<5> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<5> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 4: { - InnerTrsmLeftUpperUnitDiag<4> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<4> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 3: { - InnerTrsmLeftUpperUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftUpperUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftUpperUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<4>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<4>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 4) Kokkos::abort("InnerTrsmLeftUpperUnitDiag<4>::serial_invoke, assert failure (m<=4)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 4: { - InnerTrsmLeftUpperUnitDiag<4> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<4> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 3: { - InnerTrsmLeftUpperUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftUpperUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftUpperUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<3>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<3>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 3) Kokkos::abort("InnerTrsmLeftUpperUnitDiag<3>::serial_invoke, assert failure (m<=3)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 3: { - InnerTrsmLeftUpperUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftUpperUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftUpperUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); + InnerTrsmLeftUpperUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); inner.serial_invoke(A, n, B); break; } @@ -1081,37 +1098,37 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<3>::serial_invoke(const Va return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<2>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<2>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 2) Kokkos::abort("InnerTrsmLeftUpperUnitDiag<2>::serial_invoke, assert failure (m<=2)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 2: { - InnerTrsmLeftUpperUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftUpperUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<1>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, - const int n, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, + const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 1) Kokkos::abort("InnerTrsmLeftUpperUnitDiag<1>::serial_invoke, assert failure (m<=1)"); if (m <= 0 || n <= 0) return 0; switch (m) { case 1: { - InnerTrsmLeftUpperUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } @@ -1124,30 +1141,31 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<1>::serial_invoke(const Va /// L(m x m) X(m x n) = B (m x n) template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<5>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<5>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], a_03 = A[0 * _as0 + 3 * _as1], - a_04 = A[0 * _as0 + 4 * _as1], - /**/ a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], a_14 = A[1 * _as0 + 4 * _as1], - /**/ a_23 = A[2 * _as0 + 3 * _as1], a_24 = A[2 * _as0 + 4 * _as1], - /**/ a_34 = A[3 * _as0 + 4 * _as1]; + const ValueType a_01 = op(A[0 * m_as0 + 1 * m_as1]), a_02 = op(A[0 * m_as0 + 2 * m_as1]), + a_03 = op(A[0 * m_as0 + 3 * m_as1]), a_04 = op(A[0 * m_as0 + 4 * m_as1]), + /**/ a_12 = op(A[1 * m_as0 + 2 * m_as1]), a_13 = op(A[1 * m_as0 + 3 * m_as1]), + a_14 = op(A[1 * m_as0 + 4 * m_as1]), + /**/ a_23 = op(A[2 * m_as0 + 3 * m_as1]), a_24 = op(A[2 * m_as0 + 4 * m_as1]), + /**/ a_34 = op(A[3 * m_as0 + 4 * m_as1]); // const ValueType - // a_00 = A[0*_as0+0*_as1], - // a_11 = A[1*_as0+1*_as1], - // a_22 = A[2*_as0+2*_as1], - // a_33 = A[3*_as0+3*_as1], - // a_44 = A[4*_as0+4*_as1]; - - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1], - inv_a_11 = static_cast(1.0) / A[1 * _as0 + 1 * _as1], - inv_a_22 = static_cast(1.0) / A[2 * _as0 + 2 * _as1], - inv_a_33 = static_cast(1.0) / A[3 * _as0 + 3 * _as1], - inv_a_44 = static_cast(1.0) / A[4 * _as0 + 4 * _as1]; + // a_00 = A[0*m_as0+0*m_as1], + // a_11 = A[1*m_as0+1*m_as1], + // a_22 = A[2*m_as0+2*m_as1], + // a_33 = A[3*m_as0+3*m_as1], + // a_44 = A[4*m_as0+4*m_as1]; + + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]), + inv_a_11 = static_cast(1.0) / op(A[1 * m_as0 + 1 * m_as1]), + inv_a_22 = static_cast(1.0) / op(A[2 * m_as0 + 2 * m_as1]), + inv_a_33 = static_cast(1.0) / op(A[3 * m_as0 + 3 * m_as1]), + inv_a_44 = static_cast(1.0) / op(A[4 * m_as0 + 4 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[5]; @@ -1158,11 +1176,11 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<5>::serial_invoke(const auto &b_4p = b_p[4]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; - b_3p = B[3 * _bs0 + p * _bs1]; - b_4p = B[4 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; + b_3p = B[3 * m_bs0 + p * m_bs1]; + b_4p = B[4 * m_bs0 + p * m_bs1]; // 0 iteration b_4p *= inv_a_44; /* b_4p /= a_44;*/ @@ -1190,11 +1208,11 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<5>::serial_invoke(const b_0p *= inv_a_00; /* b_0p /= a_00; */ // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; - B[3 * _bs0 + p * _bs1] = b_3p; - B[4 * _bs0 + p * _bs1] = b_4p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; + B[3 * m_bs0 + p * m_bs1] = b_3p; + B[4 * m_bs0 + p * m_bs1] = b_4p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -1208,26 +1226,27 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<5>::serial_invoke(const } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<4>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<4>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], a_03 = A[0 * _as0 + 3 * _as1], - /**/ a_12 = A[1 * _as0 + 2 * _as1], a_13 = A[1 * _as0 + 3 * _as1], - /**/ a_23 = A[2 * _as0 + 3 * _as1]; + const ValueType a_01 = op(A[0 * m_as0 + 1 * m_as1]), a_02 = op(A[0 * m_as0 + 2 * m_as1]), + a_03 = op(A[0 * m_as0 + 3 * m_as1]), + /**/ a_12 = op(A[1 * m_as0 + 2 * m_as1]), a_13 = op(A[1 * m_as0 + 3 * m_as1]), + /**/ a_23 = op(A[2 * m_as0 + 3 * m_as1]); // const ValueType - // a_00 = A[0*_as0+0*_as1], - // a_11 = A[1*_as0+1*_as1], - // a_22 = A[2*_as0+2*_as1], - // a_33 = A[3*_as0+3*_as1]; + // a_00 = A[0*m_as0+0*m_as1], + // a_11 = A[1*m_as0+1*m_as1], + // a_22 = A[2*m_as0+2*m_as1], + // a_33 = A[3*m_as0+3*m_as1]; - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1], - inv_a_11 = static_cast(1.0) / A[1 * _as0 + 1 * _as1], - inv_a_22 = static_cast(1.0) / A[2 * _as0 + 2 * _as1], - inv_a_33 = static_cast(1.0) / A[3 * _as0 + 3 * _as1]; + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]), + inv_a_11 = static_cast(1.0) / op(A[1 * m_as0 + 1 * m_as1]), + inv_a_22 = static_cast(1.0) / op(A[2 * m_as0 + 2 * m_as1]), + inv_a_33 = static_cast(1.0) / op(A[3 * m_as0 + 3 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[4]; @@ -1237,10 +1256,10 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<4>::serial_invoke(const auto &b_3p = b_p[3]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; - b_3p = B[3 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; + b_3p = B[3 * m_bs0 + p * m_bs1]; // 0 iteration b_3p *= inv_a_33; /* b_3p /= a_33;*/ @@ -1261,10 +1280,10 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<4>::serial_invoke(const b_0p *= inv_a_00; /* b_0p /= a_00; */ // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; - B[3 * _bs0 + p * _bs1] = b_3p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; + B[3 * m_bs0 + p * m_bs1] = b_3p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -1278,23 +1297,23 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<4>::serial_invoke(const } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<3>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<3>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_01 = A[0 * _as0 + 1 * _as1], a_02 = A[0 * _as0 + 2 * _as1], - /**/ a_12 = A[1 * _as0 + 2 * _as1]; + const ValueType a_01 = op(A[0 * m_as0 + 1 * m_as1]), a_02 = op(A[0 * m_as0 + 2 * m_as1]), + /**/ a_12 = op(A[1 * m_as0 + 2 * m_as1]); // const ValueType - // a_00 = A[0*_as0+0*_as1], - // a_11 = A[1*_as0+1*_as1], - // a_22 = A[2*_as0+2*_as1]; + // a_00 = A[0*m_as0+0*m_as1], + // a_11 = A[1*m_as0+1*m_as1], + // a_22 = A[2*m_as0+2*m_as1]; - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1], - inv_a_11 = static_cast(1.0) / A[1 * _as0 + 1 * _as1], - inv_a_22 = static_cast(1.0) / A[2 * _as0 + 2 * _as1]; + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]), + inv_a_11 = static_cast(1.0) / op(A[1 * m_as0 + 1 * m_as1]), + inv_a_22 = static_cast(1.0) / op(A[2 * m_as0 + 2 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[3]; @@ -1303,9 +1322,9 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<3>::serial_invoke(const auto &b_2p = b_p[2]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; - b_2p = B[2 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; + b_2p = B[2 * m_bs0 + p * m_bs1]; // 0 iteration b_2p *= inv_a_22; /* b_2p /= a_22; */ @@ -1320,9 +1339,9 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<3>::serial_invoke(const b_0p *= inv_a_00; /* b_0p /= a_00; */ // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; - B[2 * _bs0 + p * _bs1] = b_2p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; + B[2 * m_bs0 + p * m_bs1] = b_2p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -1336,20 +1355,20 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<3>::serial_invoke(const } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<2>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<2>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; - const ValueType a_01 = A[0 * _as0 + 1 * _as1]; + const ValueType a_01 = op(A[0 * m_as0 + 1 * m_as1]); // const ValueType - // a_00 = A[0*_as0+0*_as1], - // a_11 = A[1*_as0+1*_as1]; + // a_00 = A[0*m_as0+0*m_as1], + // a_11 = A[1*m_as0+1*m_as1]; - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1], - inv_a_11 = static_cast(1.0) / A[1 * _as0 + 1 * _as1]; + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]), + inv_a_11 = static_cast(1.0) / op(A[1 * m_as0 + 1 * m_as1]); auto trsv = [&](const int p) { ValueType b_p[2]; @@ -1357,8 +1376,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<2>::serial_invoke(const auto &b_1p = b_p[1]; // load - b_0p = B[0 * _bs0 + p * _bs1]; - b_1p = B[1 * _bs0 + p * _bs1]; + b_0p = B[0 * m_bs0 + p * m_bs1]; + b_1p = B[1 * m_bs0 + p * m_bs1]; // 2 iteration b_1p *= inv_a_11; /* b_1p /= a_11; */ @@ -1368,8 +1387,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<2>::serial_invoke(const b_0p *= inv_a_00; /* b_0p /= a_00; */ // store - B[0 * _bs0 + p * _bs1] = b_0p; - B[1 * _bs0 + p * _bs1] = b_1p; + B[0 * m_bs0 + p * m_bs1] = b_0p; + B[1 * m_bs0 + p * m_bs1] = b_1p; }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -1383,20 +1402,20 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<2>::serial_invoke(const } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<1>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (n <= 0) return 0; // const ValueType - // a_00 = A[0*_as0+0*_as1]; + // a_00 = A[0*m_as0+0*m_as1]; - const ValueType inv_a_00 = static_cast(1.0) / A[0 * _as0 + 0 * _as1]; + const ValueType inv_a_00 = static_cast(1.0) / op(A[0 * m_as0 + 0 * m_as1]); auto trsv = [&](const int p) { // 0 iteration - B[0 * _bs0 + p * _bs1] *= inv_a_00; /* b_0p /= a_00; */ + B[0 * m_bs0 + p * m_bs1] *= inv_a_00; /* b_0p /= a_00; */ }; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) @@ -1415,8 +1434,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<1>::serial_invoke(const /// L(m x m) X(m x n) = B (m x n) template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<5>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<5>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 5) @@ -1426,36 +1445,36 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<5>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 5: { - InnerTrsmLeftUpperNonUnitDiag<5> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<5> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 4: { - InnerTrsmLeftUpperNonUnitDiag<4> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<4> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 3: { - InnerTrsmLeftUpperNonUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftUpperNonUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftUpperNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<4>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<4>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 4) @@ -1465,31 +1484,31 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<4>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 4: { - InnerTrsmLeftUpperNonUnitDiag<4> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<4> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 3: { - InnerTrsmLeftUpperNonUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftUpperNonUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftUpperNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<3>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<3>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 3) @@ -1499,26 +1518,26 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<3>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 3: { - InnerTrsmLeftUpperNonUnitDiag<3> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<3> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 2: { - InnerTrsmLeftUpperNonUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftUpperNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<2>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<2>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 2) @@ -1528,21 +1547,21 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<2>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 2: { - InnerTrsmLeftUpperNonUnitDiag<2> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<2> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } case 1: { - InnerTrsmLeftUpperNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<1>::serial_invoke(const ValueType *KOKKOS_RESTRICT A, +template +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B) { if (m > 1) @@ -1552,8 +1571,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperNonUnitDiag<1>::serial_invoke(const if (m <= 0 || n <= 0) return 0; switch (m) { case 1: { - InnerTrsmLeftUpperNonUnitDiag<1> inner(_as0, _as1, _bs0, _bs1); - inner.serial_invoke(A, n, B); + InnerTrsmLeftUpperNonUnitDiag<1> inner(m_as0, m_as1, m_bs0, m_bs1); + inner.serial_invoke(op, A, n, B); break; } } diff --git a/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp index 9effd454e1..c28b9671a5 100644 --- a/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_LU_Serial_Internal.hpp @@ -96,8 +96,8 @@ KOKKOS_INLINE_FUNCTION int SerialLU_Internal::invoke( const int m_abr = ib - p - mb, n_abr = jb - p - mb; // trsm update - trsm_llu.serial_invoke(Ap, pb, n_abr, Ap + mb * as1); - trsm_run.serial_invoke(Ap, pb, m_abr, Ap + mb * as0); + trsm_llu.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, n_abr, Ap + mb * as1); + trsm_run.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, m_abr, Ap + mb * as0); // gemm update Impl::SerialGemmInternal::invoke( diff --git a/batched/dense/impl/KokkosBatched_LU_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_LU_Team_Internal.hpp index 0673d8034e..e4f5914c9d 100644 --- a/batched/dense/impl/KokkosBatched_LU_Team_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_LU_Team_Internal.hpp @@ -115,10 +115,10 @@ KOKKOS_INLINE_FUNCTION int TeamLU_Internal::invoke( Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 0, mq_abr + nq_abr), [&](const int &ij) { if (ij < nq_abr) { const int j = (ij)*nb, qb = (j + nb) > n_abr ? np_abr : nb; - trsm_llu.serial_invoke(Ap, pb, qb, Ap + (j + mb) * as1); + trsm_llu.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, qb, Ap + (j + mb) * as1); } else { const int i = (ij - nq_abr) * nb, qb = (i + nb) > m_abr ? mp_abr : nb; - trsm_run.serial_invoke(Ap, pb, qb, Ap + (i + mb) * as0); + trsm_run.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, qb, Ap + (i + mb) * as0); } }); member.team_barrier(); diff --git a/batched/dense/impl/KokkosBatched_Trsm_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Trsm_Serial_Impl.hpp index 0b91a77532..af618313ff 100644 --- a/batched/dense/impl/KokkosBatched_Trsm_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Trsm_Serial_Impl.hpp @@ -104,8 +104,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(0), B_stride_1); } }; @@ -128,8 +128,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(0), B_stride_1); } }; @@ -201,8 +201,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(0), B_stride_1); } }; @@ -224,8 +224,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(0), B_stride_1); } }; @@ -297,8 +297,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B_stride_1); } }; @@ -320,8 +320,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B_stride_1); } }; @@ -392,8 +392,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B_stride_1); } }; @@ -415,8 +415,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B_stride_1); } }; @@ -488,12 +488,33 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, true, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B_stride_1); } }; -// [TO DO] ConjTranspose is not supported yet +template +struct SerialTrsm { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) { + static_assert(AViewType::rank() == 2); + constexpr size_t B_rank = BViewType::rank(); + static_assert(B_rank == 1 || B_rank == 2); + + // Quick return if possible + if (B.size() == 0) return 0; + + size_t B_extent_1 = B_rank == 1 ? 1 : B.extent(1); + size_t B_stride_1 = B_rank == 1 ? 1 : B.stride(1); + + auto info = KokkosBatched::Impl::checkTrsmInput(A, B); + if (info) return info; + + return KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B_stride_1); + } +}; /// /// L/U/C @@ -562,8 +583,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, true, B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(0), B_stride_1); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(0), B_extent_1, alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B_stride_1); } }; @@ -622,8 +643,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(1), B.stride(0)); } }; @@ -639,8 +660,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(1), B.stride(0)); } }; @@ -696,8 +717,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(1), B.stride(0)); } }; @@ -713,8 +734,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(1), B.stride(0)); } }; @@ -771,8 +792,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); } }; @@ -788,8 +809,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); } }; @@ -845,8 +866,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); } }; @@ -862,8 +883,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); } }; @@ -920,8 +941,8 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, true, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); } }; @@ -979,12 +1000,27 @@ struct SerialTrsm::invoke( - ArgDiag::use_unit_diag, true, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(1), B.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); } }; -// [TO DO] ConjTranspose is not supported yet +template +struct SerialTrsm { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) { + static_assert(AViewType::rank() == 2 && BViewType::rank() == 2); + // Quick return if possible + if (B.extent(0) == 0 || B.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsmInput(A, B); + if (info) return info; + + return KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); + } +}; } // namespace KokkosBatched diff --git a/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp index 982129d8d1..70cf0f1637 100644 --- a/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp @@ -22,17 +22,17 @@ namespace Impl { template struct SerialTrsmInternalLeftLower { - template - KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, const int n, + template + KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, Op op, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1); }; template <> -template +template KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower::invoke( - const bool use_unit_diag, const bool do_conj, const int m, const int n, const ScalarType alpha, + const bool use_unit_diag, Op op, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { const ScalarType one(1.0), zero(0.0); @@ -50,8 +50,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower::i ValueType *KOKKOS_RESTRICT b1t = B + p * bs0, *KOKKOS_RESTRICT B2 = iend ? B + (p + 1) * bs0 : nullptr; if (!use_unit_diag) { - const ValueType alpha11 = - (do_conj ? KokkosKernels::ArithTraits::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]); + const ValueType alpha11 = op(A[p * as0 + p * as1]); #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll @@ -64,9 +63,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower::i #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (int j = 0; j < jend; ++j) - B2[i * bs0 + j * bs1] -= (do_conj ? KokkosKernels::ArithTraits::conj(a21[i * as0]) * b1t[j * bs1] - : a21[i * as0] * b1t[j * bs1]); + for (int j = 0; j < jend; ++j) B2[i * bs0 + j * bs1] -= op(a21[i * as0]) * b1t[j * bs1]; } } @@ -74,9 +71,9 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower::i } template <> -template +template KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower::invoke( - const bool use_unit_diag, const bool /*do_conj*/, const int m, const int n, const ScalarType alpha, + const bool use_unit_diag, Op op, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { const ScalarType one(1.0), zero(0.0), minus_one(-1.0); @@ -102,142 +99,136 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower::inv /**/ ValueType *KOKKOS_RESTRICT Bp = BB + p * bs0; if (use_unit_diag) - trsm_u.serial_invoke(Ap, pb, jb, Bp); + trsm_u.serial_invoke(op, Ap, pb, jb, Bp); else - trsm_n.serial_invoke(Ap, pb, jb, Bp); + trsm_n.serial_invoke(op, Ap, pb, jb, Bp); // gemm update for (int i = p + mb; i < ib; i += mb) { const int mm = (i + mb) > ib ? (ib - i) : mb; - gemm.serial_invoke(minus_one, AA + i * as0 + p * as1, BB + p * bs0, mm, jb, pb, BB + i * bs0); + gemm.serial_invoke(op, KokkosBlas::Impl::OpID(), minus_one, AA + i * as0 + p * as1, BB + p * bs0, mm, jb, + pb, BB + i * bs0); } - } + }; + KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsm::Blocked::Impl::Host{});)) + KOKKOS_IF_ON_DEVICE((host_or_device(Algo::Trsm::Blocked::Impl::Device{});)) }; - KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsm::Blocked::Impl::Host{});)) - KOKKOS_IF_ON_DEVICE((host_or_device(Algo::Trsm::Blocked::Impl::Device{});)) - }; - - const bool is_small = true; //(m*n <= 64*64); - if (is_small) { - trsm(m, n, A, B); - } else { - // // some cache blocking may need (not priority yet); - // trsm(m, n, A, B); - } - } - return 0; -} - -template -struct SerialTrsmInternalLeftUpper { - template - KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, const int n, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, - /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1); -}; - -template <> -template -KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper::invoke( - const bool use_unit_diag, const bool do_conj, const int m, const int n, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { - const ScalarType one(1.0), zero(0.0); - - if (alpha == zero) - KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); - else { - if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); - ValueType *KOKKOS_RESTRICT B0 = B; - for (int p = (m - 1); p >= 0; --p) { - const int iend = p, jend = n; - - const ValueType *KOKKOS_RESTRICT a01 = A + p * as1; - ValueType *KOKKOS_RESTRICT b1t = B + p * bs0; + const bool is_small = true; //(m*n <= 64*64); + if (is_small) { + trsm(m, n, A, B); + } else { + // // some cache blocking may need (not priority yet); + // trsm(m, n, A, B); + } + } return 0; + } - if (!use_unit_diag) { - const ValueType alpha11 = - (do_conj ? KokkosKernels::ArithTraits::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]); + template + struct SerialTrsmInternalLeftUpper { + template + KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, Op op, const int m, const int n, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, + /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1); + }; + + template <> + template + KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper::invoke( + const bool use_unit_diag, Op op, const int m, const int n, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { + const ScalarType one(1.0), zero(0.0); + + if (alpha == zero) + KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); + else { + if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); + + ValueType *KOKKOS_RESTRICT B0 = B; + for (int p = (m - 1); p >= 0; --p) { + const int iend = p, jend = n; + + const ValueType *KOKKOS_RESTRICT a01 = A + p * as1; + ValueType *KOKKOS_RESTRICT b1t = B + p * bs0; + + if (!use_unit_diag) { + const ValueType alpha11 = op(A[p * as0 + p * as1]); #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (int j = 0; j < n; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11; - } + for (int j = 0; j < n; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11; + } - if (p > 0) { // Note: A workaround to produce correct results for - // complex with Intel-18.2.199 - for (int i = 0; i < iend; ++i) + if (p > 0) { // Note: A workaround to produce correct results for + // complex with Intel-18.2.199 + for (int i = 0; i < iend; ++i) #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (int j = 0; j < jend; ++j) - B0[i * bs0 + j * bs1] -= (do_conj ? KokkosKernels::ArithTraits::conj(a01[i * as0]) * b1t[j * bs1] - : a01[i * as0] * b1t[j * bs1]); + for (int j = 0; j < jend; ++j) B0[i * bs0 + j * bs1] -= op(a01[i * as0]) * b1t[j * bs1]; + } } } + return 0; } - return 0; -} -template <> -template -KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper::invoke( - const bool use_unit_diag, const bool /*do_conj*/, const int m, const int n, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { - const ScalarType one(1.0), zero(0.0), minus_one(-1.0); - - if (alpha == zero) - KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); - else { - if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); - - auto trsm = [&](const int ib, const int jb, const ValueType *KOKKOS_RESTRICT AA, - /**/ ValueType *KOKKOS_RESTRICT BB) { - auto host_or_device = [&](auto tag) { - constexpr int mb = Algo::Trsm::Blocked::Impl::mb(); - InnerTrsmLeftUpperUnitDiag trsm_u(as0, as1, bs0, bs1); - InnerTrsmLeftUpperNonUnitDiag trsm_n(as0, as1, bs0, bs1); - InnerGemmFixA gemm(as0, as1, bs0, bs1, bs0, bs1); - for (int pp = 0; pp < ib; pp += mb) { - const int ptmp = ib - pp - mb, p = ptmp < 0 ? 0 : ptmp, pb = mb + (ptmp < 0) * ptmp; - - // trsm update - const ValueType *KOKKOS_RESTRICT Ap = AA + p * as0 + p * as1; - /**/ ValueType *KOKKOS_RESTRICT Bp = BB + p * bs0; - - if (use_unit_diag) - trsm_u.serial_invoke(Ap, pb, jb, Bp); - else - trsm_n.serial_invoke(Ap, pb, jb, Bp); - - // gemm update - for (int i = 0; i < p; i += mb) { - gemm.serial_invoke(minus_one, AA + i * as0 + p * as1, Bp, (i + mb) > p ? (p - i) : mb, jb, pb, - BB + i * bs0); - } + template <> + template + KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper::invoke( + const bool use_unit_diag, Op op, const int m, const int n, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { + const ScalarType one(1.0), zero(0.0), minus_one(-1.0); + + if (alpha == zero) + KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); + else { + if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); + + auto trsm = [&](const int ib, const int jb, const ValueType *KOKKOS_RESTRICT AA, + /**/ ValueType *KOKKOS_RESTRICT BB) { + auto host_or_device = [&](auto tag) { + constexpr int mb = Algo::Trsm::Blocked::Impl::mb(); + InnerTrsmLeftUpperUnitDiag trsm_u(as0, as1, bs0, bs1); + InnerTrsmLeftUpperNonUnitDiag trsm_n(as0, as1, bs0, bs1); + InnerGemmFixA gemm(as0, as1, bs0, bs1, bs0, bs1); + for (int pp = 0; pp < ib; pp += mb) { + const int ptmp = ib - pp - mb, p = ptmp < 0 ? 0 : ptmp, pb = mb + (ptmp < 0) * ptmp; + + // trsm update + const ValueType *KOKKOS_RESTRICT Ap = AA + p * as0 + p * as1; + /**/ ValueType *KOKKOS_RESTRICT Bp = BB + p * bs0; + + if (use_unit_diag) + trsm_u.serial_invoke(op, Ap, pb, jb, Bp); + else + trsm_n.serial_invoke(op, Ap, pb, jb, Bp); + + // gemm update + for (int i = 0; i < p; i += mb) { + gemm.serial_invoke(op, KokkosBlas::Impl::OpID(), minus_one, AA + i * as0 + p * as1, Bp, + (i + mb) > p ? (p - i) : mb, jb, pb, BB + i * bs0); + } + }; + KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsm::Blocked::Impl::Host{});)) + KOKKOS_IF_ON_DEVICE((host_or_device(Algo::Trsm::Blocked::Impl::Device{});)) + }; + + const bool is_small = (m * n <= 64 * 64); + if (is_small) { + trsm(m, n, A, B); + } else { + // // some cache blocking may need (not priority yet); + // trsm(m, n, A, B); } - }; - KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsm::Blocked::Impl::Host{});)) - KOKKOS_IF_ON_DEVICE((host_or_device(Algo::Trsm::Blocked::Impl::Device{});)) - }; - - const bool is_small = (m * n <= 64 * 64); - if (is_small) { - trsm(m, n, A, B); - } else { - // // some cache blocking may need (not priority yet); - // trsm(m, n, A, B); + } return 0; } - } - return 0; -} -} // namespace Impl + } // namespace Impl } // namespace KokkosBatched #endif diff --git a/batched/dense/impl/KokkosBatched_Trsm_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsm_Team_Internal.hpp index a760389919..72765bfd66 100644 --- a/batched/dense/impl/KokkosBatched_Trsm_Team_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsm_Team_Internal.hpp @@ -113,9 +113,9 @@ KOKKOS_INLINE_FUNCTION int TeamTrsmInternalLeftLower::invok // Made this non-const in order to WORKAROUND issue #349 int j = jj * nb, qb = (j + nb) > jb ? np : nb; if (use_unit_diag) - trsm_u.serial_invoke(Ap, pb, qb, Bp + j * bs1); + trsm_u.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, qb, Bp + j * bs1); else - trsm_n.serial_invoke(Ap, pb, qb, Bp + j * bs1); + trsm_n.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, qb, Bp + j * bs1); }); member.team_barrier(); @@ -233,9 +233,9 @@ KOKKOS_INLINE_FUNCTION int TeamTrsmInternalLeftUpper::invok Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 0, (jb / nb) + (np > 0)), [&](const int &jj) { const int j = jj * nb, qb = (j + nb) > jb ? np : nb; if (use_unit_diag) - trsm_u.serial_invoke(Ap, pb, qb, Bp + j * bs1); + trsm_u.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, qb, Bp + j * bs1); else - trsm_n.serial_invoke(Ap, pb, qb, Bp + j * bs1); + trsm_n.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, qb, Bp + j * bs1); }); member.team_barrier(); diff --git a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp index 78cbdbabd4..7b0c41a75e 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp @@ -106,7 +106,8 @@ struct SerialTrsv::invoke( - ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), + b.data(), b.stride(0)); } }; @@ -120,7 +121,8 @@ struct SerialTrsv auto info = KokkosBatched::Impl::checkTrsvInput(A, b); if (info) return info; return KokkosBatched::Impl::SerialTrsvInternalLower::invoke( - ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), + b.data(), b.stride(0)); } }; @@ -177,7 +179,8 @@ struct SerialTrsv auto info = KokkosBatched::Impl::checkTrsvInput(A, b); if (info) return info; return KokkosBatched::Impl::SerialTrsvInternalUpper::invoke( - ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), + b.data(), b.stride(0)); } }; @@ -191,7 +194,8 @@ struct SerialTrsv { auto info = KokkosBatched::Impl::checkTrsvInput(A, b); if (info) return info; return KokkosBatched::Impl::SerialTrsvInternalUpper::invoke( - ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), + b.data(), b.stride(0)); } }; @@ -248,7 +252,8 @@ struct SerialTrsv::invoke( - ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), + b.data(), b.stride(0)); } }; @@ -262,7 +267,8 @@ struct SerialTrsv::invoke( - ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), + b.data(), b.stride(0)); } }; @@ -319,7 +325,8 @@ struct SerialTrsv::invoke( - ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), + b.data(), b.stride(0)); } }; @@ -333,7 +340,8 @@ struct SerialTrsv auto info = KokkosBatched::Impl::checkTrsvInput(A, b); if (info) return info; return KokkosBatched::Impl::SerialTrsvInternalUpper::invoke( - ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), + b.data(), b.stride(0)); } }; @@ -390,7 +398,8 @@ struct SerialTrsv auto info = KokkosBatched::Impl::checkTrsvInput(A, b); if (info) return info; return KokkosBatched::Impl::SerialTrsvInternalLower::invoke( - ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), + b.data(), b.stride(0)); } }; @@ -404,7 +413,8 @@ struct SerialTrsv { auto info = KokkosBatched::Impl::checkTrsvInput(A, b); if (info) return info; return KokkosBatched::Impl::SerialTrsvInternalLower::invoke( - ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpID(), A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), + b.data(), b.stride(0)); } }; @@ -461,7 +471,8 @@ struct SerialTrsv::invoke( - ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), + b.data(), b.stride(0)); } }; @@ -475,7 +486,8 @@ struct SerialTrsv::invoke( - ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + ArgDiag::use_unit_diag, KokkosBlas::Impl::OpConj(), A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), + b.data(), b.stride(0)); } }; diff --git a/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp index e66fcfe24f..1402aa761b 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp @@ -26,19 +26,20 @@ namespace Impl { template struct SerialTrsvInternalLower { - template - KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, + template + KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, Op op, const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, /**/ ValueType *KOKKOS_RESTRICT b, const int bs0); }; template <> -template -KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke( - const bool use_unit_diag, const bool do_conj, const int m, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { +template +KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke(const bool use_unit_diag, Op op, + const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT b, + const int bs0) { const ScalarType one(1.0), zero(0.0); if (alpha == zero) @@ -57,24 +58,22 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invok // with KOKKOS_RESTRICT a compiler assumes that the pointer is not // accessed by others op(/=) uses this pointer and changes the associated // values, which brings a compiler problem - if (!use_unit_diag) - *beta1 = (do_conj ? *beta1 / KokkosKernels::ArithTraits::conj(A[p * as0 + p * as1]) - : *beta1 / A[p * as0 + p * as1]); + if (!use_unit_diag) *beta1 = *beta1 / op(A[p * as0 + p * as1]); - for (int i = 0; i < iend; ++i) - b2[i * bs0] -= - (do_conj ? KokkosKernels::ArithTraits::conj(a21[i * as0]) * (*beta1) : a21[i * as0] * (*beta1)); + for (int i = 0; i < iend; ++i) b2[i * bs0] -= op(a21[i * as0]) * (*beta1); } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke( - const bool use_unit_diag, const bool /*do_conj*/, const int m, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { +template +KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke(const bool use_unit_diag, Op op, + const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT b, + const int bs0) { const ScalarType one(1.0), zero(0.0), minus_one(-1.0); if (alpha == zero) @@ -96,13 +95,13 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke( /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; if (use_unit_diag) - trsm_u.serial_invoke(Ap, pb, 1, bp); + trsm_u.serial_invoke(op, Ap, pb, 1, bp); else - trsm_n.serial_invoke(Ap, pb, 1, bp); + trsm_n.serial_invoke(op, Ap, pb, 1, bp); // gemv update - KokkosBlas::Impl::SerialGemvInternal::invoke(m - p - pb, pb, minus_one, Ap + pb * as0, as0, - as1, bp, bs0, one, bp + pb * bs0, bs0); + KokkosBlas::Impl::SerialGemvInternal::invoke(op, m - p - pb, pb, minus_one, Ap + pb * as0, + as0, as1, bp, bs0, one, bp + pb * bs0, bs0); } }; KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsv::Blocked::Impl::Host{});)) @@ -117,19 +116,20 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke( template struct SerialTrsvInternalUpper { - template - KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, + template + KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, Op op, const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, /**/ ValueType *KOKKOS_RESTRICT b, const int bs0); }; template <> -template -KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke( - const bool use_unit_diag, const bool do_conj, const int m, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { +template +KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke(const bool use_unit_diag, Op op, + const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT b, + const int bs0) { const ScalarType one(1.0), zero(0.0); if (alpha == zero) @@ -148,24 +148,22 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invok // with KOKKOS_RESTRICT a compiler assumes that the pointer is not // accessed by others op(/=) uses this pointer and changes the associated // values, which brings a compiler problem - if (!use_unit_diag) - *beta1 = (do_conj ? *beta1 / KokkosKernels::ArithTraits::conj(A[p * as0 + p * as1]) - : *beta1 / A[p * as0 + p * as1]); + if (!use_unit_diag) *beta1 = *beta1 / op(A[p * as0 + p * as1]); - for (int i = 0; i < iend; ++i) - b0[i * bs0] -= - (do_conj ? KokkosKernels::ArithTraits::conj(a01[i * as0]) * (*beta1) : a01[i * as0] * (*beta1)); + for (int i = 0; i < iend; ++i) b0[i * bs0] -= op(a01[i * as0]) * (*beta1); } } return 0; } template <> -template -KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke( - const bool use_unit_diag, const bool /*do_conj*/, const int m, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { +template +KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke(const bool use_unit_diag, Op op, + const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, + const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT b, + const int bs0) { const ScalarType one(1.0), zero(0.0), minus_one(-1.0); // note that parallel range is different ( m*n vs m-1*n); @@ -188,13 +186,13 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke( /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; if (use_unit_diag) - trsm_u.serial_invoke(Ap, pb, 1, bp); + trsm_u.serial_invoke(op, Ap, pb, 1, bp); else - trsm_n.serial_invoke(Ap, pb, 1, bp); + trsm_n.serial_invoke(op, Ap, pb, 1, bp); // gemv update - KokkosBlas::Impl::SerialGemvInternal::invoke(p, pb, minus_one, Ap - p * as0, as0, as1, bp, - bs0, one, b, bs0); + KokkosBlas::Impl::SerialGemvInternal::invoke(op, p, pb, minus_one, Ap - p * as0, as0, as1, + bp, bs0, one, b, bs0); } }; KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsm::Blocked::Impl::Host{});)) @@ -211,7 +209,8 @@ struct [[deprecated("Use KokkosBatched::SerialTrsv instead")]] SerialTrsvInterna KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { - return Impl::SerialTrsvInternalLower::invoke(use_unit_diag, false, m, alpha, A, as0, as1, b, bs0); + return Impl::SerialTrsvInternalLower::invoke(use_unit_diag, KokkosBlas::Impl::OpID(), m, alpha, A, as0, + as1, b, bs0); } }; @@ -221,7 +220,8 @@ struct [[deprecated("Use KokkosBatched::SerialTrsv instead")]] SerialTrsvInterna KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { - return Impl::SerialTrsvInternalUpper::invoke(use_unit_diag, false, m, alpha, A, as0, as1, b, bs0); + return Impl::SerialTrsvInternalUpper::invoke(use_unit_diag, KokkosBlas::Impl::OpID(), m, alpha, A, as0, + as1, b, bs0); } }; diff --git a/batched/dense/impl/KokkosBatched_Trsv_Team_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsv_Team_Internal.hpp index 0190ae3ff3..e30a20c3e8 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Team_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Team_Internal.hpp @@ -102,9 +102,9 @@ KOKKOS_INLINE_FUNCTION int TeamTrsvInternalLower::invoke( member.team_barrier(); if (member.team_rank() == 0) { if (use_unit_diag) - trsm_u.serial_invoke(Ap, pb, 1, bp); + trsm_u.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, 1, bp); else - trsm_n.serial_invoke(Ap, pb, 1, bp); + trsm_n.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, 1, bp); } // gemv update @@ -204,9 +204,9 @@ KOKKOS_INLINE_FUNCTION int TeamTrsvInternalUpper::invoke( member.team_barrier(); if (member.team_rank() == 0) { if (use_unit_diag) - trsm_u.serial_invoke(Ap, pb, 1, bp); + trsm_u.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, 1, bp); else - trsm_n.serial_invoke(Ap, pb, 1, bp); + trsm_n.serial_invoke(KokkosBlas::Impl::OpID(), Ap, pb, 1, bp); } // gemv update diff --git a/batched/dense/src/KokkosBatched_InnerGemmFixA_Decl.hpp b/batched/dense/src/KokkosBatched_InnerGemmFixA_Decl.hpp index 8fdfd69cb9..da543650e5 100644 --- a/batched/dense/src/KokkosBatched_InnerGemmFixA_Decl.hpp +++ b/batched/dense/src/KokkosBatched_InnerGemmFixA_Decl.hpp @@ -9,21 +9,21 @@ namespace KokkosBatched { template struct InnerGemmFixA { - const int _as0, _as1, _bs0, _bs1, _cs0, _cs1; + const int m_as0, m_as1, m_bs0, m_bs1, m_cs0, m_cs1; KOKKOS_INLINE_FUNCTION InnerGemmFixA(const int as0, const int as1, const int bs0, const int bs1, const int cs0, const int cs1) - : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1), _cs0(cs0), _cs1(cs1) {} + : m_as0(as0), m_as1(as1), m_bs0(bs0), m_bs1(bs1), m_cs0(cs0), m_cs1(cs1) {} // serial rank update - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int n, /**/ ValueType *KOKKOS_RESTRICT C); // serial rank update for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(OpA opA, OpB opB, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const ValueType *KOKKOS_RESTRICT B, const int m, const int n, const int k, /**/ ValueType *KOKKOS_RESTRICT C); }; diff --git a/batched/dense/src/KokkosBatched_InnerTrsm_Decl.hpp b/batched/dense/src/KokkosBatched_InnerTrsm_Decl.hpp index 9001eb2fc8..355e63313c 100644 --- a/batched/dense/src/KokkosBatched_InnerTrsm_Decl.hpp +++ b/batched/dense/src/KokkosBatched_InnerTrsm_Decl.hpp @@ -4,6 +4,7 @@ #define KOKKOSBATCHED_INNER_TRSM_DECL_HPP /// \author Kyungjoo Kim (kyukim@sandia.gov) +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) namespace KokkosBatched { @@ -11,20 +12,20 @@ namespace KokkosBatched { // Solve L(m x m) X(m x n) = B(m x n) template struct InnerTrsmLeftLowerUnitDiag { - const int _as0, _as1, _bs0, _bs1; + const int m_as0, m_as1, m_bs0, m_bs1; KOKKOS_INLINE_FUNCTION InnerTrsmLeftLowerUnitDiag(const int as0, const int as1, const int bs0, const int bs1) - : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1) {} + : m_as0(as0), m_as1(as1), m_bs0(bs0), m_bs1(bs1) {} // trisolve - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B); // for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, const int n, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B); }; @@ -32,20 +33,20 @@ struct InnerTrsmLeftLowerUnitDiag { // Solve L(m x m) X(m x n) = B(m x n) template struct InnerTrsmLeftLowerNonUnitDiag { - const int _as0, _as1, _bs0, _bs1; + const int m_as0, m_as1, m_bs0, m_bs1; KOKKOS_INLINE_FUNCTION InnerTrsmLeftLowerNonUnitDiag(const int as0, const int as1, const int bs0, const int bs1) - : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1) {} + : m_as0(as0), m_as1(as1), m_bs0(bs0), m_bs1(bs1) {} // trisolve - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B); // for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, const int n, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B); }; @@ -53,20 +54,20 @@ struct InnerTrsmLeftLowerNonUnitDiag { // Solve U(m x m) X(m x n) = B(m x n) template struct InnerTrsmLeftUpperUnitDiag { - const int _as0, _as1, _bs0, _bs1; + const int m_as0, m_as1, m_bs0, m_bs1; KOKKOS_INLINE_FUNCTION InnerTrsmLeftUpperUnitDiag(const int as0, const int as1, const int bs0, const int bs1) - : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1) {} + : m_as0(as0), m_as1(as1), m_bs0(bs0), m_bs1(bs1) {} // trisolve - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B); // for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, const int n, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B); }; @@ -74,20 +75,20 @@ struct InnerTrsmLeftUpperUnitDiag { // Solve U(m x m) X(m x n) = B(m x n) template struct InnerTrsmLeftUpperNonUnitDiag { - const int _as0, _as1, _bs0, _bs1; + const int m_as0, m_as1, m_bs0, m_bs1; KOKKOS_INLINE_FUNCTION InnerTrsmLeftUpperNonUnitDiag(const int as0, const int as1, const int bs0, const int bs1) - : _as0(as0), _as1(as1), _bs0(bs0), _bs1(bs1) {} + : m_as0(as0), m_as1(as1), m_bs0(bs0), m_bs1(bs1) {} // trisolve - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int n, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int n, /**/ ValueType *KOKKOS_RESTRICT B); // for remainder - template - KOKKOS_INLINE_FUNCTION int serial_invoke(const ValueType *KOKKOS_RESTRICT A, const int m, const int n, + template + KOKKOS_INLINE_FUNCTION int serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT A, const int m, const int n, /**/ ValueType *KOKKOS_RESTRICT B); }; diff --git a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp index 27226e7c34..1dfb2ebe48 100644 --- a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp @@ -401,19 +401,12 @@ int test_batched_trsv() { AlgoTagType>(1); Test::Trsv::impl_test_batched_trsv_analytical(2); - - // FIXME: ConjTranspose with blocking is not implemented yet - if constexpr (!std::is_same_v) { - Test::Trsv::impl_test_batched_trsv_blocking(0, 10); - } + Test::Trsv::impl_test_batched_trsv_blocking(0, 10); for (int i = 0; i < 10; ++i) { - // FIXME: ConjTranspose with blocking is not implemented yet - if constexpr (!std::is_same_v) { - Test::Trsv::impl_test_batched_trsv_blocking(1, i); - } + Test::Trsv::impl_test_batched_trsv_blocking(1, i); Test::Trsv::impl_test_batched_trsv(1, i); Test::Trsv::impl_test_batched_trsv(2, @@ -428,19 +421,12 @@ int test_batched_trsv() { AlgoTagType>(1); Test::Trsv::impl_test_batched_trsv_analytical(2); - - // FIXME: ConjTranspose with blocking is not implemented yet - if constexpr (!std::is_same_v) { - Test::Trsv::impl_test_batched_trsv_blocking(0, 10); - } + Test::Trsv::impl_test_batched_trsv_blocking(0, 10); for (int i = 0; i < 10; ++i) { - // FIXME: ConjTranspose with blocking is not implemented yet - if constexpr (!std::is_same_v) { - Test::Trsv::impl_test_batched_trsv_blocking(1, i); - } + Test::Trsv::impl_test_batched_trsv_blocking(1, i); Test::Trsv::impl_test_batched_trsv(1, i); Test::Trsv::impl_test_batched_trsv(2, diff --git a/blas/impl/KokkosBlas3_trsm_impl.hpp b/blas/impl/KokkosBlas3_trsm_impl.hpp index 327c1ac35d..13ae3b0e23 100644 --- a/blas/impl/KokkosBlas3_trsm_impl.hpp +++ b/blas/impl/KokkosBlas3_trsm_impl.hpp @@ -99,25 +99,25 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'U') || (diag[0] == 'u'))) KokkosBatched::Impl::SerialTrsmInternalLeftLower::invoke( - Diag::Unit::use_unit_diag, false, B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(0), B.stride(1)); + Diag::Unit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(0), B.stride(1)); if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'N') || (diag[0] == 'n'))) KokkosBatched::Impl::SerialTrsmInternalLeftLower::invoke( - Diag::NonUnit::use_unit_diag, false, B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), A.stride(1), - B.data(), B.stride(0), B.stride(1)); + Diag::NonUnit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(0), B.stride(1)); // Side::Left, Uplo::Lower, Trans::Transpose if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'U') || (diag[0] == 'u'))) KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( - Diag::Unit::use_unit_diag, false, B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(0), B.stride(1)); + Diag::Unit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B.stride(1)); if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'N') || (diag[0] == 'n'))) KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( - Diag::NonUnit::use_unit_diag, false, B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), A.stride(0), - B.data(), B.stride(0), B.stride(1)); + Diag::NonUnit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B.stride(1)); // Side::Left, Uplo::Lower, Trans::ConjTranspose if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && @@ -133,25 +133,25 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'U') || (diag[0] == 'u'))) KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( - Diag::Unit::use_unit_diag, false, B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(0), B.stride(1)); + Diag::Unit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(0), B.stride(1)); if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'N') || (diag[0] == 'n'))) KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( - Diag::NonUnit::use_unit_diag, false, B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), A.stride(1), - B.data(), B.stride(0), B.stride(1)); + Diag::NonUnit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(0), B.stride(1)); // Side::Left, Uplo::Upper, Trans::Transpose if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'U') || (diag[0] == 'u'))) KokkosBatched::Impl::SerialTrsmInternalLeftLower::invoke( - Diag::Unit::use_unit_diag, false, B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(0), B.stride(1)); + Diag::Unit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B.stride(1)); if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'N') || (diag[0] == 'n'))) KokkosBatched::Impl::SerialTrsmInternalLeftLower::invoke( - Diag::NonUnit::use_unit_diag, false, B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), A.stride(0), - B.data(), B.stride(0), B.stride(1)); + Diag::NonUnit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(0), B.stride(1)); // Side::Left, Uplo::Upper, Trans::ConjTranspose if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && @@ -167,25 +167,25 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'U') || (diag[0] == 'u'))) KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( - Diag::Unit::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(1), B.stride(0)); + Diag::Unit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(1), B.stride(0)); if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'N') || (diag[0] == 'n'))) KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( - Diag::NonUnit::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), - B.data(), B.stride(1), B.stride(0)); + Diag::NonUnit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(1), B.stride(0)); // Side::Right, Uplo::Lower, Trans::Transpose if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'U') || (diag[0] == 'u'))) KokkosBatched::Impl::SerialTrsmInternalLeftLower::invoke( - Diag::Unit::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(1), B.stride(0)); + Diag::Unit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'N') || (diag[0] == 'n'))) KokkosBatched::Impl::SerialTrsmInternalLeftLower::invoke( - Diag::NonUnit::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), - B.data(), B.stride(1), B.stride(0)); + Diag::NonUnit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); // Side::Right, Uplo::Lower, Trans::ConjTranspose if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && @@ -201,25 +201,25 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'U') || (diag[0] == 'u'))) KokkosBatched::Impl::SerialTrsmInternalLeftLower::invoke( - Diag::Unit::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), - B.stride(1), B.stride(0)); + Diag::Unit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(1), B.stride(0)); if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'N') || (diag[0] == 'n'))) KokkosBatched::Impl::SerialTrsmInternalLeftLower::invoke( - Diag::NonUnit::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), - B.data(), B.stride(1), B.stride(0)); + Diag::NonUnit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), + A.stride(0), B.data(), B.stride(1), B.stride(0)); // Side::Right, Uplo::Upper, Trans::Transpose if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'U') || (diag[0] == 'u'))) KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( - Diag::Unit::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), - B.stride(1), B.stride(0)); + Diag::Unit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'N') || (diag[0] == 'n'))) KokkosBatched::Impl::SerialTrsmInternalLeftUpper::invoke( - Diag::NonUnit::use_unit_diag, false, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), - B.data(), B.stride(1), B.stride(0)); + Diag::NonUnit::use_unit_diag, KokkosBlas::Impl::OpID(), B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), + A.stride(1), B.data(), B.stride(1), B.stride(0)); // Side::Right, Uplo::Upper, Trans::ConjTranspose if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && From c90b9c26980a3e0dcad9c7890a4724ada887350d Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 5 May 2026 13:49:57 +0900 Subject: [PATCH 2/4] fix nesting Signed-off-by: Yuuichi Asahi --- .../KokkosBatched_Trsm_Serial_Internal.hpp | 197 +++++++++--------- 1 file changed, 100 insertions(+), 97 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp index 70cf0f1637..d5fed66b2f 100644 --- a/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp @@ -109,126 +109,129 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower::inv gemm.serial_invoke(op, KokkosBlas::Impl::OpID(), minus_one, AA + i * as0 + p * as1, BB + p * bs0, mm, jb, pb, BB + i * bs0); } - }; - KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsm::Blocked::Impl::Host{});)) - KOKKOS_IF_ON_DEVICE((host_or_device(Algo::Trsm::Blocked::Impl::Device{});)) + } }; - - const bool is_small = true; //(m*n <= 64*64); - if (is_small) { - trsm(m, n, A, B); - } else { - // // some cache blocking may need (not priority yet); - // trsm(m, n, A, B); - } - } return 0; + KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsm::Blocked::Impl::Host{});)) + KOKKOS_IF_ON_DEVICE((host_or_device(Algo::Trsm::Blocked::Impl::Device{});)) + }; + + const bool is_small = true; //(m*n <= 64*64); + if (is_small) { + trsm(m, n, A, B); + } else { + // // some cache blocking may need (not priority yet); + // trsm(m, n, A, B); + } } + return 0; +} - template - struct SerialTrsmInternalLeftUpper { - template - KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, Op op, const int m, const int n, - const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, - const int as1, - /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1); - }; - - template <> +template +struct SerialTrsmInternalLeftUpper { template - KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper::invoke( - const bool use_unit_diag, Op op, const int m, const int n, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { - const ScalarType one(1.0), zero(0.0); + KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, Op op, const int m, const int n, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, + /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1); +}; - if (alpha == zero) - KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); - else { - if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); +template <> +template +KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper::invoke( + const bool use_unit_diag, Op op, const int m, const int n, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { + const ScalarType one(1.0), zero(0.0); + + if (alpha == zero) + KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); + else { + if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); - ValueType *KOKKOS_RESTRICT B0 = B; - for (int p = (m - 1); p >= 0; --p) { - const int iend = p, jend = n; + ValueType *KOKKOS_RESTRICT B0 = B; + for (int p = (m - 1); p >= 0; --p) { + const int iend = p, jend = n; - const ValueType *KOKKOS_RESTRICT a01 = A + p * as1; - ValueType *KOKKOS_RESTRICT b1t = B + p * bs0; + const ValueType *KOKKOS_RESTRICT a01 = A + p * as1; + ValueType *KOKKOS_RESTRICT b1t = B + p * bs0; - if (!use_unit_diag) { - const ValueType alpha11 = op(A[p * as0 + p * as1]); + if (!use_unit_diag) { + const ValueType alpha11 = op(A[p * as0 + p * as1]); #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (int j = 0; j < n; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11; - } + for (int j = 0; j < n; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11; + } - if (p > 0) { // Note: A workaround to produce correct results for - // complex with Intel-18.2.199 - for (int i = 0; i < iend; ++i) + if (p > 0) { // Note: A workaround to produce correct results for + // complex with Intel-18.2.199 + for (int i = 0; i < iend; ++i) #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (int j = 0; j < jend; ++j) B0[i * bs0 + j * bs1] -= op(a01[i * as0]) * b1t[j * bs1]; - } + for (int j = 0; j < jend; ++j) B0[i * bs0 + j * bs1] -= op(a01[i * as0]) * b1t[j * bs1]; } } - return 0; } + return 0; +} - template <> - template - KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper::invoke( - const bool use_unit_diag, Op op, const int m, const int n, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { - const ScalarType one(1.0), zero(0.0), minus_one(-1.0); - - if (alpha == zero) - KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); - else { - if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); - - auto trsm = [&](const int ib, const int jb, const ValueType *KOKKOS_RESTRICT AA, - /**/ ValueType *KOKKOS_RESTRICT BB) { - auto host_or_device = [&](auto tag) { - constexpr int mb = Algo::Trsm::Blocked::Impl::mb(); - InnerTrsmLeftUpperUnitDiag trsm_u(as0, as1, bs0, bs1); - InnerTrsmLeftUpperNonUnitDiag trsm_n(as0, as1, bs0, bs1); - InnerGemmFixA gemm(as0, as1, bs0, bs1, bs0, bs1); - for (int pp = 0; pp < ib; pp += mb) { - const int ptmp = ib - pp - mb, p = ptmp < 0 ? 0 : ptmp, pb = mb + (ptmp < 0) * ptmp; - - // trsm update - const ValueType *KOKKOS_RESTRICT Ap = AA + p * as0 + p * as1; - /**/ ValueType *KOKKOS_RESTRICT Bp = BB + p * bs0; - - if (use_unit_diag) - trsm_u.serial_invoke(op, Ap, pb, jb, Bp); - else - trsm_n.serial_invoke(op, Ap, pb, jb, Bp); - - // gemm update - for (int i = 0; i < p; i += mb) { - gemm.serial_invoke(op, KokkosBlas::Impl::OpID(), minus_one, AA + i * as0 + p * as1, Bp, - (i + mb) > p ? (p - i) : mb, jb, pb, BB + i * bs0); - } - }; - KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsm::Blocked::Impl::Host{});)) - KOKKOS_IF_ON_DEVICE((host_or_device(Algo::Trsm::Blocked::Impl::Device{});)) - }; - - const bool is_small = (m * n <= 64 * 64); - if (is_small) { - trsm(m, n, A, B); - } else { - // // some cache blocking may need (not priority yet); - // trsm(m, n, A, B); +template <> +template +KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper::invoke( + const bool use_unit_diag, Op op, const int m, const int n, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) { + const ScalarType one(1.0), zero(0.0), minus_one(-1.0); + + if (alpha == zero) + KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1); + else { + if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); + + auto trsm = [&](const int ib, const int jb, const ValueType *KOKKOS_RESTRICT AA, + /**/ ValueType *KOKKOS_RESTRICT BB) { + auto host_or_device = [&](auto tag) { + constexpr int mb = Algo::Trsm::Blocked::Impl::mb(); + InnerTrsmLeftUpperUnitDiag trsm_u(as0, as1, bs0, bs1); + InnerTrsmLeftUpperNonUnitDiag trsm_n(as0, as1, bs0, bs1); + InnerGemmFixA gemm(as0, as1, bs0, bs1, bs0, bs1); + for (int pp = 0; pp < ib; pp += mb) { + const int ptmp = ib - pp - mb, p = ptmp < 0 ? 0 : ptmp, pb = mb + (ptmp < 0) * ptmp; + + // trsm update + const ValueType *KOKKOS_RESTRICT Ap = AA + p * as0 + p * as1; + /**/ ValueType *KOKKOS_RESTRICT Bp = BB + p * bs0; + + if (use_unit_diag) + trsm_u.serial_invoke(op, Ap, pb, jb, Bp); + else + trsm_n.serial_invoke(op, Ap, pb, jb, Bp); + + // gemm update + for (int i = 0; i < p; i += mb) { + gemm.serial_invoke(op, KokkosBlas::Impl::OpID(), minus_one, AA + i * as0 + p * as1, Bp, + (i + mb) > p ? (p - i) : mb, jb, pb, BB + i * bs0); + } } - } return 0; + }; + KOKKOS_IF_ON_HOST((host_or_device(Algo::Trsm::Blocked::Impl::Host{});)) + KOKKOS_IF_ON_DEVICE((host_or_device(Algo::Trsm::Blocked::Impl::Device{});)) + }; + + const bool is_small = (m * n <= 64 * 64); + if (is_small) { + trsm(m, n, A, B); + } else { + // // some cache blocking may need (not priority yet); + // trsm(m, n, A, B); } - - } // namespace Impl + } + return 0; +} +} // namespace Impl } // namespace KokkosBatched #endif From 49de5c9a894dc8ff8944ea3f3425ef9103c36b30 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 5 May 2026 13:57:40 +0900 Subject: [PATCH 3/4] Add static rank check when input is a View Signed-off-by: Yuuichi Asahi --- .../impl/KokkosBatched_Trsv_Serial_Impl.hpp | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp index 7b0c41a75e..3787320f0b 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp @@ -15,10 +15,24 @@ namespace Impl { template KOKKOS_INLINE_FUNCTION static int checkTrsvInput([[maybe_unused]] const AViewType &A, [[maybe_unused]] const bViewType &b) { - static_assert(Kokkos::is_view_v || Kokkos::is_dyn_rank_view_v, - "KokkosBatched::trsv: AViewType must be either a Kokkos::View or a Kokkos::DynRankView."); - static_assert(Kokkos::is_view_v || Kokkos::is_dyn_rank_view_v, - "KokkosBatched::trsv: bViewType must be either a Kokkos::View or a Kokkos::DynRankView."); + if constexpr (Kokkos::is_view_v) { + static_assert(AViewType::rank() == 2, "KokkosBatched::trsv: A must be a rank 2 View."); + } else if constexpr (Kokkos::is_dyn_rank_view_v) { + KOKKOS_EXPECTS((A.rank() == 2)); + } else { + static_assert(Kokkos::is_view_v || Kokkos::is_dyn_rank_view_v, + "KokkosBatched::trsv: AViewType must be either a Kokkos::View or a Kokkos::DynRankView."); + } + + if constexpr (Kokkos::is_view_v) { + static_assert(bViewType::rank() == 1, "KokkosBatched::trsv: b must be a rank 1 View."); + } else if constexpr (Kokkos::is_dyn_rank_view_v) { + KOKKOS_EXPECTS((b.rank() == 1)); + } else { + static_assert(Kokkos::is_view_v || Kokkos::is_dyn_rank_view_v, + "KokkosBatched::trsv: bViewType must be either a Kokkos::View or a Kokkos::DynRankView."); + } + #ifndef NDEBUG if (A.rank() != 2) { Kokkos::printf( From 8c1fdde59967382432d8dfcea695f80965d8900c Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 5 May 2026 14:09:20 +0900 Subject: [PATCH 4/4] op is unused in UnitDiag n=1 case Signed-off-by: Yuuichi Asahi --- batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp index c52c158ec3..a7bd9c51be 100644 --- a/batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_InnerTrsm_Serial_Impl.hpp @@ -208,7 +208,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<2>::serial_invoke(Op op, c template <> template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT /* A */, +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftLowerUnitDiag<1>::serial_invoke(Op /* op */, + const ValueType *KOKKOS_RESTRICT /* A */, const int /* n */, /**/ ValueType *KOKKOS_RESTRICT /* B */) { return 0; @@ -993,7 +994,8 @@ KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<2>::serial_invoke(Op op, c template <> template -KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<1>::serial_invoke(Op op, const ValueType *KOKKOS_RESTRICT /* A */, +KOKKOS_INLINE_FUNCTION int InnerTrsmLeftUpperUnitDiag<1>::serial_invoke(Op /*op*/, + const ValueType *KOKKOS_RESTRICT /* A */, const int /* n */, /**/ ValueType *KOKKOS_RESTRICT /* B */) { return 0;