diff --git a/cpp/oneapi/dal/algo/subgraph_isomorphism/backend/cpu/matching.hpp b/cpp/oneapi/dal/algo/subgraph_isomorphism/backend/cpu/matching.hpp index 61d9064be6e..1ae6a947b3b 100644 --- a/cpp/oneapi/dal/algo/subgraph_isomorphism/backend/cpu/matching.hpp +++ b/cpp/oneapi/dal/algo/subgraph_isomorphism/backend/cpu/matching.hpp @@ -47,6 +47,7 @@ class matching_engine { inner_alloc alloc); virtual ~matching_engine(); + template void run_and_wait(global_stack& gstack, std::int64_t& busy_engine_count, std::int64_t& current_match_count, @@ -59,6 +60,7 @@ class matching_engine { std::int64_t state_exploration_bit(bool check_solution = true); std::int64_t state_exploration_list(bool check_solution = true); + template bool check_if_max_match_count_reached(std::int64_t& cumulative_match_count, std::int64_t delta, std::int64_t target_match_count); @@ -92,9 +94,40 @@ class matching_engine { std::int64_t extract_candidates(bool check_solution); bool check_vertex_candidate(bool check_solution, std::int64_t candidate); + template void set_not_busy(bool& is_busy_engine, std::int64_t& busy_engine_count); }; +template +void increment_shared_value(std::int64_t& value, std::int64_t delta = 1) { + if constexpr (is_parallel) { + dal::detail::atomic_increment(value, delta); + } + else { + value += delta; + } +} + +template +void decrement_shared_value(std::int64_t& value, std::int64_t delta = 1) { + if constexpr (is_parallel) { + dal::detail::atomic_decrement(value, delta); + } + else { + value -= delta; + } +} + +template +std::int64_t load_shared_value(std::int64_t& value) { + if constexpr (is_parallel) { + return dal::detail::atomic_load(value); + } + else { + return value; + } +} + template class engine_bundle { public: @@ -372,28 +405,31 @@ std::int64_t matching_engine::state_exploration() { } template +template bool matching_engine::check_if_max_match_count_reached(std::int64_t& cumulative_match_count, std::int64_t delta, std::int64_t target_match_count) { bool is_reached = false; if (delta > 0) { - dal::detail::atomic_increment(cumulative_match_count, delta); + increment_shared_value(cumulative_match_count, delta); } - if (dal::detail::atomic_load(cumulative_match_count) >= target_match_count) { + if (load_shared_value(cumulative_match_count) >= target_match_count) { is_reached = true; } return is_reached; } template +template void matching_engine::set_not_busy(bool& is_busy_engine, std::int64_t& busy_engine_count) { if (is_busy_engine) { is_busy_engine = false; - dal::detail::atomic_decrement(busy_engine_count); + decrement_shared_value(busy_engine_count); } } template +template void matching_engine::run_and_wait(global_stack& gstack, std::int64_t& busy_engine_count, std::int64_t& cumulative_match_count, @@ -407,8 +443,8 @@ void matching_engine::run_and_wait(global_stack& gstack, ONEDAL_ASSERT(pattern != nullptr); for (;;) { if (target_match_count > 0 && - dal::detail::atomic_load(cumulative_match_count) >= target_match_count) { - set_not_busy(is_busy_engine, busy_engine_count); + load_shared_value(cumulative_match_count) >= target_match_count) { + set_not_busy(is_busy_engine, busy_engine_count); break; } if (hlocal_stack.states_in_stack() > 0) { @@ -417,27 +453,28 @@ void matching_engine::run_and_wait(global_stack& gstack, ONEDAL_ASSERT(hlocal_stack.states_in_stack() > 0); const auto delta = state_exploration(); current_match_count += delta; - if (target_match_count > 0 && check_if_max_match_count_reached(cumulative_match_count, - delta, - target_match_count)) { - set_not_busy(is_busy_engine, busy_engine_count); + if (target_match_count > 0 && + check_if_max_match_count_reached(cumulative_match_count, + delta, + target_match_count)) { + set_not_busy(is_busy_engine, busy_engine_count); break; } } else { gstack.pop(hlocal_stack); if (hlocal_stack.empty()) { - set_not_busy(is_busy_engine, busy_engine_count); + set_not_busy(is_busy_engine, busy_engine_count); if (target_match_count > 0 && - dal::detail::atomic_load(cumulative_match_count) >= target_match_count) { + load_shared_value(cumulative_match_count) >= target_match_count) { break; } - if (dal::detail::atomic_load(busy_engine_count) == 0) + if (load_shared_value(busy_engine_count) == 0) break; } else if (!is_busy_engine) { is_busy_engine = true; - dal::detail::atomic_increment(busy_engine_count); + increment_shared_value(busy_engine_count); } } } @@ -569,13 +606,23 @@ solution engine_bundle::run(std::int64_t max_match_count) { global_stack gstack(pattern->get_vertex_count(), allocator); std::int64_t busy_engine_count(array_size); std::int64_t cumulative_match_count(0); - dal::detail::threader_for(array_size, array_size, [&](const int index) { - engine_array[index].run_and_wait(gstack, - busy_engine_count, - cumulative_match_count, - max_match_count, - false); - }); + + if (array_size == 1) { + engine_array[0].template run_and_wait(gstack, + busy_engine_count, + cumulative_match_count, + max_match_count, + false); + } + else { + dal::detail::threader_for(array_size, array_size, [&](const int index) { + engine_array[index].template run_and_wait(gstack, + busy_engine_count, + cumulative_match_count, + max_match_count, + false); + }); + } auto aggregated_solution = combine_solutions(engine_array, array_size, max_match_count);