diff --git a/pypolychord/_pypolychord.cpp b/pypolychord/_pypolychord.cpp index 72b88ab5..ac92ea69 100644 --- a/pypolychord/_pypolychord.cpp +++ b/pypolychord/_pypolychord.cpp @@ -5,12 +5,19 @@ #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include +#ifdef USE_MPI +#include +#include +#endif /* Initialize the module */ #ifdef PYTHON3 PyMODINIT_FUNC PyInit__pypolychord(void) { import_array(); +#ifdef USE_MPI + import_mpi4py(); +#endif return PyModule_Create(&_pypolychordmodule); } #else @@ -18,6 +25,9 @@ PyMODINIT_FUNC init_pypolychord(void) { Py_InitModule3("_pypolychord", module_methods, module_docstring); import_array(); +#ifdef USE_MPI + import_mpi4py(); +#endif } #endif @@ -121,12 +131,12 @@ static PyObject *run_pypolychord(PyObject *, PyObject *args) Settings S; PyObject *temp_logl, *temp_prior, *temp_dumper; - PyObject* py_grade_dims, *py_grade_frac, *py_nlives; + PyObject* py_grade_dims, *py_grade_frac, *py_nlives, *py_comm; char* base_dir, *file_root; if (!PyArg_ParseTuple(args, - "OOOiiiiiiiiddidiiiiiiiiiiidissO!O!O!i:run", + "OOOiiiiiiiiddidiiiiiiiiiiidissO!O!O!iO:run", &temp_logl, &temp_prior, &temp_dumper, @@ -163,7 +173,8 @@ static PyObject *run_pypolychord(PyObject *, PyObject *args) &py_grade_dims, &PyDict_Type, &py_nlives, - &S.seed + &S.seed, + &py_comm ) ) return NULL; @@ -216,7 +227,14 @@ static PyObject *run_pypolychord(PyObject *, PyObject *args) python_dumper = temp_dumper; /* Run PolyChord */ - try{ run_polychord(loglikelihood, prior, dumper, S); } + try{ +#ifdef USE_MPI + MPI_Comm *comm = PyMPIComm_Get(py_comm); + run_polychord(loglikelihood, prior, dumper, S, *comm); +#else + run_polychord(loglikelihood, prior, dumper, S); +#endif + } catch (PythonException& e) { Py_DECREF(py_grade_frac);Py_DECREF(py_grade_dims);Py_DECREF(python_loglikelihood);Py_DECREF(python_prior); diff --git a/pypolychord/polychord.py b/pypolychord/polychord.py index 947fe4b7..c3a68b62 100644 --- a/pypolychord/polychord.py +++ b/pypolychord/polychord.py @@ -14,7 
+14,7 @@ def default_dumper(live, dead, logweights, logZ, logZerr): def run_polychord(loglikelihood, nDims, nDerived, settings, - prior=default_prior, dumper=default_dumper): + prior=default_prior, dumper=default_dumper, comm=None): """ Runs PolyChord. @@ -143,13 +143,15 @@ def run_polychord(loglikelihood, nDims, nDerived, settings, Final output evidence statistics """ - - try: - from mpi4py import MPI - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - except ImportError: - rank = 0 + rank = 0 + if comm is None: + try: + from mpi4py import MPI + comm = MPI.COMM_WORLD + except ImportError: + pass + if comm is not None: + rank = comm.rank if rank == 0: Path(settings.cluster_dir).mkdir(parents=True, exist_ok=True) @@ -207,7 +209,8 @@ def wrap_prior(cube, theta): settings.grade_frac, settings.grade_dims, settings.nlives, - settings.seed) + settings.seed, + comm) if settings.cube_samples is not None: settings.read_resume = read_resume @@ -512,10 +515,9 @@ def run(loglikelihood, nDims, **kwargs): try: from mpi4py import MPI - comm = MPI.COMM_WORLD - rank = comm.Get_rank() + default_comm = MPI.COMM_WORLD except ImportError: - rank = 0 + default_comm = None paramnames = kwargs.pop('paramnames', None) @@ -552,6 +554,7 @@ def run(loglikelihood, nDims, **kwargs): 'grade_dims': [nDims], 'nlives': {}, 'seed': -1, + 'comm': default_comm, } default_kwargs['grade_frac'] = ([1.0]*len(default_kwargs['grade_dims']) if 'grade_dims' not in kwargs else @@ -563,6 +566,11 @@ def run(loglikelihood, nDims, **kwargs): default_kwargs.update(kwargs) kwargs = default_kwargs + if kwargs['comm'] is not None: + rank = kwargs['comm'].rank + else: + rank = 0 + if rank == 0: (Path(kwargs['base_dir']) / kwargs['cluster_dir']).mkdir( parents=True, exist_ok=True) @@ -631,7 +639,7 @@ def wrap_prior(cube, theta): kwargs['grade_dims'], kwargs['nlives'], kwargs['seed'], - ) + kwargs['comm']) if 'cube_samples' in kwargs: kwargs['read_resume'] = read_resume diff --git a/setup.py b/setup.py index 
09558f5a..d075db15 100644 --- a/setup.py +++ b/setup.py @@ -12,10 +12,19 @@ import numpy -def check_compiler(default_CC="gcc"): +try: + import mpi4py +except ImportError: + mpi4py_get_include = None +else: + mpi4py_get_include = mpi4py.get_include() + + +def check_compiler(): """Checks what compiler is being used (clang, intel, or gcc).""" - CC = default_CC if "CC" not in os.environ else os.environ["CC"] + CC = os.getenv('CC', 'mpicc' if mpi4py_get_include else 'gcc') + os.environ['CC'] = CC CC_version = subprocess.check_output([CC, "-v"], stderr=subprocess.STDOUT).decode("utf-8").lower() if "clang" in CC_version: @@ -107,14 +116,22 @@ def run(self): subprocess.run(["make", "veryclean"], check=True, env=os.environ) return super().run() + +include_dirs = ['src/polychord', numpy.get_include()] + if "--no-mpi" in sys.argv: NAME += '_nompi' DOCLINES[1] = DOCLINES[1] + ' (cannot be used with MPI)' +elif mpi4py_get_include: + CPPRUNTIMELIB_FLAG += ["-DUSE_MPI"] + print(mpi4py_get_include) + include_dirs += [mpi4py_get_include] + pypolychord_module = Extension( name='_pypolychord', library_dirs=['lib'], - include_dirs=['src/polychord', numpy.get_include()], + include_dirs=include_dirs, libraries=['chord',], extra_link_args=RPATH_FLAG + CPPRUNTIMELIB_FLAG, extra_compile_args= ["-std=c++11"] + RPATH_FLAG + CPPRUNTIMELIB_FLAG, diff --git a/src/polychord/mpi_utils.F90 b/src/polychord/mpi_utils.F90 index 0c587bbf..e5bd84bc 100644 --- a/src/polychord/mpi_utils.F90 +++ b/src/polychord/mpi_utils.F90 @@ -22,6 +22,9 @@ module mpi_module integer, parameter :: tag_run_epoch_babies=10 integer, parameter :: tag_run_stop=11 + integer, parameter :: tag_tag_gen=12 + integer, parameter :: tag_tag_run=13 + type mpi_bundle integer :: rank integer :: nprocs @@ -476,6 +479,7 @@ function catch_seed(seed_point,cholesky,logL,epoch,mpi_information) result(more_ real(dp),intent(out),dimension(:) :: seed_point !> The seed point to be caught real(dp),intent(out),dimension(:,:) :: cholesky !> Cholesky 
matrix to be caught real(dp),intent(out) :: logL !> loglikelihood contour to be caught + integer :: tag_run integer, intent(out) :: epoch type(mpi_bundle), intent(in) :: mpi_information @@ -483,26 +487,37 @@ function catch_seed(seed_point,cholesky,logL,epoch,mpi_information) result(more_ integer, dimension(MPI_STATUS_SIZE) :: mpistatus ! status identifier - call MPI_RECV( &! - seed_point, &! - size(seed_point), &! - MPI_DOUBLE_PRECISION, &! + tag_run, &! + 1, &! + MPI_INTEGER, &! mpi_information%root, &! - MPI_ANY_TAG, &! + tag_tag_run, &! mpi_information%communicator,&! mpistatus, &! mpierror &! ) - if(mpistatus(MPI_TAG) == tag_run_stop ) then + + if(tag_run == tag_run_stop) then more_points_needed = .false. return - else if(mpistatus(MPI_TAG) == tag_run_seed) then + else if(tag_run == tag_run_seed) then more_points_needed = .true. else call halt_program('worker error: unrecognised tag') end if + call MPI_RECV( &! + seed_point, &! + size(seed_point), &! + MPI_DOUBLE_PRECISION, &! + mpi_information%root, &! + tag_run, &! + mpi_information%communicator,&! + mpistatus, &! + mpierror &! + ) + call MPI_RECV( &! cholesky, &! size(cholesky,1)*size(cholesky,1),&! @@ -548,26 +563,34 @@ subroutine throw_seed(seed_point,cholesky,logL,mpi_information,worker_id,epoch,k type(mpi_bundle),intent(in) :: mpi_information !> mpi handle integer, intent(in) :: worker_id !> identity of target worker integer, intent(in) :: epoch !> epoch of seed - logical, intent(in) :: keep_going !> Further signal whether to keep going + logical, intent(in) :: keep_going !> Further signal whether to keep going - integer :: tag ! tag variable to - - tag = tag_run_stop ! Default tag is stop tag - if(keep_going) tag = tag_run_seed ! If we want to keep going then change this to the seed tag + integer :: tag_run ! tag variable to + tag_run = tag_run_stop ! Default tag is stop tag + if(keep_going) tag_run = tag_run_seed ! If we want to keep going then change this to the seed tag call MPI_SEND( &! 
- seed_point, &! - size(seed_point), &! - MPI_DOUBLE_PRECISION, &! + tag_run, &! + 1, &! + MPI_INTEGER, &! worker_id, &! - tag, &! + tag_tag_run, &! mpi_information%communicator,&! mpierror &! ) if(.not. keep_going) return ! Stop here if we're wrapping up + call MPI_SEND( &! + seed_point, &! + size(seed_point), &! + MPI_DOUBLE_PRECISION, &! + worker_id, &! + tag_run, &! + mpi_information%communicator, &! + mpierror &! + ) call MPI_SEND( &! cholesky, &! size(cholesky,1)*size(cholesky,2),&! @@ -620,17 +643,14 @@ subroutine no_more_points(mpi_information,worker_id) type(mpi_bundle), intent(in) :: mpi_information integer, intent(in) :: worker_id !> Worker to request a new point from - - integer :: empty_buffer(0) ! empty buffer to send - call MPI_SEND( & - empty_buffer, &! not sending anything - 0, &! size of nothing - MPI_INTEGER, &! sending no integers - worker_id, &! process id to send to - tag_gen_stop, &! continuation tag - mpi_information%communicator,&! mpi handle - mpierror &! error flag + tag_gen_stop, &! + 1, &! + MPI_INTEGER, &! + worker_id, &! + tag_tag_gen, &! + mpi_information%communicator,&! + mpierror &! ) end subroutine no_more_points @@ -650,6 +670,16 @@ subroutine request_live_point(live_point,mpi_information,worker_id) integer, intent(in) :: worker_id !> Worker to request a new point from real(dp), intent(in), dimension(:) :: live_point !> The live point to be sent + call MPI_SEND( & + tag_gen_request, &! + 1, &! + MPI_INTEGER, &! + worker_id, &! + tag_tag_gen, &! + mpi_information%communicator,&! + mpierror &! + ) + call MPI_SEND( &! live_point, &! live point being sent @@ -675,6 +705,29 @@ function live_point_needed(live_point,mpi_information) integer, dimension(MPI_STATUS_SIZE) :: mpistatus ! status identifier logical :: live_point_needed !> Whether we need more points or not + integer :: tag_gen + + call MPI_RECV( &! + tag_gen, &! + 1, &! + MPI_INTEGER, &! + mpi_information%root, &! + tag_tag_gen, &! + mpi_information%communicator,&! 
+ mpistatus, &! + mpierror &! + ) + + ! If we've received a kill signal, then exit this loop + if(tag_gen == tag_gen_stop) then + live_point_needed = .false. + return + else if(tag_gen == tag_gen_request) then + live_point_needed = .true. + else + call halt_program('generate error: unrecognised tag') + end if + call MPI_RECV( &! live_point, &! live point recieved @@ -687,15 +740,6 @@ function live_point_needed(live_point,mpi_information) mpierror &! error flag ) - ! If we've recieved a kill signal, then exit this loop - if(mpistatus(MPI_TAG) == tag_gen_stop ) then - live_point_needed = .false. - else if(mpistatus(MPI_TAG) == tag_gen_request) then - live_point_needed = .true. - else - call halt_program('generate error: unrecognised tag') - end if - end function live_point_needed diff --git a/tests/test_custom_mpi_comm.py b/tests/test_custom_mpi_comm.py new file mode 100644 index 00000000..e3aa70cf --- /dev/null +++ b/tests/test_custom_mpi_comm.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +""" +Test custom MPI communicator functionality in PolyChord. + +This test validates that the new `comm` parameter in pypolychord.run() +and run_polychord() correctly accepts custom MPI communicators instead +of just using MPI.COMM_WORLD. 
+""" + +import numpy as np +import pytest + +try: + from mpi4py import MPI + HAS_MPI = True +except ImportError: + HAS_MPI = False + +# Only run MPI tests if mpi4py is available +pytest_plugins = [] +if HAS_MPI: + import pypolychord + from pypolychord.priors import UniformPrior + from pypolychord.settings import PolyChordSettings + + +def gaussian_likelihood(theta): + """Simple 2D gaussian likelihood for testing.""" + nDims = len(theta) + r2 = sum(theta**2) + logL = -0.5 * r2 + return logL, [r2] # return likelihood and derived parameter + + +@pytest.mark.skipif(not HAS_MPI, reason="mpi4py not available") +def test_custom_communicator_interface_exists(): + """Test that the comm parameter exists in the interfaces.""" + # Just test that we can call the functions with comm parameter without hanging + import inspect + + # Check run_polychord signature has explicit comm parameter + run_polychord_sig = inspect.signature(pypolychord.run_polychord) + assert 'comm' in run_polychord_sig.parameters + + # run() uses **kwargs, so comm should be accepted as keyword argument + # Test that we can bind comm parameter to run() + run_sig = inspect.signature(pypolychord.run) + try: + bound_args = run_sig.bind(gaussian_likelihood, 2, comm=None) + # If we get here, comm is accepted as a keyword argument + assert True + except TypeError: + pytest.fail("run() does not accept comm as keyword argument") + + # This test just validates the interface exists - actual MPI testing + # would require more complex setup to avoid hangs + + +@pytest.mark.skipif(not HAS_MPI, reason="mpi4py not available") +def test_none_communicator_parameter(): + """Test that passing comm=None parameter is accepted.""" + # Just test that the parameter is accepted without actually running + # (to avoid MPI hangs in CI) + comm = MPI.COMM_WORLD + + # Test that we can pass None without error + try: + # We won't actually run this to avoid hangs, just test signature + import inspect + sig = inspect.signature(pypolychord.run) 
+ bound_args = sig.bind( + gaussian_likelihood, 2, + nlive=5, comm=None + ) + # If we get here, the signature accepts comm=None + assert True + except TypeError: + pytest.fail("comm=None parameter not accepted") + + +if __name__ == "__main__": + # Simple interface test only to avoid MPI hangs + if HAS_MPI: + print("Testing MPI communicator interface...") + test_custom_communicator_interface_exists() + print("Interface test passed!") + test_none_communicator_parameter() + print("None parameter test passed!") + else: + print("mpi4py not available, skipping tests") \ No newline at end of file diff --git a/tests/test_mpi_functional.py b/tests/test_mpi_functional.py new file mode 100644 index 00000000..a74bc3c8 --- /dev/null +++ b/tests/test_mpi_functional.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +""" +Minimal functional test for custom MPI communicator support. +This actually runs PolyChord with different communicators. +""" + +try: + from mpi4py import MPI + import pypolychord + from pypolychord.priors import UniformPrior + HAS_MPI = True +except ImportError: + HAS_MPI = False + exit(0) + +def gaussian_likelihood(theta): + """Simple Gaussian likelihood centered at origin.""" + # 2D Gaussian with unit variance + logL = -0.5 * sum(theta**2) + r2 = sum(theta**2) + return logL, [r2] # return log likelihood and derived parameter + +if __name__ == "__main__": + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + size = comm.Get_size() + + print(f"Rank {rank}/{size}: Testing custom MPI communicator functionality") + + # Test 1: Default behavior (should use COMM_WORLD internally) + try: + output1 = pypolychord.run( + gaussian_likelihood, 2, # 2D Gaussian + nlive=10, num_repeats=5, max_ndead=50, + file_root=f'test_default_{rank}', + write_resume=False, write_dead=False, + write_stats=False, + feedback=0, do_clustering=False + ) + # run() returns None when anesthetic not installed, but that means it completed + if output1 is None: + print(f"Rank {rank}: Default comm test - 
SUCCESS (completed, anesthetic not available)") + else: + print(f"Rank {rank}: Default comm test - SUCCESS (logZ ≈ {output1.logZ:.2f})") + except Exception as e: + print(f"Rank {rank}: Default comm test - FAILED: {e}") + + # Test 2: Explicit COMM_WORLD + try: + output2 = pypolychord.run( + gaussian_likelihood, 2, # 2D Gaussian + nlive=10, num_repeats=5, max_ndead=50, + file_root=f'test_explicit_{rank}', + write_resume=False, write_dead=False, + write_stats=False, + feedback=0, do_clustering=False, + comm=comm # Explicit COMM_WORLD + ) + if output2 is None: + print(f"Rank {rank}: Explicit COMM_WORLD test - SUCCESS (completed, anesthetic not available)") + else: + print(f"Rank {rank}: Explicit COMM_WORLD test - SUCCESS (logZ ≈ {output2.logZ:.2f})") + except Exception as e: + print(f"Rank {rank}: Explicit COMM_WORLD test - FAILED: {e}") + + # Test 3: Duplicated communicator + try: + dup_comm = comm.Dup() + output3 = pypolychord.run( + gaussian_likelihood, 2, # 2D Gaussian + nlive=10, num_repeats=5, max_ndead=50, + file_root=f'test_dup_{rank}', + write_resume=False, write_dead=False, + write_stats=False, + feedback=0, do_clustering=False, + comm=dup_comm # Custom duplicated communicator + ) + dup_comm.Free() + if output3 is None: + print(f"Rank {rank}: Duplicated communicator test - SUCCESS (completed, anesthetic not available)") + else: + print(f"Rank {rank}: Duplicated communicator test - SUCCESS (logZ ≈ {output3.logZ:.2f})") + except Exception as e: + print(f"Rank {rank}: Duplicated communicator test - FAILED: {e}") + + print(f"Rank {rank}: All tests completed") + + # Synchronize before exit + comm.Barrier() \ No newline at end of file