diff --git a/batched/KokkosBatched_Util.hpp b/batched/KokkosBatched_Util.hpp index c8dae37678..8955b186ed 100644 --- a/batched/KokkosBatched_Util.hpp +++ b/batched/KokkosBatched_Util.hpp @@ -23,6 +23,7 @@ #include "Kokkos_Complex.hpp" +#include "Kokkos_DynRankView.hpp" #include "KokkosKernels_config.h" #include "KokkosKernels_Macros.hpp" #include "KokkosKernels_SimpleUtils.hpp" @@ -677,9 +678,14 @@ KOKKOS_INLINE_FUNCTION void fma_bounds_check(ViewType v, SizeType m, SizeType n, namespace Impl { template KOKKOS_INLINE_FUNCTION int get_extent_int(const ViewType &v, const int r) { - static_assert(Kokkos::is_view_v, "KokkosBatched: ViewType is not a Kokkos::View."); - constexpr std::size_t V_rank = ViewType::rank(); - static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2."); + // Only allow View and DynRankView objects + if constexpr (Kokkos::is_view_v) { + static_assert(ViewType::rank() == 2, "KokkosBatched: ViewType must be rank 0, 1, or 2"); + } else { + static_assert(Kokkos::is_dyn_rank_view_v, "KokkosBatched: ViewType must be a DynRankView or a View"); + } + + const std::size_t V_rank = v.rank(); if (r == 0) { int V_extent_0 = V_rank < 1 ? 1 : v.extent_int(0); @@ -694,9 +700,14 @@ KOKKOS_INLINE_FUNCTION int get_extent_int(const ViewType &v, const int r) { template KOKKOS_INLINE_FUNCTION std::size_t get_stride(const ViewType &v, const int r) { - static_assert(Kokkos::is_view_v, "KokkosBatched: ViewType is not a Kokkos::View."); - constexpr std::size_t V_rank = ViewType::rank(); - static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2."); + // Only allow View and DynRankView objects + if constexpr (Kokkos::is_view_v) { + static_assert(ViewType::rank() == 2, "KokkosBatched: ViewType must be rank 0, 1, or 2"); + } else { + static_assert(Kokkos::is_dyn_rank_view_v, "KokkosBatched: ViewType must be a DynRankView or a View"); + } + + const std::size_t V_rank = v.rank(); if (r == 0) { std::size_t V_stride_0 = V_rank < 1 ? 1 : v.stride(0); @@ -708,6 +719,7 @@ KOKKOS_INLINE_FUNCTION std::size_t get_stride(const ViewType &v, const int r) { return 1; } } + } // namespace Impl } // namespace KokkosBatched