Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ struct __subgroup_radix_sort
{
assert(__src.size() <= std::numeric_limits<std::uint16_t>::max());
assert(__block_size * __wg_size <= std::numeric_limits<std::uint16_t>::max());
uint16_t __n = __src.size();
std::uint16_t __n = __src.size();
assert(__n <= __block_size * __wg_size);

using _ValT = oneapi::dpl::__internal::__value_t<_RangeIn>;
using _KeyT = oneapi::dpl::__internal::__key_t<_Proj, _RangeIn>;

const auto __counter_buf_sz = __get_counter_buf_size(__wg_size);
_TempBuf<_ValT, _SLM_tag_val> __buf_val(__block_size * __wg_size);
_TempBuf<uint32_t, _SLM_counter> __buf_count(__counter_buf_sz);
_TempBuf<std::uint32_t, _SLM_counter> __buf_count(__counter_buf_sz);

sycl::nd_range __range{sycl::range{__wg_size}, sycl::range{__wg_size}};
return __q.submit([&](sycl::handler& __cgh) {
Expand All @@ -202,9 +202,9 @@ struct __subgroup_radix_sort
_ValT __v[__block_size];
__storage() {}
} __values;
uint16_t __wi = __it.get_local_linear_id();
uint16_t __begin_bit = 0;
constexpr uint16_t __end_bit = sizeof(_KeyT) * ::std::numeric_limits<unsigned char>::digits;
std::uint16_t __wi = __it.get_local_linear_id();
std::uint16_t __begin_bit = 0;
constexpr std::uint16_t __end_bit = sizeof(_KeyT) * std::numeric_limits<unsigned char>::digits;

//copy (move) construction of the values
__block_load<_ValT>(__wi, __src, __values.__v, __n);
Expand All @@ -213,24 +213,24 @@ struct __subgroup_radix_sort

while (true)
{
uint16_t __indices[__block_size]; //indices for indirect access in the "re-order" phase
std::uint16_t __indices[__block_size]; //indices for indirect access in the "re-order" phase
{
//pointers (for performance reasons) to the bucket counters
uint32_t* __counters[__block_size];
// Cache bin indices to avoid recomputation from values
std::uint16_t __bins[__block_size];

//1. "counting" phase
//counter initialization
auto __pcounter = __dpl_sycl::__get_accessor_ptr(__counter_lacc) + __wi;

_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __bin_count; ++__i)
for (std::uint16_t __i = 0; __i < __bin_count; ++__i)
__pcounter[__i * __wg_size] = 0;

_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __block_size; ++__i)
for (std::uint16_t __i = 0; __i < __block_size; ++__i)
{
const uint16_t __idx = __wi * __block_size + __i;
const uint16_t __bin =
const std::uint16_t __idx = __wi * __block_size + __i;
const std::uint16_t __bin =
__idx < __n
? __get_bucket</*mask*/ __bin_count - 1>(
__order_preserving_cast<__is_asc>(
Expand All @@ -239,9 +239,10 @@ struct __subgroup_radix_sort
: __bin_count - 1 /*default bin for out of range elements (when idx >= n)*/;

//"counting" and local offset calculation
__counters[__i] = &__pcounter[__bin * __wg_size];
__indices[__i] = *__counters[__i];
*__counters[__i] = __indices[__i] + 1;
__bins[__i] = __bin;
std::uint32_t* __p = &__pcounter[__bin * __wg_size];
__indices[__i] = *__p;
*__p = __indices[__i] + 1;
}
__dpl_sycl::__group_barrier(__it, decltype(__buf_count)::get_fence());

Expand All @@ -250,20 +251,21 @@ struct __subgroup_radix_sort
//TODO: probably can be further optimized

//scan contiguous numbers
uint16_t __bin_sum[__bin_count];
std::uint16_t __bin_sum[__bin_count];
__bin_sum[0] = __counter_lacc[__wi * __bin_count];

_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 1; __i < __bin_count; ++__i)
for (std::uint16_t __i = 1; __i < __bin_count; ++__i)
__bin_sum[__i] = __bin_sum[__i - 1] + __counter_lacc[__wi * __bin_count + __i];
__dpl_sycl::__group_barrier(__it, decltype(__buf_count)::get_fence());

//exclusive scan local sum
uint16_t __sum_scan = __dpl_sycl::__exclusive_scan_over_group(
__it.get_group(), __bin_sum[__bin_count - 1], __dpl_sycl::__plus<uint16_t>());
std::uint16_t __sum_scan = __dpl_sycl::__exclusive_scan_over_group(
__it.get_group(), __bin_sum[__bin_count - 1],
__dpl_sycl::__plus<std::uint16_t>());
//add to local sum, generate exclusive scan result
_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __bin_count; ++__i)
for (std::uint16_t __i = 0; __i < __bin_count; ++__i)
__counter_lacc[__wi * __bin_count + __i + 1] = __sum_scan + __bin_sum[__i];

if (__wi == 0)
Expand All @@ -272,10 +274,10 @@ struct __subgroup_radix_sort
}

_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __block_size; ++__i)
for (std::uint16_t __i = 0; __i < __block_size; ++__i)
{
// a global index is a local offset plus a global base index
__indices[__i] += *__counters[__i];
__indices[__i] += __pcounter[__bins[__i] * __wg_size];
}
}

Expand All @@ -287,22 +289,22 @@ struct __subgroup_radix_sort
{
// the last iteration - writing out the result
_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __block_size; ++__i)
for (std::uint16_t __i = 0; __i < __block_size; ++__i)
{
const uint16_t __r = __indices[__i];
const std::uint16_t __r = __indices[__i];
if (__r < __n)
{
//move the values to source range and destroy the values
__src[__r] = ::std::move(__values.__v[__i]);
__src[__r] = std::move(__values.__v[__i]);
__values.__v[__i].~_ValT();
}
}

//destroy values in exchange buffer
_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __block_size; ++__i)
for (std::uint16_t __i = 0; __i < __block_size; ++__i)
{
const uint16_t __idx = __wi * __block_size + __i;
const std::uint16_t __idx = __wi * __block_size + __i;
if (__idx < __n)
__exchange_lacc[__idx].~_ValT();
}
Expand All @@ -313,31 +315,31 @@ struct __subgroup_radix_sort
if (__begin_bit == __radix) //the first sort iteration
{
_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __block_size; ++__i)
for (std::uint16_t __i = 0; __i < __block_size; ++__i)
{
const uint16_t __r = __indices[__i];
const std::uint16_t __r = __indices[__i];
if (__r < __n)
new (&__exchange_lacc[__r]) _ValT(::std::move(__values.__v[__i]));
new (&__exchange_lacc[__r]) _ValT(std::move(__values.__v[__i]));
}
}
else
{
_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __block_size; ++__i)
for (std::uint16_t __i = 0; __i < __block_size; ++__i)
{
const uint16_t __r = __indices[__i];
const std::uint16_t __r = __indices[__i];
if (__r < __n)
__exchange_lacc[__r] = ::std::move(__values.__v[__i]);
__exchange_lacc[__r] = std::move(__values.__v[__i]);
}
}
__dpl_sycl::__group_barrier(__it, decltype(__buf_val)::get_fence());

_ONEDPL_PRAGMA_UNROLL
for (uint16_t __i = 0; __i < __block_size; ++__i)
for (std::uint16_t __i = 0; __i < __block_size; ++__i)
{
const uint16_t __idx = __wi * __block_size + __i;
const std::uint16_t __idx = __wi * __block_size + __i;
if (__idx < __n)
__values.__v[__i] = ::std::move(__exchange_lacc[__idx]);
__values.__v[__i] = std::move(__exchange_lacc[__idx]);
}
__dpl_sycl::__group_barrier(__it, decltype(__buf_val)::get_fence());
}
Expand Down
Loading