diff --git a/include/oneapi/dpl/pstl/parallel_backend_utils.h b/include/oneapi/dpl/pstl/parallel_backend_utils.h index af85a4fc5e9..82a44069206 100644 --- a/include/oneapi/dpl/pstl/parallel_backend_utils.h +++ b/include/oneapi/dpl/pstl/parallel_backend_utils.h @@ -280,25 +280,36 @@ __set_difference_construct(_ForwardIterator1 __first1, _ForwardIterator1 __last1 { using _Tp = typename ::std::iterator_traits<_OutputIterator>::value_type; - for (; __first1 != __last1;) + // __proj1_val < __proj2_val + auto __op_val1_lt_val2 = [](_ForwardIterator1& __it1, _ForwardIterator2&, _OutputIterator& __out_it) { + ::new (::std::addressof(*__out_it)) _Tp(*__it1); + ++__it1; + ++__out_it; + }; + + // __proj2_val < __proj1_val + auto __op_val2_lt_val1 = [](_ForwardIterator1&, _ForwardIterator2& __it2, _OutputIterator&) { ++__it2; }; + + // __proj1_val == __proj2_val + auto __op_val1_eq_val2 = [](_ForwardIterator1& __it1, _ForwardIterator2& __it2, _OutputIterator&) { + ++__it1; + ++__it2; + }; + + // 1. Main set_difference operation + while (__first1 != __last1 && __first2 != __last2) { - if (__first2 == __last2) - return __cc_range(__first1, __last1, __result); - - if (std::invoke(__comp, std::invoke(__proj1, *__first1), std::invoke(__proj2, *__first2))) - { - ::new (::std::addressof(*__result)) _Tp(*__first1); - ++__result; - ++__first1; - } - else - { - if (!std::invoke(__comp, std::invoke(__proj2, *__first2), std::invoke(__proj1, *__first1))) - ++__first1; - ++__first2; - } + const bool __val1_lt_val2 = + std::invoke(__comp, std::invoke(__proj1, *__first1), std::invoke(__proj2, *__first2)); + const bool __val2_lt_val1 = + !__val1_lt_val2 && std::invoke(__comp, std::invoke(__proj2, *__first2), std::invoke(__proj1, *__first1)); + + __val1_lt_val2 ? __op_val1_lt_val2(__first1, __first2, __result) + : (__val2_lt_val1 ? __op_val2_lt_val1(__first1, __first2, __result) + : __op_val1_eq_val2(__first1, __first2, __result)); } - return __result; + + return __cc_range(__first1, __last1, __result); } template ::value_type; - for (; __first1 != __last1;) + while (__first1 != __last1 && __first2 != __last2) { - if (__first2 == __last2) - return __cc_range(__first1, __last1, __result); - if (std::invoke(__comp, std::invoke(__proj1, *__first1), std::invoke(__proj2, *__first2))) { ::new (::std::addressof(*__result)) _Tp(*__first1); @@ -333,7 +341,11 @@ __set_symmetric_difference_construct(_ForwardIterator1 __first1, _ForwardIterato ++__first2; } } - return __cc_range(__first2, __last2, __result); + + if (__first1 == __last1) + return __cc_range(__first2, __last2, __result); + + return __cc_range(__first1, __last1, __result); } template