diff --git a/sparse/impl/KokkosSparse_CrsMatrix_traversal_impl.hpp b/sparse/impl/KokkosSparse_CrsMatrix_traversal_impl.hpp new file mode 100644 index 0000000000..805e93d481 --- /dev/null +++ b/sparse/impl/KokkosSparse_CrsMatrix_traversal_impl.hpp @@ -0,0 +1,158 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +namespace KokkosSparse { +namespace Impl { + +template +struct crsmatrix_traversal_functor { + using size_type = typename matrix_type::non_const_size_type; + using ordinal_type = typename matrix_type::non_const_ordinal_type; + using value_type = typename matrix_type::non_const_value_type; + + using team_policy_type = Kokkos::TeamPolicy; + using team_member_type = typename team_policy_type::member_type; + + matrix_type A; + functor_type func; + ordinal_type rows_per_team; + + crsmatrix_traversal_functor(const matrix_type& A_, const functor_type& func_, + const ordinal_type rows_per_team_) + : A(A_), func(func_), rows_per_team(rows_per_team_) {} + + // RangePolicy overload + KOKKOS_INLINE_FUNCTION void operator()(const ordinal_type rowIdx) const { + for (size_type entryIdx = A.graph.row_map(rowIdx); + entryIdx < A.graph.row_map(rowIdx + 1); ++entryIdx) { + const ordinal_type colIdx = A.graph.entries(entryIdx); + const value_type value = A.values(entryIdx); + + func(rowIdx, entryIdx, colIdx, value); + } + } + + // TeamPolicy overload + KOKKOS_INLINE_FUNCTION void operator()(const team_member_type& dev) const { + const ordinal_type teamWork = dev.league_rank() * rows_per_team; + Kokkos::parallel_for( + Kokkos::TeamThreadRange(dev, rows_per_team), [&](ordinal_type loop) { + // iRow represents a row of the matrix, so its correct type is + // ordinal_type. + const ordinal_type rowIdx = teamWork + loop; + if (rowIdx >= A.numRows()) { + return; + } + + const ordinal_type row_length = + A.graph.row_map(rowIdx + 1) - A.graph.row_map(rowIdx); + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(dev, row_length), + [&](ordinal_type rowEntryIdx) { + const size_type entryIdx = A.graph.row_map(rowIdx) + + static_cast(rowEntryIdx); + const ordinal_type colIdx = A.graph.entries(entryIdx); + const value_type value = A.values(entryIdx); + + func(rowIdx, entryIdx, colIdx, value); + }); + }); + } +}; + +template +int64_t crsmatrix_traversal_launch_parameters(int64_t numRows, int64_t nnz, + int64_t rows_per_thread, + int& team_size, + int& vector_length) { + int64_t rows_per_team; + int64_t nnz_per_row = nnz / numRows; + + if (nnz_per_row < 1) nnz_per_row = 1; + + int max_vector_length = + Kokkos::TeamPolicy::vector_length_max(); + + if (vector_length < 1) { + vector_length = 1; + while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) + vector_length *= 2; + } + + // Determine rows per thread + if (rows_per_thread < 1) { + if (KokkosKernels::Impl::kk_is_gpu_exec_space()) + rows_per_thread = 1; + else { + if (nnz_per_row < 20 && nnz > 5000000) { + rows_per_thread = 256; + } else + rows_per_thread = 64; + } + } + + if (team_size < 1) { + if (KokkosKernels::Impl::kk_is_gpu_exec_space()) { + team_size = 256 / vector_length; + } else { + team_size = 1; + } + } + + rows_per_team = rows_per_thread * team_size; + + return rows_per_team; +} + +template +void crsmatrix_traversal_on_host(const execution_space& space, + const crsmatrix_type& A, + const functor_type& func) { + // Wrap user functor with crsmatrix_traversal_functor + crsmatrix_traversal_functor + traversal_func(A, func, -1); + + // Launch traversal kernel + Kokkos::parallel_for( + "KokkosSparse::crsmatrix_traversal", + Kokkos::RangePolicy(space, 0, A.numRows()), + traversal_func); +} + +template +void crsmatrix_traversal_on_gpu(const execution_space& space, + const crsmatrix_type& A, + const functor_type& func) { + // Wrap user functor with crsmatrix_traversal_functor + int64_t rows_per_thread = 0; + int team_size = 0, vector_length = 0; + const int64_t rows_per_team = + crsmatrix_traversal_launch_parameters( + A.numRows(), A.nnz(), rows_per_thread, team_size, vector_length); + const int nteams = + (static_cast(A.numRows()) + rows_per_team - 1) / rows_per_team; + crsmatrix_traversal_functor + traversal_func(A, func, rows_per_team); + + // Launch traversal kernel + Kokkos::parallel_for("KokkosSparse::crsmatrix_traversal", + Kokkos::TeamPolicy( + space, nteams, team_size, vector_length), + traversal_func); +} + +} // namespace Impl +} // namespace KokkosSparse diff --git a/sparse/src/KokkosSparse_CrsMatrix_traversal.hpp b/sparse/src/KokkosSparse_CrsMatrix_traversal.hpp new file mode 100644 index 0000000000..55ccd2446a --- /dev/null +++ b/sparse/src/KokkosSparse_CrsMatrix_traversal.hpp @@ -0,0 +1,92 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +/// \file KokkosSparse_CrsMatrix_traversal.hpp +/// \brief Traversal method to access all entries in a CrsMatrix +/// +/// This file provides a public interface to traversal +/// methods that are used as a common and efficient way +/// to access entries in a matrix on host and/or device. + +#ifndef KOKKOSSPARSE_CRSMATRIX_TRAVERSAL_HPP +#define KOKKOSSPARSE_CRSMATRIX_TRAVERSAL_HPP + +#include "Kokkos_Core.hpp" + +#include "KokkosSparse_CrsMatrix.hpp" +#include "KokkosKernels_ExecSpaceUtils.hpp" + +#include "KokkosSparse_CrsMatrix_traversal_impl.hpp" + +namespace KokkosSparse { +namespace Experimental { + +/// \brief Public interface to sparse matrix traversal algorithm. +/// +/// Loop over the entries of the input matrix and apply the functor +/// to them. The functor itself may contain its own data to save results +/// after the traversal completes. +/// +/// \tparam execution_space +/// \tparam crsmatrix_type +/// \tparam functor_type +/// +/// \param space [in] execution space instance that specifies where the kernel +/// will be executed. +/// \param matrix [in] the matrix to be traversed. +/// \param functor [in] a functor that is being called on each local entries +/// of the crsmatrix and that implement a user defined capabilities. +/// +template +void crsmatrix_traversal(const execution_space& space, + const crsmatrix_type& matrix, functor_type& functor) { + // Check if a quick return can be performed + if (!matrix.nnz()) { + return; + } + + // Choose between device and host implementation + if constexpr (KokkosKernels::Impl::kk_is_gpu_exec_space()) { + KokkosSparse::Impl::crsmatrix_traversal_on_gpu(space, matrix, functor); + } else { + KokkosSparse::Impl::crsmatrix_traversal_on_host(space, matrix, functor); + } +} + +/// \brief Public interface to sparse matrix traversal algorithm. +/// +/// Loop over the entries of the input matrix and apply the functor +/// to them. The functor itself may contain its own data to save results +/// after the traversal completes. +/// +/// \tparam crsmatrix_type +/// \tparam functor_type +/// +/// \param matrix [in] the matrix to be traversed. +/// \param functor [in] a functor that is being called on each local entries +/// of the crsmatrix and that implement a user defined capabilities. +/// +template +void crsmatrix_traversal(const crsmatrix_type& matrix, functor_type& functor) { + using execution_space = typename crsmatrix_type::execution_space; + execution_space space{}; + crsmatrix_traversal(space, matrix, functor); +} + +} // namespace Experimental +} // namespace KokkosSparse + +#endif // KOKKOSSPARSE_CRSMATRIX_TRAVERSAL_HPP diff --git a/sparse/unit_test/Test_Sparse.hpp b/sparse/unit_test/Test_Sparse.hpp index 8ae06b598a..1d88b95f09 100644 --- a/sparse/unit_test/Test_Sparse.hpp +++ b/sparse/unit_test/Test_Sparse.hpp @@ -44,6 +44,7 @@ #include "Test_Sparse_crs2ccs.hpp" #include "Test_Sparse_removeCrsMatrixZeros.hpp" #include "Test_Sparse_extractCrsDiagonalBlocks.hpp" +#include "Test_Sparse_crsmatrix_traversal.hpp" // TPL specific tests, these require // particular pairs of backend and TPL diff --git a/sparse/unit_test/Test_Sparse_crsmatrix_traversal.hpp b/sparse/unit_test/Test_Sparse_crsmatrix_traversal.hpp new file mode 100644 index 0000000000..f1f680080a --- /dev/null +++ b/sparse/unit_test/Test_Sparse_crsmatrix_traversal.hpp @@ -0,0 +1,124 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +/// \file Test_Sparse_SortCrs.hpp +/// \brief Tests for sort_crs_matrix and sort_crs_graph in +/// KokkosSparse_CrsMatrix_traversal.hpp + +#ifndef TEST_SPARSE_CRSMATRIX_TRAVERSAL_HPP +#define TEST_SPARSE_CRSMATRIX_TRAVERSAL_HPP + +#include + +#include "KokkosKernels_Test_Structured_Matrix.hpp" +#include "KokkosSparse_CrsMatrix_traversal.hpp" + +namespace TestCrsMatrixTraversal { + +template +struct diag_extraction { + using diag_view = typename CrsMatrix::values_type::non_const_type; + using size_type = typename CrsMatrix::non_const_size_type; + using ordinal_type = typename CrsMatrix::non_const_ordinal_type; + using value_type = typename CrsMatrix::non_const_value_type; + + diag_view diag; + + diag_extraction(CrsMatrix A) { + diag = diag_view("diag values", A.numRows()); + }; + + KOKKOS_INLINE_FUNCTION void operator()(const ordinal_type rowIdx, + const size_type /*entryIdx*/, + const ordinal_type colIdx, + const value_type value) const { + if (rowIdx == colIdx) { + diag(rowIdx) = value; + } + } +}; + +} // namespace TestCrsMatrixTraversal + +void testCrsMatrixTraversal(int testCase) { + using namespace TestCrsMatrixTraversal; + using Device = + Kokkos::Device; + using Matrix = KokkosSparse::CrsMatrix; + using Vector = Kokkos::View; + + constexpr int nx = 4, ny = 4; + constexpr bool leftBC = true, rightBC = false, topBC = false, botBC = false; + + Kokkos::View mat_structure("Matrix Structure", + 2); + mat_structure(0, 0) = nx; + mat_structure(0, 1) = (leftBC ? 1 : 0); + mat_structure(0, 2) = (rightBC ? 1 : 0); + + mat_structure(1, 0) = ny; + mat_structure(1, 1) = (topBC ? 1 : 0); + mat_structure(1, 2) = (botBC ? 1 : 0); + + Matrix A = Test::generate_structured_matrix2D("FD", mat_structure); + + Vector diag_ref("diag ref", A.numRows()); + auto diag_ref_h = Kokkos::create_mirror_view(diag_ref); + diag_ref_h(0) = 1; + diag_ref_h(1) = 3; + diag_ref_h(2) = 3; + diag_ref_h(3) = 2; + diag_ref_h(4) = 1; + diag_ref_h(5) = 4; + diag_ref_h(6) = 4; + diag_ref_h(7) = 3; + diag_ref_h(8) = 1; + diag_ref_h(9) = 4; + diag_ref_h(10) = 4; + diag_ref_h(11) = 3; + diag_ref_h(12) = 1; + diag_ref_h(13) = 3; + diag_ref_h(14) = 3; + diag_ref_h(15) = 2; + + // Run the diagonal extraction functor + // using traversal function. + diag_extraction func(A); + KokkosSparse::Experimental::crsmatrix_traversal(A, func); + Kokkos::fence(); + + // Extract the diagonal view from functor + auto diag_h = Kokkos::create_mirror_view(func.diag); + Kokkos::deep_copy(diag_h, func.diag); + + // Check for correctness + bool matches = true; + for (int rowIdx = 0; rowIdx < A.numRows(); ++rowIdx) { + if (diag_ref_h(rowIdx) != diag_h(rowIdx)) matches = false; + } + + EXPECT_TRUE(matches) + << "Test case " << testCase + << ": matrix with zeros filtered out does not match reference."; +} + +TEST_F(TestCategory, sparse_crsmatrix_traversal) { + for (int testCase = 0; testCase < 1; testCase++) + testCrsMatrixTraversal(testCase); +} + +#endif // TEST_SPARSE_CRSMATRIX_TRAVERSAL_HPP