Skip to content

Commit d162e26

Browse files
committed
Replace O(N) recursive sequence_map_inverse with O(1) pack expansion
Use constexpr loop in find_source_index to locate permutation inverse indices, then expand via pack expansion for O(1) template instantiation depth instead of O(N) recursive template instantiation.
1 parent 44f481a commit d162e26

File tree

1 file changed

+34
-20
lines changed

1 file changed

+34
-20
lines changed

include/ck/utility/sequence.hpp

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -525,31 +525,45 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
525525
{
526526
};
527527

528-
template <typename SeqMap>
529-
struct sequence_map_inverse
528+
// Invert a permutation sequence using O(1) template instantiations
529+
// For X2Y = {a, b, c, ...}, computes Y2X where Y2X[X2Y[i]] = i
530+
//
531+
// Uses a single constexpr member function instead of N template instantiations.
532+
// The function is called N times via pack expansion, but only one function exists.
533+
namespace detail {
534+
template <typename T, index_t N>
535+
struct ConstexprArray
530536
{
531-
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
532-
struct sequence_map_inverse_impl
533-
{
534-
static constexpr auto new_y2x =
535-
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
537+
T data[N];
536538

537-
using type =
538-
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
539-
type;
540-
};
539+
constexpr const T& operator[](index_t i) const { return data[i]; }
540+
};
541+
} // namespace detail
542+
543+
template <index_t... Is>
544+
struct sequence_map_inverse<Sequence<Is...>>
545+
{
546+
private:
547+
static constexpr detail::ConstexprArray<index_t, sizeof...(Is)> values = {{Is...}};
541548

542-
template <typename X2Y, typename WorkingY2X, index_t XBegin>
543-
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
549+
static constexpr index_t find_inverse(index_t target)
544550
{
545-
using type = WorkingY2X;
546-
};
551+
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
552+
{
553+
if(values[i] == target)
554+
return i;
555+
}
556+
return -1; // should not reach for valid permutation
557+
}
547558

548-
using type =
549-
typename sequence_map_inverse_impl<SeqMap,
550-
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
551-
0,
552-
SeqMap::Size()>::type;
559+
template <index_t... Positions>
560+
static constexpr auto compute(Sequence<Positions...>)
561+
{
562+
return Sequence<find_inverse(Positions)...>{};
563+
}
564+
565+
public:
566+
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
553567
};
554568

555569
template <index_t... Xs, index_t... Ys>

0 commit comments

Comments
 (0)