diff --git a/batched/KokkosBatched_Util.hpp b/batched/KokkosBatched_Util.hpp index c8dae37678..ceee7346da 100644 --- a/batched/KokkosBatched_Util.hpp +++ b/batched/KokkosBatched_Util.hpp @@ -21,7 +21,9 @@ #include #include +#include #include "Kokkos_Complex.hpp" +#include "Kokkos_DynRankView.hpp" #include "KokkosKernels_config.h" #include "KokkosKernels_Macros.hpp" @@ -677,9 +679,17 @@ KOKKOS_INLINE_FUNCTION void fma_bounds_check(ViewType v, SizeType m, SizeType n, namespace Impl { template KOKKOS_INLINE_FUNCTION int get_extent_int(const ViewType &v, const int r) { - static_assert(Kokkos::is_view_v, "KokkosBatched: ViewType is not a Kokkos::View."); - constexpr std::size_t V_rank = ViewType::rank(); - static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2."); + // Check for view and dynrankview + if constexpr (Kokkos::is_view_v) { + static_assert(ViewType::rank() <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2."); + } else if constexpr (Kokkos::is_dyn_rank_view_v) { + KOKKOS_EXPECTS((v.rank() <= 2)); + } else { + static_assert(Kokkos::is_view_v || Kokkos::is_dyn_rank_view_v, + "KokkosBatched: ViewType must be a Kokkos::View or a Kokkos::DynRankView"); + } + + const std::size_t V_rank = v.rank(); if (r == 0) { int V_extent_0 = V_rank < 1 ? 1 : v.extent_int(0); @@ -694,9 +704,17 @@ KOKKOS_INLINE_FUNCTION int get_extent_int(const ViewType &v, const int r) { template KOKKOS_INLINE_FUNCTION std::size_t get_stride(const ViewType &v, const int r) { - static_assert(Kokkos::is_view_v, "KokkosBatched: ViewType is not a Kokkos::View."); - constexpr std::size_t V_rank = ViewType::rank(); - static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2."); + // Check for view and dynrankview + if constexpr (Kokkos::is_view_v) { + static_assert(ViewType::rank() <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2."); + } else if constexpr (Kokkos::is_dyn_rank_view_v) { + KOKKOS_EXPECTS((v.rank() <= 2)); + } else { + static_assert(Kokkos::is_view_v || Kokkos::is_dyn_rank_view_v, + "KokkosBatched: ViewType must be a Kokkos::View or a Kokkos::DynRankView"); + } + + const std::size_t V_rank = v.rank(); if (r == 0) { std::size_t V_stride_0 = V_rank < 1 ? 1 : v.stride(0);