Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions batched/dense/impl/KokkosBatched_ApplyPivot_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,136 @@

namespace KokkosBatched {

///
/// Serial Internal Impl
/// ========================

///
/// Forward pivot apply
///

/// row swap
template <>
struct SerialApplyPivot<Side::Left, Direct::Forward> {
  /// Applies a single pivot to A.
  ///
  /// \param piv  pivot offset handed unchanged to the internal kernel
  ///             (the kernel does nothing when it is zero).
  /// \param A    rank-1 or rank-2 view, modified in place.
  /// \return     always 0.
  ///
  /// For a rank-1 view only stride(0) is forwarded. For a rank-2 view the
  /// kernel receives n = A.extent(1) together with (stride(0), stride(1)).
  template <typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int piv, const AViewType &A) {
    if (AViewType::rank == 1) {
      const int as0 = A.stride(0);
      SerialApplyPivotVectorForwardInternal::invoke(piv, A.data(), as0);
    } else if (AViewType::rank == 2) {
      const int n = A.extent(1), as0 = A.stride(0), as1 = A.stride(1);
      // The matrix kernel is named SerialApplyPivotMatrixForwardInternal.
      SerialApplyPivotMatrixForwardInternal::invoke(n, piv, A.data(), as0, as1);
    }
    return 0;
  }

  /// Applies a sequence of pivots to A, visiting the entries of piv in
  /// increasing index order.
  ///
  /// \param piv  view of pivot offsets; extent(0) entries are read with
  ///             stride(0).
  /// \param A    rank-1 or rank-2 view, modified in place.
  /// \return     always 0.
  template <typename PivViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType piv, const AViewType &A) {
    if (AViewType::rank == 1) {
      const int plen = piv.extent(0), ps0 = piv.stride(0), as0 = A.stride(0);
      SerialApplyPivotVectorForwardInternal::invoke(plen, piv.data(), ps0, A.data(), as0);
    } else if (AViewType::rank == 2) {
      // row permutation
      const int plen = piv.extent(0), ps0 = piv.stride(0), n = A.extent(1), as0 = A.stride(0), as1 = A.stride(1);
      SerialApplyPivotMatrixForwardInternal::invoke(n, plen, piv.data(), ps0, A.data(), as0, as1);
    }
    return 0;
  }
};

/// column swap
template <>
struct SerialApplyPivot<Side::Right, Direct::Forward> {
  /// Applies a single pivot to A.
  ///
  /// \param piv  pivot offset handed unchanged to the internal kernel
  ///             (the kernel does nothing when it is zero).
  /// \param A    rank-1 or rank-2 view, modified in place.
  /// \return     always 0.
  ///
  /// For a rank-2 view the strides are passed in swapped order
  /// (stride(1), stride(0)) with m = A.extent(0), so the shared matrix kernel
  /// exchanges along dimension 1 instead of dimension 0.
  template <typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int piv, const AViewType &A) {
    if (AViewType::rank == 1) {
      const int as0 = A.stride(0);
      SerialApplyPivotVectorForwardInternal::invoke(piv, A.data(), as0);
    } else if (AViewType::rank == 2) {
      const int m = A.extent(0), as0 = A.stride(0), as1 = A.stride(1);
      SerialApplyPivotMatrixForwardInternal::invoke(m, piv, A.data(), as1, as0);
    }
    return 0;
  }

  /// Applies a sequence of pivots to A, visiting the entries of piv in
  /// increasing index order.
  ///
  /// \param piv  view of pivot offsets; extent(0) entries are read with
  ///             stride(0).
  /// \param A    rank-1 or rank-2 view, modified in place.
  /// \return     always 0.
  template <typename PivViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType &piv, const AViewType &A) {
    if (AViewType::rank == 1) {
      // The vector kernel takes (plen, p, ps0, A, as0); the pivot stride must
      // be forwarded or no overload matches.
      const int plen = piv.extent(0), ps0 = piv.stride(0), as0 = A.stride(0);
      SerialApplyPivotVectorForwardInternal::invoke(plen, piv.data(), ps0, A.data(), as0);
    } else if (AViewType::rank == 2) {
      // column permutation
      const int plen = piv.extent(0), ps0 = piv.stride(0), m = A.extent(0), as0 = A.stride(0), as1 = A.stride(1);
      SerialApplyPivotMatrixForwardInternal::invoke(m, plen, piv.data(), ps0, A.data(), as1, as0);
    }
    return 0;
  }
};


///
/// Backward pivot apply
///

/// row swap
template <>
struct SerialApplyPivot<Side::Left, Direct::Backward> {
  /// Single-pivot overload: forwards piv and the raw data/strides of A to the
  /// backward vector kernel (rank 1) or backward matrix kernel (rank 2).
  /// Always returns 0.
  template <typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int piv, const AViewType &A) {
    if (AViewType::rank == 2) {
      const int num_cols   = A.extent(1);
      const int row_stride = A.stride(0);
      const int col_stride = A.stride(1);
      SerialApplyPivotMatrixBackwardInternal::invoke(num_cols, piv, A.data(), row_stride, col_stride);
    } else if (AViewType::rank == 1) {
      const int elem_stride = A.stride(0);
      SerialApplyPivotVectorBackwardInternal::invoke(piv, A.data(), elem_stride);
    }
    return 0;
  }

  /// Pivot-array overload: forwards the length, data pointer and stride of piv
  /// plus the raw data/strides of A to the backward kernels, which walk the
  /// pivots from the last entry down to the first. Always returns 0.
  template <typename PivViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType piv, const AViewType &A) {
    if (AViewType::rank == 2) {
      // row permutation
      const int num_pivots   = piv.extent(0);
      const int pivot_stride = piv.stride(0);
      const int num_cols     = A.extent(1);
      const int row_stride   = A.stride(0);
      const int col_stride   = A.stride(1);
      SerialApplyPivotMatrixBackwardInternal::invoke(num_cols, num_pivots, piv.data(), pivot_stride, A.data(),
                                                     row_stride, col_stride);
    } else if (AViewType::rank == 1) {
      const int num_pivots   = piv.extent(0);
      const int pivot_stride = piv.stride(0);
      const int elem_stride  = A.stride(0);
      SerialApplyPivotVectorBackwardInternal::invoke(num_pivots, piv.data(), pivot_stride, A.data(), elem_stride);
    }
    return 0;
  }
};

/// column swap
template <>
struct SerialApplyPivot<Side::Right, Direct::Backward> {
  // NOTE: the serial specializations are keyed on <Side, Direct> only; no
  // MemberType argument is declared under `template <>`.

  /// Applies a single pivot to A.
  ///
  /// \param piv  pivot offset handed unchanged to the internal kernel
  ///             (the kernel does nothing when it is zero).
  /// \param A    rank-1 or rank-2 view, modified in place.
  /// \return     always 0.
  ///
  /// For a rank-2 view the strides are passed in swapped order
  /// (stride(1), stride(0)) with m = A.extent(0), so the shared matrix kernel
  /// exchanges along dimension 1 instead of dimension 0.
  template <typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int piv, const AViewType &A) {
    if (AViewType::rank == 1) {
      const int as0 = A.stride(0);
      SerialApplyPivotVectorBackwardInternal::invoke(piv, A.data(), as0);
    } else if (AViewType::rank == 2) {
      const int m = A.extent(0), as0 = A.stride(0), as1 = A.stride(1);
      SerialApplyPivotMatrixBackwardInternal::invoke(m, piv, A.data(), as1, as0);
    }
    return 0;
  }

  /// Applies a sequence of pivots to A, visiting the entries of piv from the
  /// last index down to the first.
  ///
  /// \param piv  view of pivot offsets; extent(0) entries are read with
  ///             stride(0).
  /// \param A    rank-1 or rank-2 view, modified in place.
  /// \return     always 0.
  template <typename PivViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType &piv, const AViewType &A) {
    if (AViewType::rank == 1) {
      // The vector kernel takes (plen, p, ps0, A, as0); the pivot stride must
      // be forwarded or no overload matches.
      const int plen = piv.extent(0), ps0 = piv.stride(0), as0 = A.stride(0);
      SerialApplyPivotVectorBackwardInternal::invoke(plen, piv.data(), ps0, A.data(), as0);
    } else if (AViewType::rank == 2) {
      // column permutation
      const int plen = piv.extent(0), ps0 = piv.stride(0), m = A.extent(0), as0 = A.stride(0), as1 = A.stride(1);
      SerialApplyPivotMatrixBackwardInternal::invoke(m, plen, piv.data(), ps0, A.data(), as1, as0);
    }
    return 0;
  }
};


///
/// TeamVector Internal Impl
/// ========================
Expand Down
140 changes: 140 additions & 0 deletions batched/dense/impl/KokkosBatched_ApplyPivot_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,146 @@

namespace KokkosBatched {

///
/// Serial Internal Impl
/// ========================

///
/// Forward
///
struct SerialApplyPivotVectorForwardInternal {
  /// Exchanges A[0] with A[piv * as0]; does nothing when piv == 0.
  ///
  /// \param piv  offset (in elements) of the entry exchanged with the first one.
  /// \param A    pointer to the first entry, modified in place.
  /// \param as0  stride between consecutive entries of A.
  /// \return     always 0.
  template <typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int piv,
                                           /* */ ValueType *KOKKOS_RESTRICT A, const int as0) {
    if (piv != 0) {
      const int idx_p     = piv * as0;
      const ValueType tmp = A[0];
      A[0]                = A[idx_p];
      A[idx_p]            = tmp;
    }
    return 0;
  }

  /// Applies plen pivots in increasing index order: for each i, entry i is
  /// exchanged with entry i + p[i * ps0] (skipped when that offset is zero).
  ///
  /// \param plen  number of pivots to apply.
  /// \param p     pivot offsets, relative to the current index i.
  /// \param ps0   stride between consecutive pivots in p.
  /// \param A     pointer to the first entry, modified in place.
  /// \param as0   stride between consecutive entries of A.
  /// \return      always 0.
  template <typename IntType, typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0,
                                           /* */ ValueType *KOKKOS_RESTRICT A, const int as0) {
    for (int i = 0; i < plen; ++i) {
      const int piv = p[i * ps0];
      if (piv != 0) {
        const int idx_i = i * as0, idx_p = (i + piv) * as0;
        const ValueType tmp = A[idx_i];
        A[idx_i]            = A[idx_p];
        A[idx_p]            = tmp;
      }
    }
    // Return only after every pivot has been processed; this also gives a
    // well-defined result when plen <= 0.
    return 0;
  }
};

/// Pivot a row
struct SerialApplyPivotMatrixForwardInternal {
  /// Single pivot: for each of the n lines (line j starts at A + j * as1),
  /// exchanges the entry at offset 0 with the entry at offset piv * as0.
  /// Does nothing when piv == 0. Always returns 0.
  template <typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int piv,
                                           /* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) {
    if (piv == 0) return 0;

    const int target = piv * as0;
    for (int line = 0; line < n; ++line) {
      ValueType *KOKKOS_RESTRICT base = A + line * as1;
      const ValueType held            = base[0];
      base[0]                         = base[target];
      base[target]                    = held;
    }
    return 0;
  }

  /// Pivot sequence: for each of the n lines, walks the plen pivots in
  /// increasing order and exchanges entry i with entry i + p[i * ps0]
  /// (entries are as0 apart); zero pivots are skipped. Always returns 0.
  template <typename IntType, typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0,
                                           /* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) {
    for (int line = 0; line < n; ++line) {
      ValueType *KOKKOS_RESTRICT base = A + line * as1;
      for (int k = 0; k < plen; ++k) {
        const int shift = p[k * ps0];
        if (shift == 0) continue;

        const int from       = k * as0;
        const int to         = (k + shift) * as0;
        const ValueType held = base[from];
        base[from]           = base[to];
        base[to]             = held;
      }
    }
    return 0;
  }
};

///
/// Backward
///
struct SerialApplyPivotVectorBackwardInternal {
  /// Single pivot: exchanges A[0] with A[piv * as0] unless piv is zero.
  /// Always returns 0.
  template <typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int piv,
                                           /* */ ValueType *KOKKOS_RESTRICT A, const int as0) {
    if (piv == 0) return 0;

    const int target     = piv * as0;
    const ValueType held = A[0];
    A[0]                 = A[target];
    A[target]            = held;
    return 0;
  }

  /// Pivot sequence applied from the last pivot down to the first: entry k is
  /// exchanged with entry k + p[k * ps0] (entries are as0 apart); zero pivots
  /// are skipped. Always returns 0.
  template <typename IntType, typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0,
                                           /* */ ValueType *KOKKOS_RESTRICT A, const int as0) {
    for (int k = plen; k-- > 0;) {
      const int shift = p[k * ps0];
      if (shift == 0) continue;

      const int from       = k * as0;
      const int to         = (k + shift) * as0;
      const ValueType held = A[from];
      A[from]              = A[to];
      A[to]                = held;
    }
    return 0;
  }
};

/// Pivot a row
struct SerialApplyPivotMatrixBackwardInternal {
  /// Single pivot: for each of the n lines (line j starts at A + j * as1),
  /// exchanges the entry at offset 0 with the entry at offset piv * as0.
  /// Does nothing when piv == 0. Always returns 0.
  template <typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int piv,
                                           /* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) {
    if (piv == 0) return 0;

    const int target = piv * as0;
    for (int line = 0; line < n; ++line) {
      ValueType *KOKKOS_RESTRICT base = A + line * as1;
      const ValueType held            = base[0];
      base[0]                         = base[target];
      base[target]                    = held;
    }
    return 0;
  }

  /// Pivot sequence: for each of the n lines, walks the plen pivots from the
  /// last one down to the first and exchanges entry k with entry
  /// k + p[k * ps0] (entries are as0 apart); zero pivots are skipped.
  /// Always returns 0.
  template <typename IntType, typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int n, const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0,
                                           /* */ ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) {
    for (int line = 0; line < n; ++line) {
      ValueType *KOKKOS_RESTRICT base = A + line * as1;
      for (int k = plen; k-- > 0;) {
        const int shift = p[k * ps0];
        if (shift == 0) continue;

        const int from       = k * as0;
        const int to         = (k + shift) * as0;
        const ValueType held = base[from];
        base[from]           = base[to];
        base[to]             = held;
      }
    }
    return 0;
  }
};

///
/// TeamVector Internal Impl
/// ========================
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_QR_WITH_COLUMNPIVOTING_SERIAL_IMPL_HPP
#define KOKKOSBATCHED_QR_WITH_COLUMNPIVOTING_SERIAL_IMPL_HPP

/// \author Kyungjoo Kim (kyukim@sandia.gov)

#include <type_traits>

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_QR_WithColumnPivoting_Serial_Internal.hpp"

namespace KokkosBatched {

///
/// Serial Impl
/// ===============

struct SerialQR_WithColumnPivoting<Algo::QR::Unblocked> {
template <typename AViewType, typename tViewType, typename pViewType, typename wViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const tViewType &t, const pViewType &p,
const wViewType &w,
/* */ int &matrix_rank) {
return SerialQR_WithColumnPivotingInternal::invoke(A.extent(0), A.extent(1), A.data(), A.stride(0), A.stride(1),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am going to guess that some assumptions are made here. For instance you are pulling extents and strides 0 and 1 from A so you expect a rank 2 view. There should be a static assertion on the rank of A then.
You might also want to check that p is storing some kind of integral type?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, will do.

t.data(), t.stride(0), p.data(), p.stride(0), w.data(),
matrix_rank);
}
};

} // namespace KokkosBatched

#endif
Loading