Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions batched/KokkosBatched_Util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "Kokkos_Complex.hpp"

#include "Kokkos_DynRankView.hpp"
#include "KokkosKernels_config.h"
#include "KokkosKernels_Macros.hpp"
#include "KokkosKernels_SimpleUtils.hpp"
Expand Down Expand Up @@ -677,9 +678,14 @@ KOKKOS_INLINE_FUNCTION void fma_bounds_check(ViewType v, SizeType m, SizeType n,
namespace Impl {
template <typename ViewType>
KOKKOS_INLINE_FUNCTION int get_extent_int(const ViewType &v, const int r) {
static_assert(Kokkos::is_view_v<ViewType>, "KokkosBatched: ViewType is not a Kokkos::View.");
constexpr std::size_t V_rank = ViewType::rank();
static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2.");
// Only allow View and DynRankView objects
if constexpr (Kokkos::is_view_v<ViewType>) {
static_assert(ViewType::rank() == 2, "KokkosBatched: ViewType must be rank 0, 1, or 2");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
static_assert(ViewType::rank() == 2, "KokkosBatched: ViewType must be rank 0, 1, or 2");
static_assert(ViewType::rank() <= 2, "KokkosBatched: ViewType must be rank 0, 1, or 2");

} else {
static_assert(Kokkos::is_dyn_rank_view_v<ViewType>, "KokkosBatched: ViewType must be a DynRankView or a View");
}

const std::size_t V_rank = v.rank();
Comment thread
lucbv marked this conversation as resolved.

if (r == 0) {
int V_extent_0 = V_rank < 1 ? 1 : v.extent_int(0);
Expand All @@ -694,9 +700,14 @@ KOKKOS_INLINE_FUNCTION int get_extent_int(const ViewType &v, const int r) {

template <typename ViewType>
KOKKOS_INLINE_FUNCTION std::size_t get_stride(const ViewType &v, const int r) {
static_assert(Kokkos::is_view_v<ViewType>, "KokkosBatched: ViewType is not a Kokkos::View.");
constexpr std::size_t V_rank = ViewType::rank();
static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2.");
// Only allow View and DynRankView objects
if constexpr (Kokkos::is_view_v<ViewType>) {
static_assert(ViewType::rank() == 2, "KokkosBatched: ViewType must be rank 0, 1, or 2");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
static_assert(ViewType::rank() == 2, "KokkosBatched: ViewType must be rank 0, 1, or 2");
static_assert(ViewType::rank() <= 2, "KokkosBatched: ViewType must be rank 0, 1, or 2");

} else {
static_assert(Kokkos::is_dyn_rank_view_v<ViewType>, "KokkosBatched: ViewType must be a DynRankView or a View");
}

const std::size_t V_rank = v.rank();
Comment thread
lucbv marked this conversation as resolved.

if (r == 0) {
std::size_t V_stride_0 = V_rank < 1 ? 1 : v.stride(0);
Expand All @@ -708,6 +719,7 @@ KOKKOS_INLINE_FUNCTION std::size_t get_stride(const ViewType &v, const int r) {
return 1;
}
}

} // namespace Impl

} // namespace KokkosBatched
Expand Down
Loading