Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions batched/dense/impl/KokkosBatched_Rot_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#ifndef KOKKOSBATCHED_ROT_IMPL_HPP_
#define KOKKOSBATCHED_ROT_IMPL_HPP_

#include <KokkosBlas_util.hpp>
#include <KokkosBatched_Util.hpp>
#include "KokkosBatched_Rot_Internal.hpp"

namespace KokkosBatched {
namespace Impl {
/// \brief Validates the template arguments and view extents used by the batched rot kernels.
///
/// Compile-time requirements (enforced with static_assert):
///   - XViewType and YViewType are rank-1 Kokkos::Views with non-const value types
///   - CType is real
///   - the value types of x and y are either both complex or both real
///   - SType is complex only for complex input with Conj = true ({c,z}rot); it is real otherwise
///
/// Run-time requirement (compiled only when NDEBUG is not defined):
///   - x and y have the same length
///
/// \tparam Conj: whether the conjugate of s is used in the update of y
/// \tparam CType: type of the cosine
/// \tparam SType: type of the sine
/// \tparam XViewType: type of the vector x
/// \tparam YViewType: type of the vector y
///
/// \param[in] x: vector x, only its extent is inspected
/// \param[in] y: vector y, only its extent is inspected
///
/// \return 0 if the checks pass, 1 if the run-time length check fails
template <bool Conj, typename CType, typename SType, typename XViewType, typename YViewType>
KOKKOS_INLINE_FUNCTION static int checkRotInput([[maybe_unused]] const XViewType &x,
                                                [[maybe_unused]] const YViewType &y) {
  // View-ness and rank
  static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::rot: XViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view_v<YViewType>, "KokkosBatched::rot: YViewType is not a Kokkos::View.");
  static_assert(XViewType::rank() == 1, "KokkosBatched::rot: XViewType must have rank 1.");
  static_assert(YViewType::rank() == 1, "KokkosBatched::rot: YViewType must have rank 1.");

  // Both vectors are written to, so neither may be a view of const
  using x_value_type = typename XViewType::non_const_value_type;
  using y_value_type = typename YViewType::non_const_value_type;
  static_assert(std::is_same_v<typename XViewType::value_type, x_value_type>,
                "KokkosBatched::rot: XViewType must have non-const value type.");
  static_assert(std::is_same_v<typename YViewType::value_type, y_value_type>,
                "KokkosBatched::rot: YViewType must have non-const value type.");

  // Real/complex consistency of the scalar and vector types
  using c_traits = KokkosKernels::ArithTraits<CType>;
  using s_traits = KokkosKernels::ArithTraits<SType>;
  using x_traits = KokkosKernels::ArithTraits<x_value_type>;
  using y_traits = KokkosKernels::ArithTraits<y_value_type>;

  static_assert(!c_traits::is_complex, "KokkosBatched::rot: CType must be real.");
  static_assert(x_traits::is_complex == y_traits::is_complex,
                "KokkosBatched::rot: XViewType and YViewType must be either both complex or both real.");

  if constexpr (!x_traits::is_complex) {
    // {s,d} rot, S must be real
    static_assert(!s_traits::is_complex, "KokkosBatched::rot: SType must be real for real input.");
  } else if constexpr (Conj) {
    // {c,z}rot, S must be complex
    static_assert(s_traits::is_complex,
                  "KokkosBatched::rot: SType must be complex for complex input with Conj = true.");
  } else {
    // {cs,zd}rot, S must be real
    static_assert(!s_traits::is_complex,
                  "KokkosBatched::rot: SType must be real for complex input with Conj = false.");
  }

#ifndef NDEBUG
  const int x_len = x.extent_int(0);
  const int y_len = y.extent_int(0);

  if (x_len != y_len) {
    Kokkos::printf(
        "KokkosBatched::rot: x and y must have the same length: x length = "
        "%d, y length = %d\n",
        x_len, y_len);
    return 1;
  }
#endif
  return 0;
}
} // namespace Impl

///
/// Serial Impl
/// ===========

// {s,d,cs,zd}rot interface for Conj = false
//   x(i) := c*x(i) + s*y(i)
//   y(i) := c*y(i) - s*x(i)
//
// {c,z}rot interface for Conj = true
//   x(i) := c*x(i) + s*y(i)
//   y(i) := c*y(i) - conj(s)*x(i)
//
// Returns 0 on success, or the non-zero code reported by Impl::checkRotInput
// when the input validation fails.
template <bool Conj>
template <typename XViewType, typename YViewType, typename CType, typename SType>
KOKKOS_INLINE_FUNCTION int SerialRot<Conj>::invoke(const XViewType &x, const YViewType &y, const CType c,
                                                   const SType s) {
  // Validate before the quick return: otherwise an empty x paired with a
  // non-empty y would bypass the length check and be reported as success.
  const int info = Impl::checkRotInput<Conj, CType, SType>(x, y);
  if (info) return info;

  // Quick return if possible
  const int n = x.extent_int(0);
  if (n == 0) return 0;

  // OpConj applies conj() to s in the update of y, OpID leaves s unchanged
  using op = std::conditional_t<Conj, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
  return Impl::SerialRotInternal::invoke(op(), n, x.data(), x.stride(0), y.data(), y.stride(0), c, s);
}

///
/// Team Impl
/// ===========

// {s,d,cs,zd}rot interface for Conj = false
//   x(i) := c*x(i) + s*y(i)
//   y(i) := c*y(i) - s*x(i)
//
// {c,z}rot interface for Conj = true
//   x(i) := c*x(i) + s*y(i)
//   y(i) := c*y(i) - conj(s)*x(i)
//
// Returns 0 on success, or the non-zero code reported by Impl::checkRotInput
// when the input validation fails.
template <typename MemberType, bool Conj>
template <typename XViewType, typename YViewType, typename CType, typename SType>
KOKKOS_INLINE_FUNCTION int TeamRot<MemberType, Conj>::invoke(const MemberType &member, const XViewType &x,
                                                             const YViewType &y, const CType c, const SType s) {
  // Validate before the quick return: otherwise an empty x paired with a
  // non-empty y would bypass the length check and be reported as success.
  const int info = Impl::checkRotInput<Conj, CType, SType>(x, y);
  if (info) return info;

  // Quick return if possible
  const int n = x.extent_int(0);
  if (n == 0) return 0;

  // OpConj applies conj() to s in the update of y, OpID leaves s unchanged
  using op = std::conditional_t<Conj, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;

  return Impl::TeamRotInternal::invoke(member, op(), n, x.data(), x.stride(0), y.data(), y.stride(0), c, s);
}

///
/// TeamVector Impl
/// ===============

// {s,d,cs,zd}rot interface for Conj = false
//   x(i) := c*x(i) + s*y(i)
//   y(i) := c*y(i) - s*x(i)
//
// {c,z}rot interface for Conj = true
//   x(i) := c*x(i) + s*y(i)
//   y(i) := c*y(i) - conj(s)*x(i)
//
// Returns 0 on success, or the non-zero code reported by Impl::checkRotInput
// when the input validation fails.
template <typename MemberType, bool Conj>
template <typename XViewType, typename YViewType, typename CType, typename SType>
KOKKOS_INLINE_FUNCTION int TeamVectorRot<MemberType, Conj>::invoke(const MemberType &member, const XViewType &x,
                                                                   const YViewType &y, const CType c, const SType s) {
  // Validate before the quick return: otherwise an empty x paired with a
  // non-empty y would bypass the length check and be reported as success.
  const int info = Impl::checkRotInput<Conj, CType, SType>(x, y);
  if (info) return info;

  // Quick return if possible
  const int n = x.extent_int(0);
  if (n == 0) return 0;

  // OpConj applies conj() to s in the update of y, OpID leaves s unchanged
  using op = std::conditional_t<Conj, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;

  return Impl::TeamVectorRotInternal::invoke(member, op(), n, x.data(), x.stride(0), y.data(), y.stride(0), c, s);
}

} // namespace KokkosBatched

#endif // KOKKOSBATCHED_ROT_IMPL_HPP_
86 changes: 86 additions & 0 deletions batched/dense/impl/KokkosBatched_Rot_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#ifndef KOKKOSBATCHED_ROT_INTERNAL_HPP_
#define KOKKOSBATCHED_ROT_INTERNAL_HPP_

#include <KokkosBatched_Util.hpp>

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
/// ====================

/// Pointer/stride level implementation of the serial plane rotation.
struct SerialRotInternal {
  /// \param[in] op: functor applied to s in the update of y
  /// \param[in] n: number of elements to rotate
  /// \param[in,out] x: base pointer of x, element i is read and written at x[i * xs0]
  /// \param[in] xs0: stride of x
  /// \param[in,out] y: base pointer of y, element i is read and written at y[i * ys0]
  /// \param[in] ys0: stride of y
  /// \param[in] c: cosine of the rotation
  /// \param[in] s: sine of the rotation
  /// \return always 0
  template <typename Op, typename ValueType, typename CType, typename SType>
  KOKKOS_INLINE_FUNCTION static int invoke(Op op, const int n, ValueType *KOKKOS_RESTRICT x, const int xs0,
                                           ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c, const SType s);
};

template <typename Op, typename ValueType, typename CType, typename SType>
KOKKOS_INLINE_FUNCTION int SerialRotInternal::invoke(Op op, const int n, ValueType *KOKKOS_RESTRICT x, const int xs0,
                                                     ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c,
                                                     const SType s) {
  for (int idx = 0; idx < n; ++idx) {
    ValueType &xi = x[idx * xs0];
    ValueType &yi = y[idx * ys0];

    // The new x(i) is held back until y(i) has been updated, because the
    // update of y(i) still needs the old x(i).
    const auto rotated_x = c * xi + s * yi;
    yi                   = c * yi - op(s) * xi;
    xi                   = rotated_x;
  }

  return 0;
}

///
/// Team Internal Impl
/// ==================

/// Pointer/stride level implementation of the team plane rotation.
/// The n element updates are issued through Kokkos::TeamThreadRange.
struct TeamRotInternal {
  /// \param[in] member: team member handle passed to Kokkos::TeamThreadRange
  /// \param[in] op: functor applied to s in the update of y
  /// \param[in] n: number of elements to rotate
  /// \param[in,out] x: base pointer of x, element i is read and written at x[i * xs0]
  /// \param[in] xs0: stride of x
  /// \param[in,out] y: base pointer of y, element i is read and written at y[i * ys0]
  /// \param[in] ys0: stride of y
  /// \param[in] c: cosine of the rotation
  /// \param[in] s: sine of the rotation
  /// \return always 0
  template <typename MemberType, typename Op, typename ValueType, typename CType, typename SType>
  KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, Op op, const int n, ValueType *KOKKOS_RESTRICT x,
                                           const int xs0, ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c,
                                           const SType s);
};

template <typename MemberType, typename Op, typename ValueType, typename CType, typename SType>
KOKKOS_INLINE_FUNCTION int TeamRotInternal::invoke(const MemberType &member, Op op, const int n,
                                                   ValueType *KOKKOS_RESTRICT x, const int xs0,
                                                   ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c,
                                                   const SType s) {
  Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &idx) {
    ValueType &xi = x[idx * xs0];
    ValueType &yi = y[idx * ys0];

    // The new x(i) is held back until y(i) has been updated, because the
    // update of y(i) still needs the old x(i).
    const auto rotated_x = c * xi + s * yi;
    yi                   = c * yi - op(s) * xi;
    xi                   = rotated_x;
  });
  return 0;
}

///
/// TeamVector Internal Impl
/// ========================

/// Pointer/stride level implementation of the team-vector plane rotation.
/// The n element updates are issued through Kokkos::TeamVectorRange.
struct TeamVectorRotInternal {
  /// \param[in] member: team member handle passed to Kokkos::TeamVectorRange
  /// \param[in] op: functor applied to s in the update of y
  /// \param[in] n: number of elements to rotate
  /// \param[in,out] x: base pointer of x, element i is read and written at x[i * xs0]
  /// \param[in] xs0: stride of x
  /// \param[in,out] y: base pointer of y, element i is read and written at y[i * ys0]
  /// \param[in] ys0: stride of y
  /// \param[in] c: cosine of the rotation
  /// \param[in] s: sine of the rotation
  /// \return always 0
  template <typename MemberType, typename Op, typename ValueType, typename CType, typename SType>
  KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, Op op, const int n, ValueType *KOKKOS_RESTRICT x,
                                           const int xs0, ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c,
                                           const SType s);
};

template <typename MemberType, typename Op, typename ValueType, typename CType, typename SType>
KOKKOS_INLINE_FUNCTION int TeamVectorRotInternal::invoke(const MemberType &member, Op op, const int n,
                                                         ValueType *KOKKOS_RESTRICT x, const int xs0,
                                                         ValueType *KOKKOS_RESTRICT y, const int ys0, const CType c,
                                                         const SType s) {
  Kokkos::parallel_for(Kokkos::TeamVectorRange(member, n), [&](const int &idx) {
    ValueType &xi = x[idx * xs0];
    ValueType &yi = y[idx * ys0];

    // The new x(i) is held back until y(i) has been updated, because the
    // update of y(i) still needs the old x(i).
    const auto rotated_x = c * xi + s * yi;
    yi                   = c * yi - op(s) * xi;
    xi                   = rotated_x;
  });
  return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_ROT_INTERNAL_HPP_
103 changes: 103 additions & 0 deletions batched/dense/src/KokkosBatched_Rot.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
#ifndef KOKKOSBATCHED_ROT_HPP_
#define KOKKOSBATCHED_ROT_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

namespace KokkosBatched {

/// \brief Serial Batched Rot:
/// Applies a plane rotation to vectors x and y:
///   x(i) := c*x(i) + s*y(i)
///   y(i) := c*y(i) - s*x(i)       (Conj = false, {s,d,cs,zd}rot)
///   y(i) := c*y(i) - conj(s)*x(i) (Conj = true, {c,z}rot)
///
/// \tparam Conj: Boolean flag indicating whether s is used directly (false)
///               or its conjugate is used (true) in the update of y
///
/// \tparam XViewType: Input/output type for the vector x, needs to be a 1D view
///                    with a non-const value type
/// \tparam YViewType: Input/output type for the vector y, needs to be a 1D view
///                    with a non-const value type; its value type must be
///                    complex if and only if the value type of x is complex
/// \tparam CType: Input type for the cosine c (must be real)
/// \tparam SType: Input type for the sine s (must be complex when x and y are
///                complex and Conj = true, must be real otherwise)
///
/// \param[in,out] x: x is a length n vector, a rank 1 view
/// \param[in,out] y: y is a length n vector, a rank 1 view
/// \param[in] c: cosine of the rotation (real scalar)
/// \param[in] s: sine of the rotation (real or complex scalar)
///
/// \return 0 on success, a non-zero value if the input validation fails
///         (the check that x and y have the same length is compiled only
///         when NDEBUG is not defined)
///
/// No nested parallel_for is used inside of the function.
///
template <bool Conj = false>
struct SerialRot {
template <typename XViewType, typename YViewType, typename CType, typename SType>
KOKKOS_INLINE_FUNCTION static int invoke(const XViewType &x, const YViewType &y, const CType c, const SType s);
};

/// \brief Team Batched Rot:
/// Applies a plane rotation to vectors x and y:
///   x(i) := c*x(i) + s*y(i)
///   y(i) := c*y(i) - s*x(i)       (Conj = false, {s,d,cs,zd}rot)
///   y(i) := c*y(i) - conj(s)*x(i) (Conj = true, {c,z}rot)
///
/// \tparam MemberType: TeamPolicy member type
/// \tparam Conj: Boolean flag indicating whether s is used directly (false)
///               or its conjugate is used (true) in the update of y
///
/// \tparam XViewType: Input/output type for the vector x, needs to be a 1D view
///                    with a non-const value type
/// \tparam YViewType: Input/output type for the vector y, needs to be a 1D view
///                    with a non-const value type; its value type must be
///                    complex if and only if the value type of x is complex
/// \tparam CType: Input type for the cosine c (must be real)
/// \tparam SType: Input type for the sine s (must be complex when x and y are
///                complex and Conj = true, must be real otherwise)
///
/// \param[in] member: TeamPolicy member
/// \param[in,out] x: x is a length n vector, a rank 1 view
/// \param[in,out] y: y is a length n vector, a rank 1 view
/// \param[in] c: cosine of the rotation (real scalar)
/// \param[in] s: sine of the rotation (real or complex scalar)
///
/// \return 0 on success, a non-zero value if the input validation fails
///         (the check that x and y have the same length is compiled only
///         when NDEBUG is not defined)
///
/// A nested parallel_for with TeamThreadRange is used.
///
template <typename MemberType, bool Conj = false>
struct TeamRot {
template <typename XViewType, typename YViewType, typename CType, typename SType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const XViewType &x, const YViewType &y,
const CType c, const SType s);
};

/// \brief TeamVector Batched Rot:
/// Applies a plane rotation to vectors x and y:
///   x(i) := c*x(i) + s*y(i)
///   y(i) := c*y(i) - s*x(i)       (Conj = false, {s,d,cs,zd}rot)
///   y(i) := c*y(i) - conj(s)*x(i) (Conj = true, {c,z}rot)
///
/// \tparam MemberType: TeamPolicy member type
/// \tparam Conj: Boolean flag indicating whether s is used directly (false)
///               or its conjugate is used (true) in the update of y
///
/// \tparam XViewType: Input/output type for the vector x, needs to be a 1D view
///                    with a non-const value type
/// \tparam YViewType: Input/output type for the vector y, needs to be a 1D view
///                    with a non-const value type; its value type must be
///                    complex if and only if the value type of x is complex
/// \tparam CType: Input type for the cosine c (must be real)
/// \tparam SType: Input type for the sine s (must be complex when x and y are
///                complex and Conj = true, must be real otherwise)
///
/// \param[in] member: TeamPolicy member
/// \param[in,out] x: x is a length n vector, a rank 1 view
/// \param[in,out] y: y is a length n vector, a rank 1 view
/// \param[in] c: cosine of the rotation (real scalar)
/// \param[in] s: sine of the rotation (real or complex scalar)
///
/// \return 0 on success, a non-zero value if the input validation fails
///         (the check that x and y have the same length is compiled only
///         when NDEBUG is not defined)
///
/// A nested parallel_for with TeamVectorRange is used.
///
template <typename MemberType, bool Conj = false>
struct TeamVectorRot {
template <typename XViewType, typename YViewType, typename CType, typename SType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const XViewType &x, const YViewType &y,
const CType c, const SType s);
};

} // namespace KokkosBatched

#include "KokkosBatched_Rot_Impl.hpp"

#endif // KOKKOSBATCHED_ROT_HPP_
1 change: 1 addition & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// Serial kernels
#include "Test_Batched_SerialAxpy.hpp"
#include "Test_Batched_Copy.hpp"
#include "Test_Batched_Rot.hpp"
#include "Test_Batched_SerialEigendecomposition.hpp"
#include "Test_Batched_SerialEigendecomposition_Real.hpp"
#include "Test_Batched_SerialGesv.hpp"
Expand Down
Loading
Loading