diff --git a/cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp b/cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp index 7980f461942..9bf8aef8abb 100644 --- a/cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp +++ b/cpp/oneapi/dal/algo/dbscan/backend/gpu/kernel_fp_impl.hpp @@ -74,14 +74,14 @@ struct get_core_wide_kernel { Float epsilon, std::int64_t min_observations, const bk::event_vector& deps) { - const std::int64_t local_row_count = data.get_dimension(0); - const std::int64_t column_count = data.get_dimension(1); + const std::int64_t row_count64 = data.get_dimension(0); + const std::int64_t col_count64 = data.get_dimension(1); - ONEDAL_ASSERT(local_row_count > 0); - ONEDAL_ASSERT(!use_weights || weights.get_dimension(0) == local_row_count); + ONEDAL_ASSERT(row_count64 > 0); + ONEDAL_ASSERT(!use_weights || weights.get_dimension(0) == row_count64); ONEDAL_ASSERT(!use_weights || weights.get_dimension(1) == 1); - ONEDAL_ASSERT(cores.get_dimension(0) == local_row_count); - ONEDAL_ASSERT(neighbours.get_dimension(0) == local_row_count); + ONEDAL_ASSERT(cores.get_dimension(0) == row_count64); + ONEDAL_ASSERT(neighbours.get_dimension(0) == row_count64); const Float* data_ptr = data.get_data(); const Float* weights_ptr = weights.get_data(); @@ -90,62 +90,109 @@ struct get_core_wide_kernel { auto event = queue.submit([&](sycl::handler& cgh) { cgh.depends_on(deps); - const std::int64_t wg_size = get_recommended_wg_size(queue, column_count); - const std::int64_t block_split_size = - get_recommended_check_block_size(queue, column_count, wg_size); - cgh.parallel_for( - bk::make_multiple_nd_range_2d({ wg_size, local_row_count }, { wg_size, 1 }), + + const std::int64_t wg_size64 = get_recommended_wg_size(queue, col_count64); + const std::int64_t block_split_size64 = + get_recommended_check_block_size(queue, col_count64, wg_size64); + + cgh.parallel_for>( + bk::make_multiple_nd_range_2d({ wg_size64, row_count64 }, { wg_size64, 1 }), [=](sycl::nd_item<2> item) { - auto sg = item.get_sub_group(); - const std::uint32_t sg_id = sg.get_group_id()[0]; - if (sg_id > 0) + sycl::sub_group sg = item.get_sub_group(); + + // Match reference: only first sub-group of each work-group participates + if (sg.get_group_id()[0] != 0) return; - const std::uint32_t wg_id = item.get_global_id(1); - if (wg_id >= local_row_count) + const std::uint32_t row_count = static_cast(row_count64); + const std::uint32_t col_count = static_cast(col_count64); + + const std::uint32_t row_i = + static_cast(item.get_global_id(1)); + if (row_i >= row_count) return; - const std::uint32_t local_id = sg.get_local_id(); - const std::uint32_t local_size = sg.get_local_range()[0]; + const std::uint32_t lane = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_local_range()[0]; + + const Float min_obs_f = static_cast(min_observations); + + // Make block_split at least 1 + std::uint32_t block_split = static_cast(block_split_size64); + block_split = block_split ? block_split : 1u; + + // Base pointers for row_i + const std::uint32_t base_i = row_i * col_count; + const Float* const xi = data_ptr + base_i; + + // Keep neighbor count local; write back once (plus early-exit path) + Float count = neighbours_ptr[row_i]; + + // Iterate over all candidate points j + for (std::uint32_t j = 0; j < row_count; ++j) { + const Float* const xj = data_ptr + (j * col_count); - Float count = neighbours_ptr[wg_id]; - for (std::int64_t j = 0; j < local_row_count; j++) { Float sum = Float(0); - std::int64_t count_iter = 0; - for (std::int64_t i = local_id; i < column_count; i += local_size) { - count_iter++; - Float val = - data_ptr[wg_id * column_count + i] - data_ptr[j * column_count + i]; - sum += val * val; - if (count_iter % block_split_size == 0 && - local_size * count_iter <= column_count) { - Float distance_check = - sycl::reduce_over_group(sg, - sum, - sycl::ext::oneapi::plus()); - if (distance_check > epsilon) { - break; + // Periodic early pruning state + bool pruned = false; + std::uint32_t ticks = block_split; + std::uint32_t iter = 0; + + // Feature loop distributed across subgroup lanes + for (std::uint32_t k = lane; k < col_count; k += sg_size) { + ++iter; + + const Float v = xi[k] - xj[k]; + sum = sycl::fma(v, v, sum); + + // Early check every block_split iterations (reference cadence) + if (--ticks == 0) { + ticks = block_split; + + // Reference guard: local_size * count_iter <= column_count + if (sg_size * iter <= col_count) { + const Float partial = + sycl::reduce_over_group(sg, sum, sycl::plus()); + if (partial > epsilon) { + pruned = true; + break; + } } } } - Float distance = - sycl::reduce_over_group(sg, sum, sycl::ext::oneapi::plus()); - if (distance <= epsilon) { + + // If pruned, do NOT perform the final reduction (saves cost) + if (pruned) { + continue; + } + + const Float dist = + sycl::reduce_over_group(sg, sum, sycl::plus()); + + if (dist <= epsilon) { + // Only load weights if actually a neighbor count += use_weights ? weights_ptr[j] : Float(1); - if (local_id == 0) { - neighbours_ptr[wg_id] = count; - } - if (count >= min_observations && !use_weights) { - if (local_id == 0) { - cores_ptr[wg_id] = Float(1); + + // Reference early exit only in the unweighted case + if (!use_weights && count >= min_obs_f) { + if (lane == 0) { + neighbours_ptr[row_i] = count; + cores_ptr[row_i] = std::int32_t(1); } break; } } } - if (neighbours_ptr[wg_id] >= min_observations) { - cores_ptr[wg_id] = Float(1); + + // Single global writeback of neighbours (unless early exit already wrote it) + if (lane == 0) { + neighbours_ptr[row_i] = count; + + // Reference final marking uses neighbours_ptr value + if (count >= min_obs_f) { + cores_ptr[row_i] = std::int32_t(1); + } } }); });