diff --git a/litebird_sim/__init__.py b/litebird_sim/__init__.py index f29c5e0f8..42e6c018e 100644 --- a/litebird_sim/__init__.py +++ b/litebird_sim/__init__.py @@ -82,7 +82,7 @@ ) from .madam import save_simulation_for_madam from .mbs.mbs import Mbs, MbsParameters, MbsSavedMapInfo -from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, MPI_COMM_GRID +from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION from .mueller_convolver import MuellerConvolver from .noise import ( add_white_noise, @@ -239,7 +239,6 @@ "MPI_COMM_WORLD", "MPI_ENABLED", "MPI_CONFIGURATION", - "MPI_COMM_GRID", # mueller_convolver.py "MuellerConvolver", # observations.py diff --git a/litebird_sim/io.py b/litebird_sim/io.py index d56bb4d29..ca8900478 100644 --- a/litebird_sim/io.py +++ b/litebird_sim/io.py @@ -470,7 +470,7 @@ def write_list_of_observations( observations = [observations] except IndexError: # Empty list - # We do not want to return here, as we still need to participate to + # We do not want to return here, as we still need to participate in # the call to _compute_global_start_index below observations = [] # type: List[Observation] diff --git a/litebird_sim/mapmaking/common.py b/litebird_sim/mapmaking/common.py index 6fe8486fb..ad6652039 100644 --- a/litebird_sim/mapmaking/common.py +++ b/litebird_sim/mapmaking/common.py @@ -8,7 +8,6 @@ from numba import njit from litebird_sim.coordinates import CoordinateSystem -from litebird_sim.mpi import MPI_COMM_GRID from litebird_sim.observations import Observation from litebird_sim.pointings_in_obs import _get_pointings_array, _get_pol_angle @@ -109,15 +108,14 @@ def get_map_making_weights( except AttributeError: weights = np.ones(observations.n_detectors) - if check and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: - if check: - # Check that there are no weird weights - assert np.all(np.isfinite(weights)), ( - f"Not all the detectors' weights are finite numbers: {weights}" - ) - assert np.all(weights > 0.0), ( - f"Not all the detectors' weights are positive: {weights}" - ) + if check: + # Check that there are no weird weights + assert np.all(np.isfinite(weights)), ( + f"Not all the detectors' weights are finite numbers: {weights}" + ) + assert np.all(weights > 0.0), ( + f"Not all the detectors' weights are positive: {weights}" + ) return weights diff --git a/litebird_sim/mapmaking/destriper.py b/litebird_sim/mapmaking/destriper.py index 80c60d8af..6e08ca145 100644 --- a/litebird_sim/mapmaking/destriper.py +++ b/litebird_sim/mapmaking/destriper.py @@ -14,7 +14,7 @@ from litebird_sim.coordinates import CoordinateSystem, coord_sys_to_healpix_string from litebird_sim.hwp import HWP -from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID +from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD from litebird_sim.observations import Observation from litebird_sim.pointings_in_obs import ( _get_hwp_angle, @@ -44,7 +44,7 @@ __DESTRIPER_RESULTS_FILE_NAME = "destriper_results.fits" -__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_GRID.COMM_OBS_GRID.rank:04d}.fits" +__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_WORLD.rank:04d}.fits" def _split_items_into_n_segments(n: int, num_of_segments: int) -> List[int]: @@ -495,10 +495,8 @@ def _build_nobs_matrix( ) # Now we must accumulate the result of every MPI process - if MPI_ENABLED and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: - MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( - mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM - ) + if MPI_ENABLED: + MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM) # `nobs_matrix_cholesky` will *not* contain the M_i maps shown in # Eq. 9 of KurkiSuonio2009, but its Cholesky decomposition, i.e., @@ -745,12 +743,8 @@ def _compute_binned_map( ) if MPI_ENABLED: - MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( - mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM - ) - MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( - mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM - ) + MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM) + MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM) # Step 2: compute the “binned map” (Eq. 21) _sum_map_to_binned_map( @@ -990,7 +984,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float: # the dot product local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)]) if MPI_ENABLED: - return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM) + return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.SUM) else: return local_result @@ -1007,7 +1001,7 @@ def _get_stopping_factor(residual: List[npt.ArrayLike]) -> float: """ local_result = np.max(np.abs(residual)) if MPI_ENABLED: - return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX) + return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.MAX) else: return local_result @@ -1421,7 +1415,7 @@ def _run_destriper( bytes_in_temporary_buffers += mask.nbytes if MPI_ENABLED: - bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( + bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce( bytes_in_temporary_buffers, op=mpi4py.MPI.SUM, ) @@ -1623,103 +1617,91 @@ def my_gui_callback( binned_map = np.empty((3, number_of_pixels)) hit_map = np.empty(number_of_pixels) - if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: - # perform the following operations when MPI is not being used - # OR when the MPI_COMM_GRID.COMM_OBS_GRID is not a NULL communicator - if do_destriping: - try: - # This will fail if the parameter is a scalar - len(params.samples_per_baseline) - - baseline_lengths_list = params.samples_per_baseline - assert len(baseline_lengths_list) == len(obs_list), ( - f"The list baseline_lengths_list has {len(baseline_lengths_list)} " - f"elements, but there are {len(obs_list)} observations" - ) - except TypeError: - # Ok, params.samples_per_baseline is a scalar, so we must - # figure out the number of samples in each baseline within - # each observation - baseline_lengths_list = [ - split_items_evenly( - n=getattr(cur_obs, components[0]).shape[1], - sub_n=int(params.samples_per_baseline), - ) - for cur_obs in obs_list - ] - - # Each element of this list is a 2D array with shape (N_det, N_baselines), - # where N_det is the number of detectors in the i-th Observation object - recycle_baselines = False - if baselines_list is None: - baselines_list = [ - np.zeros( - (getattr(cur_obs, components[0]).shape[0], len(cur_baseline)) - ) - for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list) - ] - else: - recycle_baselines = True - - destriped_map = np.empty((3, number_of_pixels)) - ( - baselines_list, - baseline_errors_list, - history_of_stopping_factors, - best_stopping_factor, - converged, - bytes_in_temporary_buffers, - ) = _run_destriper( - obs_list=obs_list, - nobs_matrix_cholesky=nobs_matrix_cholesky, - binned_map=binned_map, - destriped_map=destriped_map, - hit_map=hit_map, - baseline_lengths_list=baseline_lengths_list, - baselines_list_start=baselines_list, - recycle_baselines=recycle_baselines, - recycled_convergence=recycled_convergence, - dm_list=detector_mask_list, - tm_list=time_mask_list, - component=components[0], - threshold=params.threshold, - max_steps=params.iter_max, - use_preconditioner=params.use_preconditioner, - callback=callback, - callback_kwargs=callback_kwargs if callback_kwargs else {}, - ) + # perform the following operations when MPI is not being used + if do_destriping: + try: + # This will fail if the parameter is a scalar + len(params.samples_per_baseline) - if MPI_ENABLED: - bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( - bytes_in_temporary_buffers, - op=mpi4py.MPI.SUM, + baseline_lengths_list = params.samples_per_baseline + assert len(baseline_lengths_list) == len(obs_list), ( + f"The list baseline_lengths_list has {len(baseline_lengths_list)} " + f"elements, but there are {len(obs_list)} observations" + ) + except TypeError: + # Ok, params.samples_per_baseline is a scalar, so we must + # figure out the number of samples in each baseline within + # each observation + baseline_lengths_list = [ + split_items_evenly( + n=getattr(cur_obs, components[0]).shape[1], + sub_n=int(params.samples_per_baseline), ) + for cur_obs in obs_list + ] + + # Each element of this list is a 2D array with shape (N_det, N_baselines), + # where N_det is the number of detectors in the i-th Observation object + recycle_baselines = False + if baselines_list is None: + baselines_list = [ + np.zeros((getattr(cur_obs, components[0]).shape[0], len(cur_baseline))) + for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list) + ] else: - # No need to run the destriping, just compute the binned map with - # one single baseline set to zero - _compute_binned_map( - obs_list=obs_list, - output_sky_map=binned_map, - output_hit_map=hit_map, - nobs_matrix_cholesky=nobs_matrix_cholesky, - component=components[0], - dm_list=detector_mask_list, - tm_list=time_mask_list, - baselines_list=None, - baseline_lengths_list=[ - np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int) - for cur_obs in obs_list - ], + recycle_baselines = True + + destriped_map = np.empty((3, number_of_pixels)) + ( + baselines_list, + baseline_errors_list, + history_of_stopping_factors, + best_stopping_factor, + converged, + bytes_in_temporary_buffers, + ) = _run_destriper( + obs_list=obs_list, + nobs_matrix_cholesky=nobs_matrix_cholesky, + binned_map=binned_map, + destriped_map=destriped_map, + hit_map=hit_map, + baseline_lengths_list=baseline_lengths_list, + baselines_list_start=baselines_list, + recycle_baselines=recycle_baselines, + recycled_convergence=recycled_convergence, + dm_list=detector_mask_list, + tm_list=time_mask_list, + component=components[0], + threshold=params.threshold, + max_steps=params.iter_max, + use_preconditioner=params.use_preconditioner, + callback=callback, + callback_kwargs=callback_kwargs if callback_kwargs else {}, + ) + + if MPI_ENABLED: + bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce( + bytes_in_temporary_buffers, + op=mpi4py.MPI.SUM, ) - bytes_in_temporary_buffers = 0 - destriped_map = None - baseline_lengths_list = None - baselines_list = None - baseline_errors_list = None - history_of_stopping_factors = None - best_stopping_factor = None - converged = True else: + # No need to run the destriping, just compute the binned map with + # one single baseline set to zero + _compute_binned_map( + obs_list=obs_list, + output_sky_map=binned_map, + output_hit_map=hit_map, + nobs_matrix_cholesky=nobs_matrix_cholesky, + component=components[0], + dm_list=detector_mask_list, + tm_list=time_mask_list, + baselines_list=None, + baseline_lengths_list=[ + np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int) + for cur_obs in obs_list + ], + ) + bytes_in_temporary_buffers = 0 destriped_map = None baseline_lengths_list = None baselines_list = None @@ -2018,11 +2000,11 @@ def _save_baselines(results: DestriperResult, output_file: Path) -> None: primary_hdu = fits.PrimaryHDU() primary_hdu.header["MPIRANK"] = ( - MPI_COMM_GRID.COMM_OBS_GRID.rank, + MPI_COMM_WORLD.rank, "The rank of the MPI process that wrote this file", ) primary_hdu.header["MPISIZE"] = ( - MPI_COMM_GRID.COMM_OBS_GRID.size, + MPI_COMM_WORLD.size, "The number of MPI processes used in the computation", ) @@ -2238,11 +2220,11 @@ def load_destriper_results( baselines_file_name = folder / __BASELINES_FILE_NAME with fits.open(baselines_file_name) as inpf: - assert MPI_COMM_GRID.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], ( + assert MPI_COMM_WORLD.rank == inpf[0].header["MPIRANK"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results " ) - assert MPI_COMM_GRID.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], ( + assert MPI_COMM_WORLD.size == inpf[0].header["MPISIZE"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results" ) diff --git a/litebird_sim/mpi.py b/litebird_sim/mpi.py index 64c311813..70ae9469e 100644 --- a/litebird_sim/mpi.py +++ b/litebird_sim/mpi.py @@ -22,57 +22,11 @@ class _SerialMpiCommunicator: size = 1 -class _GridCommClass: - """ - This class encapsulates the `COMM_OBS_GRID` and `COMM_NULL` communicators. It - offers explicitly defined setter functions so that the communicators cannot be - changed accidentally. - - Attributes: - - COMM_OBS_GRID (mpi4py.MPI.Intracomm): A subset of `MPI.COMM_WORLD` that - contain all the processes associated with non-zero observations. - - COMM_NULL (mpi4py.MPI.Comm): A NULL communicator. When MPI is not enabled, it - is set as `None`. If MPI is enabled, it is set as `MPI.COMM_NULL` - - """ - - def __init__(self, comm_obs_grid=_SerialMpiCommunicator(), comm_null=None): - self._MPI_COMM_OBS_GRID = comm_obs_grid - self._MPI_COMM_NULL = comm_null - - @property - def COMM_OBS_GRID(self): - return self._MPI_COMM_OBS_GRID - - @property - def COMM_NULL(self): - return self._MPI_COMM_NULL - - def _set_comm_obs_grid(self, comm_obs_grid): - self._MPI_COMM_OBS_GRID = comm_obs_grid - - def _set_null_comm(self, comm_null): - self._MPI_COMM_NULL = comm_null - - #: Global variable equal either to `mpi4py.MPI.COMM_WORLD` or a object #: that defines the member variables `rank = 0` and `size = 1`. MPI_COMM_WORLD = _SerialMpiCommunicator() -#: Global object with two attributes: -#: -#: - ``COMM_OBS_GRID``: It is a partition of ``MPI_COMM_WORLD`` that includes all the -#: MPI processes with global rank less than ``n_blocks_time * n_blocks_det``. On MPI -#: processes with higher ranks, it points to NULL MPI communicator -#: ``mpi4py.MPI.COMM_NULL``. -#: -#: - ``COMM_NULL``: If :data:`.MPI_ENABLED` is ``True``, this object points to a NULL -#: MPI communicator (``mpi4py.MPI.COMM_NULL``). Otherwise it is ``None``. -MPI_COMM_GRID = _GridCommClass() - #: `True` if MPI should be used by the application. The value of this #: variable is set according to the following rules: #: @@ -100,8 +54,6 @@ def _set_null_comm(self, comm_null): from mpi4py import MPI MPI_COMM_WORLD = MPI.COMM_WORLD - MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD) - MPI_COMM_GRID._set_null_comm(comm_null=MPI.COMM_NULL) MPI_ENABLED = True MPI_CONFIGURATION = mpi4py.get_config() except ImportError: diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index a62244207..ecb6f6cab 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -1,19 +1,18 @@ # -*- encoding: utf-8 -*- -import numbers from collections import defaultdict from dataclasses import dataclass -from typing import Union, List, Any, Optional +from typing import Union, List, Any, Optional, Dict import astropy.time import numpy as np import numpy.typing as npt from .coordinates import DEFAULT_TIME_SCALE -from .detectors import DetectorInfo, InstrumentInfo +from .detectors import InstrumentInfo from .distribute import distribute_evenly, distribute_detector_blocks from .hwp import HWP -from .mpi import MPI_COMM_GRID, _SerialMpiCommunicator +from .mpi import _SerialMpiCommunicator from .pointings import PointingProvider from .scanning import RotQuaternion @@ -243,40 +242,52 @@ def _get_local_start_time_start_and_n_samples(self): return self.start_time_global + start * delta, start, num - def _set_attributes_from_list_of_dict(self, list_of_dict, root): + def _set_attributes_from_list_of_dict( + self, + list_of_dict: List[Dict[str, str]], + root: int, + ) -> None: + """ + Take a list of dictionaries describing each detector and propagate them + """ np.testing.assert_equal(len(list_of_dict), self.n_detectors_global) # Turn list of dict into dict of arrays if not self.comm or self.comm.rank == root: # Build a list of all the keys in the dictionaries contained within - # `list_of_dict` (which is a *list* of dictionaries) + # `list_of_dict` (which is a *list* of dictionaries). `keys` is a list of + # strings like `name`, `net_ukrts`, `fknee_mhz`, etc. keys = list(set().union(*list_of_dict) - set(dir(self))) # This will be the dictionary associating each key with the - # *array* of value for that dictionary - dict_of_array = {k: [] for k in keys} + # *array* of values for that dictionary + dict_of_array = {cur_key: [] for cur_key in keys} # This array associates either np.nan or None to each type; # the former indicates that the value is a NumPy array, while # None is used for everything else nan_or_none = {} - for k in keys: - for d in list_of_dict: - if k in d: + for cur_key in keys: + for cur_det_dict in list_of_dict: + if cur_key in cur_det_dict: try: - nan_or_none[k] = np.nan * d[k] + nan_or_none[cur_key] = np.nan * cur_det_dict[cur_key] except TypeError: - nan_or_none[k] = None + nan_or_none[cur_key] = None break # Finally, build `dict_of_array` - for d in list_of_dict: - for k in keys: - dict_of_array[k].append(d.get(k, nan_or_none[k])) + for cur_det_dict in list_of_dict: + for cur_key in keys: + dict_of_array[cur_key].append( + cur_det_dict.get(cur_key, nan_or_none[cur_key]) + ) - # Why should this code iterate over `keys`?!? - for k in keys: - dict_of_array = {k: np.array(dict_of_array[k]) for k in keys} + # So far, dict_of_array entries are plain lists. This converts them into + # NumPy arrays + dict_of_array = { + cur_key: np.array(dict_of_array[cur_key]) for cur_key in keys + } else: keys = None dict_of_array = {} @@ -285,8 +296,8 @@ def _set_attributes_from_list_of_dict(self, list_of_dict, root): if self.comm and self.comm.size > 1: keys = self.comm.bcast(keys) - for k in keys: - self.setattr_det_global(k, dict_of_array.get(k), root) + for cur_key in keys: + self.setattr_det_global(cur_key, dict_of_array.get(cur_key), root) @property def n_samples_global(self): @@ -404,11 +415,11 @@ def _check_blocks(self, n_blocks_det, n_blocks_time): "You can not have more time blocks than time samples " f"({n_blocks_time} > {self.n_blocks_time})" ) - elif self.comm.size < n_blocks_det * n_blocks_time: + elif self.comm.size != n_blocks_det * n_blocks_time: raise ValueError( - "Too many blocks: n_blocks_det x n_blocks_time = " - f"{n_blocks_det * n_blocks_time} but the number " - f"processes is {self.comm.size}" + "Number of blocks requested: n_blocks_det x n_blocks_time = " + f"{n_blocks_det * n_blocks_time} is not equal to the number " + f"of available MPI processes: comm.size = {self.comm.size}" ) def _get_start_and_num(self, n_blocks_det, n_blocks_time): @@ -683,19 +694,10 @@ def setattr_det_global(self, name, info, root=0): setattr(self, name, info) return - if ( - MPI_COMM_GRID.COMM_OBS_GRID == MPI_COMM_GRID.COMM_NULL - ): # The process does not own any detector (and TOD) - null_det = DetectorInfo() - attribute = getattr(null_det, name, None) - value = np.array([0]) if isinstance(attribute, numbers.Number) else [None] - setattr(self, name, value) - return - - my_col = MPI_COMM_GRID.COMM_OBS_GRID.rank % self._n_blocks_time + my_col = self.comm.rank % self._n_blocks_time root_col = root // self._n_blocks_det if my_col == root_col: - if MPI_COMM_GRID.COMM_OBS_GRID.rank == root: + if self.comm.rank == root: starts, nums, _, _ = self._get_start_and_num( self._n_blocks_det, self._n_blocks_time ) @@ -814,7 +816,7 @@ def prepare_pointings( # If the hwp object is passed and is not initialised in the observations, it gets applied to all detectors if hwp is None: - assert all(m is None for m in self.mueller_hwp), ( + assert self.no_mueller_hwp(), ( "Some detectors have been initialized with a mueller_hwp," "but no HWP object has been passed to prepare_pointings." ) @@ -1019,20 +1021,20 @@ def precompute_pointings( self.pointing_matrix = pointing_matrix self.hwp_angle = hwp_angle + def no_mueller_hwp(self) -> bool: + "Return True if no detectors have defined a Mueller matrix for the HWP" + return (self.mueller_hwp is None) or all(m is None for m in self.mueller_hwp) + def _set_mpi_subcommunicators(self): """ - This function splits the global MPI communicator into three kinds of + This function splits the global MPI communicator into two kinds of sub-communicators: - 1. A sub-communicator containing all the processes with global rank less than - `n_blocks_det * n_blocks_time`. Outside of this global rank, the - sub-communicator is NULL. - - 2. A sub-communicator for each block of detectors, that contains all the + 1. A sub-communicator for each block of detectors, that contains all the processes corresponding to that detector block. This sub-communicator is an attribute of the :class:`.Observation` class. - 3. A sub-communicator for each block of time that contains all the processes + 2. A sub-communicator for each block of time that contains all the processes corresponding to that time block. This sub-communicator is an attribute of the :class:`.Observation` class. """ @@ -1043,30 +1045,7 @@ def _set_mpi_subcommunicators(self): self.comm_time_block = _SerialMpiCommunicator() if self.comm and self.comm.size > 1: - if self.comm.rank < self.n_blocks_det * self.n_blocks_time: - matrix_color = 1 - else: - from .mpi import MPI - - matrix_color = MPI.UNDEFINED - - # Case1: For `0 < rank < n_blocks_det * n_blocks_time`, - # `comm_obs_grid` is a sub-communicator that includes processes - # from rank 0 to `n_blocks_det * n_blocks_time - 1`. - # Case 2: For `n_blocks_det * n_blocks_time <= rank < comm.size`, - # `comm_obs_grid = MPI.COMM_NULL` - comm_obs_grid = self.comm.Split(matrix_color, self.comm.rank) - MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=comm_obs_grid) - - # If the `MPI_COMM_GRID.COMM_OBS_GRID` is not NULL, we split it in - # communicators corresponding to each detector and time block - # If `MPI_COMM_GRID.COMM_OBS_GRID` is NULL, we set the communicators - # corresponding to detector and time blocks to NULL. - if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: - det_color = MPI_COMM_GRID.COMM_OBS_GRID.rank // self.n_blocks_time - time_color = MPI_COMM_GRID.COMM_OBS_GRID.rank % self.n_blocks_time - self.comm_det_block = MPI_COMM_GRID.COMM_OBS_GRID.Split(det_color) - self.comm_time_block = MPI_COMM_GRID.COMM_OBS_GRID.Split(time_color) - else: - self.comm_det_block = MPI_COMM_GRID.COMM_NULL - self.comm_time_block = MPI_COMM_GRID.COMM_NULL + det_color = self.comm.rank // self.n_blocks_time + time_color = self.comm.rank % self.n_blocks_time + self.comm_det_block = self.comm.Split(det_color) + self.comm_time_block = self.comm.Split(time_color) diff --git a/litebird_sim/pointings_in_obs.py b/litebird_sim/pointings_in_obs.py index f5c4f346a..ce3213c1c 100644 --- a/litebird_sim/pointings_in_obs.py +++ b/litebird_sim/pointings_in_obs.py @@ -119,7 +119,7 @@ def _get_hwp_angle( hwp_angle = obs.get_hwp_angle(pointings_dtype=pointing_dtype) else: if hasattr(obs, "mueller_hwp"): - assert all(m is None for m in obs.mueller_hwp), ( + assert obs.no_mueller_hwp(), ( "Detectors have been initialized with a mueller_hwp," "but no HWP is either passed or initilized in the pointing" ) diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index de5e8f650..7f0a2e027 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -51,7 +51,7 @@ destriper_log_callback, ) from .mbs import Mbs, MbsParameters -from .mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID +from .mpi import MPI_ENABLED, MPI_COMM_WORLD from .noise import add_noise_to_observations from .non_linearity import NonLinParams, apply_quadratic_nonlin_to_observations from .observations import Observation, TodDescription @@ -1340,8 +1340,7 @@ def set_scanning_strategy( num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: - num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs) + num_of_obs = self.mpi_comm.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_quaternions.md") @@ -1419,11 +1418,8 @@ def prepare_pointings( memory_occupation = pointing_provider.bore2ecliptic_quats.quats.nbytes num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: - memory_occupation = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( - memory_occupation - ) - num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs) + memory_occupation = self.mpi_comm.allreduce(memory_occupation) + num_of_obs = self.mpi_comm.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_pointings.md") diff --git a/test/test_detector_blocks.py b/test/test_detector_blocks.py index 50658c5e8..f24f7d8d0 100644 --- a/test/test_detector_blocks.py +++ b/test/test_detector_blocks.py @@ -132,31 +132,15 @@ def test_mpi_subcommunicators(dets=dets): det_blocks_attributes=det_blocks_attribute, ) - if lbs.MPI_COMM_GRID.COMM_OBS_GRID != lbs.MPI_COMM_GRID.COMM_NULL: - # since unused MPI processes stay at the end of global, - # communicator, the rank of the used processes in - # `MPI_COMM_GRID.COMM_OBS_GRID` must be same as their rank in - # global communicator - np.testing.assert_equal(lbs.MPI_COMM_GRID.COMM_OBS_GRID.rank, comm.rank) - - for obs in sim.observations: - # comm_det_block.rank + comm_time_block.rank * n_block_time - # must be equal to the global communicator rank for the - # used processes. It follows from the way split colors - # were defined. - np.testing.assert_equal( - obs.comm_det_block.rank + obs.comm_time_block.rank * obs.n_blocks_time, - comm.rank, - ) - else: - for obs in sim.observations: - # the global rank of the unused MPI processes must be larger than the number of used processes. - assert comm.rank > (obs.n_blocks_det * obs.n_blocks_time - 1) - - # The block communicators on the unused MPI processes must - # be the NULL communicators - np.testing.assert_equal(obs.comm_det_block, lbs.MPI_COMM_GRID.COMM_NULL) - np.testing.assert_equal(obs.comm_time_block, lbs.MPI_COMM_GRID.COMM_NULL) + for obs in sim.observations: + # comm_det_block.rank + comm_time_block.rank * n_block_time + # must be equal to the global communicator rank for the + # used processes. It follows from the way split colors + # were defined. + np.testing.assert_equal( + obs.comm_det_block.rank + obs.comm_time_block.rank * obs.n_blocks_time, + comm.rank, + ) if __name__ == "__main__":