diff --git a/blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp b/blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp index 4f8f8c37c2..49e612214a 100644 --- a/blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp +++ b/blas/tpls/KokkosBlas1_rot_tpl_spec_avail.hpp @@ -119,6 +119,7 @@ KOKKOSBLAS1_ROT_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex, #endif // rocBLAS +/* #ifdef KOKKOSKERNELS_ENABLE_TPL_ROCBLAS #define KOKKOSBLAS1_ROT_TPL_SPEC_AVAIL_ROCBLAS(SCALAR, LAYOUT, EXECSPACE, \ MEMSPACE) \ @@ -143,6 +144,7 @@ KOKKOSBLAS1_ROT_TPL_SPEC_AVAIL_ROCBLAS(Kokkos::complex, Kokkos::LayoutLeft, Kokkos::HIP, Kokkos::HIPSpace) #endif +*/ } // namespace Impl } // namespace KokkosBlas diff --git a/blas/tpls/KokkosBlas1_rotg_tpl_spec_decl.hpp b/blas/tpls/KokkosBlas1_rotg_tpl_spec_decl.hpp index 0c7e814334..ef2f0b5488 100644 --- a/blas/tpls/KokkosBlas1_rotg_tpl_spec_decl.hpp +++ b/blas/tpls/KokkosBlas1_rotg_tpl_spec_decl.hpp @@ -521,9 +521,16 @@ namespace Impl { KokkosBlas::Impl::RocBlasSingleton& singleton = \ KokkosBlas::Impl::RocBlasSingleton::singleton(); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ - rocblasSetStream(singleton.handle, space.hip_stream())); \ + rocblas_set_stream(singleton.handle, space.hip_stream())); \ + rocblas_pointer_mode pointer_mode; \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_get_pointer_mode(singleton.handle, &pointer_mode)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_set_pointer_mode( \ + singleton.handle, rocblas_pointer_mode_device)); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_drotg( \ singleton.handle, a.data(), b.data(), c.data(), s.data())); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(singleton.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -551,9 +558,16 @@ namespace Impl { KokkosBlas::Impl::RocBlasSingleton& singleton = \ KokkosBlas::Impl::RocBlasSingleton::singleton(); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ - rocblasSetStream(singleton.handle, space.hip_stream())); \ + rocblas_set_stream(singleton.handle, space.hip_stream())); \ + rocblas_pointer_mode pointer_mode; \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_get_pointer_mode(singleton.handle, &pointer_mode)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_set_pointer_mode( \ + singleton.handle, rocblas_pointer_mode_device)); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_srotg( \ singleton.handle, a.data(), b.data(), c.data(), s.data())); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(singleton.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -584,12 +598,19 @@ namespace Impl { KokkosBlas::Impl::RocBlasSingleton& singleton = \ KokkosBlas::Impl::RocBlasSingleton::singleton(); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ - rocblasSetStream(singleton.handle, space.hip_stream())); \ + rocblas_set_stream(singleton.handle, space.hip_stream())); \ + rocblas_pointer_mode pointer_mode; \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_get_pointer_mode(singleton.handle, &pointer_mode)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_set_pointer_mode( \ + singleton.handle, rocblas_pointer_mode_device)); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_zrotg( \ singleton.handle, \ reinterpret_cast(a.data()), \ reinterpret_cast(b.data()), c.data(), \ reinterpret_cast(s.data()))); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(singleton.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -619,12 +640,19 @@ namespace Impl { KokkosBlas::Impl::RocBlasSingleton& singleton = \ KokkosBlas::Impl::RocBlasSingleton::singleton(); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ - rocblasSetStream(singleton.handle, space.hip_stream())); \ + rocblas_set_stream(singleton.handle, space.hip_stream())); \ + rocblas_pointer_mode pointer_mode; \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_get_pointer_mode(singleton.handle, &pointer_mode)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_set_pointer_mode( \ + singleton.handle, rocblas_pointer_mode_device)); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_crotg( \ singleton.handle, \ reinterpret_cast(a.data()), \ reinterpret_cast(b.data()), c.data(), \ reinterpret_cast(s.data()))); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(singleton.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ }; diff --git a/blas/tpls/KokkosBlas1_rotm_tpl_spec_decl.hpp b/blas/tpls/KokkosBlas1_rotm_tpl_spec_decl.hpp index a73b95e7e8..7cc983f42e 100644 --- a/blas/tpls/KokkosBlas1_rotm_tpl_spec_decl.hpp +++ b/blas/tpls/KokkosBlas1_rotm_tpl_spec_decl.hpp @@ -256,34 +256,37 @@ namespace Impl { template <> \ struct Rotm< \ EXEC_SPACE, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ true, ETI_SPEC_AVAIL> { \ - using DXView = \ - Kokkos::View, \ + using VectorView = \ + Kokkos::View, \ Kokkos::MemoryTraits>; \ - using YView = Kokkos::View, \ Kokkos::MemoryTraits>; \ - using PView = \ - Kokkos::View, \ - Kokkos::MemoryTraits>; \ \ - static void rotm(EXEC_SPACE const& space, DXView const& d1, \ - DXView const& d2, DXView const& x1, YView const& y1, \ - PView const& param) { \ - Kokkos::Profiling::pushRegion("KokkosBlas::nrm1[TPL_ROCBLAS,double]"); \ + static void rotm(EXEC_SPACE const& space, VectorView const& X, \ + VectorView const& Y, PView const& param) { \ + Kokkos::Profiling::pushRegion("KokkosBlas::rotm[TPL_ROCBLAS,double]"); \ rotm_print_specialization(); \ KokkosBlas::Impl::RocBlasSingleton& s = \ KokkosBlas::Impl::RocBlasSingleton::singleton(); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ - rocblasSetStream(s.handle, space.hip_stream())); \ - KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_drotm(s.handle, &a, &b, &c, &s)); \ + rocblas_set_stream(s.handle, space.hip_stream())); \ + rocblas_pointer_mode pointer_mode; \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_get_pointer_mode(s.handle, &pointer_mode)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(s.handle, rocblas_pointer_mode_device)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_drotm(s.handle, static_cast(X.extent(0)), X.data(), 1, \ + Y.data(), 1, param.data())); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(s.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -302,33 +305,37 @@ KOKKOSBLAS1_DROTM_TPL_SPEC_DECL_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIP, template <> \ struct Rotm< \ EXEC_SPACE, \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits>, \ true, ETI_SPEC_AVAIL> { \ - using DXView = \ - Kokkos::View, \ + using VectorView = \ + Kokkos::View, \ Kokkos::MemoryTraits>; \ - using YView = Kokkos::View, \ Kokkos::MemoryTraits>; \ - using PView = \ - Kokkos::View, \ - Kokkos::MemoryTraits>; \ \ - static void rotm(EXEC_SPACE const& space, DXView const& d1, \ - DXView const& d2, DXView const& x1, YView const& y1, \ - PView const& param) { \ - Kokkos::Profiling::pushRegion("KokkosBlas::nrm1[TPL_ROCBLAS,float]"); \ + static void rotm(EXEC_SPACE const& space, VectorView const& X, \ + VectorView const& Y, PView const& param) { \ + Kokkos::Profiling::pushRegion("KokkosBlas::rotm[TPL_ROCBLAS,float]"); \ rotm_print_specialization(); \ KokkosBlas::Impl::RocBlasSingleton& s = \ KokkosBlas::Impl::RocBlasSingleton::singleton(); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ - rocblasSetStream(s.handle, space.hip_stream())); \ - KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_srotm(s.handle, &a, &b, &c, &s)); \ + rocblas_set_stream(s.handle, space.hip_stream())); \ + rocblas_pointer_mode pointer_mode; \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_get_pointer_mode(s.handle, &pointer_mode)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(s.handle, rocblas_pointer_mode_device)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_srotm(s.handle, static_cast(X.extent(0)), X.data(), 1, \ + Y.data(), 1, param.data())); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(s.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ }; diff --git a/blas/tpls/KokkosBlas1_rotmg_tpl_spec_decl.hpp b/blas/tpls/KokkosBlas1_rotmg_tpl_spec_decl.hpp index 0305cde3f5..30619f3970 100644 --- a/blas/tpls/KokkosBlas1_rotmg_tpl_spec_decl.hpp +++ b/blas/tpls/KokkosBlas1_rotmg_tpl_spec_decl.hpp @@ -296,13 +296,22 @@ namespace Impl { static void rotmg(EXEC_SPACE const& space, DXView const& d1, \ DXView const& d2, DXView const& x1, YView const& y1, \ PView const& param) { \ - Kokkos::Profiling::pushRegion("KokkosBlas::nrm1[TPL_ROCBLAS,double]"); \ + Kokkos::Profiling::pushRegion("KokkosBlas::rotmg[TPL_ROCBLAS,double]"); \ rotmg_print_specialization(); \ KokkosBlas::Impl::RocBlasSingleton& s = \ KokkosBlas::Impl::RocBlasSingleton::singleton(); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ - rocblasSetStream(s.handle, space.hip_stream())); \ - KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_drotmg(s.handle, &a, &b, &c, &s)); \ + rocblas_set_stream(s.handle, space.hip_stream())); \ + rocblas_pointer_mode pointer_mode; \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_get_pointer_mode(s.handle, &pointer_mode)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(s.handle, rocblas_pointer_mode_device)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_drotmg(s.handle, d1.data(), \ + d2.data(), x1.data(), \ + y1.data(), param.data())); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(s.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -341,13 +350,22 @@ KOKKOSBLAS1_DROTMG_TPL_SPEC_DECL_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIP, static void rotmg(EXEC_SPACE const& space, DXView const& d1, \ DXView const& d2, DXView const& x1, YView const& y1, \ PView const& param) { \ - Kokkos::Profiling::pushRegion("KokkosBlas::nrm1[TPL_ROCBLAS,float]"); \ + Kokkos::Profiling::pushRegion("KokkosBlas::rotmg[TPL_ROCBLAS,float]"); \ rotmg_print_specialization(); \ KokkosBlas::Impl::RocBlasSingleton& s = \ KokkosBlas::Impl::RocBlasSingleton::singleton(); \ KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ - rocblasSetStream(s.handle, space.hip_stream())); \ - KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_srotmg(s.handle, &a, &b, &c, &s)); \ + rocblas_set_stream(s.handle, space.hip_stream())); \ + rocblas_pointer_mode pointer_mode; \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_get_pointer_mode(s.handle, &pointer_mode)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(s.handle, rocblas_pointer_mode_device)); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_srotmg(s.handle, d1.data(), \ + d2.data(), x1.data(), \ + y1.data(), param.data())); \ + KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \ + rocblas_set_pointer_mode(s.handle, pointer_mode)); \ Kokkos::Profiling::popRegion(); \ } \ };