diff --git a/projects/composablekernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/projects/composablekernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 8b0b35935f0..e90fb9aefaa 100644 --- a/projects/composablekernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/projects/composablekernel/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -6,6 +6,7 @@ #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor/static_tensor.hpp" #include "ck/utility/is_detected.hpp" @@ -61,24 +62,38 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static constexpr auto SrcScalarPerVector = Number{}; static constexpr auto DstScalarPerVector = Number{}; + // Scalar per access sequences for SFC + static constexpr auto src_scalar_per_access_ = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + static constexpr auto dst_scalar_per_access_ = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + // SpaceFillingCurve types for precomputed index lookups + using SrcSFC = SpaceFillingCurve>; + using DstSFC = SpaceFillingCurve>; + // ===================================================================== // Private implementation methods (must be declared before public methods // that call them) // ===================================================================== __device__ static constexpr auto GetSrcCoordinateResetStep() { - return Helper::ComputeCoordinateResetStep(); + return ThreadwiseTransferHelper_SFC::ComputeSFCCoordinateResetStep< + SliceLengths, + SrcDimAccessOrder, + decltype(src_scalar_per_access_)>(); } __device__ static constexpr auto GetDstCoordinateResetStep() { - return Helper::ComputeCoordinateResetStep(); + return ThreadwiseTransferHelper_SFC::ComputeSFCCoordinateResetStep< + SliceLengths, + DstDimAccessOrder, + decltype(dst_scalar_per_access_)>(); } __device__ static constexpr auto GetSrcThreadScratchDescriptor() @@ -112,29 +127,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); #else - // OOB Check - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + // OOB Check using SpaceFillingCurve for O(1) index lookup + constexpr index_t num_access = SrcSFC::GetNumOfAccess(); - // loop over tensor and copy - static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = - Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - - // calculate src data index - constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, - ordered_src_access_lengths, - forward_sweep, - src_dim_access_order, - src_scalar_per_access); + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + // Get data index from precomputed lookup table - O(1) + constexpr auto src_data_idx = SrcSFC::GetIndex(iAccess); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -292,38 +291,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_same, remove_cvref_t>::value, "wrong! SrcBuffer and SrcData data type are inconsistent"); - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + constexpr index_t num_access = SrcSFC::GetNumOfAccess(); - // make forward and backward steps - const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access); - const auto src_backward_steps = - Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access); - - // loop over tensor and copy - static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = - Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - - // calculate src data index - constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, - ordered_src_access_lengths, - forward_sweep, - src_dim_access_order, - src_scalar_per_access); + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + // Get data index from precomputed lookup table - O(1) + constexpr auto src_data_idx = SrcSFC::GetIndex(iAccess); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -396,25 +372,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 .template SetAsType( src_data_idx_seq, op_r_v.template AsType()[Helper::I0]); - constexpr auto move_on_dim = - Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); - - // move src coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); - } - } - }); + // move src coord using precomputed forward step + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SrcSFC::GetForwardStep(iAccess); + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + } }); // move src coordinate back to slice origin (or not) @@ -445,35 +409,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_same, remove_cvref_t>::value, "wrong! SrcBuffer or DstBuffer data type is wrong"); - // src scalar per access on each dim - // TODO: don't use this - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + constexpr index_t num_access = DstSFC::GetNumOfAccess(); - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // make forward and backward steps - const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access); - const auto dst_backward_steps = - Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access); - - // loop over tensor and copy - static_ford{}([&](auto ordered_dst_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = - Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); - - // calculate dst data index - constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx, - ordered_dst_access_lengths, - forward_sweep, - dst_dim_access_order, - dst_scalar_per_access); + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + // Get data index from precomputed lookup table - O(1) + constexpr auto dst_data_idx = DstSFC::GetIndex(iAccess); constexpr auto dst_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -501,25 +442,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_dst_valid, dst_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = - Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); - - // move dst coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); - } - } - }); + // move dst coord using precomputed forward step + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = DstSFC::GetForwardStep(iAccess); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } }); // move dst coordinate back to slice origin (or not)