Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
Expand Down Expand Up @@ -61,24 +62,38 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};

// Scalar per access sequences for SFC
static constexpr auto src_scalar_per_access_ = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
static constexpr auto dst_scalar_per_access_ = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

// SpaceFillingCurve types for precomputed index lookups
using SrcSFC = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access_)>>;
using DstSFC = SpaceFillingCurve<SliceLengths,
DstDimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access_)>>;

// =====================================================================
// Private implementation methods (must be declared before public methods
// that call them)
// =====================================================================
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
return Helper::ComputeCoordinateResetStep<SliceLengths,
SrcVectorDim,
SrcScalarPerVector_,
SrcDimAccessOrder>();
return ThreadwiseTransferHelper_SFC::ComputeSFCCoordinateResetStep<
SliceLengths,
SrcDimAccessOrder,
decltype(src_scalar_per_access_)>();
}

__device__ static constexpr auto GetDstCoordinateResetStep()
{
return Helper::ComputeCoordinateResetStep<SliceLengths,
DstVectorDim,
DstScalarPerVector_,
DstDimAccessOrder>();
return ThreadwiseTransferHelper_SFC::ComputeSFCCoordinateResetStep<
SliceLengths,
DstDimAccessOrder,
decltype(dst_scalar_per_access_)>();
}

__device__ static constexpr auto GetSrcThreadScratchDescriptor()
Expand Down Expand Up @@ -112,29 +127,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
#else
// OOB Check
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

constexpr auto src_dim_access_order = SrcDimAccessOrder{};

constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// OOB Check using SpaceFillingCurve for O(1) index lookup
constexpr index_t num_access = SrcSFC::GetNumOfAccess();

// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths);

// calculate src data index
constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths,
forward_sweep,
src_dim_access_order,
src_scalar_per_access);
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
// Get data index from precomputed lookup table - O(1)
constexpr auto src_data_idx = SrcSFC::GetIndex(iAccess);

constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
Expand Down Expand Up @@ -292,38 +291,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");

// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");

constexpr auto src_dim_access_order = SrcDimAccessOrder{};

constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
constexpr index_t num_access = SrcSFC::GetNumOfAccess();

// make forward and backward steps
const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access);
const auto src_backward_steps =
Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access);

// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths);

// calculate src data index
constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths,
forward_sweep,
src_dim_access_order,
src_scalar_per_access);
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
// Get data index from precomputed lookup table - O(1)
constexpr auto src_data_idx = SrcSFC::GetIndex(iAccess);

constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
Expand Down Expand Up @@ -396,25 +372,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
.template SetAsType<dst_vector_t>(
src_data_idx_seq, op_r_v.template AsType<dst_vector_t>()[Helper::I0]);

constexpr auto move_on_dim =
Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths);

// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
// move src coord using precomputed forward step
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto forward_step = SrcSFC::GetForwardStep(iAccess);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
}
});

// move src coordinate back to slice origin (or not)
Expand Down Expand Up @@ -445,35 +409,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");

// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr index_t num_access = DstSFC::GetNumOfAccess();

constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

// make forward and backward steps
const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access);
const auto dst_backward_steps =
Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access);

// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths);

// calculate dst data index
constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx,
ordered_dst_access_lengths,
forward_sweep,
dst_dim_access_order,
dst_scalar_per_access);
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
// Get data index from precomputed lookup table - O(1)
constexpr auto dst_data_idx = DstSFC::GetIndex(iAccess);

constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
Expand Down Expand Up @@ -501,25 +442,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[Helper::I0]);

constexpr auto move_on_dim =
Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths);

// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
// move dst coord using precomputed forward step
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto forward_step = DstSFC::GetForwardStep(iAccess);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});

// move dst coordinate back to slice origin (or not)
Expand Down
Loading