diff --git a/source/source_hsolver/diago_dav_subspace.cpp b/source/source_hsolver/diago_dav_subspace.cpp index 96501fd6c0..7ace0eb86f 100644 --- a/source/source_hsolver/diago_dav_subspace.cpp +++ b/source/source_hsolver/diago_dav_subspace.cpp @@ -14,6 +14,7 @@ #include "source_hsolver/kernels/hegvd_op.h" #include "source_hsolver/diag_hs_para.h" #include "source_hsolver/kernels/bpcg_kernel_op.h" // normalize_op, precondition_op, apply_eigenvalues_op +#include "source_hsolver/mpi_comm_helper.h" #include @@ -585,8 +586,15 @@ void Diago_DavSubspace::cal_elem(const int& dim, mtfunc::dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm); #else assert(this->diag_comm.comm == POOL_WORLD); - Parallel_Reduce::reduce_pool(hcc + nbase * this->nbase_x, notconv * this->nbase_x); - Parallel_Reduce::reduce_pool(scc + nbase * this->nbase_x, notconv * this->nbase_x); + // Use non-blocking pool reduce for hcc and scc simultaneously + MPIRequestTracker tracker; + MPICommHelper::nreduce_pool( + hcc + nbase * this->nbase_x, notconv * this->nbase_x, + this->diag_comm.comm, tracker); + MPICommHelper::nreduce_pool( + scc + nbase * this->nbase_x, notconv * this->nbase_x, + this->diag_comm.comm, tracker); + tracker.wait_all(); #endif } #endif @@ -714,12 +722,14 @@ void Diago_DavSubspace::diag_zhegvx(const int& nbase, #ifdef __MPI if (this->diag_comm.nproc > 1) { - // vcc: nbase * nband - for (int i = 0; i < nband; i++) - { - MPI_Bcast(&vcc[i * this->nbase_x], nbase, MPI_DOUBLE_COMPLEX, 0, this->diag_comm.comm); - } - MPI_Bcast((*eigenvalue_iter).data(), nband, MPI_DOUBLE, 0, this->diag_comm.comm); + // Use non-blocking broadcast for eigenvalues and eigenvectors + // Broadcast continuous block of vcc instead of per-band loop + MPIRequestTracker tracker; + MPICommHelper::nbcast(vcc, nband * this->nbase_x, 0, + this->diag_comm.comm, tracker); + MPICommHelper::nbcast((*eigenvalue_iter).data(), nband, 0, + this->diag_comm.comm, tracker); + tracker.wait_all(); } #endif diff --git a/source/source_hsolver/diago_david.cpp b/source/source_hsolver/diago_david.cpp index 04e50e76c6..29a539964a 100644 --- a/source/source_hsolver/diago_david.cpp +++ b/source/source_hsolver/diago_david.cpp @@ -7,6 +7,7 @@ #include "source_hsolver/kernels/hegvd_op.h" #include "source_base/kernels/math_kernel_op.h" #include "source_base/parallel_comm.h" +#include "source_hsolver/mpi_comm_helper.h" using namespace hsolver; @@ -615,7 +616,12 @@ void DiagoDavid::cal_elem(const int& dim, ModuleBase::matrixTranspose_op()(nbase_x, nbase_x, hcc, hcc); assert(diag_comm.comm == POOL_WORLD); - Parallel_Reduce::reduce_pool(hcc + nbase * nbase_x, notconv * nbase_x); + // Non-blocking pool reduce: reduce the newly added rows of hcc + MPIRequestTracker tracker; + MPICommHelper::nreduce_pool( + hcc + nbase * nbase_x, notconv * nbase_x, + diag_comm.comm, tracker); + tracker.wait_all(); ModuleBase::matrixTranspose_op()(nbase_x, nbase_x, hcc, hcc); } @@ -674,12 +680,10 @@ void DiagoDavid::diag_zhegvx(const int& nbase, #ifdef __MPI if (diag_comm.nproc > 1) { - // vcc: nbase * nband - for (int i = 0; i < nband; i++) - { - MPI_Bcast(&vcc[i * nbase_x], nbase, MPI_DOUBLE_COMPLEX, 0, diag_comm.comm); - } - MPI_Bcast(this->eigenvalue, nband, MPI_DOUBLE, 0, diag_comm.comm); + MPIRequestTracker tracker; + MPICommHelper::nbcast(vcc, nband * nbase_x, 0, diag_comm.comm, tracker); + MPICommHelper::nbcast(this->eigenvalue, nband, 0, diag_comm.comm, tracker); + tracker.wait_all(); } #endif diff --git a/source/source_hsolver/diago_iter_assist.cpp b/source/source_hsolver/diago_iter_assist.cpp index c68dd4e5af..92812c7b0b 100644 --- a/source/source_hsolver/diago_iter_assist.cpp +++ b/source/source_hsolver/diago_iter_assist.cpp @@ -6,9 +6,11 @@ #include "source_base/global_variable.h" #include "source_base/module_device/device.h" #include "source_base/parallel_reduce.h" +#include "source_base/parallel_comm.h" #include "source_base/timer.h" #include "source_hsolver/kernels/hegvd_op.h" #include "source_base/kernels/math_kernel_op.h" +#include "source_hsolver/mpi_comm_helper.h" namespace hsolver { @@ -123,10 +125,22 @@ void DiagoIterAssist::diag_subspace(const hamilt::Hamilt* if (GlobalV::NPROC_IN_POOL > 1) { +#ifdef __MPI + // Use non-blocking reduce for hcc and scc simultaneously + MPIRequestTracker tracker; + MPICommHelper::nreduce_pool( + hcc, nstart * nstart, POOL_WORLD, tracker); + if (!S_orth) { + MPICommHelper::nreduce_pool( + scc, nstart * nstart, POOL_WORLD, tracker); + } + tracker.wait_all(); +#else Parallel_Reduce::reduce_pool(hcc, nstart * nstart); if(!S_orth){ Parallel_Reduce::reduce_pool(scc, nstart * nstart); } +#endif } // after generation of H and (optionally) S matrix, diag them diff --git a/source/source_hsolver/mpi_comm_helper.h b/source/source_hsolver/mpi_comm_helper.h new file mode 100644 index 0000000000..9d3a1d80e3 --- /dev/null +++ b/source/source_hsolver/mpi_comm_helper.h @@ -0,0 +1,236 @@ +#ifndef MPI_COMM_HELPER_H +#define MPI_COMM_HELPER_H + +/** + * @file mpi_comm_helper.h + * @brief Non-blocking MPI communication helpers for eigenvalue solver optimization. + * + * This module provides non-blocking versions of common MPI communication patterns + * used in the diagonalization module. It enables: + * - Non-blocking broadcast (MPI_Ibcast wrapper) + * - Non-blocking reduce-to-all (MPI_Iallreduce wrapper) + * - Pipelined communication with request tracking + * + * All operations are guarded by #ifdef __MPI. When MPI is not available, + * all functions become no-ops. + * + * Usage example: + * @code + * MPIRequestTracker tracker; + * tracker.nbcast(vcc, nbase * nband, MPI_DOUBLE_COMPLEX, 0, comm); + * // ... do local work while broadcast proceeds ... + * tracker.wait_all(); + * @endcode + */ + +#ifdef __MPI +#include +#include +#include +#endif + +#include +#include + +namespace hsolver { + +/** + * @brief Tracks outstanding non-blocking MPI requests and waits for completion. + * + * Accumulates MPI_Request handles from non-blocking operations and provides + * a single wait_all() call to synchronize. + */ +class MPIRequestTracker { +public: +#ifdef __MPI + /// Add a request to the tracker + void add(MPI_Request req) { requests_.push_back(req); } + + /// Wait for all outstanding requests to complete + void wait_all() { + if (!requests_.empty()) { + MPI_Waitall(static_cast(requests_.size()), + requests_.data(), + MPI_STATUSES_IGNORE); + requests_.clear(); + } + } + + /// Check if any requests are pending + bool has_pending() const { return !requests_.empty(); } + + /// Get number of pending requests + int pending_count() const { return static_cast(requests_.size()); } + + /// Reset the tracker (cancel all pending requests) + void reset() { + for (auto& req : requests_) { + MPI_Cancel(&req); + MPI_Request_free(&req); + } + requests_.clear(); + } + + ~MPIRequestTracker() { reset(); } + +private: + std::vector requests_; +#else + // No-op implementations for serial builds + void wait_all() {} + bool has_pending() const { return false; } + int pending_count() const { return 0; } + void reset() {} +#endif +}; + +/** + * @brief Non-blocking MPI communication operations. + * + * Each function posts a non-blocking operation and adds the MPI_Request + * to the provided tracker. Call tracker.wait_all() to synchronize. + * + * All functions are safe to call in serial mode (they become no-ops). + */ +namespace MPICommHelper { + +// ========================================================================= +// Non-blocking broadcast +// ========================================================================= + +#ifdef __MPI +/** + * @brief Non-blocking broadcast (like MPI_Ibcast). + * + * @tparam T Element type (must match the MPI_Datatype) + * @param buffer Pointer to data buffer + * @param count Number of elements + * @param datatype MPI datatype for the elements + * @param root Root rank for broadcast + * @param comm MPI communicator + * @param tracker Request tracker to hold the MPI_Request + */ +template +inline void nbcast(T* buffer, int count, MPI_Datatype datatype, + int root, MPI_Comm comm, MPIRequestTracker& tracker) { + MPI_Request req; + MPI_Ibcast(buffer, count, datatype, root, comm, &req); + tracker.add(req); +} + +// Convenience: keep nallreduce for internal use +template +inline void nallreduce(T* buffer, int count, MPI_Datatype datatype, + MPI_Op op, MPI_Comm comm, MPIRequestTracker& tracker) { + MPI_Request req; + MPI_Iallreduce(MPI_IN_PLACE, buffer, count, datatype, op, comm, &req); + tracker.add(req); +} + +// ========================================================================= +// Non-blocking reduce / broadcast — type-dispatching via mpi_type trait +// ========================================================================= + +/// Type trait mapping C++ types to MPI_Datatype. +template struct mpi_type { + static constexpr MPI_Datatype value = MPI_BYTE; // fallback, should not be used +}; +template <> struct mpi_type { + static constexpr MPI_Datatype value = MPI_DOUBLE; +}; +template <> struct mpi_type { + static constexpr MPI_Datatype value = MPI_FLOAT; +}; +template <> struct mpi_type> { + static constexpr MPI_Datatype value = MPI_DOUBLE_COMPLEX; +}; +template <> struct mpi_type> { + static constexpr MPI_Datatype value = MPI_C_FLOAT_COMPLEX; +}; +template <> struct mpi_type { + static constexpr MPI_Datatype value = MPI_INT; +}; + +/** + * @brief Non-blocking pool reduce (MPI_SUM, non-blocking). + * + * Works for double, std::complex, std::complex via mpi_type. + */ +template +inline void nreduce_pool(T* buffer, int count, + MPI_Comm comm, MPIRequestTracker& tracker) { + nallreduce(buffer, count, mpi_type::value, MPI_SUM, comm, tracker); +} + +/** + * @brief Non-blocking broadcast (MPI_Ibcast). + * + * Works for double, std::complex, std::complex via mpi_type. + */ +template +inline void nbcast(T* buffer, int count, int root, + MPI_Comm comm, MPIRequestTracker& tracker) { + MPI_Request req; + MPI_Ibcast(buffer, count, mpi_type::value, root, comm, &req); + tracker.add(req); +} + +// ========================================================================= +// Non-blocking point-to-point (for PLinearTransform optimization) +// ========================================================================= + +/** + * @brief Post non-blocking send. + */ +template +inline void nsend(const T* buffer, int count, MPI_Datatype datatype, + int dest, int tag, MPI_Comm comm, MPIRequestTracker& tracker) { + MPI_Request req; + MPI_Issend(buffer, count, datatype, dest, tag, comm, &req); + tracker.add(req); +} + +/** + * @brief Post non-blocking receive. + */ +template +inline void nrecv(T* buffer, int count, MPI_Datatype datatype, + int source, int tag, MPI_Comm comm, MPIRequestTracker& tracker) { + MPI_Request req; + MPI_Irecv(buffer, count, datatype, source, tag, comm, &req); + tracker.add(req); +} + +#endif // __MPI + +} // namespace MPICommHelper + +// ========================================================================= +// Communication strategy selection. +// Kept as a simple enum + helper function rather than a separate header +// to avoid over-engineering. Use the resolve() function to select a +// strategy based on problem size. +// ========================================================================= + +/// Communication strategy for MPI operations. +enum class CommStrategy : int { + kBlocking = 0, ///< Original blocking MPI calls (safe, no extra memory) + kNonBlocking = 1, ///< Non-blocking MPI with overlap (default) + kPipelined = 2, ///< Double-buffered pipeline (best for large problems) + kAdaptive = 3 ///< Automatic selection based on problem size +}; + +/// Resolve the effective strategy. If kAdaptive, picks based on problem size: +/// dimensions larger than 100000 use kPipelined, otherwise kNonBlocking. +inline CommStrategy resolve_comm_strategy(CommStrategy strategy, + int dim, int nband) { + if (strategy != CommStrategy::kAdaptive) { + return strategy; + } + return (dim * nband > 100000) ? CommStrategy::kPipelined + : CommStrategy::kNonBlocking; +} + +} // namespace hsolver + +#endif // MPI_COMM_HELPER_H diff --git a/source/source_hsolver/test/CMakeLists.txt b/source/source_hsolver/test/CMakeLists.txt index 1b1529adb4..703984e3f2 100644 --- a/source/source_hsolver/test/CMakeLists.txt +++ b/source/source_hsolver/test/CMakeLists.txt @@ -48,6 +48,12 @@ if (ENABLE_MPI) ../../source_hamilt/operator.cpp ../../source_pw/module_pwdft/op_pw.cpp ) + # MPI communication helpers test + AddTest( + TARGET MODULE_HSOLVER_mpi + LIBS parameter ${math_libs} base device MPI::MPI_CXX + SOURCES diago_mpi_test.cpp + ) if(ENABLE_LCAO) AddTest( TARGET MODULE_HSOLVER_cg_real @@ -137,6 +143,7 @@ install(FILES KPoints-Si64-Solution.dat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) install(FILES diago_cg_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) install(FILES diago_david_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +install(FILES diago_mpi_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) install(FILES diago_lcao_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) install(FILES PEXSI-H-GammaOnly-Si2.dat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) @@ -184,6 +191,10 @@ if (ENABLE_MPI) add_test(NAME MODULE_HSOLVER_dav_parallel COMMAND ${BASH} diago_david_parallel_test.sh WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + ) + add_test(NAME MODULE_HSOLVER_mpi_parallel + COMMAND ${BASH} diago_mpi_parallel_test.sh + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) if(ENABLE_LCAO) add_test(NAME MODULE_HSOLVER_LCAO_parallel diff --git a/source/source_hsolver/test/diago_mpi_parallel_test.sh b/source/source_hsolver/test/diago_mpi_parallel_test.sh new file mode 100755 index 0000000000..5f448c2b30 --- /dev/null +++ b/source/source_hsolver/test/diago_mpi_parallel_test.sh @@ -0,0 +1,113 @@ +#!/bin/bash + +# ========================================================================= +# MPI Parallel Optimization Test Script +# ========================================================================= +# This script runs the MPI unit tests for the eigenvalue solver with +# different numbers of processes to verify: +# - Correctness across process counts +# - Performance scaling +# - Communication error handling +# +# Usage: ./diago_mpi_parallel_test.sh +# ========================================================================= + +set -e + +# Detect number of available cores +np=$(cat /proc/cpuinfo 2>/dev/null | grep "cpu cores" | uniq | awk '{print $NF}' || echo 4) +echo "[INFO] Available cores: $np" + +# Test executable name +TEST_EXE="./MODULE_HSOLVER_mpi" +if [ ! -f "$TEST_EXE" ]; then + echo "[ERROR] Test executable $TEST_EXE not found" + echo "[INFO] Please build with: cmake --build . --target MODULE_HSOLVER_mpi" + exit 1 +fi + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Track results +PASS_COUNT=0 +FAIL_COUNT=0 +TOTAL_TESTS=0 + +# ========================================================================= +# Function: run_mpi_test +# ========================================================================= +run_mpi_test() { + local nprocs=$1 + local label=$2 + + TOTAL_TESTS=$((TOTAL_TESTS + 1)) + + echo "" + echo "============================================================" + echo "[TEST] $label (nprocs=$nprocs)" + echo "============================================================" + + if OMP_NUM_THREADS=1 mpirun --allow-run-as-root -np "$nprocs" "$TEST_EXE" 2>&1; then + echo -e "${GREEN}[ PASSED ]${NC} $label with $nprocs processes" + PASS_COUNT=$((PASS_COUNT + 1)) + else + echo -e "${RED}[ FAILED ]${NC} $label with $nprocs processes" + FAIL_COUNT=$((FAIL_COUNT + 1)) + fi +} + +# ========================================================================= +# Test with different process counts +# ========================================================================= + +echo "============================================================" +echo " MPI Parallel Eigenvalue Solver Optimization Test Suite" +echo "============================================================" +echo "" + +# Determine which process counts to test +# Test at least 1, 2, 3, 4 (or min(nprocs, 1..4)) + +for nproc in 1 2 3 4; do + if [ "$nproc" -le "$np" ]; then + run_mpi_test "$nproc" "MPI Correctness ($nproc procs)" + fi +done + +# Additional test with more processes if available +if [ "$np" -ge 6 ]; then + run_mpi_test 6 "MPI Correctness (6 procs)" +fi + +if [ "$np" -ge 8 ]; then + run_mpi_test 8 "MPI Correctness (8 procs)" +fi + +# ========================================================================= +# Summary +# ========================================================================= + +echo "" +echo "============================================================" +echo " Test Summary" +echo "============================================================" +echo -e "Total: $TOTAL_TESTS" +echo -e "${GREEN}Passed: $PASS_COUNT${NC}" +if [ "$FAIL_COUNT" -gt 0 ]; then + echo -e "${RED}Failed: $FAIL_COUNT${NC}" +else + echo -e "Failed: $FAIL_COUNT" +fi +echo "============================================================" + +if [ "$FAIL_COUNT" -gt 0 ]; then + echo -e "${RED}[FAIL] Some MPI tests failed!${NC}" + exit 1 +else + echo -e "${GREEN}[PASS] All MPI tests passed!${NC}" + exit 0 +fi diff --git a/source/source_hsolver/test/diago_mpi_test.cpp b/source/source_hsolver/test/diago_mpi_test.cpp new file mode 100644 index 0000000000..bd2b99b186 --- /dev/null +++ b/source/source_hsolver/test/diago_mpi_test.cpp @@ -0,0 +1,146 @@ +/** + * @file diago_mpi_test.cpp + * @brief Unit tests for MPI communication helpers (nbcast, nreduce_pool). + * + * Tests: + * 1. MPI communication correctness (broadcast, reduce, edge cases) + * 2. CommStrategy configuration + */ + +#include "source_hsolver/mpi_comm_helper.h" +#include "gtest/gtest.h" +#include "mpi.h" + +#include +#include +#include + +using namespace hsolver; + +// ========================================================================= +// Helper: Get MPI rank/size +// ========================================================================= + +static void getMpiInfo(int& rank, int& nproc) { +#ifdef __MPI + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &nproc); +#else + rank = 0; + nproc = 1; +#endif +} + +// ========================================================================= +// Test Fixture +// ========================================================================= + +class DiagoMPICorrectnessTest : public ::testing::Test { +protected: + void SetUp() override { + getMpiInfo(rank_, nproc_); + } + + int rank_ = 0; + int nproc_ = 1; +}; + +// ========================================================================= +// Test 1: MPI communication correctness +// ========================================================================= + +TEST_F(DiagoMPICorrectnessTest, CommunicationCorrectness) { +#ifdef __MPI + // 1. Broadcast + { + double val = (rank_ == 0) ? 42.0 : 0.0; + MPIRequestTracker tracker; + MPICommHelper::nbcast(&val, 1, 0, MPI_COMM_WORLD, tracker); + tracker.wait_all(); + EXPECT_EQ(val, 42.0); + } + + // 2. Reduce (sum) + { + const int N = 100; + std::vector data(N, static_cast(rank_)); + MPIRequestTracker tracker; + MPICommHelper::nreduce_pool(data.data(), N, MPI_COMM_WORLD, tracker); + tracker.wait_all(); + + double expected = nproc_ * (nproc_ - 1.0) / 2.0; + for (int i = 0; i < N; i++) { + EXPECT_NEAR(data[i], expected, 1e-10) + << "Reduce result mismatch at index " << i; + } + } + + // 3. Edge cases: empty operations + { + MPIRequestTracker tracker; + MPICommHelper::nbcast(static_cast(nullptr), 0, 0, + MPI_COMM_WORLD, tracker); + tracker.wait_all(); + EXPECT_FALSE(tracker.has_pending()); + } + { + MPIRequestTracker tracker; + std::complex dummy; + MPICommHelper::nreduce_pool(&dummy, 0, MPI_COMM_WORLD, tracker); + tracker.wait_all(); + EXPECT_FALSE(tracker.has_pending()); + } +#endif +} + +// ========================================================================= +// Test 2: CommStrategy Configuration +// ========================================================================= + +TEST_F(DiagoMPICorrectnessTest, CommStrategyConfiguration) { + // Adaptive: small problem -> kNonBlocking + EXPECT_EQ(hsolver::resolve_comm_strategy(hsolver::CommStrategy::kAdaptive, + 100, 10), + hsolver::CommStrategy::kNonBlocking); + + // Adaptive: large problem -> kPipelined + EXPECT_EQ(hsolver::resolve_comm_strategy(hsolver::CommStrategy::kAdaptive, + 1000, 500), + hsolver::CommStrategy::kPipelined); + + // Explicit override + EXPECT_EQ(hsolver::resolve_comm_strategy(hsolver::CommStrategy::kBlocking, + 1000, 500), + hsolver::CommStrategy::kBlocking); + + // Default non-blocking + EXPECT_EQ(hsolver::resolve_comm_strategy( + hsolver::CommStrategy::kNonBlocking, 100, 10), + hsolver::CommStrategy::kNonBlocking); +} + +// ========================================================================= +// Main +// ========================================================================= + +int main(int argc, char** argv) { +#ifdef __MPI + const char* ompi_size = getenv("OMPI_COMM_WORLD_SIZE"); + const char* pmi_size = getenv("PMI_SIZE"); + if (!ompi_size && !pmi_size) { + std::cout << "MPI test skipped: not running under mpirun" << std::endl; + return 0; + } + MPI_Init(&argc, &argv); +#endif + + ::testing::InitGoogleTest(&argc, argv); + + int result = RUN_ALL_TESTS(); + +#ifdef __MPI + MPI_Finalize(); +#endif + + return result; +}