Skip to content

Commit 00849ac

Browse files
committed
Replace lambdas with named functors in transform_tensor_descriptor
Lambda expressions in transform_tensor_descriptor created unique template instantiations for each capture combination. This change replaces lambdas with named functor structs to reduce instantiation count:

- Add merge_sequences_functor and unpack_and_merge_sequences helper
- Add convert_visible_to_hidden_id and convert_visible_ids_to_hidden_ids
- Add generate_arithmetic_sequence_from_scan

Build analysis shows instantiation count dropped from 388 to 32 (92% reduction).
1 parent d7e7fbd commit 00849ac

File tree

3 files changed

+72
-33
lines changed

3 files changed

+72
-33
lines changed

include/ck/tensor_description/tensor_descriptor.hpp

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@ struct TensorDescriptor
3636

3737
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
3838
{
39-
constexpr auto all_low_dim_ids = unpack(
40-
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
39+
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});
4140

42-
constexpr auto all_up_dim_ids = unpack(
43-
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
41+
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});
4442

4543
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
4644

@@ -311,6 +309,45 @@ struct lambda_get_up_dim_num
311309
}
312310
};
313311

312+
// Functor to convert a single visible dimension id to hidden id
313+
// Replaces inner lambda in transform_tensor_descriptor
314+
// Note: transform_sequences passes index_t values, not Number<> types
315+
template <typename OldTensorDescriptor>
316+
struct convert_visible_to_hidden_id
317+
{
318+
__host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
319+
{
320+
return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
321+
}
322+
};
323+
324+
// Functor to convert a sequence of visible dimension ids to hidden ids
325+
// Replaces outer lambda in transform_tensor_descriptor
326+
template <typename OldTensorDescriptor>
327+
struct convert_visible_ids_to_hidden_ids
328+
{
329+
template <typename LowDimVisibleIds>
330+
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
331+
{
332+
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
333+
low_dim_visible_ids);
334+
}
335+
};
336+
337+
// Functor to generate arithmetic sequences from scan results
338+
// Replaces lambda in transform_tensor_descriptor that generates up_dim_hidden_idss
339+
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
340+
struct generate_arithmetic_sequence_from_scan
341+
{
342+
template <typename I>
343+
__host__ __device__ constexpr auto operator()(I) const
344+
{
345+
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
346+
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
347+
return typename arithmetic_sequence_gen<start, end, 1>::type{};
348+
}
349+
};
350+
314351
template <typename OldTensorDescriptor,
315352
typename NewTransforms,
316353
typename NewLowerDimensionOldVisibleIdss,
@@ -327,11 +364,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
327364
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
328365
"wrong! inconsitent number of transform");
329366

330-
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
331-
NewLowerDimensionOldVisibleIdss{});
367+
constexpr auto all_old_top_ids =
368+
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});
332369

333-
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
334-
NewUpperDimensionNewVisibleIdss{});
370+
constexpr auto all_new_top_ids =
371+
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
335372

336373
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
337374
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
@@ -341,17 +378,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
341378
// lower dimension's hidden idss
342379
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
343380
// sequences)
344-
constexpr auto low_dim_hidden_idss = transform_tuples(
345-
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
346-
[](auto low_dim_visible_ids) constexpr {
347-
return transform_sequences(
348-
// convert lower dimension visible id to hidden id
349-
[](auto low_dim_visible_id) constexpr {
350-
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
351-
},
352-
low_dim_visible_ids);
353-
},
354-
NewLowerDimensionOldVisibleIdss{});
381+
constexpr auto low_dim_hidden_idss =
382+
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
383+
NewLowerDimensionOldVisibleIdss{});
355384

356385
constexpr index_t num_new_transform = NewTransforms::Size();
357386

@@ -364,22 +393,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
364393
constexpr auto up_dim_numbers_scan = merge_sequences(
365394
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
366395

396+
using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
367397
constexpr auto up_dim_hidden_idss = generate_tuple(
368-
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
369-
return
370-
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
371-
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
372-
1>::type{};
373-
},
398+
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
374399
Number<num_new_transform>{});
375400

376401
// new visible dimension's hidden ids
377402
constexpr auto unordered_new_visible_dim_hidden_ids =
378-
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
403+
unpack_and_merge_sequences(up_dim_hidden_idss);
379404

380405
constexpr auto new_visible_dim_unordered2ordered =
381-
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
382-
NewUpperDimensionNewVisibleIdss{});
406+
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
383407

384408
constexpr auto new_visible_dim_hidden_ids =
385409
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);

include/ck/tensor_operation/gpu/device/matrix_padder.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP
4343
},
4444
Number<num_dim>{});
4545

46-
// lower dimension Id
47-
const auto lower_dimss =
48-
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
49-
50-
// upper dimension Id
46+
// lower/upper dimension Ids
47+
const auto lower_dimss = generate_identity_sequences<num_dim>();
5148
const auto upper_dimss = lower_dimss;
5249

5350
return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);

include/ck/utility/sequence_helper.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,22 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
3434
return Sequence<Is...>{};
3535
}
3636

37+
// Functor for merge_sequences to avoid lambda instantiation overhead
38+
struct merge_sequences_functor
39+
{
40+
template <typename... Seqs>
41+
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
42+
{
43+
return merge_sequences(seqs...);
44+
}
45+
};
46+
47+
// Helper to unpack a tuple of sequences and merge them
48+
// Replaces: unpack([](auto... xs) { return merge_sequences(xs...); }, tuple_of_sequences)
49+
template <typename TupleOfSequences>
50+
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences)
51+
{
52+
return unpack(merge_sequences_functor{}, TupleOfSequences{});
53+
}
54+
3755
} // namespace ck

0 commit comments

Comments
 (0)