diff --git a/include/oneapi/dpl/experimental/kt/internal/cooperative_lookback.h b/include/oneapi/dpl/experimental/kt/internal/cooperative_lookback.h
index e8798e67626..8c92a548a3b 100644
--- a/include/oneapi/dpl/experimental/kt/internal/cooperative_lookback.h
+++ b/include/oneapi/dpl/experimental/kt/internal/cooperative_lookback.h
@@ -342,7 +342,7 @@ struct __cooperative_lookback
         if (__is_full_ballot_bits)
         {
             oneapi::dpl::__par_backend_hetero::__sub_group_scan_partial<
-                __sub_group_size, /*__use_subgroup_ops=*/true, /*__is_inclusive*/ true,
+                /*__use_subgroup_ops=*/true, /*__is_inclusive*/ true,
                 /*__init_present*/ decltype(__is_initialized)::value>(
                 __subgroup, __tile_value, __binary_op, __running, __lowest_item_with_full + 1,
                 static_cast(nullptr));
@@ -351,7 +351,7 @@ struct __cooperative_lookback
         else
         {
             oneapi::dpl::__par_backend_hetero::__sub_group_scan<
-                __sub_group_size, /*__use_subgroup_ops=*/true, /*__is_inclusive*/ true,
+                /*__use_subgroup_ops=*/true, /*__is_inclusive*/ true,
                 /*__init_present*/ decltype(__is_initialized)::value>(
                 __subgroup, __tile_value, __binary_op, __running, static_cast(nullptr));
             return false;
diff --git a/include/oneapi/dpl/experimental/kt/internal/sub_group/sub_group_scan.h b/include/oneapi/dpl/experimental/kt/internal/sub_group/sub_group_scan.h
index 3f97ab08830..0d39c6d4435 100644
--- a/include/oneapi/dpl/experimental/kt/internal/sub_group/sub_group_scan.h
+++ b/include/oneapi/dpl/experimental/kt/internal/sub_group/sub_group_scan.h
@@ -91,15 +91,13 @@ __sub_group_scan(const _SubGroup& __sub_group, _InputTypeWrapped __input[__iters
     _ScanValueType* __no_slm = nullptr;
     if (__is_full)
     {
-        oneapi::dpl::__par_backend_hetero::__sub_group_scan<__sub_group_size, /*__use_subgroup_ops=*/true,
-                                                            /*__is_inclusive*/ true,
+        oneapi::dpl::__par_backend_hetero::__sub_group_scan</*__use_subgroup_ops=*/true, /*__is_inclusive*/ true,
                                                             /*__init_present*/ false>(
             __sub_group, __extract_scan_input(__input[0]), __binary_op, __carry, __no_slm);
         _ONEDPL_PRAGMA_UNROLL
         for (std::uint16_t __i = 1; __i < __iters_per_item; ++__i)
         {
-            oneapi::dpl::__par_backend_hetero::__sub_group_scan<__sub_group_size, /*__use_subgroup_ops=*/true,
-                                                                /*__is_inclusive*/ true,
+            oneapi::dpl::__par_backend_hetero::__sub_group_scan</*__use_subgroup_ops=*/true, /*__is_inclusive*/ true,
                                                                 /*__init_present*/ true>(
                 __sub_group, __extract_scan_input(__input[__i]), __binary_op, __carry, __no_slm);
         }
@@ -111,7 +109,7 @@ __sub_group_scan(const _SubGroup& __sub_group, _InputTypeWrapped __input[__iters
     std::uint16_t __i = 0;
     if (__limited_iters_per_item == 1)
     {
-        oneapi::dpl::__par_backend_hetero::__sub_group_scan_partial<__sub_group_size, /*__use_subgroup_ops=*/true,
+        oneapi::dpl::__par_backend_hetero::__sub_group_scan_partial</*__use_subgroup_ops=*/true,
                                                                     /*__is_inclusive*/ true, /*__init_present*/ false>(
             __sub_group, __extract_scan_input(__input[__i]), __binary_op, __carry,
@@ -119,18 +117,17 @@ __sub_group_scan(const _SubGroup& __sub_group, _InputTypeWrapped __input[__iters
     }
     else if (__limited_iters_per_item > 1)
     {
-        oneapi::dpl::__par_backend_hetero::__sub_group_scan<__sub_group_size, /*__use_subgroup_ops=*/true,
-                                                            /*__is_inclusive*/ true,
+        oneapi::dpl::__par_backend_hetero::__sub_group_scan</*__use_subgroup_ops=*/true, /*__is_inclusive*/ true,
                                                             /*__init_present*/ false>(
             __sub_group, __extract_scan_input(__input[__i++]), __binary_op, __carry, __no_slm);
         for (; __i < __limited_iters_per_item - 1; ++__i)
         {
-            oneapi::dpl::__par_backend_hetero::__sub_group_scan<__sub_group_size, /*__use_subgroup_ops=*/true,
+            oneapi::dpl::__par_backend_hetero::__sub_group_scan</*__use_subgroup_ops=*/true,
                                                                 /*__is_inclusive*/ true, /*__init_present*/ true>(
                 __sub_group, __extract_scan_input(__input[__i]), __binary_op, __carry, __no_slm);
         }
-        oneapi::dpl::__par_backend_hetero::__sub_group_scan_partial<__sub_group_size, /*__use_subgroup_ops=*/true,
+        oneapi::dpl::__par_backend_hetero::__sub_group_scan_partial</*__use_subgroup_ops=*/true,
                                                                     /*__is_inclusive*/ true, /*__init_present*/ true>(
             __sub_group,
             __extract_scan_input(__input[__i]), __binary_op, __carry,
diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
index b1b987cf52b..b3d728f9bf1 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
@@ -681,8 +681,6 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag, _Execut
 
     using _Type = typename _InitType::__value_type;
 
-    bool __use_reduce_then_scan = oneapi::dpl::__par_backend_hetero::__is_gpu_with_reduce_then_scan_sg_sz(__q_local);
-
     // The single work-group implementation requires a fundamental type which must be trivially copyable.
     if constexpr (std::is_trivially_copyable_v<_Type>)
     {
@@ -696,7 +694,7 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag, _Execut
 
         // GPU: reduce-then-scan is efficient for moderate-sized inputs, so the single-group cutoff is low.
         // CPU: kernel launch overhead dominates, so prefer the single-group path for larger inputs.
-        std::size_t __single_group_upper_limit = __use_reduce_then_scan ? 2048 : 16384;
+        std::size_t __single_group_upper_limit = __q_local.get_device().is_gpu() ? 2048 : 16384;
         if (__group_scan_fits_in_slm<_Type>(__q_local, __n, __n_uniform, __single_group_upper_limit))
         {
             auto __event = __parallel_transform_scan_single_group<_CustomName>(
@@ -711,8 +709,7 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag, _Execut
             }
         }
     }
-
-    if (__use_reduce_then_scan)
+    // reduce-then-scan implementation
     {
         using _GenInput = oneapi::dpl::__par_backend_hetero::__gen_transform_input<_UnaryOperation, typename _InitType::__value_type>;
@@ -727,32 +724,6 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag, _Execut
             __binary_op, __gen_transform, _ScanInputTransform{}, _WriteOp{}, __init, _Inclusive{},
             /*_IsUniquePattern=*/std::false_type{});
     }
-    else // use multi pass scan implementation
-    {
-        using _Assigner = unseq_backend::__scan_assigner;
-        using _NoAssign = unseq_backend::__scan_ignore;
-        using _UnaryFunctor = unseq_backend::walk_n<_UnaryOperation>;
-        using _Unchanged = unseq_backend::__unchanged;
-
-        _Assigner __assign_op;
-        _NoAssign __ignore_op;
-        _Unchanged __read_op;
-
-        auto&& [__event, __payload] = __parallel_transform_scan_base<_CustomName>(
-            __q_local, std::forward<_Range1>(__in_rng), std::forward<_Range2>(__out_rng), __init,
-            // local scan
-            unseq_backend::__scan<_Inclusive, _BinaryOperation, _UnaryFunctor, _Assigner, _Assigner, _Unchanged,
-                                  _InitType>{__binary_op, _UnaryFunctor{__unary_op}, __assign_op, __assign_op,
-                                             __read_op},
-            // scan between groups
-            unseq_backend::__scan>{
-                __binary_op, __read_op, __ignore_op, __assign_op, __read_op},
-            // global scan
-            unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init},
-            /*apex*/ __ignore_op);
-        return __future(std::move(__event), __result_and_scratch_storage<_Type>(__move_state_from(__payload)));
-    }
 }
 template (__rng), std::forward<_Range2>(__result), __n, __n_out,
             oneapi::dpl::__internal::__unique_at_index<_BinaryPredicate, true>{__pred}, _Assign{}, __max_wg_size);
     }
-    else if (__n_out >= __n && oneapi::dpl::__par_backend_hetero::__is_gpu_with_reduce_then_scan_sg_sz(__q_local))
+    else if (__n_out >= __n) // TODO: figure out how to support limited output ranges in the reduce-then-scan pattern
     {
         using _GenMask = oneapi::dpl::__par_backend_hetero::__gen_unique_mask<_BinaryPredicate>;
@@ -905,28 +876,16 @@ __parallel_partition_copy(oneapi::dpl::__internal::__device_backend_tag, _Execut
 {
     using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
     using _Size1 = oneapi::dpl::__internal::__difference_t<_Range1>;
+    using _GenMask = oneapi::dpl::__par_backend_hetero::__gen_mask<_UnaryPredicate>;
+    using _WriteOp = oneapi::dpl::__par_backend_hetero::__write_to_id_if_else;
 
     sycl::queue __q_local = __exec.queue();
 
     _Size1 __n = oneapi::dpl::__ranges::__size(__rng);
 
-    if (oneapi::dpl::__par_backend_hetero::__is_gpu_with_reduce_then_scan_sg_sz(__q_local))
-    {
-        using _GenMask = oneapi::dpl::__par_backend_hetero::__gen_mask<_UnaryPredicate>;
-        using _WriteOp =
-            oneapi::dpl::__par_backend_hetero::__write_to_id_if_else;
-
-        return __parallel_reduce_then_scan_copy<_CustomName>(
-            __q_local, std::forward<_Range1>(__rng), std::forward<_Range2>(__result), __n, _GenMask{__pred}, _WriteOp{},
-            /*_IsUniquePattern=*/std::false_type{});
-    }
-    else
-    {
-        auto&& [__event, __payload] = __parallel_scan_copy<_CustomName>(
-            __q_local, std::forward<_Range1>(__rng), std::forward<_Range2>(__result), __n,
-            oneapi::dpl::__internal::__pred_at_index{__pred}, unseq_backend::__partition_by_mask{});
-        return __future(std::move(__event), __result_and_scratch_storage<_Size1>(__move_state_from(__payload)));
-    }
+    return __parallel_reduce_then_scan_copy<_CustomName>(
+        __q_local, std::forward<_Range1>(__rng), std::forward<_Range2>(__result), __n, _GenMask{__pred}, _WriteOp{},
+        /*_IsUniquePattern=*/std::false_type{});
 }
 template (__in_rng), std::forward<_OutRng>(__out_rng), __n, __n_out,
             oneapi::dpl::__internal::__pred_at_index{__pred}, __assign, __max_wg_size);
     }
-    else if (__n_out >= __n && oneapi::dpl::__par_backend_hetero::__is_gpu_with_reduce_then_scan_sg_sz(__q_local))
+    else if (__n_out >= __n) // TODO: figure out how to support limited output ranges in the reduce-then-scan pattern
     {
         using _GenMask = oneapi::dpl::__par_backend_hetero::__gen_mask<_Pred>;
@@ -1373,27 +1332,20 @@ std::size_t
 __set_op_impl(_SetTag __set_tag, sycl::queue& __q, _Range1&& __rng1, _Range2&& __rng2, _Range3&& __result,
               _Compare __comp, _Proj1 __proj1, _Proj2 __proj2)
 {
-    //can we use reduce then scan?
-    if (oneapi::dpl::__par_backend_hetero::__is_gpu_with_reduce_then_scan_sg_sz(__q))
+    if (__check_use_write_a_alg{}(__set_tag, __rng1, __rng2))
+    {
+        // use reduce then scan with set_a write
+        return __set_write_a_only_op>(
+            __set_tag, /*use_reduce_then_scan=*/std::true_type{}, __q, std::forward<_Range1>(__rng1),
+            std::forward<_Range2>(__rng2), std::forward<_Range3>(__result), __comp, __proj1, __proj2);
+    }
+    else
     {
-        if (__check_use_write_a_alg{}(__set_tag, __rng1, __rng2))
-        {
-            // use reduce then scan with set_a write
-            return __set_write_a_only_op>(
-                __set_tag, /*use_reduce_then_scan=*/std::true_type{}, __q, std::forward<_Range1>(__rng1),
-                std::forward<_Range2>(__rng2), std::forward<_Range3>(__result), __comp, __proj1, __proj2);
-        }
         return __parallel_set_write_a_b_op>(
                    __set_tag, __q, std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
                    std::forward<_Range3>(__result), __comp, __proj1, __proj2)
             .get();
     }
-    else
-    {
-        return __set_write_a_only_op>(
-            __set_tag, /*use_reduce_then_scan=*/std::false_type{}, __q, std::forward<_Range1>(__rng1),
-            std::forward<_Range2>(__rng2), std::forward<_Range3>(__result), __comp, __proj1, __proj2);
-    }
 }
 template ;
     // Prior to icpx 2025.0, the reduce-then-scan path performs poorly and should be avoided.
 #if !defined(__INTEL_LLVM_COMPILER) || __INTEL_LLVM_COMPILER >= 20250000
-    if (oneapi::dpl::__par_backend_hetero::__is_gpu_with_reduce_then_scan_sg_sz(__q_local))
-    {
-        auto __res = oneapi::dpl::__par_backend_hetero::__parallel_reduce_by_segment_reduce_then_scan<_CustomName>(
-            __q_local, std::forward<_Range1>(__keys), std::forward<_Range2>(__values),
-            std::forward<_Range3>(__out_keys), std::forward<_Range4>(__out_values), __binary_pred, __binary_op);
-        // Because our init type ends up being tuple, return the first component which is the write index. Add 1 to return the
-        // past-the-end iterator pair of segmented reduction.
-        return std::get<0>(__res.get()) + 1;
-    }
+    auto __res = oneapi::dpl::__par_backend_hetero::__parallel_reduce_by_segment_reduce_then_scan<_CustomName>(
+        __q_local, std::forward<_Range1>(__keys), std::forward<_Range2>(__values), std::forward<_Range3>(__out_keys),
+        std::forward<_Range4>(__out_values), __binary_pred, __binary_op);
+    // Because our init type ends up being a tuple, return the first component, which is the write index. Add 1 to
+    // return the past-the-end iterator pair of the segmented reduction.
+    return std::get<0>(__res.get()) + 1;
 #endif
     return __parallel_reduce_by_segment_fallback(
         oneapi::dpl::__internal::__device_backend_tag{}, std::forward<_ExecutionPolicy>(__exec),
@@ -2492,23 +2441,10 @@ __parallel_scan_by_segment(oneapi::dpl::__internal::__device_backend_tag, _Execu
     assert(oneapi::dpl::__ranges::__size(__keys) > 0);
 
     sycl::queue __q_local = __exec.queue();
 
-    if (oneapi::dpl::__par_backend_hetero::__is_gpu_with_reduce_then_scan_sg_sz(__q_local))
-    {
-        __parallel_scan_by_segment_reduce_then_scan<_CustomName, __is_inclusive>(
-            __q_local, std::forward<_Range1>(__keys), std::forward<_Range2>(__values),
-            std::forward<_Range3>(__out_values), __binary_pred, __binary_op, __init)
-            .wait();
-        return;
-    }
-    // Implicit synchronization in this call. We need to wrap the policy as the implementation may still call
-    // reduce-then-scan and needs to avoid duplicate kernel names.
-    __parallel_scan_by_segment_fallback<_CustomName, __is_inclusive>(
-        oneapi::dpl::__internal::__device_backend_tag{},
-        oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__scan_by_seg_fallback>(
-            std::forward<_ExecutionPolicy>(__exec)),
-        std::forward<_Range1>(__keys), std::forward<_Range2>(__values), std::forward<_Range3>(__out_values),
-        __binary_pred, __binary_op, __init,
-        oneapi::dpl::unseq_backend::__has_known_identity<_BinaryOperator, _ValueType>{});
+    __parallel_scan_by_segment_reduce_then_scan<_CustomName, __is_inclusive>(
+        __q_local, std::forward<_Range1>(__keys), std::forward<_Range2>(__values), std::forward<_Range3>(__out_values),
+        __binary_pred, __binary_op, __init)
+        .wait();
 }
 
 } // namespace __par_backend_hetero
diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h
index 52a75857d01..10fdb39035d 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h
@@ -1249,15 +1249,15 @@ __group_broadcast(const __dpl_sycl::__sub_group& __sub_group, _ValueType __value
     }
 }
 
-template <std::uint8_t __sub_group_size, bool __use_subgroup_ops, bool __init_present, typename _MaskOp,
-          typename _InitBroadcastId, typename _ValueType, typename _BinaryOp, typename _LazyValueType>
+template <bool __use_subgroup_ops, bool __init_present, typename _MaskOp, typename _InitBroadcastId,
+          typename _ValueType, typename _BinaryOp, typename _LazyValueType>
 void
 __exclusive_sub_group_masked_scan(const __dpl_sycl::__sub_group& __sub_group, _MaskOp __mask_fn,
                                   _InitBroadcastId __init_broadcast_id, _ValueType& __value, _BinaryOp __binary_op,
                                   _LazyValueType& __init_and_carry, _ValueType* __comm_slm)
 {
     std::uint8_t __sub_group_local_id = __sub_group.get_local_linear_id();
-    _ONEDPL_PRAGMA_UNROLL
+    const std::uint8_t __sub_group_size = __sub_group.get_max_local_range()[0];
     for (std::uint8_t __shift = 1; __shift <= __sub_group_size / 2; __shift <<= 1)
     {
         _ValueType __partial_carry_in =
@@ -1294,15 +1294,15 @@ __exclusive_sub_group_masked_scan(const __dpl_sycl::__sub_group& __sub_group, _M
     //return by reference __value and __init_and_carry
 }
 
-template <std::uint8_t __sub_group_size, bool __use_subgroup_ops, bool __init_present, typename _MaskOp,
-          typename _InitBroadcastId, typename _ValueType, typename _BinaryOp, typename _LazyValueType>
+template <bool __use_subgroup_ops, bool __init_present, typename _MaskOp, typename _InitBroadcastId,
+          typename _ValueType, typename _BinaryOp, typename _LazyValueType>
 void
 __inclusive_sub_group_masked_scan(const __dpl_sycl::__sub_group& __sub_group, _MaskOp __mask_fn,
                                   _InitBroadcastId __init_broadcast_id, _ValueType& __value, _BinaryOp __binary_op,
                                   _LazyValueType& __init_and_carry, _ValueType* __comm_slm)
 {
     std::uint8_t __sub_group_local_id = __sub_group.get_local_linear_id();
-    _ONEDPL_PRAGMA_UNROLL
+    const std::uint8_t __sub_group_size = __sub_group.get_max_local_range()[0];
     for (std::uint8_t __shift = 1; __shift <= __sub_group_size / 2; __shift <<= 1)
     {
         _ValueType __partial_carry_in =
@@ -1326,8 +1326,8 @@ __inclusive_sub_group_masked_scan(const __dpl_sycl::__sub_group& __sub_group, _M
     //return by reference __value and __init_and_carry
 }
 
-template <std::uint8_t __sub_group_size, bool __use_subgroup_ops, bool __is_inclusive, bool __init_present,
-          typename _MaskOp, typename _InitBroadcastId, typename _ValueType, typename _BinaryOp, typename _LazyValueType>
+template <bool __use_subgroup_ops, bool __is_inclusive, bool __init_present, typename _MaskOp,
+          typename _InitBroadcastId, typename _ValueType, typename _BinaryOp, typename _LazyValueType>
 void
 __sub_group_masked_scan(const __dpl_sycl::__sub_group& __sub_group, _MaskOp __mask_fn,
                         _InitBroadcastId __init_broadcast_id, _ValueType& __value, _BinaryOp __binary_op,
@@ -1335,30 +1335,30 @@ __sub_group_masked_scan(const __dpl_sycl::__sub_group& __sub_group, _MaskOp __ma
 {
     if constexpr (__is_inclusive)
     {
-        __inclusive_sub_group_masked_scan<__sub_group_size, __use_subgroup_ops, __init_present>(
+        __inclusive_sub_group_masked_scan<__use_subgroup_ops, __init_present>(
             __sub_group, __mask_fn, __init_broadcast_id, __value, __binary_op, __init_and_carry, __comm_slm);
     }
     else
     {
-        __exclusive_sub_group_masked_scan<__sub_group_size, __use_subgroup_ops, __init_present>(
+        __exclusive_sub_group_masked_scan<__use_subgroup_ops, __init_present>(
            __sub_group, __mask_fn, __init_broadcast_id, __value, __binary_op, __init_and_carry, __comm_slm);
     }
 }
 
-template <std::uint8_t __sub_group_size, bool __use_subgroup_ops, bool __is_inclusive, bool __init_present,
-          typename _ValueType, typename _BinaryOp, typename _LazyValueType>
+template <bool __use_subgroup_ops, bool __is_inclusive, bool __init_present, typename _ValueType,
+          typename _BinaryOp, typename _LazyValueType>
 void
 __sub_group_scan(const __dpl_sycl::__sub_group& __sub_group, _ValueType& __value, _BinaryOp __binary_op,
                  _LazyValueType& __init_and_carry, _ValueType* __comm_slm)
 {
     auto __mask_fn = [](auto __sub_group_local_id, auto __offset) { return __sub_group_local_id >= __offset; };
-    constexpr std::uint8_t __init_broadcast_id = __sub_group_size - 1;
-    __sub_group_masked_scan<__sub_group_size, __use_subgroup_ops, __is_inclusive, __init_present>(
+    std::uint8_t __init_broadcast_id = __sub_group.get_max_local_range()[0] - 1;
+    __sub_group_masked_scan<__use_subgroup_ops, __is_inclusive, __init_present>(
         __sub_group, __mask_fn, __init_broadcast_id, __value, __binary_op, __init_and_carry, __comm_slm);
 }
 
-template <std::uint8_t __sub_group_size, bool __use_subgroup_ops, bool __is_inclusive, bool __init_present,
-          typename _ValueType, typename _BinaryOp, typename _LazyValueType, typename _SizeType>
+template <bool __use_subgroup_ops, bool __is_inclusive, bool __init_present, typename _ValueType,
+          typename _BinaryOp, typename _LazyValueType, typename _SizeType>
 void
 __sub_group_scan_partial(const __dpl_sycl::__sub_group& __sub_group, _ValueType& __value, _BinaryOp __binary_op,
                          _LazyValueType& __init_and_carry, _SizeType __elements_to_process, _ValueType* __comm_slm)
 {
@@ -1367,14 +1367,13 @@ __sub_group_scan_partial(const __dpl_sycl::__sub_group& __sub_group, _ValueType&
         return __sub_group_local_id >= __offset && __sub_group_local_id < __elements_to_process;
     };
     std::uint8_t __init_broadcast_id = __elements_to_process - 1;
-    __sub_group_masked_scan<__sub_group_size, __use_subgroup_ops, __is_inclusive, __init_present>(
+    __sub_group_masked_scan<__use_subgroup_ops, __is_inclusive, __init_present>(
         __sub_group, __mask_fn, __init_broadcast_id, __value, __binary_op, __init_and_carry, __comm_slm);
 }
 
-template <std::uint8_t __sub_group_size, bool __use_subgroup_ops, bool __is_inclusive, bool __init_present,
-          bool __capture_output, std::uint16_t __max_inputs_per_item, typename _GenInput, typename _ScanInputTransform,
-          typename _BinaryOp, typename _WriteOp, typename _LazyValueType, typename _InRng, typename _OutRng, typename _ScanValueType>
+template <bool __use_subgroup_ops, bool __is_inclusive, bool __init_present, bool __capture_output,
+          std::uint16_t __max_inputs_per_item, typename _GenInput, typename _ScanInputTransform, typename _BinaryOp,
+          typename _WriteOp, typename _LazyValueType, typename _InRng, typename _OutRng, typename _ScanValueType>
 void
 __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenInput __gen_input,
                                _ScanInputTransform __scan_input_transform, _BinaryOp __binary_op, _WriteOp __write_op,
@@ -1384,6 +1383,8 @@ __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenI
                                std::uint32_t __active_subgroups, _ScanValueType* __comm_slm)
 {
     using _GenInputType = std::invoke_result_t<_GenInput, _InRng, std::size_t, typename _GenInput::TempData&>;
+
+    const std::uint8_t __sub_group_size = __sub_group.get_max_local_range()[0];
     bool __is_full_block = (__iters_per_item == __max_inputs_per_item);
     bool __is_full_thread = __subgroup_start_id + __iters_per_item * __sub_group_size <= __n;
     using _TempData = typename _GenInput::TempData;
@@ -1392,7 +1393,7 @@ __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenI
         {
             _GenInputType __v = __gen_input(__in_rng, __start_id, __temp_data);
-            __sub_group_scan<__sub_group_size, __use_subgroup_ops, __is_inclusive, __init_present>(
+            __sub_group_scan<__use_subgroup_ops, __is_inclusive, __init_present>(
                 __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, __comm_slm);
             if constexpr (__capture_output)
             {
@@ -1406,7 +1407,7 @@ __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenI
             for (std::uint32_t __j = 1; __j < __max_inputs_per_item; __j++)
             {
                 __v = __gen_input(__in_rng, __start_id + __j * __sub_group_size, __temp_data);
-                __sub_group_scan<__sub_group_size, __use_subgroup_ops, __is_inclusive, /*__init_present=*/true>(
+                __sub_group_scan<__use_subgroup_ops, __is_inclusive, /*__init_present=*/true>(
                     __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, __comm_slm);
                 if constexpr (__capture_output)
                 {
@@ -1421,7 +1422,7 @@ __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenI
             for (std::uint32_t __j = 1; __j < __iters_per_item; __j++)
             {
                 __v = __gen_input(__in_rng, __start_id + __j * __sub_group_size, __temp_data);
-                __sub_group_scan<__sub_group_size, __use_subgroup_ops, __is_inclusive,
-                                 /*__init_present=*/true>(
+                __sub_group_scan<__use_subgroup_ops, __is_inclusive, /*__init_present=*/true>(
                     __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, __comm_slm);
                 if constexpr (__capture_output)
                 {
@@ -1442,7 +1443,7 @@ __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenI
         {
             std::size_t __local_id = (__start_id < __n) ? __start_id : __n - 1;
             _GenInputType __v = __gen_input(__in_rng, __local_id, __temp_data);
-            __sub_group_scan_partial<__sub_group_size, __use_subgroup_ops, __is_inclusive, __init_present>(
+            __sub_group_scan_partial<__use_subgroup_ops, __is_inclusive, __init_present>(
                 __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, __n - __subgroup_start_id,
                 __comm_slm);
             if constexpr (__capture_output)
@@ -1454,7 +1455,7 @@ __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenI
         else
         {
             _GenInputType __v = __gen_input(__in_rng, __start_id, __temp_data);
-            __sub_group_scan<__sub_group_size, __use_subgroup_ops, __is_inclusive, __init_present>(
+            __sub_group_scan<__use_subgroup_ops, __is_inclusive, __init_present>(
                 __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, __comm_slm);
             if constexpr (__capture_output)
             {
@@ -1465,7 +1466,7 @@ __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenI
             {
                 std::size_t __local_id = __start_id + __j * __sub_group_size;
                 __v = __gen_input(__in_rng, __local_id, __temp_data);
-                __sub_group_scan<__sub_group_size, __use_subgroup_ops, __is_inclusive, /*__init_present=*/true>(
+                __sub_group_scan<__use_subgroup_ops, __is_inclusive, /*__init_present=*/true>(
                     __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, __comm_slm);
                 if constexpr (__capture_output)
                 {
@@ -1476,7 +1477,7 @@ __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenI
             std::size_t __offset = __start_id + (__iters - 1) * __sub_group_size;
             std::size_t __local_id = (__offset < __n) ? __offset : __n - 1;
             __v = __gen_input(__in_rng, __local_id, __temp_data);
-            __sub_group_scan_partial<__sub_group_size, __use_subgroup_ops, __is_inclusive, /*__init_present=*/true>(
+            __sub_group_scan_partial<__use_subgroup_ops, __is_inclusive, /*__init_present=*/true>(
                 __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry,
                 __n - (__subgroup_start_id + (__iters - 1) * __sub_group_size), __comm_slm);
             if constexpr (__capture_output)
@@ -1489,41 +1490,6 @@ __scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenI
     }
 }
 
-constexpr inline std::uint8_t
-__get_reduce_then_scan_default_sg_sz()
-{
-    return 32;
-}
-
-constexpr inline std::uint8_t
-__get_reduce_then_scan_workaround_sg_sz()
-{
-    return 16;
-}
-
-// The default sub-group size for reduce-then-scan is 32, but we conditionally enable sub-group sizes of 16 on Intel
-// devices to workaround a hardware bug. From the host side, return 32 to assert that this sub-group size is supported
-// by an arbitrary device.
-constexpr inline std::uint8_t
-__get_reduce_then_scan_reqd_sg_sz_host()
-{
-    return __get_reduce_then_scan_default_sg_sz();
-}
-
-// To workaround a hardware bug on certain Intel iGPUs with older driver versions and -O0 device compilation, use a
-// sub-group size of 16. Note this function may only be called on the device as _ONEDPL_DETECT_SPIRV_COMPILATION is only valid here.
-constexpr inline std::uint8_t
-__get_reduce_then_scan_actual_sg_sz_device()
-{
-    return
-#if _ONEDPL_DETECT_COMPILER_OPTIMIZATIONS_ENABLED || !_ONEDPL_DETECT_SPIRV_COMPILATION
-        __get_reduce_then_scan_default_sg_sz();
-#else
-        __get_reduce_then_scan_workaround_sg_sz();
-#endif
-}
-
 struct __reduce_then_scan_sub_group_params
 {
     __reduce_then_scan_sub_group_params(std::uint32_t __work_group_size, std::uint8_t __sub_group_size,
@@ -1567,7 +1533,6 @@ struct __parallel_reduce_then_scan_reduce_submitter<__max_inputs_per_item, __is_
                                                     __internal::__optional_kernel_name<_KernelName...>>
 {
     using _InitValueType = typename _InitType::__value_type;
-    static constexpr std::uint8_t __sub_group_size = __get_reduce_then_scan_actual_sg_sz_device();
 
     template
     void
@@ -1576,6 +1541,7 @@ struct __parallel_reduce_then_scan_reduce_submitter<__max_inputs_per_item, __is_
                const std::size_t __inputs_remaining, const std::size_t __block_num) const
     {
         __dpl_sycl::__sub_group __sub_group = __ndi.get_sub_group();
+        const std::uint8_t __sub_group_size = __sub_group.get_max_local_range()[0];
 
         __reduce_then_scan_sub_group_params __sub_group_params(
             __work_group_size, __sub_group_size, __max_num_work_groups, __max_block_size, __inputs_remaining);
@@ -1608,7 +1574,7 @@ struct __parallel_reduce_then_scan_reduce_submitter<__max_inputs_per_item, __is_
         {
             // adjust for lane-id
            // compute sub-group local prefix on T0..63, K samples/T, send to accumulator kernel
-            __scan_through_elements_helper<__sub_group_size, __use_subgroup_ops, __is_inclusive,
+            __scan_through_elements_helper<__use_subgroup_ops, __is_inclusive,
                                            /*__init_present=*/false,
                                            /*__capture_output=*/false, __max_inputs_per_item>(
                 __sub_group, __gen_reduce_input, oneapi::dpl::identity{}, __reduce_op, nullptr, __sub_group_carry,
@@ -1633,9 +1599,8 @@ struct __parallel_reduce_then_scan_reduce_submitter<__max_inputs_per_item, __is_
                 // fill with unused dummy values to avoid overrunning input
                 std::uint32_t __load_id = std::min(std::uint32_t{__sub_group_local_id}, __active_subgroups - 1);
                 _InitValueType __v = __sub_group_partials[__load_id];
-                __sub_group_scan_partial<__sub_group_size, __use_subgroup_ops, /*__is_inclusive=*/true,
-                                         /*__init_present=*/false>(__sub_group, __v, __reduce_op, __sub_group_carry,
-                                                                   __active_subgroups, __comm_slm_ptr);
+                __sub_group_scan_partial<__use_subgroup_ops, /*__is_inclusive=*/true, /*__init_present=*/false>(
+                    __sub_group, __v, __reduce_op, __sub_group_carry, __active_subgroups, __comm_slm_ptr);
                 if (__sub_group_local_id < __active_subgroups)
                     __tmp_acc[__start_id + __sub_group_local_id] = __v;
             }
@@ -1644,18 +1609,16 @@ struct __parallel_reduce_then_scan_reduce_submitter<__max_inputs_per_item, __is_
                 std::uint32_t __reduction_scan_id = __sub_group_local_id;
                 // need to pull out first iteration to avoid identity
                 _InitValueType __v = __sub_group_partials[__reduction_scan_id];
-                __sub_group_scan<__sub_group_size, __use_subgroup_ops, /*__is_inclusive=*/true,
-                                 /*__init_present=*/false>(__sub_group, __v, __reduce_op, __sub_group_carry,
-                                                           __comm_slm_ptr);
+                __sub_group_scan<__use_subgroup_ops, /*__is_inclusive=*/true, /*__init_present=*/false>(
+                    __sub_group, __v, __reduce_op, __sub_group_carry, __comm_slm_ptr);
                 __tmp_acc[__start_id + __reduction_scan_id] = __v;
                 __reduction_scan_id += __sub_group_size;
 
                 for (std::uint32_t __i = 1; __i < __iters - 1; __i++)
                 {
                     __v = __sub_group_partials[__reduction_scan_id];
-                    __sub_group_scan<__sub_group_size, __use_subgroup_ops, /*__is_inclusive=*/true,
-                                     /*__init_present=*/true>(__sub_group, __v, __reduce_op, __sub_group_carry,
-                                                              __comm_slm_ptr);
+                    __sub_group_scan<__use_subgroup_ops, /*__is_inclusive=*/true, /*__init_present=*/true>(
+                        __sub_group, __v, __reduce_op, __sub_group_carry, __comm_slm_ptr);
                     __tmp_acc[__start_id + __reduction_scan_id] = __v;
                     __reduction_scan_id += __sub_group_size;
                 }
@@ -1666,8 +1629,7 @@ struct __parallel_reduce_then_scan_reduce_submitter<__max_inputs_per_item, __is_
                 std::uint32_t __load_id =
                     std::min(__reduction_scan_id, __sub_group_params.__num_sub_groups_local - 1);
                 __v = __sub_group_partials[__load_id];
-                __sub_group_scan_partial<__sub_group_size, __use_subgroup_ops, /*__is_inclusive=*/true,
-                                         /*__init_present=*/true>(
+                __sub_group_scan_partial<__use_subgroup_ops, /*__is_inclusive=*/true, /*__init_present=*/true>(
                     __sub_group, __v, __reduce_op, __sub_group_carry,
                     __active_subgroups - ((__iters - 1) * __sub_group_size), __comm_slm_ptr);
                 if (__reduction_scan_id < __sub_group_params.__num_sub_groups_local)
@@ -1695,7 +1657,7 @@ struct __parallel_reduce_then_scan_reduce_submitter<__max_inputs_per_item, __is_
         auto __temp_acc = __scratch_container.template __get_scratch_acc(
             __cgh, __dpl_sycl::__no_init{});
         __cgh.parallel_for<_KernelName...>(
-            __nd_range, [=, *this](sycl::nd_item<1> __ndi) [[sycl::reqd_sub_group_size(__sub_group_size)]] {
+            __nd_range, [=, *this](sycl::nd_item<1> __ndi) [[_ONEDPL_SYCL_REQD_SUB_GROUP_SIZE_IF_SUPPORTED(32)]] {
                 _InitValueType* __tmp_acc = _TmpStorageAcc::__get_usm_or_buffer_accessor_ptr(__temp_acc);
 
                 if (!__use_slm_for_comm)
@@ -1734,7 +1696,6 @@ struct __parallel_reduce_then_scan_scan_submitter<__max_inputs_per_item, __is_in
                                                   __internal::__optional_kernel_name<_KernelName...>>
 {
     using _InitValueType = typename _InitType::__value_type;
-    static constexpr std::uint8_t __sub_group_size = __get_reduce_then_scan_actual_sg_sz_device();
 
     template
     _InitValueType
@@ -1760,6 +1721,7 @@ struct __parallel_reduce_then_scan_scan_submitter<__max_inputs_per_item, __is_in
                const std::size_t __inputs_remaining, const std::size_t __block_num) const
     {
         __dpl_sycl::__sub_group __sub_group = __ndi.get_sub_group();
+        const std::uint8_t __sub_group_size = __sub_group.get_max_local_range()[0];
         _InitValueType* __comm_slm_ptr = __use_subgroup_ops ? nullptr : &__comm_slm[0];
 
         __reduce_then_scan_sub_group_params __sub_group_params(
@@ -1844,7 +1806,7 @@ struct __parallel_reduce_then_scan_scan_submitter<__max_inputs_per_item, __is_in
                 std::size_t __reduction_id = (__proposed_id < __subgroups_before_my_group)
                                                  ? __proposed_id
                                                  : __subgroups_before_my_group - 1;
                 _InitValueType __value = __tmp_acc[__reduction_id];
-                __sub_group_scan_partial<__sub_group_size, __use_subgroup_ops, /*__is_inclusive=*/true,
+                __sub_group_scan_partial<__use_subgroup_ops, /*__is_inclusive=*/true,
                                          /*__init_present=*/false>(__sub_group, __value, __reduce_op, __carry_last,
                                                                    __remaining_elements, __comm_slm_ptr);
             }
@@ -1857,17 +1819,15 @@ struct __parallel_reduce_then_scan_scan_submitter<__max_inputs_per_item, __is_in
                 std::uint32_t __reduction_id_increment = __sub_group_params.__num_sub_groups_local * __sub_group_size;
                 _InitValueType __value = __tmp_acc[__reduction_id];
-                __sub_group_scan<__sub_group_size, __use_subgroup_ops, /*__is_inclusive=*/true,
-                                 /*__init_present=*/false>(__sub_group, __value, __reduce_op, __carry_last,
-                                                           __comm_slm_ptr);
+                __sub_group_scan<__use_subgroup_ops, /*__is_inclusive=*/true, /*__init_present=*/false>(
+                    __sub_group, __value, __reduce_op, __carry_last, __comm_slm_ptr);
                 __reduction_id += __reduction_id_increment;
                 // then some number of full iterations
                 for (std::uint32_t __i = 1; __i < __pre_carry_iters - 1; __i++)
                 {
                     __value = __tmp_acc[__reduction_id];
-                    __sub_group_scan<__sub_group_size, __use_subgroup_ops, /*__is_inclusive=*/true,
-                                     /*__init_present=*/true>(__sub_group, __value, __reduce_op, __carry_last,
-                                                              __comm_slm_ptr);
+                    __sub_group_scan<__use_subgroup_ops, /*__is_inclusive=*/true, /*__init_present=*/true>(
+                        __sub_group, __value, __reduce_op, __carry_last, __comm_slm_ptr);
                     __reduction_id += __reduction_id_increment;
                 }
@@ -1879,7 +1839,7 @@ struct __parallel_reduce_then_scan_scan_submitter<__max_inputs_per_item, __is_in
                 std::size_t __final_reduction_id =
                     std::min(std::size_t{__reduction_id}, __subgroups_before_my_group - 1);
                 __value = __tmp_acc[__final_reduction_id];
-                __sub_group_scan_partial<__sub_group_size, __use_subgroup_ops, /*__is_inclusive=*/true,
+                __sub_group_scan_partial<__use_subgroup_ops, /*__is_inclusive=*/true,
                                          /*__init_present=*/true>(__sub_group, __value, __reduce_op, __carry_last,
                                                                   __remaining_elements, __comm_slm_ptr);
             }
@@ -1978,7 +1938,7 @@ struct __parallel_reduce_then_scan_scan_submitter<__max_inputs_per_item, __is_in
 
         if (__sub_group_carry_initialized)
         {
-            __scan_through_elements_helper<__sub_group_size, __use_subgroup_ops, __is_inclusive,
+            __scan_through_elements_helper<__use_subgroup_ops, __is_inclusive,
                                            /*__init_present=*/true,
                                            /*__capture_output=*/true, __max_inputs_per_item>(
                 __sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op, __sub_group_carry,
@@ -1987,7 +1947,7 @@ struct __parallel_reduce_then_scan_scan_submitter<__max_inputs_per_item, __is_in
         }
         else // first group first block, no subgroup carry
         {
-            __scan_through_elements_helper<__sub_group_size, __use_subgroup_ops, __is_inclusive,
+            __scan_through_elements_helper<__use_subgroup_ops, __is_inclusive,
                                            /*__init_present=*/false,
                                            /*__capture_output=*/true, __max_inputs_per_item>(
                 __sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op, __sub_group_carry,
@@ -2050,7 +2010,7 @@ struct __parallel_reduce_then_scan_scan_submitter<__max_inputs_per_item, __is_in
             __scratch_container.template __get_result_acc(__cgh, __dpl_sycl::__no_init{});
 
         __cgh.parallel_for<_KernelName...>(
-            __nd_range, [=, *this](sycl::nd_item<1> __ndi) [[sycl::reqd_sub_group_size(__sub_group_size)]] {
+            __nd_range, [=, *this](sycl::nd_item<1> __ndi) [[_ONEDPL_SYCL_REQD_SUB_GROUP_SIZE_IF_SUPPORTED(32)]] {
                 _InitValueType* __tmp_acc = _TmpStorageAcc::__get_usm_or_buffer_accessor_ptr(__temp_acc);
                 _InitValueType* __res_ptr =
                     _TmpStorageAcc::__get_usm_or_buffer_accessor_ptr(__res_acc, __max_num_sub_groups_global + 2);
@@ -2083,16 +2043,6 @@ struct __parallel_reduce_then_scan_scan_submitter<__max_inputs_per_item, __is_in
     _InitType __init;
 };
 
-// Enable reduce-then-scan if the device uses the required sub-group size and is ran on a device
-// with fast coordinated subgroup operations. We do not want to run this scan on CPU targets, as they are not
-// performant with this algorithm.
-inline bool
-__is_gpu_with_reduce_then_scan_sg_sz(const sycl::queue& __q)
-{
-    return (__q.get_device().is_gpu() &&
-            oneapi::dpl::__internal::__supports_sub_group_size(__q, __get_reduce_then_scan_reqd_sg_sz_host()));
-}
-
 // General scan-like algorithm helpers
 // _GenReduceInput - a function which accepts the input range and index to generate the data needed by the main output
 //                   used in the reduction operation (to calculate the global carries)
@@ -2119,23 +2069,32 @@ __parallel_transform_reduce_then_scan(sycl::queue& __q, const std::size_t __n, _
                                                      __reduce_then_scan_scan_kernel<_CustomName>>;
     using _ValueType = typename _InitType::__value_type;
 
-    constexpr std::uint8_t __min_sub_group_size = __get_reduce_then_scan_workaround_sg_sz();
-    constexpr std::uint8_t __max_sub_group_size = __get_reduce_then_scan_default_sg_sz();
+    // Query the device's supported sub-group sizes to allocate storage conservatively and round
+    // the work-group size appropriately. The actual sub-group size used by each kernel is determined
+    // at runtime via sub_group::get_max_local_range().
+    const auto __supported_sg_sizes = __q.get_device().template get_info<sycl::info::device::sub_group_sizes>();
+    const std::uint8_t __min_sub_group_size =
+        *std::min_element(__supported_sg_sizes.begin(), __supported_sg_sizes.end());
+    const std::uint8_t __max_sub_group_size =
+        *std::max_element(__supported_sg_sizes.begin(), __supported_sg_sizes.end());
     // Empirically determined maximum. May be less for non-full blocks.
     constexpr std::uint16_t __max_inputs_per_item =
        std::max(std::uint16_t{1}, std::uint16_t{512 / __bytes_per_work_item_iter});
     constexpr bool __inclusive = _Inclusive::value;
     constexpr bool __is_unique_pattern_v = _IsUniquePattern::value;
 
-    const std::uint32_t __max_work_group_size = oneapi::dpl::__internal::__max_work_group_size(__q, 8192);
-    // Round down to nearest multiple of the subgroup size
+    // Empirically derived caps for the work-group size based on the target device
+    const std::uint32_t __wg_size_cap = __q.get_device().is_gpu() ? 1024 : 256;
+    const std::uint32_t __max_work_group_size = oneapi::dpl::__internal::__max_work_group_size(__q, __wg_size_cap);
+    // Round down to nearest multiple of the max sub-group size to ensure compatibility with all sub-group sizes
     const std::uint32_t __work_group_size = (__max_work_group_size / __max_sub_group_size) * __max_sub_group_size;
-    // TODO: Investigate potentially basing this on some scale of the number of compute units. 128 work-groups has been
-    // found to be reasonable number for most devices.
-    constexpr std::uint32_t __num_work_groups = 128;
-    // We may use a sub-group size of 16 or 32 depending on the compiler optimization level. Allocate sufficient
-    // temporary storage to handle both cases.
+    // Scale the number of work-groups with the device's compute unit count: round up to the next power of two,
+    // then divide by 4 on GPUs.
+    const std::uint32_t __num_work_groups =
+        oneapi::dpl::__internal::__dpl_bit_ceil(
+            __q.get_device().template get_info<sycl::info::device::max_compute_units>()) /
+        (__q.get_device().is_gpu() ? 4 : 1);
+    // Allocate sufficient temporary storage for the worst case (smallest sub-group size = most sub-groups).
     const std::uint32_t __max_num_sub_groups_local = __work_group_size / __min_sub_group_size;
     const std::uint32_t __max_num_sub_groups_global = __max_num_sub_groups_local * __num_work_groups;
     const std::uint32_t __max_inputs_per_work_group = __work_group_size * __max_inputs_per_item;
@@ -2164,7 +2123,7 @@ __parallel_transform_reduce_then_scan(sycl::queue& __q, const std::size_t __n, _
 
     // Use SLM-based sub-group communication for non-trivially-copyable types or CPU targets
     // (where native sub-group operations are slow).
-    const bool __use_slm_for_comm = !std::is_trivially_copyable_v<_ValueType>;
+    const bool __use_slm_for_comm = !std::is_trivially_copyable_v<_ValueType> || !__q.get_device().is_gpu();
 
     // Reduce and scan step implementations
     using _ReduceSubmitter =
diff --git a/include/oneapi/dpl/pstl/hetero/numeric_impl_hetero.h b/include/oneapi/dpl/pstl/hetero/numeric_impl_hetero.h
index 555b5004b6b..0f0aee7774b 100644
--- a/include/oneapi/dpl/pstl/hetero/numeric_impl_hetero.h
+++ b/include/oneapi/dpl/pstl/hetero/numeric_impl_hetero.h
@@ -135,51 +135,14 @@ __pattern_transform_scan_base(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&
     auto __keep1 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read>();
     auto __buf1 = __keep1(__first, __last);
 
-    // This is a temporary workaround for an in-place exclusive scan while the SYCL backend scan pattern is not fixed.
-    const bool __is_scan_inplace_exclusive = __n > 1 && !_Inclusive{} && __iterators_possibly_equal(__first, __result);
-    if (!__is_scan_inplace_exclusive)
-    {
-        auto __keep2 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write,
-                                                              /*_IsNoInitRequested=*/true>();
-        auto __buf2 = __keep2(__result, __result + __n);
-
-        oneapi::dpl::__par_backend_hetero::__parallel_transform_scan(
-            _BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), __buf1.all_view(), __buf2.all_view(), __n,
-            __unary_op, __init, __binary_op, _Inclusive{})
-            .__checked_deferrable_wait();
-    }
-    else
-    {
-        assert(__n > 1);
-        assert(!_Inclusive{});
-        assert(__iterators_possibly_equal(__first, __result));
-
-        using _Type = typename _InitType::__value_type;
-
-        auto __policy =
-            __par_backend_hetero::make_wrapped_policy(::std::forward<_ExecutionPolicy>(__exec));
-
-        // Create temporary buffer
-        oneapi::dpl::__par_backend_hetero::__buffer<_Type> __tmp_buf(__n);
-        auto __first_tmp = __tmp_buf.get();
-        auto __last_tmp = __first_tmp + __n;
-        auto __keep2 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write,
-                                                              /*_IsNoInitRequested=*/true>();
-        auto __buf2 = __keep2(__first_tmp, __last_tmp);
-
-        // Run main algorithm and save data into temporary buffer
-        oneapi::dpl::__par_backend_hetero::__parallel_transform_scan(_BackendTag{}, __policy, __buf1.all_view(),
-                                                                     __buf2.all_view(), __n, __unary_op, __init,
-                                                                     __binary_op, _Inclusive{})
-            .wait();
-
-        // Move data from temporary buffer into results
-        oneapi::dpl::__internal::__pattern_walk2_brick(
-            __tag, ::std::move(__policy), __first_tmp, __last_tmp, __result,
-            oneapi::dpl::__internal::__brick_move<__hetero_tag<_BackendTag>>{});
-
-        //TODO: optimize copy back depending on Iterator, i.e. set_final_data for host iterator/pointer
-    }
+    auto __keep2 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write,
+                                                           /*_IsNoInitRequested=*/true>();
+    auto __buf2 = __keep2(__result, __result + __n);
+
+    oneapi::dpl::__par_backend_hetero::__parallel_transform_scan(
+        _BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), __buf1.all_view(), __buf2.all_view(), __n,
+        __unary_op, __init, __binary_op, _Inclusive{})
+        .__checked_deferrable_wait();
 
     return __result + __n;
 }
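
The sub-group-size handling in this patch reduces to two moves: query info::device::sub_group_sizes on the host to size scratch storage for the worst case (the smallest supported size yields the most sub-groups), and read the actual size at runtime inside the kernel via sub_group::get_max_local_range() instead of baking it into a template parameter. Below is a minimal standalone sketch of that pattern; it is not oneDPL code, assumes a SYCL 2020 toolchain, and every identifier in it is illustrative.

// sg_size_query_sketch.cpp -- illustrative only, not part of the patch
#include <sycl/sycl.hpp>
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    sycl::queue q;
    // Host side: supported sub-group sizes drive sizing decisions.
    std::vector<std::size_t> sizes = q.get_device().get_info<sycl::info::device::sub_group_sizes>();
    std::size_t min_sg = *std::min_element(sizes.begin(), sizes.end());
    std::size_t max_sg = *std::max_element(sizes.begin(), sizes.end());
    std::printf("sub-group sizes: min=%zu max=%zu\n", min_sg, max_sg);

    // Device side: an inclusive shift-based sub-group scan whose loop bound is the
    // runtime sub-group size, mirroring the `__shift <= __sub_group_size / 2` loops above.
    int* data = sycl::malloc_shared<int>(max_sg, q);
    std::fill(data, data + max_sg, 1);
    q.parallel_for(sycl::nd_range<1>{max_sg, max_sg}, [=](sycl::nd_item<1> it) {
         auto sg = it.get_sub_group();
         const std::uint8_t sg_size = sg.get_max_local_range()[0]; // queried at runtime
         const std::uint8_t lid = sg.get_local_linear_id();
         int v = data[it.get_global_linear_id()];
         for (std::uint8_t shift = 1; shift <= sg_size / 2; shift <<= 1)
         {
             // Pull the partial sum from the work-item `shift` lanes below and accumulate it.
             int partial = sycl::shift_group_right(sg, v, shift);
             if (lid >= shift)
                 v += partial;
         }
         data[it.get_global_linear_id()] = v; // lane i of each sub-group now holds i + 1
     }).wait();
    sycl::free(data, q);
    return 0;
}

In the patch itself the same host query feeds __min_sub_group_size/__max_sub_group_size for scratch sizing, while [[_ONEDPL_SYCL_REQD_SUB_GROUP_SIZE_IF_SUPPORTED(32)]] still requests the 32-wide fast path on devices that support it.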