From 8f82d59de3dacd3d7748ce13611482d4096c777e Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 23 Apr 2026 14:07:50 -0400 Subject: [PATCH] [CK] Use SpaceFillingCurve in ThreadwiseTensorSliceTransfer_v3r1 Replace O(N^2) ComputeForwardSweep/ComputeDataIndex/ComputeMoveOnDim calls with O(1) SpaceFillingCurve::GetIndex() lookups in RunRead(), RunWrite(), and TransferDataFromSrcThreadScratchToDstThreadScratch(). This reuses the existing SpaceFillingCurve infrastructure (already used by v6r1, v7r2, etc.) to precompute the serpentine traversal indices at compile time, reducing template instantiation depth significantly. Salvages the approach from PR #5047, rebased onto current develop which has the ThreadwiseTransferHelper refactoring. Generated-by: Claude Code (claude-opus-4-6) --- .../threadwise_tensor_slice_transfer_v3r1.hpp | 177 ++++++------------ 1 file changed, 53 insertions(+), 124 deletions(-) diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 8b0b35935f0..e90fb9aefaa 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -6,6 +6,7 @@ #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor/static_tensor.hpp" #include "ck/utility/is_detected.hpp" @@ -61,24 +62,38 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static constexpr auto SrcScalarPerVector = Number{}; static constexpr auto DstScalarPerVector = Number{}; + // Scalar per access sequences for SFC + static constexpr auto src_scalar_per_access_ = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + static constexpr auto dst_scalar_per_access_ = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + // SpaceFillingCurve types for precomputed index lookups + using SrcSFC = SpaceFillingCurve>; + using DstSFC = SpaceFillingCurve>; + // ===================================================================== // Private implementation methods (must be declared before public methods // that call them) // ===================================================================== __device__ static constexpr auto GetSrcCoordinateResetStep() { - return Helper::ComputeCoordinateResetStep(); + return ThreadwiseTransferHelper_SFC::ComputeSFCCoordinateResetStep< + SliceLengths, + SrcDimAccessOrder, + decltype(src_scalar_per_access_)>(); } __device__ static constexpr auto GetDstCoordinateResetStep() { - return Helper::ComputeCoordinateResetStep(); + return ThreadwiseTransferHelper_SFC::ComputeSFCCoordinateResetStep< + SliceLengths, + DstDimAccessOrder, + decltype(dst_scalar_per_access_)>(); } __device__ static constexpr auto GetSrcThreadScratchDescriptor() @@ -112,29 +127,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); #else - // OOB Check - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + // OOB Check using SpaceFillingCurve for O(1) index lookup + constexpr index_t num_access = SrcSFC::GetNumOfAccess(); - // loop over tensor and copy - static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = - Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - - // calculate src data index - constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, - ordered_src_access_lengths, - forward_sweep, - src_dim_access_order, - src_scalar_per_access); + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + // Get data index from precomputed lookup table - O(1) + constexpr auto src_data_idx = SrcSFC::GetIndex(iAccess); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -292,38 +291,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_same, remove_cvref_t>::value, "wrong! SrcBuffer and SrcData data type are inconsistent"); - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + constexpr index_t num_access = SrcSFC::GetNumOfAccess(); - // make forward and backward steps - const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access); - const auto src_backward_steps = - Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access); - - // loop over tensor and copy - static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = - Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - - // calculate src data index - constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, - ordered_src_access_lengths, - forward_sweep, - src_dim_access_order, - src_scalar_per_access); + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + // Get data index from precomputed lookup table - O(1) + constexpr auto src_data_idx = SrcSFC::GetIndex(iAccess); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -396,25 +372,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 .template SetAsType( src_data_idx_seq, op_r_v.template AsType()[Helper::I0]); - constexpr auto move_on_dim = - Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); - - // move src coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); - } - } - }); + // move src coord using precomputed forward step + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SrcSFC::GetForwardStep(iAccess); + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + } }); // move src coordinate back to slice origin (or not) @@ -445,35 +409,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_same, remove_cvref_t>::value, "wrong! SrcBuffer or DstBuffer data type is wrong"); - // src scalar per access on each dim - // TODO: don't use this - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + constexpr index_t num_access = DstSFC::GetNumOfAccess(); - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // make forward and backward steps - const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access); - const auto dst_backward_steps = - Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access); - - // loop over tensor and copy - static_ford{}([&](auto ordered_dst_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = - Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); - - // calculate dst data index - constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx, - ordered_dst_access_lengths, - forward_sweep, - dst_dim_access_order, - dst_scalar_per_access); + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + // Get data index from precomputed lookup table - O(1) + constexpr auto dst_data_idx = DstSFC::GetIndex(iAccess); constexpr auto dst_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -501,25 +442,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_dst_valid, dst_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = - Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); - - // move dst coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); - } - } - }); + // move dst coord using precomputed forward step + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = DstSFC::GetForwardStep(iAccess); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } }); // move dst coordinate back to slice origin (or not)