diff --git a/batched/dense/impl/KokkosBatched_Rot_Impl.hpp b/batched/dense/impl/KokkosBatched_Rot_Impl.hpp new file mode 100644 index 0000000000..14030f54a9 --- /dev/null +++ b/batched/dense/impl/KokkosBatched_Rot_Impl.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project + +#ifndef KOKKOSBATCHED_ROT_IMPL_HPP_ +#define KOKKOSBATCHED_ROT_IMPL_HPP_ + +#include +#include +#include "KokkosBatched_Rot_Internal.hpp" + +namespace KokkosBatched { +namespace Impl { +template +KOKKOS_INLINE_FUNCTION static int checkRotInput([[maybe_unused]] const XViewType &x, + [[maybe_unused]] const YViewType &y) { + static_assert(Kokkos::is_view_v, "KokkosBatched::rot: XViewType is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBatched::rot: YViewType is not a Kokkos::View."); + static_assert(XViewType::rank() == 1, "KokkosBatched::rot: XViewType must have rank 1."); + static_assert(YViewType::rank() == 1, "KokkosBatched::rot: YViewType must have rank 1."); + static_assert(std::is_same_v, + "KokkosBatched::rot: XViewType must have non-const value type."); + static_assert(std::is_same_v, + "KokkosBatched::rot: YViewType must have non-const value type."); + static_assert(!KokkosKernels::ArithTraits::is_complex, "KokkosBatched::rot: CType must be real."); + using x_value_type = typename XViewType::non_const_value_type; + using y_value_type = typename YViewType::non_const_value_type; + static_assert( + (KokkosKernels::ArithTraits::is_complex && KokkosKernels::ArithTraits::is_complex) || + (!KokkosKernels::ArithTraits::is_complex && + !KokkosKernels::ArithTraits::is_complex), + "KokkosBatched::rot: XViewType and YViewType must be either both complex or both real."); + // The SType requirement depends on whether x/y are complex and on Conj; see the three cases below. + if constexpr (KokkosKernels::ArithTraits::is_complex) { + if constexpr (Conj) { + // {c,z}rot, S must be complex + static_assert(KokkosKernels::ArithTraits::is_complex, + "KokkosBatched::rot: SType must be 
complex for complex input with Conj = true."); + } else { + // {cs,zd}rot, S must be real + static_assert(!KokkosKernels::ArithTraits::is_complex, + "KokkosBatched::rot: SType must be real for complex input with Conj = false."); + } + } else { + // {s,d} rot, S must be real + static_assert(!KokkosKernels::ArithTraits::is_complex, + "KokkosBatched::rot: SType must be real for real input."); + } + // Run-time length check, compiled only when NDEBUG is not defined; a mismatch prints a message and returns 1. +#ifndef NDEBUG + const int n = x.extent_int(0); + + if (y.extent_int(0) != n) { + Kokkos::printf( + "KokkosBatched::rot: x and y must have the same length: x length " + "= " + "%d, y length = %d\n", + n, y.extent_int(0)); + return 1; + } +#endif + return 0; +} +} // namespace Impl + +/// +/// Serial Impl +/// =========== + +// {s,d,cs,zd}rot interface for Conj = false +// x(i) := c*x(i) + s*y(i) +// y(i) := c*y(i) - s*x(i) +// +// {c,z}rot interface for Conj = true +// x(i) := c*x(i) + s*y(i) +// y(i) := c*y(i) - conj(s)*x(i) +template +template +KOKKOS_INLINE_FUNCTION int SerialRot::invoke(const XViewType &x, const YViewType &y, const CType c, + const SType s) { + // Quick return if possible. NOTE(review): this runs before checkRotInput, so an empty x paired with a non-empty y is not reported. + const int n = x.extent_int(0); + if (n == 0) return 0; + + auto info = Impl::checkRotInput(x, y); + if (info) return info; + + using op = std::conditional_t; + return Impl::SerialRotInternal::invoke(op(), n, x.data(), x.stride(0), y.data(), y.stride(0), c, s); +} + +/// +/// Team Impl +/// =========== + +// {s,d,cs,zd}rot interface for Conj = false +// x(i) := c*x(i) + s*y(i) +// y(i) := c*y(i) - s*x(i) +// +// {c,z}rot interface for Conj = true +// x(i) := c*x(i) + s*y(i) +// y(i) := c*y(i) - conj(s)*x(i) +template +template +KOKKOS_INLINE_FUNCTION int TeamRot::invoke(const MemberType &member, const XViewType &x, + const YViewType &y, const CType c, const SType s) { + // Quick return if possible. NOTE(review): this runs before checkRotInput, so an empty x paired with a non-empty y is not reported. + const int n = x.extent_int(0); + if (n == 0) return 0; + + auto info = Impl::checkRotInput(x, y); + if (info) return info; + + using op = std::conditional_t; + + return 
Impl::TeamRotInternal::invoke(member, op(), n, x.data(), x.stride(0), y.data(), y.stride(0), c, s); +} + +/// +/// TeamVector Impl +/// =============== + +// {s,d,cs,zd}rot interface for Conj = false +// x(i) := c*x(i) + s*y(i) +// y(i) := c*y(i) - s*x(i) +// +// {c,z}rot interface for Conj = true +// x(i) := c*x(i) + s*y(i) +// y(i) := c*y(i) - conj(s)*x(i) +template +template +KOKKOS_INLINE_FUNCTION int TeamVectorRot::invoke(const MemberType &member, const XViewType &x, + const YViewType &y, const CType c, const SType s) { + // Quick return if possible. NOTE(review): this runs before checkRotInput, so an empty x paired with a non-empty y is not reported. + const int n = x.extent_int(0); + if (n == 0) return 0; + + auto info = Impl::checkRotInput(x, y); + if (info) return info; + + using op = std::conditional_t; + + return Impl::TeamVectorRotInternal::invoke(member, op(), n, x.data(), x.stride(0), y.data(), y.stride(0), c, s); +} + +} // namespace KokkosBatched + +#endif // KOKKOSBATCHED_ROT_IMPL_HPP_ diff --git a/batched/dense/impl/KokkosBatched_Rot_Internal.hpp b/batched/dense/impl/KokkosBatched_Rot_Internal.hpp new file mode 100644 index 0000000000..4d2fa2b130 --- /dev/null +++ b/batched/dense/impl/KokkosBatched_Rot_Internal.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project + +#ifndef KOKKOSBATCHED_ROT_INTERNAL_HPP_ +#define KOKKOSBATCHED_ROT_INTERNAL_HPP_ + +#include + +namespace KokkosBatched { +namespace Impl { + +/// +/// Serial Internal Impl +/// ==================== +// Applies the rotation to n elements addressed as x[i * xs0] and y[i * ys0]. +struct SerialRotInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(Op op, const int n, ValueType *KOKKOS_RESTRICT x, const int xs0, + ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c, const SType s); +}; +// temp stages the new x(i) so that the y(i) update still reads the old x(i). +template +KOKKOS_INLINE_FUNCTION int SerialRotInternal::invoke(Op op, const int n, ValueType *KOKKOS_RESTRICT x, const int xs0, + ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c, + const SType s) { + for (int i = 0; i < n; i++) { + auto temp = c * x[i * 
xs0] + s * y[i * ys0]; + y[i * ys0] = c * y[i * ys0] - op(s) * x[i * xs0]; + x[i * xs0] = temp; + } + + return 0; +} + +/// +/// Team Internal Impl +/// ================== +// Same update as SerialRotInternal, iterated with Kokkos::TeamThreadRange(member, n); no team barrier is issued afterwards. +struct TeamRotInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, Op op, const int n, ValueType *KOKKOS_RESTRICT x, + const int xs0, ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c, + const SType s); +}; + +template +KOKKOS_INLINE_FUNCTION int TeamRotInternal::invoke(const MemberType &member, Op op, const int n, + ValueType *KOKKOS_RESTRICT x, const int xs0, + ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c, + const SType s) { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &i) { + auto temp = c * x[i * xs0] + s * y[i * ys0]; + y[i * ys0] = c * y[i * ys0] - op(s) * x[i * xs0]; + x[i * xs0] = temp; + }); + return 0; +} + +/// +/// TeamVector Internal Impl +/// ======================== +// Same update as SerialRotInternal, iterated with Kokkos::TeamVectorRange(member, n); no team barrier is issued afterwards. +struct TeamVectorRotInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, Op op, const int n, ValueType *KOKKOS_RESTRICT x, + const int xs0, ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c, + const SType s); +}; + +template +KOKKOS_INLINE_FUNCTION int TeamVectorRotInternal::invoke(const MemberType &member, Op op, const int n, + ValueType *KOKKOS_RESTRICT x, const int xs0, + ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c, + const SType s) { + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, n), [&](const int &i) { + auto temp = c * x[i * xs0] + s * y[i * ys0]; + y[i * ys0] = c * y[i * ys0] - op(s) * x[i * xs0]; + x[i * xs0] = temp; + }); + return 0; +} + +} // namespace Impl +} // namespace KokkosBatched + +#endif // KOKKOSBATCHED_ROT_INTERNAL_HPP_ diff --git a/batched/dense/src/KokkosBatched_Rot.hpp b/batched/dense/src/KokkosBatched_Rot.hpp new file mode 100644 index 0000000000..dedb39db06 --- /dev/null +++ b/batched/dense/src/KokkosBatched_Rot.hpp @@ -0,0 +1,103 
@@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project +#ifndef KOKKOSBATCHED_ROT_HPP_ +#define KOKKOSBATCHED_ROT_HPP_ + +#include + +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) + +namespace KokkosBatched { + +/// \brief Serial Batched Rot: +/// Applies a plane rotation to vectors x and y: +/// x(i) := c*x(i) + s*y(i) +/// y(i) := c*y(i) - s*x(i) (Conj = false, {s,d,cs,zd}rot) +/// y(i) := c*y(i) - conj(s)*x(i) (Conj = true, {c,z}rot) +/// +/// \tparam Conj: Boolean flag indicating whether s is used directly (false) +/// or its conjugate is used (true) in the update of y +/// +/// \tparam XViewType: Input/output type for the vector x, needs to be a 1D view +/// \tparam YViewType: Input/output type for the vector y, needs to be a 1D view +/// \tparam CType: Input type for the cosine c (must be real) +/// \tparam SType: Input type for the sine s (complex only when x and y are complex and Conj = true, real otherwise) +/// +/// \param[in,out] x: x is a length n vector, a rank 1 view +/// \param[in,out] y: y is a length n vector, a rank 1 view +/// \param[in] c: cosine of the rotation (real scalar) +/// \param[in] s: sine of the rotation (real or complex scalar) +/// +/// No nested parallel_for is used inside of the function. invoke returns 0 on success, or 1 if x and y differ in length (checked only when NDEBUG is not defined). 
+/// +template +struct SerialRot { + template + KOKKOS_INLINE_FUNCTION static int invoke(const XViewType &x, const YViewType &y, const CType c, const SType s); +}; + +/// \brief Team Batched Rot: +/// Applies a plane rotation to vectors x and y: +/// x(i) := c*x(i) + s*y(i) +/// y(i) := c*y(i) - s*x(i) (Conj = false, {s,d,cs,zd}rot) +/// y(i) := c*y(i) - conj(s)*x(i) (Conj = true, {c,z}rot) +/// +/// \tparam MemberType: TeamPolicy member type +/// \tparam Conj: Boolean flag indicating whether s is used directly (false) +/// or its conjugate is used (true) in the update of y +/// +/// \tparam XViewType: Input/output type for the vector x, needs to be a 1D view +/// \tparam YViewType: Input/output type for the vector y, needs to be a 1D view +/// \tparam CType: Input type for the cosine c (must be real) +/// \tparam SType: Input type for the sine s (complex only when x and y are complex and Conj = true, real otherwise) +/// +/// \param[in] member: TeamPolicy member +/// \param[in,out] x: x is a length n vector, a rank 1 view +/// \param[in,out] y: y is a length n vector, a rank 1 view +/// \param[in] c: cosine of the rotation (real scalar) +/// \param[in] s: sine of the rotation (real or complex scalar) +/// +/// A nested parallel_for with TeamThreadRange is used. invoke returns 0 on success, or 1 if x and y differ in length (checked only when NDEBUG is not defined). 
+/// +template +struct TeamRot { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const XViewType &x, const YViewType &y, + const CType c, const SType s); +}; + +/// \brief TeamVector Batched Rot: +/// Applies a plane rotation to vectors x and y: +/// x(i) := c*x(i) + s*y(i) +/// y(i) := c*y(i) - s*x(i) (Conj = false, {s,d,cs,zd}rot) +/// y(i) := c*y(i) - conj(s)*x(i) (Conj = true, {c,z}rot) +/// +/// \tparam MemberType: TeamPolicy member type +/// \tparam Conj: Boolean flag indicating whether s is used directly (false) +/// or its conjugate is used (true) in the update of y +/// +/// \tparam XViewType: Input/output type for the vector x, needs to be a 1D view +/// \tparam YViewType: Input/output type for the vector y, needs to be a 1D view +/// \tparam CType: Input type for the cosine c (must be real) +/// \tparam SType: Input type for the sine s (complex only when x and y are complex and Conj = true, real otherwise) +/// +/// \param[in] member: TeamPolicy member +/// \param[in,out] x: x is a length n vector, a rank 1 view +/// \param[in,out] y: y is a length n vector, a rank 1 view +/// \param[in] c: cosine of the rotation (real scalar) +/// \param[in] s: sine of the rotation (real or complex scalar) +/// +/// A nested parallel_for with TeamVectorRange is used. invoke returns 0 on success, or 1 if x and y differ in length (checked only when NDEBUG is not defined). 
+/// +template +struct TeamVectorRot { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const XViewType &x, const YViewType &y, + const CType c, const SType s); +}; + +} // namespace KokkosBatched + +#include "KokkosBatched_Rot_Impl.hpp" + +#endif // KOKKOSBATCHED_ROT_HPP_ diff --git a/batched/dense/unit_test/Test_Batched_Dense.hpp b/batched/dense/unit_test/Test_Batched_Dense.hpp index decb7fdd61..383dd84993 100644 --- a/batched/dense/unit_test/Test_Batched_Dense.hpp +++ b/batched/dense/unit_test/Test_Batched_Dense.hpp @@ -6,6 +6,7 @@ // Serial kernels #include "Test_Batched_SerialAxpy.hpp" #include "Test_Batched_Copy.hpp" +#include "Test_Batched_Rot.hpp" #include "Test_Batched_SerialEigendecomposition.hpp" #include "Test_Batched_SerialEigendecomposition_Real.hpp" #include "Test_Batched_SerialGesv.hpp" diff --git a/batched/dense/unit_test/Test_Batched_Rot.hpp b/batched/dense/unit_test/Test_Batched_Rot.hpp new file mode 100644 index 0000000000..c74eb6afc5 --- /dev/null +++ b/batched/dense/unit_test/Test_Batched_Rot.hpp @@ -0,0 +1,458 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) +#include +#include +#include +#include +#include +#include "Test_Batched_DenseUtils.hpp" + +namespace Test { +namespace Rot { +/// Test driver: applies SerialRot to every row k of m_x and m_y and accumulates the returned info codes. +template +struct Functor_BatchedSerialRot { + using execution_space = typename DeviceType::execution_space; + XViewType m_x; + YViewType m_y; + CType m_c; + SType m_s; + + Functor_BatchedSerialRot(const XViewType &x, const YViewType &y, const CType c, const SType s) + : m_x(x), m_y(y), m_c(c), m_s(s) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const int k, int &info) const { + auto sub_x = Kokkos::subview(m_x, k, Kokkos::ALL()); + auto sub_y = Kokkos::subview(m_y, k, Kokkos::ALL()); + + info += KokkosBatched::SerialRot::invoke(sub_x, sub_y, m_c, m_s); + } + + inline int run() { + using 
value_type = typename XViewType::non_const_value_type; + std::string name_region("KokkosBatched::Test::SerialRot"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + int info_sum = 0; + Kokkos::Profiling::pushRegion(name.c_str()); + Kokkos::RangePolicy policy(0, m_x.extent(0)); + Kokkos::parallel_reduce(name.c_str(), policy, *this, info_sum); + Kokkos::Profiling::popRegion(); + return info_sum; + } +}; +/// Test driver: one team per row k applies TeamRot or TeamVectorRot. NOTE(review): the info codes returned by invoke are discarded here. +template +struct Functor_BatchedTeamRot { + using execution_space = typename DeviceType::execution_space; + XViewType m_x; + YViewType m_y; + CType m_c; + SType m_s; + + Functor_BatchedTeamRot(const XViewType &x, const YViewType &y, const CType c, const SType s) + : m_x(x), m_y(y), m_c(c), m_s(s) {} + + template + KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const { + const int k = member.league_rank(); + auto sub_x = Kokkos::subview(m_x, k, Kokkos::ALL()); + auto sub_y = Kokkos::subview(m_y, k, Kokkos::ALL()); + if constexpr (std::is_same_v) { + KokkosBatched::TeamRot::invoke(member, sub_x, sub_y, m_c, m_s); + } else if constexpr (std::is_same_v) { + KokkosBatched::TeamVectorRot::invoke(member, sub_x, sub_y, m_c, m_s); + } + } + + inline void run() { + using value_type = typename XViewType::non_const_value_type; + std::string name_region = std::is_same_v ? 
"KokkosBatched::Test::TeamRot" + : "KokkosBatched::Test::TeamVectorRot"; + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + Kokkos::Profiling::pushRegion(name.c_str()); + const int league_size = m_x.extent_int(0); + Kokkos::TeamPolicy policy(league_size, Kokkos::AUTO); + Kokkos::parallel_for(name.c_str(), policy, *this); + Kokkos::Profiling::popRegion(); + } +}; + +/// \brief Implementation details of batched rot analytical test +/// to confirm x := c*x + s*y and y := c*y - s*x are computed correctly +/// c = 0.6, s = 0.8 +/// x: [1, 2, 3, 4] +/// y: [5, 6, 7, 8] +/// x_ref: [4.6, 6.0, 7.4, 8.8] +/// y_ref: [2.2, 2.0, 1.8, 1.6] +/// +/// \tparam DeviceType Kokkos device type +/// \tparam ScalarType Kokkos scalar type +/// \tparam LayoutType Kokkos layout type for the views +/// \tparam Conj Boolean indicating whether the conjugate of s is used +/// \tparam ArgMode: one of Mode::Serial, Mode::Team, Mode::TeamVector +/// +/// \param[in] Nb Batch size of vectors +template +void impl_test_batched_rot_analytical(const std::size_t Nb) { + using ats = typename KokkosKernels::ArithTraits; + using RealType = typename ats::mag_type; + using View2DType = Kokkos::View; + using StridedView2DType = Kokkos::View; + + const std::size_t N = 4; + View2DType x("x", Nb, N), y("y", Nb, N); + View2DType x_ref("x_ref", Nb, N), y_ref("y_ref", Nb, N); + + const std::size_t incx = 2; + // rot takes no incx argument; non-unit strides are exercised through LayoutStride views instead + Kokkos::LayoutStride layout{Nb, incx, N, Nb * incx}; + StridedView2DType x_s("x_s", layout), y_s("y_s", layout); + + auto h_x = Kokkos::create_mirror_view(x); + auto h_y = Kokkos::create_mirror_view(y); + auto h_x_ref = Kokkos::create_mirror_view(x_ref); + auto h_y_ref = Kokkos::create_mirror_view(y_ref); + + for (std::size_t ib = 0; ib < Nb; ib++) { + h_x(ib, 0) = ScalarType(1); + h_x(ib, 1) = ScalarType(2); + h_x(ib, 2) = ScalarType(3); + h_x(ib, 3) = ScalarType(4); + + h_y(ib, 0) = ScalarType(5); 
+ h_y(ib, 1) = ScalarType(6); + h_y(ib, 2) = ScalarType(7); + h_y(ib, 3) = ScalarType(8); + + // x_ref(i) = c*x(i) + s*y(i) with c = 0.6, s = 0.8 + // y_ref(i) = c*y(i) - s*x(i) + // Note: s is built from the single real value 0.8 (presumably zero imaginary part even when its type is complex), so Conj = true and Conj = false give the same result since conj(s) = s. + h_x_ref(ib, 0) = ScalarType(4.6); + h_x_ref(ib, 1) = ScalarType(6.0); + h_x_ref(ib, 2) = ScalarType(7.4); + h_x_ref(ib, 3) = ScalarType(8.8); + + h_y_ref(ib, 0) = ScalarType(2.2); + h_y_ref(ib, 1) = ScalarType(2.0); + h_y_ref(ib, 2) = ScalarType(1.8); + h_y_ref(ib, 3) = ScalarType(1.6); + } + + Kokkos::deep_copy(x, h_x); + Kokkos::deep_copy(y, h_y); + + // Deep copy to strided views + Kokkos::deep_copy(x_s, x); + Kokkos::deep_copy(y_s, y); + + // S is complex only for {c,z}rot and real for {s,d,cs,zd}rot. + using MabyBeComplexType = std::conditional_t; + + const RealType c = 0.6; + const MabyBeComplexType s = MabyBeComplexType(0.8); + + if constexpr (std::is_same_v) { + auto info = + Functor_BatchedSerialRot(x, y, c, s) + .run(); + EXPECT_EQ(info, 0); + } else { + Functor_BatchedTeamRot(x, y, c, s) + .run(); + } + + // With strided views + if constexpr (std::is_same_v) { + auto info = + Functor_BatchedSerialRot( + x_s, y_s, c, s) + .run(); + EXPECT_EQ(info, 0); + } else { + Functor_BatchedTeamRot(x_s, y_s, c, s) + .run(); + } + + RealType eps = 1.0e1 * ats::epsilon(); + Kokkos::deep_copy(h_x, x); + Kokkos::deep_copy(h_y, y); + + // Check if x := c*x + s*y and y := c*y - op(s)*x + for (std::size_t ib = 0; ib < Nb; ib++) { + for (std::size_t i = 0; i < N; i++) { + EXPECT_NEAR_KK(h_x(ib, i), h_x_ref(ib, i), eps); + EXPECT_NEAR_KK(h_y(ib, i), h_y_ref(ib, i), eps); + } + } + + // Testing for strided views x_s and y_s, reusing x and y + Kokkos::deep_copy(x, x_s); + Kokkos::deep_copy(y, y_s); + Kokkos::deep_copy(h_x, x); + Kokkos::deep_copy(h_y, y); + for (std::size_t ib = 0; ib < Nb; ib++) { + for (std::size_t i = 0; i < N; i++) { + EXPECT_NEAR_KK(h_x(ib, i), h_x_ref(ib, i), eps); + 
EXPECT_NEAR_KK(h_y(ib, i), h_y_ref(ib, i), eps); + } + } +} + +/// \brief Implementation details of batched rot test: random x and y (contiguous and LayoutStride copies) are rotated and compared against a host-side reference loop +/// +/// \tparam DeviceType Kokkos device type +/// \tparam ScalarType Kokkos scalar type +/// \tparam LayoutType Kokkos layout type for the views +/// \tparam Conj Boolean indicating whether the conjugate of s is used +/// \tparam ArgMode: one of Mode::Serial, Mode::Team, Mode::TeamVector +/// +/// \param[in] Nb Batch size of vectors +/// \param[in] N Length of vectors x and y +template +void impl_test_batched_rot(const std::size_t Nb, const std::size_t N) { + using ats = typename KokkosKernels::ArithTraits; + using RealType = typename ats::mag_type; + using View2DType = Kokkos::View; + using StridedView2DType = Kokkos::View; + + View2DType x("x", Nb, N), y("y", Nb, N); + View2DType x_ref("x_ref", Nb, N), y_ref("y_ref", Nb, N); + + const std::size_t incx = 2; + // rot takes no incx argument; non-unit strides are exercised through LayoutStride views instead + Kokkos::LayoutStride layout{Nb, incx, N, Nb * incx}; + StridedView2DType x_s("x_s", layout), y_s("y_s", layout); + + // Create random x and y + using execution_space = typename DeviceType::execution_space; + Kokkos::Random_XorShift64_Pool rand_pool(13718); + ScalarType randStart, randEnd; + + KokkosKernels::Impl::getRandomBounds(1.0, randStart, randEnd); + Kokkos::fill_random(x, rand_pool, randStart, randEnd); + Kokkos::fill_random(y, rand_pool, randStart, randEnd); + + // Save copies for reference + Kokkos::deep_copy(x_ref, x); + Kokkos::deep_copy(y_ref, y); + + // Deep copy to strided views + Kokkos::deep_copy(x_s, x); + Kokkos::deep_copy(y_s, y); + + // S is complex only for {c,z}rot and real for {s,d,cs,zd}rot. 
+ using MabyBeComplexType = std::conditional_t; + const RealType c_val = 0.6; + MabyBeComplexType s_val; + if constexpr (KokkosKernels::ArithTraits::is_complex) { + s_val = MabyBeComplexType(0.6, 0.8); + } else { + s_val = MabyBeComplexType(0.8); + } + + // Run rot on (x, y) + if constexpr (std::is_same_v) { + auto info = Functor_BatchedSerialRot( + x, y, c_val, s_val) + .run(); + EXPECT_EQ(info, 0); + } else { + Functor_BatchedTeamRot(x, y, c_val, + s_val) + .run(); + } + + // With strided views + if constexpr (std::is_same_v) { + auto info = + Functor_BatchedSerialRot( + x_s, y_s, c_val, s_val) + .run(); + EXPECT_EQ(info, 0); + } else { + Functor_BatchedTeamRot(x_s, y_s, c_val, s_val) + .run(); + } + + // Make a reference at host, using the same update order as the kernels (temp holds the new x) + auto h_x_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, x_ref); + auto h_y_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, y_ref); + + // Note: ConjTranspose corresponds to zrot where conj(s) is used + using Op = std::conditional_t; + Op op; + for (std::size_t ib = 0; ib < Nb; ib++) { + for (std::size_t i = 0; i < N; i++) { + auto s_applied = op(s_val); + auto temp = c_val * h_x_ref(ib, i) + s_val * h_y_ref(ib, i); + h_y_ref(ib, i) = c_val * h_y_ref(ib, i) - s_applied * h_x_ref(ib, i); + h_x_ref(ib, i) = temp; + } + } + + RealType eps = 1.0e1 * ats::epsilon(); + auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, x); + auto h_y = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, y); + + // Check if x := c*x + s*y and y := c*y - op(s)*x + for (std::size_t ib = 0; ib < Nb; ib++) { + for (std::size_t i = 0; i < N; i++) { + EXPECT_NEAR_KK(h_x(ib, i), h_x_ref(ib, i), eps); + EXPECT_NEAR_KK(h_y(ib, i), h_y_ref(ib, i), eps); + } + } + + // Testing for strided views x_s and y_s, reusing x and y + Kokkos::deep_copy(x, x_s); + Kokkos::deep_copy(y, y_s); + Kokkos::deep_copy(h_x, x); + Kokkos::deep_copy(h_y, y); + for (std::size_t ib = 0; ib < Nb; ib++) { + for (std::size_t i = 0; i < N; i++) { 
+ EXPECT_NEAR_KK(h_x(ib, i), h_x_ref(ib, i), eps); + EXPECT_NEAR_KK(h_y(ib, i), h_y_ref(ib, i), eps); + } + } +} + +} // namespace Rot +} // namespace Test +/// Runs the analytical test (Nb = 1, 2) and the random test (Nb = 1, 2 with N = 0..9) for each instantiated layout. +template +int test_batched_rot() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + using LayoutType = Kokkos::LayoutLeft; + Test::Rot::impl_test_batched_rot_analytical(1); + Test::Rot::impl_test_batched_rot_analytical(2); + for (int i = 0; i < 10; i++) { + Test::Rot::impl_test_batched_rot(1, i); + Test::Rot::impl_test_batched_rot(2, i); + } + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + using LayoutType = Kokkos::LayoutRight; + Test::Rot::impl_test_batched_rot_analytical(1); + Test::Rot::impl_test_batched_rot_analytical(2); + for (int i = 0; i < 10; i++) { + Test::Rot::impl_test_batched_rot(1, i); + Test::Rot::impl_test_batched_rot(2, i); + } + } +#endif + + return 0; +} + +#if defined(KOKKOSKERNELS_INST_FLOAT) +// Serial +TEST_F(TestCategory, test_batched_serial_rot_i_float) { + test_batched_rot(); +} +TEST_F(TestCategory, test_batched_serial_rot_c_float) { + test_batched_rot(); +} + +// Team +TEST_F(TestCategory, test_batched_team_rot_i_float) { + test_batched_rot(); +} +TEST_F(TestCategory, test_batched_team_rot_c_float) { + test_batched_rot(); +} + +// TeamVector +TEST_F(TestCategory, test_batched_teamvector_rot_i_float) { + test_batched_rot(); +} +TEST_F(TestCategory, test_batched_teamvector_rot_c_float) { + test_batched_rot(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +// Serial +TEST_F(TestCategory, test_batched_serial_rot_i_double) { + test_batched_rot(); +} +TEST_F(TestCategory, test_batched_serial_rot_c_double) { + test_batched_rot(); +} + +// Team +TEST_F(TestCategory, test_batched_team_rot_i_double) { + test_batched_rot(); +} +TEST_F(TestCategory, test_batched_team_rot_c_double) { + test_batched_rot(); +} + +// TeamVector +TEST_F(TestCategory, test_batched_teamvector_rot_i_double) { + test_batched_rot(); +} +TEST_F(TestCategory, test_batched_teamvector_rot_c_double) { 
+ test_batched_rot(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +// Serial (in the complex cases the _i_ tests pass Conj = false and the _c_ tests pass Conj = true) +TEST_F(TestCategory, test_batched_serial_rot_i_fcomplex) { + test_batched_rot, false, KokkosBatched::Mode::Serial>(); +} +TEST_F(TestCategory, test_batched_serial_rot_c_fcomplex) { + test_batched_rot, true, KokkosBatched::Mode::Serial>(); +} + +// Team +TEST_F(TestCategory, test_batched_team_rot_i_fcomplex) { + test_batched_rot, false, KokkosBatched::Mode::Team>(); +} +TEST_F(TestCategory, test_batched_team_rot_c_fcomplex) { + test_batched_rot, true, KokkosBatched::Mode::Team>(); +} + +// TeamVector +TEST_F(TestCategory, test_batched_teamvector_rot_i_fcomplex) { + test_batched_rot, false, KokkosBatched::Mode::TeamVector>(); +} +TEST_F(TestCategory, test_batched_teamvector_rot_c_fcomplex) { + test_batched_rot, true, KokkosBatched::Mode::TeamVector>(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +// Serial +TEST_F(TestCategory, test_batched_serial_rot_i_dcomplex) { + test_batched_rot, false, KokkosBatched::Mode::Serial>(); +} +TEST_F(TestCategory, test_batched_serial_rot_c_dcomplex) { + test_batched_rot, true, KokkosBatched::Mode::Serial>(); +} + +// Team +TEST_F(TestCategory, test_batched_team_rot_i_dcomplex) { + test_batched_rot, false, KokkosBatched::Mode::Team>(); +} +TEST_F(TestCategory, test_batched_team_rot_c_dcomplex) { + test_batched_rot, true, KokkosBatched::Mode::Team>(); +} + +// TeamVector +TEST_F(TestCategory, test_batched_teamvector_rot_i_dcomplex) { + test_batched_rot, false, KokkosBatched::Mode::TeamVector>(); +} +TEST_F(TestCategory, test_batched_teamvector_rot_c_dcomplex) { + test_batched_rot, true, KokkosBatched::Mode::TeamVector>(); +} +#endif