diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4018509f..cf6dd5a8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -87,6 +87,7 @@ jobs: for proc in 1 5 9 ; do echo "Running MPI test ($MPI) with $proc processes" PYTHONPATH=. mpiexec -n $proc python3 ./test/test_mpi.py + PYTHONPATH=. mpiexec -n $proc python3 ./test/test_detector_blocks.py done - name: Tests OpenMPI @@ -95,4 +96,5 @@ jobs: for proc in 1 2 ; do echo "Running MPI test ($MPI) with $proc processes" PYTHONPATH=. mpiexec -n $proc python3 ./test/test_mpi.py + PYTHONPATH=. mpiexec -n $proc python3 ./test/test_detector_blocks.py done diff --git a/docs/source/mpi.rst b/docs/source/mpi.rst index ec64e75a..79735ad1 100644 --- a/docs/source/mpi.rst +++ b/docs/source/mpi.rst @@ -133,6 +133,31 @@ variable :data:`.MPI_ENABLED`:: To ensure that your code uses MPI in the proper way, you should always use :data:`.MPI_COMM_WORLD` instead of importing ``mpi4py`` directly. +The simulation framework also provides a global object +:data:`.MPI_COMM_GRID`. It has two attributes: + +- ``COMM_OBS_GRID``: This is an MPI communicator that contains all the + MPI processes with the global rank less than ``n_blocks_time * n_blocks_det``. + It provides a safety net to the operations and MPI communications + that are needed to be performed only on the partition of :data:`.MPI_COMM_WORLD` + that contain non-zero number of pointings and TODs. By default, + ``COMM_OBS_GRID`` points to the global MPI communicator :data:`.MPI_COMM_WORLD`. + It is updated once :class:`.Observation` are defined. For example, + consider the case when a user runs the simulation with 10 MPI + processes but due some specific ``det_blocks_attributes`` argument + in :class:`.Observation` class, the number of detector and time + blocks are determined to be 2 and 4 respectively. Then the + simulation framework will store the pointings and TODs only on + :math:`2\times4=8` MPI processes and the last two ranks of :data:`.MPI_COMM_WORLD` + will be left unused. Once this happens, ``COMM_OBS_GRID`` on first 8 + ranks (rank 0 to 7) will point to the local sub-communicator + containing the processes with global rank 0 to 7. On the unused + ranks, it will simply point to the NULL communicator. +- ``COMM_NULL``: If :data:`.MPI_ENABLED` is ``True``, this object + points to a NULL MPI communicator (``mpi4py.MPI.COMM_NULL``). + Otherwise it is set to ``None``. The user should compare + ``COMM_OBS_GRID`` with ``COMM_NULL`` on every MPI process in order + to avoid running a piece of code on unused MPI processes. Enabling/disabling MPI ---------------------- diff --git a/docs/source/observations.rst b/docs/source/observations.rst index 9afe6798..3c9b9077 100644 --- a/docs/source/observations.rst +++ b/docs/source/observations.rst @@ -97,8 +97,15 @@ With this memory layout, typical operations look like this:: Parallel applications --------------------- -The only work that the :class:`.Observation` class actually does is handling -parallelism. ``obs.tod`` can be distributed over a +The :class:`.Observation` class allows the distribution of ``obs.tod`` over multiple MPI +processes to enable the parallelization of computations. The distribution of ``obs.tod`` +can be achieved in two different ways: + +1. Uniform distribution of detectors along the detector axis +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +With ``n_blocks_det`` and ``n_blocks_time`` arguments of :class:`.Observation` class, +the ``obs.tod`` is evenly distributed over a ``n_blocks_det`` by ``n_blocks_time`` grid of MPI ranks. The blocks can be changed at run-time. @@ -111,7 +118,7 @@ The main advantage is that the example operations in the Serial section are achieved with the same lines of code. The price to pay is that you have to set detector properties with special methods. -:: +.. code-block:: python import litebird_sim as lbs from mpi4py import MPI @@ -158,21 +165,222 @@ TOD) gets distributed. .. image:: ./images/observation_data_distribution.png -When ``n_blocks_det != 1``, keep in mind that ``obs.tod[0]`` or -``obs.wn_levels[0]`` are quantities of the first *local* detector, not global. -This should not be a problem as the only thing that matters is that the two -quantities refer to the same detector. If you need the global detector index, -you can get it with ``obs.det_idx[0]``, which is created -at construction time. - -To get a better understanding of how observations are being used in a -MPI simulation, use the method :meth:`.Simulation.describe_mpi_distribution`. -This method must be called *after* the observations have been allocated using -:meth:`.Simulation.create_observations`; it will return an instance of the -class :class:`.MpiDistributionDescr`, which can be inspected to determine -which detectors and time spans are covered by each observation in all the -MPI processes that are being used. For more information, refer to the Section -:ref:`simulations`. +2. Custom grouping of detectors along the detector axis +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +While uniform distribution of detectors along the detector axis optimizes load +balancing, it is less suitable for simulating the some effects, like crosstalk and +noise correlation between the detectors. This uniform distribution across MPI +processes necessitates the transfer of large TOD arrays across multiple MPI processes, +which complicates the code implementation and may potentially lead to significant +performance overhead. To save us from this situation, the :class:`.Observation` class +accepts an argument ``det_blocks_attributes`` that is a list of string objects +specifying the detector attributes to create the group of detectors. Once the +detector groups are made, the detectors are distributed to the MPI processes in such +a way that all the detectors of a group reside on the same MPI process. + +If a valid ``det_blocks_attributes`` argument is passed to the :class:`.Observation` +class, the arguments ``n_blocks_det`` and ``n_blocks_time`` are ignored. Since the +``det_blocks_attributes`` creates the detector blocks dynamically, the +``n_blocks_time`` is computed during runtime using the size of MPI communicator and +the number of detector blocks (``n_blocks_time = comm.size // n_blocks_det``). + +The detector blocks made in this way can be accessed with +``Observation.detector_blocks``. It is a dictionary object has the tuple of +``det_blocks_attributes`` values as dictionary keys and the list of detectors +corresponding to the key as dictionary values. This dictionary is sorted so that the +group with the largest number of detectors comes first and the one with +the fewest detectors comes last. + +The following example illustrates the distribution of ``obs.tod`` matrix across the +MPI processes when ``det_blocks_attributes`` is specified. + +.. code-block:: python + + import litebird_sim as lbs + + comm = lbs.MPI_COMM_WORLD + + start_time = 456 + duration_s = 100 + sampling_freq_Hz = 1 + + # Creating a list of detectors. + dets = [ + lbs.DetectorInfo( + name="channel1_w9_detA", + wafer="wafer_9", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w3_detB", + wafer="wafer_3", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detC", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detD", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detA", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detB", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), + ] + + # Initializing a simulation + sim = lbs.Simulation( + start_time=start_time, + duration_s=duration_s, + random_seed=12345, + mpi_comm=comm, + ) + + # Creating the observations with detector blocks + sim.create_observations( + detectors=dets, + split_list_over_processes=False, + num_of_obs_per_detector=3, + det_blocks_attributes=["channel"], # case 1 and 2 + # det_blocks_attributes=["channel", "wafer"] # case 3 + ) + +With the list of detectors defined in the code snippet above, we can see how the +detectors axis and time axis is divided depending on the size of MPI communicator and +``det_blocks_attributes``. + +**Case 1** + +*Size of MPI communicator = 3*, ``det_blocks_attributes=["channel"]`` + +:: + + Detector axis ---> + Two blocks ---> + +------------------+ +------------------+ + | Rank 0 | | Rank 1 | + +------------------+ +------------------+ + | channel1_w9_detA | | channel2_w4_detA | + T + + + + + i O | channel1_w3_detB | | channel2_w4_detB | + m n + + +------------------+ + e e | channel1_w1_detC | + + + + a b | channel1_w1_detD | + x l +------------------+ + i o + s c ........................................... + k + | | +------------------+ + | | | Rank 2 | + ⋎ ⋎ +------------------+ + | (Unused) | + +------------------+ + +**Case 2** + +*Size of MPI communicator = 4*, ``det_blocks_attributes=["channel"]`` + +:: + + Detector axis ---> + Two blocks ---> + +------------------+ +------------------+ + | Rank 0 | | Rank 2 | + +------------------+ +------------------+ + | channel1_w9_detA | | channel2_w4_detA | + + + + + + | channel1_w3_detB | | channel2_w4_detB | + T + + +------------------+ + i T | channel1_w1_detC | + m w + + + e o | channel1_w1_detD | + +------------------+ + a b + x l ........................................... + i o + s c +------------------+ +------------------+ + k | Rank 1 | | Rank 3 | + | | +------------------+ +------------------+ + | | | channel1_w9_detA | | channel2_w4_detA | + ⋎ ⋎ + + + + + | channel1_w3_detB | | channel2_w4_detB | + + + +------------------+ + | channel1_w1_detC | + + + + | channel1_w1_detD | + +------------------+ + +**Case 3** + +*Size of MPI communicator = 10*, ``det_blocks_attributes=["channel", "wafer"]`` + +:: + + Detector axis ---> + Four blocks ---> + +------------------+ +------------------+ +------------------+ +------------------+ + | Rank 0 | | Rank 2 | | Rank 4 | | Rank 6 | + +------------------+ +------------------+ +------------------+ +------------------+ + T | channel1_w1_detC | | channel2_w4_detA | | channel1_w9_detA | | channel1_w3_detB | + i T + + + + +------------------+ +------------------+ + m w | channel1_w1_detD | | channel2_w4_detB | + e o +------------------+ +------------------+ + + ......................................................................................... + + a b +------------------+ +------------------+ +------------------+ +------------------+ + x l | Rank 1 | | Rank 3 | | Rank 5 | | Rank 7 | + i o +------------------+ +------------------+ +------------------+ +------------------+ + s c | channel1_w1_detC | | channel2_w4_detA | | channel1_w9_detA | | channel1_w3_detB | + k + + + + +------------------+ +------------------+ + | | | channel1_w1_detD | | channel2_w4_detB | + | | +------------------+ +------------------+ + ⋎ ⋎ + ......................................................................................... + + +------------------+ +------------------+ + | Rank 8 | | Rank 9 | + +------------------+ +------------------+ + | (Unused) | | (Unused) | + +------------------+ +------------------+ + +.. note:: + When ``n_blocks_det != 1``, keep in mind that ``obs.tod[0]`` or + ``obs.wn_levels[0]`` are quantities of the first *local* detector, not global. + This should not be a problem as the only thing that matters is that the two + quantities refer to the same detector. If you need the global detector index, + you can get it with ``obs.det_idx[0]``, which is created at construction time. + ``obs.det_idx`` stores the detector indices of the detectors available to an + :class:`.Observation` class, with respect to the list of detectors stored in + ``obs.detectors_global`` variable. + +.. note:: + To get a better understanding of how observations are being used in a + MPI simulation, use the method :meth:`.Simulation.describe_mpi_distribution`. + This method must be called *after* the observations have been allocated using + :meth:`.Simulation.create_observations`; it will return an instance of the + class :class:`.MpiDistributionDescr`, which can be inspected to determine + which detectors and time spans are covered by each observation in all the + MPI processes that are being used. For more information, refer to the Section + :ref:`simulations`. Other notable functionalities ----------------------------- diff --git a/litebird_sim/__init__.py b/litebird_sim/__init__.py index b4a8aa6b..44285c14 100644 --- a/litebird_sim/__init__.py +++ b/litebird_sim/__init__.py @@ -72,7 +72,7 @@ ) from .madam import save_simulation_for_madam from .mbs.mbs import Mbs, MbsParameters, MbsSavedMapInfo -from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION +from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, MPI_COMM_GRID from .noise import ( add_white_noise, add_one_over_f_noise, @@ -218,6 +218,7 @@ def destripe_with_toast2(*args, **kwargs): "MPI_COMM_WORLD", "MPI_ENABLED", "MPI_CONFIGURATION", + "MPI_COMM_GRID", # observations.py "Observation", "TodDescription", diff --git a/litebird_sim/detectors.py b/litebird_sim/detectors.py index a489c13b..f2d0d981 100644 --- a/litebird_sim/detectors.py +++ b/litebird_sim/detectors.py @@ -75,6 +75,9 @@ class DetectorInfo: - channel (Union[str, None]): The channel. The default is None + - squid (Union[int, None]): The squid number of the detector. + The default value is None. + - sampling_rate_hz (float): The sampling rate of the ADC associated with this detector. The default is 0.0 @@ -136,6 +139,7 @@ class DetectorInfo: pixel: Union[int, None] = None pixtype: Union[str, None] = None channel: Union[str, None] = None + squid: Union[int, None] = None sampling_rate_hz: float = 0.0 fwhm_arcmin: float = 0.0 ellipticity: float = 0.0 @@ -148,8 +152,6 @@ class DetectorInfo: fknee_mhz: float = 0.0 fmin_hz: float = 0.0 alpha: float = 0.0 - bandcenter_ghz: float = 0.0 - bandwidth_ghz: float = 0.0 pol: Union[str, None] = None orient: Union[str, None] = None quat: Any = None @@ -175,6 +177,7 @@ def from_dict(dictionary: Dict[str, Any]): - ``pixel`` - ``pixtype`` - ``channel`` + - ``squid`` - ``bandcenter_ghz`` - ``bandwidth_ghz`` - ``band_freqs_ghz`` diff --git a/litebird_sim/distribute.py b/litebird_sim/distribute.py index 3649a877..ef893d82 100644 --- a/litebird_sim/distribute.py +++ b/litebird_sim/distribute.py @@ -50,7 +50,7 @@ def distribute_evenly(num_of_elements, num_of_groups): # If leftovers == 0, then the number of elements is divided evenly # by num_of_groups, and the solution is trivial. If it's not, then - # each of the "leftoverss" is placed in one of the first groups. + # each of the "leftovers" is placed in one of the first groups. # # Example: let's split 8 elements in 3 groups. In this case, # base_length=2 and leftovers=2 (elements #7 and #8): @@ -68,7 +68,7 @@ def distribute_evenly(num_of_elements, num_of_groups): cur_length = base_length + 1 cur_pos = cur_length * i else: - # No need to accomodate for leftovers, but consider their + # No need to accommodate for leftovers, but consider their # presence in fixing the starting position for this group cur_length = base_length cur_pos = base_length * i + leftovers @@ -84,6 +84,47 @@ def distribute_evenly(num_of_elements, num_of_groups): return result +def distribute_detector_blocks(detector_blocks): + """Similar to the :func:`distribute_evenly()` function, this function + returns a list of named-tuples, with fields `start_idx` (the starting + index of the detector in a group within the global list of detectors) and + num_of_elements` (the number of detectors in the group). Unlike + :func:`distribute_evenly()`, this function simply uses the detector groups + given in the `detector_blocks` attribute. + + Example: + Following the example given in + :meth:`litebird_sim.Observation._make_detector_blocks()`, + `distribute_detector_blocks()` will return + + ``` + [ + Span(start_idx=0, num_of_elements=2), + Span(start_idx=2, num_of_elements=2), + Span(start_idx=4, num_of_elements=1), + ] + ``` + + Args: + detector_blocks (dict): The detector block object. See :meth:`litebird_sim.Observation._make_detector_blocks()`. + + Returns: + A list of 2-elements named-tuples containing (1) the starting index of + the detectors of the block with respect to the flatten list of entire + detector blocks and (2) the number of elements in the detector block. + """ + cur_position = 0 + prev_length = 0 + result = [] + for key in detector_blocks: + cur_length = len(detector_blocks[key]) + cur_position += prev_length + prev_length = cur_length + result.append(Span(start_idx=cur_position, num_of_elements=cur_length)) + + return result + + # The following implementation of the painter's partition problem is # heavily inspired by the code at # https://www.geeksforgeeks.org/painters-partition-problem-set-2/?ref=rp diff --git a/litebird_sim/mapmaking/binner.py b/litebird_sim/mapmaking/binner.py index 682a5fbe..9622a234 100644 --- a/litebird_sim/mapmaking/binner.py +++ b/litebird_sim/mapmaking/binner.py @@ -68,7 +68,7 @@ class BinnerResult: @njit def _solve_binning(nobs_matrix, atd): - # Sove the map-making equation + # Solve the map-making equation # # This method alters the parameter `nobs_matrix`, so that after its completion # each 3×3 matrix in nobs_matrix[idx, :, :] will be the *inverse*. diff --git a/litebird_sim/mapmaking/common.py b/litebird_sim/mapmaking/common.py index 3f60e1c3..ac2f1393 100644 --- a/litebird_sim/mapmaking/common.py +++ b/litebird_sim/mapmaking/common.py @@ -219,7 +219,10 @@ def _compute_pixel_indices( if output_coordinate_system == CoordinateSystem.Galactic: # Free curr_pointings_det if the output map is already in Galactic coordinates - del curr_pointings_det + try: + del curr_pointings_det + except UnboundLocalError: + pass return pixidx_all, polang_all diff --git a/litebird_sim/mapmaking/destriper.py b/litebird_sim/mapmaking/destriper.py index c56fe6df..71e323af 100644 --- a/litebird_sim/mapmaking/destriper.py +++ b/litebird_sim/mapmaking/destriper.py @@ -20,7 +20,7 @@ from numba import njit, prange import healpy as hp -from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD +from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID from typing import Callable, Union, List, Optional, Tuple, Any, Dict from litebird_sim.hwp import HWP from litebird_sim.observations import Observation @@ -44,7 +44,7 @@ __DESTRIPER_RESULTS_FILE_NAME = "destriper_results.fits" -__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_WORLD.rank:04d}.fits" +__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_GRID.COMM_OBS_GRID.rank:04d}.fits" def _split_items_into_n_segments(n: int, num_of_segments: int) -> List[int]: @@ -498,8 +498,10 @@ def _build_nobs_matrix( ) # Now we must accumulate the result of every MPI process - if MPI_ENABLED: - MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM) + if MPI_ENABLED and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( + mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM + ) # `nobs_matrix_cholesky` will *not* contain the M_i maps shown in # Eq. 9 of KurkiSuonio2009, but its Cholesky decomposition, i.e., @@ -746,8 +748,12 @@ def _compute_binned_map( ) if MPI_ENABLED: - MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM) - MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM) + MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( + mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM + ) + MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( + mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM + ) # Step 2: compute the “binned map” (Eq. 21) _sum_map_to_binned_map( @@ -987,7 +993,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float: # the dot product local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)]) if MPI_ENABLED: - return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.SUM) + return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM) else: return local_result @@ -1004,7 +1010,7 @@ def _get_stopping_factor(residual: List[npt.ArrayLike]) -> float: """ local_result = np.max(np.abs(residual)) if MPI_ENABLED: - return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.MAX) + return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX) else: return local_result @@ -1418,7 +1424,7 @@ def _run_destriper( bytes_in_temporary_buffers += mask.nbytes if MPI_ENABLED: - bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce( + bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( bytes_in_temporary_buffers, op=mpi4py.MPI.SUM, ) @@ -1613,91 +1619,103 @@ def my_gui_callback( binned_map = np.empty((3, number_of_pixels)) hit_map = np.empty(number_of_pixels) - if do_destriping: - try: - # This will fail if the parameter is a scalar - len(params.samples_per_baseline) - - baseline_lengths_list = params.samples_per_baseline - assert len(baseline_lengths_list) == len(obs_list), ( - f"The list baseline_lengths_list has {len(baseline_lengths_list)} " - f"elements, but there are {len(obs_list)} observations" - ) - except TypeError: - # Ok, params.samples_per_baseline is a scalar, so we must - # figure out the number of samples in each baseline within - # each observation - baseline_lengths_list = [ - split_items_evenly( - n=getattr(cur_obs, components[0]).shape[1], - sub_n=int(params.samples_per_baseline), + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + # perform the following operations when MPI is not being used + # OR when the MPI_COMM_GRID.COMM_OBS_GRID is not a NULL communicator + if do_destriping: + try: + # This will fail if the parameter is a scalar + len(params.samples_per_baseline) + + baseline_lengths_list = params.samples_per_baseline + assert len(baseline_lengths_list) == len(obs_list), ( + f"The list baseline_lengths_list has {len(baseline_lengths_list)} " + f"elements, but there are {len(obs_list)} observations" ) - for cur_obs in obs_list - ] + except TypeError: + # Ok, params.samples_per_baseline is a scalar, so we must + # figure out the number of samples in each baseline within + # each observation + baseline_lengths_list = [ + split_items_evenly( + n=getattr(cur_obs, components[0]).shape[1], + sub_n=int(params.samples_per_baseline), + ) + for cur_obs in obs_list + ] + + # Each element of this list is a 2D array with shape (N_det, N_baselines), + # where N_det is the number of detectors in the i-th Observation object + recycle_baselines = False + if baselines_list is None: + baselines_list = [ + np.zeros( + (getattr(cur_obs, components[0]).shape[0], len(cur_baseline)) + ) + for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list) + ] + else: + recycle_baselines = True + + destriped_map = np.empty((3, number_of_pixels)) + ( + baselines_list, + baseline_errors_list, + history_of_stopping_factors, + best_stopping_factor, + converged, + bytes_in_temporary_buffers, + ) = _run_destriper( + obs_list=obs_list, + nobs_matrix_cholesky=nobs_matrix_cholesky, + binned_map=binned_map, + destriped_map=destriped_map, + hit_map=hit_map, + baseline_lengths_list=baseline_lengths_list, + baselines_list_start=baselines_list, + recycle_baselines=recycle_baselines, + recycled_convergence=recycled_convergence, + dm_list=detector_mask_list, + tm_list=time_mask_list, + component=components[0], + threshold=params.threshold, + max_steps=params.iter_max, + use_preconditioner=params.use_preconditioner, + callback=callback, + callback_kwargs=callback_kwargs if callback_kwargs else {}, + ) - # Each element of this list is a 2D array with shape (N_det, N_baselines), - # where N_det is the number of detectors in the i-th Observation object - recycle_baselines = False - if baselines_list is None: - baselines_list = [ - np.zeros((getattr(cur_obs, components[0]).shape[0], len(cur_baseline))) - for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list) - ] + if MPI_ENABLED: + bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( + bytes_in_temporary_buffers, + op=mpi4py.MPI.SUM, + ) else: - recycle_baselines = True - - destriped_map = np.empty((3, number_of_pixels)) - ( - baselines_list, - baseline_errors_list, - history_of_stopping_factors, - best_stopping_factor, - converged, - bytes_in_temporary_buffers, - ) = _run_destriper( - obs_list=obs_list, - nobs_matrix_cholesky=nobs_matrix_cholesky, - binned_map=binned_map, - destriped_map=destriped_map, - hit_map=hit_map, - baseline_lengths_list=baseline_lengths_list, - baselines_list_start=baselines_list, - recycle_baselines=recycle_baselines, - recycled_convergence=recycled_convergence, - dm_list=detector_mask_list, - tm_list=time_mask_list, - component=components[0], - threshold=params.threshold, - max_steps=params.iter_max, - use_preconditioner=params.use_preconditioner, - callback=callback, - callback_kwargs=callback_kwargs if callback_kwargs else {}, - ) - - if MPI_ENABLED: - bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce( - bytes_in_temporary_buffers, - op=mpi4py.MPI.SUM, + # No need to run the destriping, just compute the binned map with + # one single baseline set to zero + _compute_binned_map( + obs_list=obs_list, + output_sky_map=binned_map, + output_hit_map=hit_map, + nobs_matrix_cholesky=nobs_matrix_cholesky, + component=components[0], + dm_list=detector_mask_list, + tm_list=time_mask_list, + baselines_list=None, + baseline_lengths_list=[ + np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int) + for cur_obs in obs_list + ], ) + bytes_in_temporary_buffers = 0 + destriped_map = None + baseline_lengths_list = None + baselines_list = None + baseline_errors_list = None + history_of_stopping_factors = None + best_stopping_factor = None + converged = True else: - # No need to run the destriping, just compute the binned map with - # one single baseline set to zero - _compute_binned_map( - obs_list=obs_list, - output_sky_map=binned_map, - output_hit_map=hit_map, - nobs_matrix_cholesky=nobs_matrix_cholesky, - component=components[0], - dm_list=detector_mask_list, - tm_list=time_mask_list, - baselines_list=None, - baseline_lengths_list=[ - np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int) - for cur_obs in obs_list - ], - ) - bytes_in_temporary_buffers = 0 - destriped_map = None baseline_lengths_list = None baselines_list = None @@ -1707,14 +1725,18 @@ def my_gui_callback( converged = True # Add the temporary memory that was allocated *before* calling the destriper - bytes_in_temporary_buffers += sum( - [ - cur_obs.destriper_weights.nbytes - + cur_obs.destriper_pixel_idx.nbytes - + cur_obs.destriper_pol_angle_rad.nbytes - for cur_obs in obs_list - ] - ) + try: + bytes_in_temporary_buffers += sum( + [ + cur_obs.destriper_weights.nbytes + + cur_obs.destriper_pixel_idx.nbytes + + cur_obs.destriper_pol_angle_rad.nbytes + for cur_obs in obs_list + ] + ) + except UnboundLocalError: + # The case when `bytes_in_temporary_buffers` is not defined + bytes_in_temporary_buffers = 0 # We're nearly done! Let's clean up some stuff… if not keep_weights: @@ -1992,11 +2014,11 @@ def _save_baselines(results: DestriperResult, output_file: Path) -> None: primary_hdu = fits.PrimaryHDU() primary_hdu.header["MPIRANK"] = ( - MPI_COMM_WORLD.rank, + MPI_COMM_GRID.COMM_OBS_GRID.rank, "The rank of the MPI process that wrote this file", ) primary_hdu.header["MPISIZE"] = ( - MPI_COMM_WORLD.size, + MPI_COMM_GRID.COMM_OBS_GRID.size, "The number of MPI processes used in the computation", ) @@ -2212,11 +2234,11 @@ def load_destriper_results( baselines_file_name = folder / __BASELINES_FILE_NAME with fits.open(baselines_file_name) as inpf: - assert MPI_COMM_WORLD.rank == inpf[0].header["MPIRANK"], ( + assert MPI_COMM_GRID.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results " ) - assert MPI_COMM_WORLD.size == inpf[0].header["MPISIZE"], ( + assert MPI_COMM_GRID.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results" ) diff --git a/litebird_sim/mpi.py b/litebird_sim/mpi.py index fe623248..64c31181 100644 --- a/litebird_sim/mpi.py +++ b/litebird_sim/mpi.py @@ -22,10 +22,57 @@ class _SerialMpiCommunicator: size = 1 +class _GridCommClass: + """ + This class encapsulates the `COMM_OBS_GRID` and `COMM_NULL` communicators. It + offers explicitly defined setter functions so that the communicators cannot be + changed accidentally. + + Attributes: + + COMM_OBS_GRID (mpi4py.MPI.Intracomm): A subset of `MPI.COMM_WORLD` that + contain all the processes associated with non-zero observations. + + COMM_NULL (mpi4py.MPI.Comm): A NULL communicator. When MPI is not enabled, it + is set as `None`. If MPI is enabled, it is set as `MPI.COMM_NULL` + + """ + + def __init__(self, comm_obs_grid=_SerialMpiCommunicator(), comm_null=None): + self._MPI_COMM_OBS_GRID = comm_obs_grid + self._MPI_COMM_NULL = comm_null + + @property + def COMM_OBS_GRID(self): + return self._MPI_COMM_OBS_GRID + + @property + def COMM_NULL(self): + return self._MPI_COMM_NULL + + def _set_comm_obs_grid(self, comm_obs_grid): + self._MPI_COMM_OBS_GRID = comm_obs_grid + + def _set_null_comm(self, comm_null): + self._MPI_COMM_NULL = comm_null + + #: Global variable equal either to `mpi4py.MPI.COMM_WORLD` or a object #: that defines the member variables `rank = 0` and `size = 1`. MPI_COMM_WORLD = _SerialMpiCommunicator() + +#: Global object with two attributes: +#: +#: - ``COMM_OBS_GRID``: It is a partition of ``MPI_COMM_WORLD`` that includes all the +#: MPI processes with global rank less than ``n_blocks_time * n_blocks_det``. On MPI +#: processes with higher ranks, it points to NULL MPI communicator +#: ``mpi4py.MPI.COMM_NULL``. +#: +#: - ``COMM_NULL``: If :data:`.MPI_ENABLED` is ``True``, this object points to a NULL +#: MPI communicator (``mpi4py.MPI.COMM_NULL``). Otherwise it is ``None``. +MPI_COMM_GRID = _GridCommClass() + #: `True` if MPI should be used by the application. The value of this #: variable is set according to the following rules: #: @@ -53,6 +100,8 @@ class _SerialMpiCommunicator: from mpi4py import MPI MPI_COMM_WORLD = MPI.COMM_WORLD + MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD) + MPI_COMM_GRID._set_null_comm(comm_null=MPI.COMM_NULL) MPI_ENABLED = True MPI_CONFIGURATION = mpi4py.get_config() except ImportError: diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index 7997de56..4bcac47e 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -2,13 +2,18 @@ from dataclasses import dataclass from typing import Union, List, Any, Optional +import numbers import astropy.time import numpy as np import numpy.typing as npt +from collections import defaultdict + from .coordinates import DEFAULT_TIME_SCALE -from .distribute import distribute_evenly +from .distribute import distribute_evenly, distribute_detector_blocks +from .detectors import DetectorInfo +from .mpi import MPI_COMM_GRID, _SerialMpiCommunicator @dataclass @@ -80,11 +85,20 @@ class Observation: sampling_rate_hz (float): The sampling frequency, in Hertz. + det_blocks_attributes (list of strings): The list of detector + attributes that will be used to divide the detector axis of the + tod matrix and all its attributes. For example, with + ``det_blocks_attributes = ["wafer", "pixel"]``, the detectors will + be divided into blocks such that all detectors in a block will + have the same ``wafer`` and ``pixel`` attribute. + n_blocks_det (int): divide the detector axis of the tod (and all the - arrays of detector attributes) in `n_blocks_det` blocks + arrays of detector attributes) in `n_blocks_det` blocks. It will + be ignored if ``det_blocks_attributes`` is not `None`. n_blocks_time (int): divide the time axis of the tod in - `n_blocks_time` blocks + `n_blocks_time` blocks. It will be ignored + if ``det_blocks_attributes`` is not `None`. comm: either `None` (do not use MPI) or a MPI communicator object, like `mpi4py.MPI.COMM_WORLD`. Its size is required to be at @@ -103,6 +117,7 @@ def __init__( sampling_rate_hz: float, allocate_tod=True, tods=None, + det_blocks_attributes: Union[List[str], None] = None, n_blocks_det=1, n_blocks_time=1, comm=None, @@ -123,27 +138,36 @@ def __init__( delta = 1.0 / sampling_rate_hz self.end_time_global = start_time_global + n_samples_global * delta + self._sampling_rate_hz = sampling_rate_hz + self._det_blocks_attributes = det_blocks_attributes + self.detector_blocks = None + if isinstance(detectors, int): self._n_detectors_global = detectors else: - if comm and comm.size > 1: - self._n_detectors_global = comm.bcast(len(detectors), root) + if self.comm and self.comm.size > 1: + self._n_detectors_global = self.comm.bcast(len(detectors), root) + + if self._det_blocks_attributes is not None: + n_blocks_det, n_blocks_time = self._make_detector_blocks( + detectors, self.comm + ) else: self._n_detectors_global = len(detectors) - self._sampling_rate_hz = sampling_rate_hz - - # Neme of the attributes that store an array with the value of a + # Name of the attributes that store an array with the value of a # property for each of the (local) detectors self._attr_det_names = [] self._check_blocks(n_blocks_det, n_blocks_time) - if comm and comm.size > 1: + if self.comm and self.comm.size > 1: self._n_blocks_det = n_blocks_det self._n_blocks_time = n_blocks_time else: self._n_blocks_det = 1 self._n_blocks_time = 1 + self._set_mpi_subcommunicators() + self.tod_list = tods for cur_tod in self.tod_list: if allocate_tod: @@ -159,8 +183,17 @@ def __init__( setattr(self, cur_tod.name, None) self.setattr_det_global("det_idx", np.arange(self._n_detectors_global), root) + + self.detectors_global = [] + + if self.detector_blocks is not None: + for key in self.detector_blocks: + self.detectors_global += self.detector_blocks[key] + else: + self.detectors_global = detectors + if not isinstance(detectors, int): - self._set_attributes_from_list_of_dict(detectors, root) + self._set_attributes_from_list_of_dict(self.detectors_global, root) ( self.start_time, @@ -176,7 +209,10 @@ def sampling_rate_hz(self): @property def n_detectors(self): - return len(self.det_idx) + if self.det_idx is None: + return 0 + else: + return len(self.det_idx) def _get_local_start_time_start_and_n_samples(self): _, _, start, num = self._get_start_and_num( @@ -203,7 +239,7 @@ def _get_local_start_time_start_and_n_samples(self): return self.start_time_global + start * delta, start, num def _set_attributes_from_list_of_dict(self, list_of_dict, root): - assert len(list_of_dict) == self.n_detectors_global + np.testing.assert_equal(len(list_of_dict), self.n_detectors_global) # Turn list of dict into dict of arrays if not self.comm or self.comm.rank == root: @@ -273,10 +309,86 @@ def n_blocks_time(self): def n_blocks_det(self): return self._n_blocks_det + def _make_detector_blocks(self, detectors, comm): + """This function distributes the detectors in groups such that each + group has the same set of attributes specified by the strings in + `self._det_block_attributes`. Once the groups are made, the number of + detector blocks is set to be the total number of detector groups, + whereas the number of time blocks is computed using the number of + detector blocks and the size of `comm` communicator. + + The detector blocks are stored in `self.detector_blocks`. This + dictionary object has the tuple of `self._det_blocks_attributes` values + as dictionary keys and the list of detectors corresponding to the key + as dictionary values. This dictionary is sorted so that the + group with the largest number of detectors comes first and the one with + the fewest detectors comes last. + + Example: + For + + ``` + detectors = [ + "000_002_123_xx_140_x", + "000_005_321_xx_140_x", + "000_004_456_xx_119_x", + "000_002_654_xx_140_x", + "000_004_789_xx_119_x", + ] + ``` + + and `self._det_blocks_attributes = ["channel", "wafer"]`, + `_make_detector_blocks()` will set + + ``` + self.detector_blocks = { + ("140", "L02"): ["000_002_123_xx_140_x", "000_002_654_xx_140_x"], + ("119", "L04"): ["000_004_456_xx_119_x", "000_004_789_xx_119_x"], + ("140", "L05"): ["000_005_321_xx_140_x"], + } + ``` + + and return `n_blocks_det = 3` + + Args: + detectors (List[dict]): List of detectors + + comm: The MPI communicator + + Returns: + n_blocks_det (int): Number of detector blocks + + n_blocks_time (int): Number of time blocks + + """ + self.detector_blocks = defaultdict(list) + for det in detectors: + key = tuple(det[attribute] for attribute in self._det_blocks_attributes) + self.detector_blocks[key].append(det) + + self.detector_blocks = dict( + sorted( + self.detector_blocks.items(), + key=lambda item: len(item[1]), + reverse=True, + ) + ) + n_blocks_det = len(self.detector_blocks) + n_blocks_time = comm.size // n_blocks_det + + return n_blocks_det, n_blocks_time + def _check_blocks(self, n_blocks_det, n_blocks_time): if self.comm is None: if n_blocks_det != 1 or n_blocks_time != 1: raise ValueError("Only one block allowed without an MPI comm") + elif n_blocks_det == 0 or n_blocks_time == 0: + raise ValueError( + "The number of detector blocks and the number of time blocks " + "must be must be non-zero\n" + f"n_blocks_det = {n_blocks_det}, " + f"n_blocks_time = {n_blocks_time}" + ) elif n_blocks_det > self.n_detectors_global: raise ValueError( "You can not have more detector blocks than detectors " @@ -296,21 +408,33 @@ def _check_blocks(self, n_blocks_det, n_blocks_time): def _get_start_and_num(self, n_blocks_det, n_blocks_time): """For both detectors and time, returns the starting (global) - index and lenght of each block if the number of blocks is changed to the + index and length of each block if the number of blocks is changed to the values passed as arguments """ - det_start, det_n = np.array( - [ - [span.start_idx, span.num_of_elements] - for span in distribute_evenly(self._n_detectors_global, n_blocks_det) - ] - ).T + if self._det_blocks_attributes is None or self.comm.size == 1: + det_start, det_n = np.array( + [ + [span.start_idx, span.num_of_elements] + for span in distribute_evenly( + self._n_detectors_global, n_blocks_det + ) + ] + ).T + else: + det_start, det_n = np.array( + [ + [span.start_idx, span.num_of_elements] + for span in distribute_detector_blocks(self.detector_blocks) + ] + ).T + time_start, time_n = np.array( [ [span.start_idx, span.num_of_elements] for span in distribute_evenly(self._n_samples_global, n_blocks_time) ] ).T + return ( np.array(det_start), np.array(det_n), @@ -327,6 +451,7 @@ def _get_tod_shape(self, n_blocks_det, n_blocks_time): return (self._n_detectors_global, self._n_samples_global) _, det_n, _, time_n = self._get_start_and_num(n_blocks_det, n_blocks_time) + try: return ( det_n[self.comm.rank // n_blocks_time], @@ -550,26 +675,27 @@ def setattr_det_global(self, name, info, root=0): setattr(self, name, info) return - is_in_grid = self.comm.rank < self._n_blocks_det * self._n_blocks_time - comm_grid = self.comm.Split(int(is_in_grid)) - if not is_in_grid: # The process does not own any detector (and TOD) - setattr(self, name, None) + if ( + MPI_COMM_GRID.COMM_OBS_GRID == MPI_COMM_GRID.COMM_NULL + ): # The process does not own any detector (and TOD) + null_det = DetectorInfo() + attribute = getattr(null_det, name, None) + value = 0 if isinstance(attribute, numbers.Number) else None + setattr(self, name, value) return - my_col = comm_grid.rank % self._n_blocks_time - comm_col = comm_grid.Split(my_col) + my_col = MPI_COMM_GRID.COMM_OBS_GRID.rank % self._n_blocks_time root_col = root // self._n_blocks_det if my_col == root_col: - if comm_grid.rank == root: + if MPI_COMM_GRID.COMM_OBS_GRID.rank == root: starts, nums, _, _ = self._get_start_and_num( self._n_blocks_det, self._n_blocks_time ) info = [info[s : s + n] for s, n in zip(starts, nums)] - info = comm_col.scatter(info, root) + info = self.comm_time_block.scatter(info, root) - comm_row = comm_grid.Split(comm_grid.rank // self._n_blocks_time) - info = comm_row.bcast(info, root_col) + info = self.comm_det_block.bcast(info, root_col) assert (not self.tod_list) or len(info) == len( getattr(self, self.tod_list[0].name) ) @@ -662,7 +788,7 @@ def get_pointings( pointing_buffer: Optional[npt.NDArray] = None, hwp_buffer: Optional[npt.NDArray] = None, pointings_dtype=np.float32, - ) -> (npt.NDArray, Optional[npt.NDArray]): + ) -> tuple[npt.NDArray, Optional[npt.NDArray]]: """ Compute the pointings for one or more detectors in this observation @@ -785,3 +911,55 @@ def get_pointings( ) return pointing_buffer, hwp_buffer + + def _set_mpi_subcommunicators(self): + """ + This function splits the global MPI communicator into three kinds of + sub-communicators: + + 1. A sub-communicator containing all the processes with global rank less than + `n_blocks_det * n_blocks_time`. Outside of this global rank, the + sub-communicator is NULL. + + 2. A sub-communicator for each block of detectors, that contains all the + processes corresponding to that detector block. If a process doesn't + contain a detector, the sub-communicator is NULL. + + 3. A sub-communicator for each block of time that contains all the processes + corresponding to that time block. If a process doesn't contain a detector, + the sub-communicator is NULL. + """ + + # Set the detector and time block sub-communicators to + # `_SerialMpiCommunicator()` when MPI is not being used + self.comm_det_block = _SerialMpiCommunicator() + self.comm_time_block = _SerialMpiCommunicator() + + if self.comm and self.comm.size > 1: + if self.comm.rank < self.n_blocks_det * self.n_blocks_time: + matrix_color = 1 + else: + from .mpi import MPI + + matrix_color = MPI.UNDEFINED + + # Case1: For `0 < rank < n_blocks_det * n_blocks_time`, + # `comm_obs_grid` is a sub-communicator that includes processes + # from rank 0 to `n_blocks_det * n_blocks_time - 1`. + # Case 2: For `n_blocks_det * n_blocks_time <= rank < comm.size`, + # `comm_obs_grid = MPI.COMM_NULL` + comm_obs_grid = self.comm.Split(matrix_color, self.comm.rank) + MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=comm_obs_grid) + + # If the `MPI_COMM_GRID.COMM_OBS_GRID` is not NULL, we split it in + # communicators corresponding to each detector and time block + # If `MPI_COMM_GRID.COMM_OBS_GRID` is NULL, we set the communicators + # corresponding to detector and time blocks to NULL. + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + det_color = MPI_COMM_GRID.COMM_OBS_GRID.rank // self.n_blocks_time + time_color = MPI_COMM_GRID.COMM_OBS_GRID.rank % self.n_blocks_time + self.comm_det_block = MPI_COMM_GRID.COMM_OBS_GRID.Split(det_color) + self.comm_time_block = MPI_COMM_GRID.COMM_OBS_GRID.Split(time_color) + else: + self.comm_det_block = MPI_COMM_GRID.COMM_NULL + self.comm_time_block = MPI_COMM_GRID.COMM_NULL diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index 7a213a0a..51dde0cf 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -44,7 +44,7 @@ DestriperResult, destriper_log_callback, ) -from .mpi import MPI_ENABLED, MPI_COMM_WORLD +from .mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID from .noise import add_noise_to_observations from .observations import Observation, TodDescription from .pointings_in_obs import prepare_pointings, precompute_pointings @@ -889,6 +889,7 @@ def create_observations( detectors: List[DetectorInfo], num_of_obs_per_detector: int = 1, split_list_over_processes=True, + det_blocks_attributes: Union[List[str], None] = None, n_blocks_det=1, n_blocks_time=1, root=0, @@ -922,7 +923,12 @@ def create_observations( simulating 10 detectors and you specify ``n_blocks_det=5``, this means that each observation will handle ``10 / 5 = 2`` detectors. The default is that *all* the detectors be kept - together (``n_blocks_det=1``). + together (``n_blocks_det=1``). On the other hand, the parameter + `det_blocks_attributes` specifies the list of detector attributes + to create the groups of detectors. For example, with + ``det_blocks_attributes = ["wafer", "pixel"]``, the detectors will + be divided into groups such that all detectors in a group will + have the same ``wafer`` and ``pixel`` attribute. The parameter `n_blocks_time` specifies the number of time splits of the observations. In the case of a 3-month-long @@ -1013,6 +1019,7 @@ def create_observations( start_time_global=cur_time, sampling_rate_hz=sampfreq_hz, n_samples_global=nsamples, + det_blocks_attributes=det_blocks_attributes, n_blocks_det=n_blocks_det, n_blocks_time=n_blocks_time, comm=(None if split_list_over_processes else self.mpi_comm), @@ -1214,7 +1221,8 @@ def set_scanning_strategy( num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - num_of_obs = MPI_COMM_WORLD.allreduce(num_of_obs) + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_quaternions.md") @@ -1311,8 +1319,11 @@ def prepare_pointings( memory_occupation = pointing_provider.bore2ecliptic_quats.quats.nbytes num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - memory_occupation = MPI_COMM_WORLD.allreduce(memory_occupation) - num_of_obs = MPI_COMM_WORLD.allreduce(num_of_obs) + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + memory_occupation = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( + memory_occupation + ) + num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_pointings.md") diff --git a/test/test_detector_blocks.py b/test/test_detector_blocks.py new file mode 100644 index 00000000..e5f9dce9 --- /dev/null +++ b/test/test_detector_blocks.py @@ -0,0 +1,158 @@ +import numpy as np +import litebird_sim as lbs + +# data for testing detector blocks and MPI sub-communicators +sampling_freq_Hz = 1 +dets = [ + lbs.DetectorInfo( + name="channel1_w9_detA", + wafer="wafer_9", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w3_detB", + wafer="wafer_3", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detC", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detD", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detA", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detB", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), +] + + +def test_detector_blocks(dets=dets, sampling_freq_Hz=sampling_freq_Hz): + comm = lbs.MPI_COMM_WORLD + + start_time = 456 + duration_s = 100 + nobs_per_det = 3 + + if comm.size > 4: + det_blocks_attribute = ["channel", "wafer"] + else: + det_blocks_attribute = ["channel"] + + sim = lbs.Simulation( + start_time=start_time, + duration_s=duration_s, + random_seed=12345, + mpi_comm=comm, + ) + + sim.create_observations( + detectors=dets, + split_list_over_processes=False, + num_of_obs_per_detector=nobs_per_det, + det_blocks_attributes=det_blocks_attribute, + ) + + tod_len_per_det_per_proc = 0 + for obs in sim.observations: + tod_shape = obs.tod.shape + + n_blocks_det = obs.n_blocks_det + n_blocks_time = obs.n_blocks_time + tod_len_per_det_per_proc += obs.tod.shape[1] + + # No testing required if the proc doesn't owns a detector + if obs.det_idx is not None: + det_names_per_obs = [ + obs.detectors_global[idx]["name"] for idx in obs.det_idx + ] + + # Testing if the mapping between the obs.name and + # obs.det_idx is consistent with obs.detectors_global + np.testing.assert_equal(obs.name, det_names_per_obs) + + # Testing the distribution of the number of detectors per + # detector block + np.testing.assert_equal(obs.name.shape[0], tod_shape[0]) + + # Testing if the distribution of samples along the time axis is consistent + if comm.rank < n_blocks_det * n_blocks_time: + arr = [ + span.num_of_elements + for span in lbs.distribute.distribute_evenly( + duration_s * sampling_freq_Hz, n_blocks_time * nobs_per_det + ) + ] + + start_idx = (comm.rank % n_blocks_time) * nobs_per_det + stop_idx = start_idx + nobs_per_det + np.testing.assert_equal(sum(arr[start_idx:stop_idx]), tod_len_per_det_per_proc) + + +def test_mpi_subcommunicators(dets=dets): + comm = lbs.MPI_COMM_WORLD + + start_time = 456 + duration_s = 100 + nobs_per_det = 3 + + if comm.size > 4: + det_blocks_attribute = ["channel", "wafer"] + else: + det_blocks_attribute = ["channel"] + + sim = lbs.Simulation( + start_time=start_time, + duration_s=duration_s, + random_seed=12345, + mpi_comm=comm, + ) + + sim.create_observations( + detectors=dets, + split_list_over_processes=False, + num_of_obs_per_detector=nobs_per_det, + det_blocks_attributes=det_blocks_attribute, + ) + + if lbs.MPI_COMM_GRID.COMM_OBS_GRID != lbs.MPI_COMM_GRID.COMM_NULL: + # since unused MPI processes stay at the end of global, + # communicator, the rank of the used processes in + # `MPI_COMM_GRID.COMM_OBS_GRID` must be same as their rank in + # global communicator + np.testing.assert_equal(lbs.MPI_COMM_GRID.COMM_OBS_GRID.rank, comm.rank) + + for obs in sim.observations: + # comm_det_block.rank + comm_time_block.rank * n_block_time + # must be equal to the global communicator rank for the + # used processes. It follows from the way split colors + # were defined. + np.testing.assert_equal( + obs.comm_det_block.rank + obs.comm_time_block.rank * obs.n_blocks_time, + comm.rank, + ) + else: + for obs in sim.observations: + # the global rank of the unused MPI processes must be larger than the number of used processes. + assert comm.rank > (obs.n_blocks_det * obs.n_blocks_time - 1) + + # The block communicators on the unused MPI processes must + # be the NULL communicators + np.testing.assert_equal(obs.comm_det_block, lbs.MPI_COMM_GRID.COMM_NULL) + np.testing.assert_equal(obs.comm_time_block, lbs.MPI_COMM_GRID.COMM_NULL)