Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 92 additions & 45 deletions cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,113 +66,160 @@
/// of the updated arrays (cores and neighbours) for reading and writing
template <typename Float, bool use_weights>
struct get_core_wide_kernel {
static auto run(sycl::queue& queue,
const pr::ndview<Float, 2>& data,
const pr::ndview<Float, 2>& weights,
pr::ndview<std::int32_t, 1>& cores,
pr::ndview<Float, 1>& neighbours,
Float epsilon,
std::int64_t min_observations,
const bk::event_vector& deps) {
const std::int64_t local_row_count = data.get_dimension(0);
const std::int64_t column_count = data.get_dimension(1);
const std::int64_t row_count64 = data.get_dimension(0);
const std::int64_t col_count64 = data.get_dimension(1);

ONEDAL_ASSERT(local_row_count > 0);
ONEDAL_ASSERT(!use_weights || weights.get_dimension(0) == local_row_count);
ONEDAL_ASSERT(row_count64 > 0);
ONEDAL_ASSERT(!use_weights || weights.get_dimension(0) == row_count64);
ONEDAL_ASSERT(!use_weights || weights.get_dimension(1) == 1);
ONEDAL_ASSERT(cores.get_dimension(0) == local_row_count);
ONEDAL_ASSERT(neighbours.get_dimension(0) == local_row_count);
ONEDAL_ASSERT(cores.get_dimension(0) == row_count64);
ONEDAL_ASSERT(neighbours.get_dimension(0) == row_count64);

const Float* data_ptr = data.get_data();
const Float* weights_ptr = weights.get_data();
std::int32_t* cores_ptr = cores.get_mutable_data();
Float* neighbours_ptr = neighbours.get_mutable_data();

auto event = queue.submit([&](sycl::handler& cgh) {
cgh.depends_on(deps);
const std::int64_t wg_size = get_recommended_wg_size(queue, column_count);
const std::int64_t block_split_size =
get_recommended_check_block_size(queue, column_count, wg_size);
cgh.parallel_for(
bk::make_multiple_nd_range_2d({ wg_size, local_row_count }, { wg_size, 1 }),

const std::int64_t wg_size64 = get_recommended_wg_size(queue, col_count64);
const std::int64_t block_split_size64 =
get_recommended_check_block_size(queue, col_count64, wg_size64);

cgh.parallel_for<get_core_wide_kernel<Float, use_weights>>(
bk::make_multiple_nd_range_2d({ wg_size64, row_count64 }, { wg_size64, 1 }),
[=](sycl::nd_item<2> item) {
auto sg = item.get_sub_group();
const std::uint32_t sg_id = sg.get_group_id()[0];
if (sg_id > 0)
sycl::sub_group sg = item.get_sub_group();

// Match reference: only first sub-group of each work-group participates
if (sg.get_group_id()[0] != 0)
return;

const std::uint32_t wg_id = item.get_global_id(1);
if (wg_id >= local_row_count)
const std::uint32_t row_count = static_cast<std::uint32_t>(row_count64);
const std::uint32_t col_count = static_cast<std::uint32_t>(col_count64);

const std::uint32_t row_i =
static_cast<std::uint32_t>(item.get_global_id(1));
if (row_i >= row_count)
return;

const std::uint32_t local_id = sg.get_local_id();
const std::uint32_t local_size = sg.get_local_range()[0];
const std::uint32_t lane = sg.get_local_id()[0];
const std::uint32_t sg_size = sg.get_local_range()[0];

const Float min_obs_f = static_cast<Float>(min_observations);

// Make block_split at least 1
std::uint32_t block_split = static_cast<std::uint32_t>(block_split_size64);
block_split = block_split ? block_split : 1u;

// Base pointers for row_i
const std::uint32_t base_i = row_i * col_count;
const Float* const xi = data_ptr + base_i;

// Keep neighbor count local; write back once (plus early-exit path)
Float count = neighbours_ptr[row_i];

// Iterate over all candidate points j
for (std::uint32_t j = 0; j < row_count; ++j) {
const Float* const xj = data_ptr + (j * col_count);

Float count = neighbours_ptr[wg_id];
for (std::int64_t j = 0; j < local_row_count; j++) {
Float sum = Float(0);
std::int64_t count_iter = 0;
for (std::int64_t i = local_id; i < column_count; i += local_size) {
count_iter++;
Float val =
data_ptr[wg_id * column_count + i] - data_ptr[j * column_count + i];
sum += val * val;

if (count_iter % block_split_size == 0 &&
local_size * count_iter <= column_count) {
Float distance_check =
sycl::reduce_over_group(sg,
sum,
sycl::ext::oneapi::plus<Float>());
if (distance_check > epsilon) {
break;
// Periodic early pruning state
bool pruned = false;
std::uint32_t ticks = block_split;
std::uint32_t iter = 0;

// Feature loop distributed across subgroup lanes
for (std::uint32_t k = lane; k < col_count; k += sg_size) {
++iter;

const Float v = xi[k] - xj[k];
sum = sycl::fma(v, v, sum);

// Early check every block_split iterations (reference cadence)
if (--ticks == 0) {
ticks = block_split;

// Reference guard: local_size * count_iter <= column_count
if (sg_size * iter <= col_count) {
const Float partial =
sycl::reduce_over_group(sg, sum, sycl::plus<Float>());
if (partial > epsilon) {
pruned = true;
break;
}
}
}
}
Float distance =
sycl::reduce_over_group(sg, sum, sycl::ext::oneapi::plus<Float>());
if (distance <= epsilon) {

// If pruned, do NOT perform the final reduction (saves cost)
if (pruned) {
continue;
}

const Float dist =
sycl::reduce_over_group(sg, sum, sycl::plus<Float>());

if (dist <= epsilon) {
// Only load weights if actually a neighbor
count += use_weights ? weights_ptr[j] : Float(1);
if (local_id == 0) {
neighbours_ptr[wg_id] = count;
}
if (count >= min_observations && !use_weights) {
if (local_id == 0) {
cores_ptr[wg_id] = Float(1);

// Reference early exit only in the unweighted case
if (!use_weights && count >= min_obs_f) {
if (lane == 0) {
neighbours_ptr[row_i] = count;
cores_ptr[row_i] = std::int32_t(1);
}
break;
}
}
}
if (neighbours_ptr[wg_id] >= min_observations) {
cores_ptr[wg_id] = Float(1);

// Single global writeback of neighbours (unless early exit already wrote it)
if (lane == 0) {
neighbours_ptr[row_i] = count;

// Reference final marking uses neighbours_ptr value
if (count >= min_obs_f) {
cores_ptr[row_i] = std::int32_t(1);
}
}
});
});
return event;
}
};

/// A struct that finds the core points without using sub-groups.
/// It is effective only in narrow cases, i.e. when the column count is less than 4.
///
/// @tparam Float Floating-point type used to perform computations
/// @tparam use_weights Bool type used to check that weights are enabled
///
/// @param[in] queue The SYCL queue
/// @param[in] data The input data of size `row_count` x `column_count`
/// @param[in] weights The input weights of size `row_count` x `1`
/// @param[in] cores The current cores of size `row_count` x `1`
/// @param[in] neighbours The current neighbours of size `row_count` x `1`
/// it contains the counter of neighbours for each point
/// @param[in] epsilon The input parameter epsilon
/// @param[in] min_observations The input parameter min_observation
/// @param[in] deps Events indicating availability of the `data` for reading or writing
///
/// @return A SYCL event indicating the availability
/// of the updated arrays (cores and neighbours) for reading and writing
template <typename Float, bool use_weights>
struct get_core_narrow_kernel {

Check notice on line 222 in cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp

View check run for this annotation

codefactor.io / CodeFactor

cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp#L69-L222

Complex Method
static auto run(sycl::queue& queue,
const pr::ndview<Float, 2>& data,
const pr::ndview<Float, 2>& weights,
Expand Down
Loading