@@ -36,11 +36,9 @@ struct TensorDescriptor
3636
3737 __host__ __device__ static constexpr index_t GetNumOfHiddenDimension ()
3838 {
39- constexpr auto all_low_dim_ids = unpack (
40- [](auto &&... xs) constexpr { return merge_sequences (xs...); }, LowerDimensionIdss{});
39+ constexpr auto all_low_dim_ids = unpack_and_merge_sequences (LowerDimensionIdss{});
4140
42- constexpr auto all_up_dim_ids = unpack (
43- [](auto &&... xs) constexpr { return merge_sequences (xs...); }, UpperDimensionIdss{});
41+ constexpr auto all_up_dim_ids = unpack_and_merge_sequences (UpperDimensionIdss{});
4442
4543 constexpr auto all_dim_ids = merge_sequences (all_low_dim_ids, all_up_dim_ids);
4644
@@ -311,6 +309,45 @@ struct lambda_get_up_dim_num
311309 }
312310};
313311
312+ // Functor to convert a single visible dimension id to hidden id
313+ // Replaces inner lambda in transform_tensor_descriptor
314+ // Note: transform_sequences passes index_t values, not Number<> types
315+ template <typename OldTensorDescriptor>
316+ struct convert_visible_to_hidden_id
317+ {
318+ __host__ __device__ constexpr auto operator ()(index_t low_dim_visible_id) const
319+ {
320+ return OldTensorDescriptor::GetVisibleDimensionIds ().At (low_dim_visible_id);
321+ }
322+ };
323+
324+ // Functor to convert a sequence of visible dimension ids to hidden ids
325+ // Replaces outer lambda in transform_tensor_descriptor
326+ template <typename OldTensorDescriptor>
327+ struct convert_visible_ids_to_hidden_ids
328+ {
329+ template <typename LowDimVisibleIds>
330+ __host__ __device__ constexpr auto operator ()(LowDimVisibleIds low_dim_visible_ids) const
331+ {
332+ return transform_sequences (convert_visible_to_hidden_id<OldTensorDescriptor>{},
333+ low_dim_visible_ids);
334+ }
335+ };
336+
337+ // Functor to generate arithmetic sequences from scan results
338+ // Replaces lambda in transform_tensor_descriptor that generates up_dim_hidden_idss
339+ template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
340+ struct generate_arithmetic_sequence_from_scan
341+ {
342+ template <typename I>
343+ __host__ __device__ constexpr auto operator ()(I) const
344+ {
345+ constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At (I{});
346+ constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At (I{} + Number<1 >{});
347+ return typename arithmetic_sequence_gen<start, end, 1 >::type{};
348+ }
349+ };
350+
314351template <typename OldTensorDescriptor,
315352 typename NewTransforms,
316353 typename NewLowerDimensionOldVisibleIdss,
@@ -327,11 +364,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
327364 NewTransforms::Size () == NewUpperDimensionNewVisibleIdss::Size (),
328365 " wrong! inconsitent number of transform" );
329366
330- constexpr auto all_old_top_ids = unpack ([]( auto ... xs) { return merge_sequences (xs...); },
331- NewLowerDimensionOldVisibleIdss{});
367+ constexpr auto all_old_top_ids =
368+ unpack_and_merge_sequences ( NewLowerDimensionOldVisibleIdss{});
332369
333- constexpr auto all_new_top_ids = unpack ([]( auto ... xs) { return merge_sequences (xs...); },
334- NewUpperDimensionNewVisibleIdss{});
370+ constexpr auto all_new_top_ids =
371+ unpack_and_merge_sequences ( NewUpperDimensionNewVisibleIdss{});
335372
336373 static_assert (is_valid_sequence_map<decltype (all_old_top_ids)>::value &&
337374 is_valid_sequence_map<decltype (all_new_top_ids)>::value,
@@ -341,17 +378,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
341378 // lower dimension's hidden idss
342379 // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
343380 // sequences)
344- constexpr auto low_dim_hidden_idss = transform_tuples (
345- // convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
346- [](auto low_dim_visible_ids) constexpr {
347- return transform_sequences (
348- // convert lower dimension visible id to hidden id
349- [](auto low_dim_visible_id) constexpr {
350- return OldTensorDescriptor::GetVisibleDimensionIds ()[low_dim_visible_id];
351- },
352- low_dim_visible_ids);
353- },
354- NewLowerDimensionOldVisibleIdss{});
381+ constexpr auto low_dim_hidden_idss =
382+ transform_tuples (convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
383+ NewLowerDimensionOldVisibleIdss{});
355384
356385 constexpr index_t num_new_transform = NewTransforms::Size ();
357386
@@ -364,22 +393,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
364393 constexpr auto up_dim_numbers_scan = merge_sequences (
365394 Sequence<0 >{}, inclusive_scan_sequence (up_dim_numbers, math::plus<index_t >{}, Number<0 >{}));
366395
396+ using UpDimNumbersScanType = remove_cvref_t <decltype (up_dim_numbers_scan)>;
367397 constexpr auto up_dim_hidden_idss = generate_tuple (
368- [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
369- return
370- typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
371- old_hidden_dim_number + up_dim_numbers_scan[i + 1 ],
372- 1 >::type{};
373- },
398+ generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
374399 Number<num_new_transform>{});
375400
376401 // new visible dimension's hidden ids
377402 constexpr auto unordered_new_visible_dim_hidden_ids =
378- unpack ([]( auto ... xs) constexpr { return merge_sequences (xs...); }, up_dim_hidden_idss);
403+ unpack_and_merge_sequences ( up_dim_hidden_idss);
379404
380405 constexpr auto new_visible_dim_unordered2ordered =
381- unpack ([](auto ... xs) constexpr { return merge_sequences (xs...); },
382- NewUpperDimensionNewVisibleIdss{});
406+ unpack_and_merge_sequences (NewUpperDimensionNewVisibleIdss{});
383407
384408 constexpr auto new_visible_dim_hidden_ids =
385409 unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New (new_visible_dim_unordered2ordered);
0 commit comments