From 43fa5052dfebaae8f84ae921754529c35883dcb2 Mon Sep 17 00:00:00 2001 From: Jonah Miller Date: Tue, 26 Aug 2025 10:46:05 -0400 Subject: [PATCH 1/2] First attempt to add serial QR decomposition with Householder pivoting --- .../impl/KokkosBatched_ApplyPivot_Impl.hpp | 130 ++++++++++++++ .../KokkosBatched_ApplyPivot_Internal.hpp | 140 +++++++++++++++ ...ched_QR_WithColumnPivoting_Serial_Impl.hpp | 43 +++++ ..._QR_WithColumnPivoting_Serial_Internal.hpp | 159 ++++++++++++++++++ .../src/KokkosBatched_ApplyPivot_Decl.hpp | 12 ++ ...kkosBatched_QR_WithColumnPivoting_Decl.hpp | 13 ++ 6 files changed, 497 insertions(+) create mode 100644 batched/dense/impl/KokkosBatched_QR_WithColumnPivoting_Serial_Impl.hpp create mode 100644 batched/dense/impl/KokkosBatched_QR_WithColumnPivoting_Serial_Internal.hpp diff --git a/batched/dense/impl/KokkosBatched_ApplyPivot_Impl.hpp b/batched/dense/impl/KokkosBatched_ApplyPivot_Impl.hpp index adbce23d9a..58d83db335 100644 --- a/batched/dense/impl/KokkosBatched_ApplyPivot_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_ApplyPivot_Impl.hpp @@ -23,6 +23,136 @@ namespace KokkosBatched { +/// +/// Serial Internal Impl +/// ======================== + +/// +/// Forward pivot apply +/// + +/// row swap (rank-1: entries of a vector; rank-2: whole rows across all n columns) +template<> +struct SerialApplyPivot { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int piv, const AViewType &A) { + if (AViewType::rank == 1) { + const int as0 = A.stride(0); + SerialApplyPivotVectorForwardInternal::invoke(piv, A.data(), as0); + } else if (AViewType::rank == 2) { + const int n = A.extent(1), as0 = A.stride(0), as1 = A.stride(1); + SerialApplyPivotMatrixForwardInternal::invoke(n, piv, A.data(), as0, as1); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType piv, const AViewType &A) { + if (AViewType::rank == 1) { + const int plen = piv.extent(0), ps0 = piv.stride(0), as0 = A.stride(0); + SerialApplyPivotVectorForwardInternal::invoke(plen, piv.data(), ps0, A.data(), as0); + } else if
(AViewType::rank == 2) { + // row permutation + const int plen = piv.extent(0), ps0 = piv.stride(0), n = A.extent(1), as0 = A.stride(0), as1 = A.stride(1); + SerialApplyPivotMatrixForwardInternal::invoke(n, plen, piv.data(), ps0, A.data(), as0, as1); + } + return 0; + } +}; + +/// column swap (rank-2: strides are passed as (as1, as0) so whole columns are exchanged across all m rows) +template <> +struct SerialApplyPivot { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int piv, const AViewType &A) { + if (AViewType::rank == 1) { + const int as0 = A.stride(0); + SerialApplyPivotVectorForwardInternal::invoke(piv, A.data(), as0); + } else if (AViewType::rank == 2) { + const int m = A.extent(0), as0 = A.stride(0), as1 = A.stride(1); + SerialApplyPivotMatrixForwardInternal::invoke(m, piv, A.data(), as1, as0); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType &piv, const AViewType &A) { + if (AViewType::rank == 1) { + const int plen = piv.extent(0), ps0 = piv.stride(0), as0 = A.stride(0); + SerialApplyPivotVectorForwardInternal ::invoke(plen, piv.data(), ps0, A.data(), as0); + } else if (AViewType::rank == 2) { + // column permutation + const int plen = piv.extent(0), ps = piv.stride(0), m = A.extent(0), as0 = A.stride(0), as1 = A.stride(1); + SerialApplyPivotMatrixForwardInternal ::invoke(m, plen, piv.data(), ps, A.data(), as1, as0); + } + return 0; + } +}; + + +/// +/// Backward pivot apply +/// + +/// row swap (rank-1: entries of a vector; rank-2: whole rows across all n columns) +template <> +struct SerialApplyPivot { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int piv, const AViewType &A) { + if (AViewType::rank == 1) { + const int as0 = A.stride(0); + SerialApplyPivotVectorBackwardInternal::invoke(piv, A.data(), as0); + } else if (AViewType::rank == 2) { + const int n = A.extent(1), as0 = A.stride(0), as1 = A.stride(1); + SerialApplyPivotMatrixBackwardInternal::invoke(n, piv, A.data(), as0, as1); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType piv, const AViewType &A) { + if (AViewType::rank == 1) { + const int plen = piv.extent(0), 
ps0 = piv.stride(0), as0 = A.stride(0); + SerialApplyPivotVectorBackwardInternal::invoke(plen, piv.data(), ps0, A.data(), as0); + } else if (AViewType::rank == 2) { + // row permutation + const int plen = piv.extent(0), ps0 = piv.stride(0), n = A.extent(1), as0 = A.stride(0), as1 = A.stride(1); + SerialApplyPivotMatrixBackwardInternal::invoke(n, plen, piv.data(), ps0, A.data(), as0, as1); + } + return 0; + } +}; + +/// column swap (rank-2: strides are passed as (as1, as0) so whole columns are exchanged across all m rows) +template <> +struct SerialApplyPivot { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int piv, const AViewType &A) { + if (AViewType::rank == 1) { + const int as0 = A.stride(0); + SerialApplyPivotVectorBackwardInternal::invoke(piv, A.data(), as0); + } else if (AViewType::rank == 2) { + const int m = A.extent(0), as0 = A.stride(0), as1 = A.stride(1); + SerialApplyPivotMatrixBackwardInternal::invoke(m, piv, A.data(), as1, as0); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType &piv, const AViewType &A) { + if (AViewType::rank == 1) { + const int plen = piv.extent(0), ps0 = piv.stride(0), as0 = A.stride(0); + SerialApplyPivotVectorBackwardInternal ::invoke(plen, piv.data(), ps0, A.data(), as0); + } else if (AViewType::rank == 2) { + // column permutation + const int plen = piv.extent(0), ps = piv.stride(0), m = A.extent(0), as0 = A.stride(0), as1 = A.stride(1); + SerialApplyPivotMatrixBackwardInternal ::invoke(m, plen, piv.data(), ps, A.data(), as1, as0); + } + return 0; + } +}; + + /// /// TeamVector Internal Impl /// ======================== diff --git a/batched/dense/impl/KokkosBatched_ApplyPivot_Internal.hpp b/batched/dense/impl/KokkosBatched_ApplyPivot_Internal.hpp index e013252255..8a285859c0 100644 --- a/batched/dense/impl/KokkosBatched_ApplyPivot_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_ApplyPivot_Internal.hpp @@ -22,6 +22,146 @@ namespace KokkosBatched { +/// +/// Serial Internal Impl +/// ======================== + +/// +/// Forward (pivot arrays are applied for i = 0, ..., plen - 1; entry i is swapped with entry i + p[i]) +/// +struct SerialApplyPivotVectorForwardInternal { + 
template + KOKKOS_INLINE_FUNCTION static int invoke(const int piv, + /* */ ValueType *KOKKOS_RESTRICT A, const int as0) { + if (piv != 0) { + const int idx_p = piv * as0; + const ValueType tmp = A[0]; + A[0] = A[idx_p]; + A[idx_p] = tmp; + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0, + /* */ ValueType *KOKKOS_RESTRICT A, const int as0) { + for (int i = 0; i < plen; ++i) { + const int piv = p[i * ps0]; + if (piv != 0) { + const int idx_i = i * as0, idx_p = (i + piv) * as0; + const ValueType tmp = A[idx_i]; + A[idx_i] = A[idx_p]; + A[idx_p] = tmp; + } + } + return 0; + } +}; + +/// Pivot a matrix: repeats the vector swap on each of the n slices starting at A + j * as1 +struct SerialApplyPivotMatrixForwardInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int piv, + /* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { + if (piv != 0) { + for (int j = 0; j < n; ++j) { + ValueType *KOKKOS_RESTRICT A_at_j = A + j * as1; + const int idx_p = piv * as0; + const ValueType tmp = A_at_j[0]; + A_at_j[0] = A_at_j[idx_p]; + A_at_j[idx_p] = tmp; + } + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0, + /* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { + for (int j = 0; j < n; ++j) { + ValueType *KOKKOS_RESTRICT A_at_j = A + j * as1; + for (int i = 0; i < plen; ++i) { + const int piv = p[i * ps0]; + if (piv != 0) { + const int idx_i = i * as0, idx_p = (i + piv) * as0; + const ValueType tmp = A_at_j[idx_i]; + A_at_j[idx_i] = A_at_j[idx_p]; + A_at_j[idx_p] = tmp; + } + } + } + return 0; + } +}; + +/// +/// Backward (pivot arrays are applied for i = plen - 1, ..., 0; entry i is swapped with entry i + p[i]) +/// +struct SerialApplyPivotVectorBackwardInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int piv, + /* */ ValueType *KOKKOS_RESTRICT A, const int as0) { + if (piv != 0) { + const int idx_p = piv * as0; + const ValueType tmp = A[0]; + A[0] = A[idx_p]; + A[idx_p] = 
tmp; + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0, + /* */ ValueType *KOKKOS_RESTRICT A, const int as0) { + for (int i = (plen - 1); i >= 0; --i) { + const int piv = p[i * ps0]; + if (piv != 0) { + const int idx_i = i * as0, idx_p = (i + piv) * as0; + const ValueType tmp = A[idx_i]; + A[idx_i] = A[idx_p]; + A[idx_p] = tmp; + } + } + return 0; + } +}; + +/// Pivot a matrix: repeats the vector swap on each of the n slices starting at A + j * as1 +struct SerialApplyPivotMatrixBackwardInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int piv, + /* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { + if (piv != 0) { + for (int j = 0; j < n; ++j) { + ValueType *KOKKOS_RESTRICT A_at_j = A + j * as1; + const int idx_p = piv * as0; + const ValueType tmp = A_at_j[0]; + A_at_j[0] = A_at_j[idx_p]; + A_at_j[idx_p] = tmp; + } + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0, + /* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { + for (int j = 0; j < n; ++j) { + ValueType *KOKKOS_RESTRICT A_at_j = A + j * as1; + for (int i = (plen - 1); i >= 0; --i) { + const int piv = p[i * ps0]; + if (piv != 0) { + const int idx_i = i * as0, idx_p = (i + piv) * as0; + const ValueType tmp = A_at_j[idx_i]; + A_at_j[idx_i] = A_at_j[idx_p]; + A_at_j[idx_p] = tmp; + } + } + } + return 0; + } +}; + /// /// TeamVector Internal Impl /// ======================== diff --git a/batched/dense/impl/KokkosBatched_QR_WithColumnPivoting_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_QR_WithColumnPivoting_Serial_Impl.hpp new file mode 100644 index 0000000000..34008c61f9 --- /dev/null +++ b/batched/dense/impl/KokkosBatched_QR_WithColumnPivoting_Serial_Impl.hpp @@ -0,0 +1,43 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 
4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#ifndef KOKKOSBATCHED_QR_WITH_COLUMNPIVOTING_SERIAL_IMPL_HPP +#define KOKKOSBATCHED_QR_WITH_COLUMNPIVOTING_SERIAL_IMPL_HPP + +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "KokkosBatched_Util.hpp" +#include "KokkosBatched_QR_WithColumnPivoting_Serial_Internal.hpp" + +namespace KokkosBatched { + +/// +/// Serial Impl (forwards the extents, strides and data pointers of A, t, p, w to the internal kernel) +/// =============== +template <> +struct SerialQR_WithColumnPivoting { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const tViewType &t, const pViewType &p, + const wViewType &w, + /* */ int &matrix_rank) { + return SerialQR_WithColumnPivotingInternal::invoke(A.extent(0), A.extent(1), A.data(), A.stride(0), A.stride(1), + t.data(), t.stride(0), p.data(), p.stride(0), w.data(), + matrix_rank); + } +}; + +} // namespace KokkosBatched + +#endif diff --git a/batched/dense/impl/KokkosBatched_QR_WithColumnPivoting_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_QR_WithColumnPivoting_Serial_Internal.hpp new file mode 100644 index 0000000000..86918b9550 --- /dev/null +++ b/batched/dense/impl/KokkosBatched_QR_WithColumnPivoting_Serial_Internal.hpp @@ -0,0 +1,159 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#ifndef KOKKOSBATCHED_QR_WITH_COLUMNPIVOTING_SERIAL_INTERNAL_HPP +#define KOKKOSBATCHED_QR_WITH_COLUMNPIVOTING_SERIAL_INTERNAL_HPP + +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "KokkosBatched_Util.hpp" + +#include "KokkosBatched_FindAmax_Internal.hpp" +#include "KokkosBatched_Dot.hpp" +#include "KokkosBatched_ApplyPivot_Internal.hpp" + +#include "KokkosBatched_Householder_Serial_Internal.hpp" +#include "KokkosBatched_ApplyHouseholder_Serial_Internal.hpp" + +namespace KokkosBatched { + +/// +/// Serial Internal +/// =================== +/// +/// this impl follows the flame interface of householder transformation +/// +struct SerialUpdateColumnNormsInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int n, const ValueType *KOKKOS_RESTRICT a, const int as0, + /* */ ValueType *KOKKOS_RESTRICT norm, const int ns0) { + using ats = Kokkos::ArithTraits; + for (int j = 0; j < n; ++j) { + const int idx_a = j * as0, idx_n = j * ns0; + norm[idx_n] -= ats::conj(a[idx_a]) * a[idx_a]; + } + return 0; + } +}; + +struct SerialQR_WithColumnPivotingInternal { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int m, // m = NumRows(A) + const int n, // n = NumCols(A) + /* */ ValueType *A, const int as0, const int as1, + /* */ ValueType *t, const int ts0, + /* */ IntType *p, const int ps0, + /* */ ValueType *w, + /* */ int &matrix_rank) { + using value_type = ValueType; + using int_type = IntType; + using ats = Kokkos::ArithTraits; + + /// Given a matrix A, it computes QR decomposition of the matrix + /// - t is to store tau and w is for workspace + + // partitions used for loop iteration + Partition2x2 A_part2x2(as0, as1); + Partition3x3 A_part3x3(as0, as1); + + // column vector of tau (size of min_mn) + Partition2x1 t_part2x1(ts0); + Partition3x1 t_part3x1(ts0); + + // row vector for norm and p (size of n) + 
Partition1x2 p_part1x2(ps0); + Partition1x3 p_part1x3(ps0); + + Partition1x2 norm_part1x2(1); + Partition1x3 norm_part1x3(1); + + // loop size + const int min_mn = m < n ? m : n; + + // workspace (norm and householder application, 2*max(m,n) is needed) + value_type *norm = w; + w += n; + + // initial partition of A where ATL has a zero dimension + A_part2x2.partWithATL(A, m, n, 0, 0); + t_part2x1.partWithAT(t, min_mn, 0); + + p_part1x2.partWithAL(p, n, 0); + norm_part1x2.partWithAL(norm, n, 0); + + // compute initial column norms (replaced by dot product) + SerialDotInternal::invoke(m, n, A, as0, as1, A, as0, as1, norm, 1); + + const bool finish_when_rank_found = (matrix_rank == -1); + + matrix_rank = min_mn; + value_type max_diag(0); + for (int m_atl = 0; m_atl < min_mn; ++m_atl) { + const int n_AR = n - m_atl; + + // part 2x2 into 3x3 + A_part3x3.partWithABR(A_part2x2, 1, 1); + const int m_A22 = m - m_atl - 1; + const int n_A22 = n - m_atl - 1; + + t_part3x1.partWithAB(t_part2x1, 1); + value_type *tau = t_part3x1.A1; + + p_part1x3.partWithAR(p_part1x2, 1); + int_type *pividx = p_part1x3.A1; + + norm_part1x3.partWithAR(norm_part1x2, 1); + + /// ----------------------------------------------------- + // find max location (NOTE(review): SerialFindAmaxInternal is not added by this patch; confirm KokkosBatched_FindAmax_Internal.hpp defines it) + SerialFindAmaxInternal::invoke(n_AR, norm_part1x2.AR, 1, pividx); + + // apply pivot + SerialApplyPivotVectorForwardInternal::invoke(*pividx, norm_part1x2.AR, 1); + SerialApplyPivotMatrixForwardInternal::invoke(m, *pividx, A_part2x2.ATR, as1, as0); + + // perform householder transformation (generate step on A11/A21 with result scalar in tau; the apply step follows) + SerialLeftHouseholderInternal::invoke(m_A22, A_part3x3.A11, A_part3x3.A21, as0, tau); + + // left apply householder to A22 + SerialApplyLeftHouseholderInternal::invoke(m_A22, n_A22, tau, A_part3x3.A21, as0, A_part3x3.A12, as1, + A_part3x3.A22, as0, as1, w); + + // break condition + if (matrix_rank == min_mn) { + if (m_atl == 0) max_diag = ats::abs(A[0]); + const value_type val_diag = ats::abs(A_part3x3.A11[0]), threshold(10 * max_diag * ats::epsilon()); 
+ if (val_diag < threshold) { + matrix_rank = m_atl; + if (finish_when_rank_found) break; + } + } + + // norm update + SerialUpdateColumnNormsInternal::invoke(n_A22, A_part3x3.A12, as1, norm_part1x3.A2, 1); + /// ----------------------------------------------------- + A_part2x2.mergeToATL(A_part3x3); + t_part2x1.mergeToAT(t_part3x1); + p_part1x2.mergeToAL(p_part1x3); + norm_part1x2.mergeToAL(norm_part1x3); + } + + return 0; + } +}; + +} // end namespace KokkosBatched + +#endif diff --git a/batched/dense/src/KokkosBatched_ApplyPivot_Decl.hpp b/batched/dense/src/KokkosBatched_ApplyPivot_Decl.hpp index f4795f863d..eae968b883 100644 --- a/batched/dense/src/KokkosBatched_ApplyPivot_Decl.hpp +++ b/batched/dense/src/KokkosBatched_ApplyPivot_Decl.hpp @@ -22,6 +22,18 @@ namespace KokkosBatched { +/// +/// Serial +/// ========== +template +struct SerialApplyPivot { + template + KOKKOS_INLINE_FUNCTION static int invoke(const int piv, const AViewType &A); + + template + KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType piv, const AViewType &A); +}; + /// /// TeamVector /// ========== diff --git a/batched/dense/src/KokkosBatched_QR_WithColumnPivoting_Decl.hpp b/batched/dense/src/KokkosBatched_QR_WithColumnPivoting_Decl.hpp index 3c663f237e..32de12c326 100644 --- a/batched/dense/src/KokkosBatched_QR_WithColumnPivoting_Decl.hpp +++ b/batched/dense/src/KokkosBatched_QR_WithColumnPivoting_Decl.hpp @@ -22,6 +22,18 @@ namespace KokkosBatched { +/// +/// Serial QR +/// + +template +struct SerialQR_WithColumnPivoting { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const tViewType &t, const pViewType &p, + const wViewType &w, + /* */ int &matrix_rank); +}; + /// /// TeamVector QR /// @@ -36,6 +48,7 @@ struct TeamVectorQR_WithColumnPivoting { } // namespace KokkosBatched +#include "KokkosBatched_QR_WithColumnPivoting_TeamVector_Serial.hpp" #include "KokkosBatched_QR_WithColumnPivoting_TeamVector_Impl.hpp" #endif From 
f335b01df062a2069bf8ccf4c3d55166505ad5fc Mon Sep 17 00:00:00 2001 From: Jonah Miller Date: Tue, 26 Aug 2025 11:36:20 -0400 Subject: [PATCH 2/2] oops include the correctly named file --- batched/dense/src/KokkosBatched_QR_WithColumnPivoting_Decl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/batched/dense/src/KokkosBatched_QR_WithColumnPivoting_Decl.hpp b/batched/dense/src/KokkosBatched_QR_WithColumnPivoting_Decl.hpp index 32de12c326..be63a4d732 100644 --- a/batched/dense/src/KokkosBatched_QR_WithColumnPivoting_Decl.hpp +++ b/batched/dense/src/KokkosBatched_QR_WithColumnPivoting_Decl.hpp @@ -48,7 +48,7 @@ struct TeamVectorQR_WithColumnPivoting { } // namespace KokkosBatched -#include "KokkosBatched_QR_WithColumnPivoting_TeamVector_Serial.hpp" +#include "KokkosBatched_QR_WithColumnPivoting_Serial_Impl.hpp" #include "KokkosBatched_QR_WithColumnPivoting_TeamVector_Impl.hpp" #endif