From b06a95571c8c67660b3b29dfb44db28b912a4ca9 Mon Sep 17 00:00:00 2001 From: Andrea Sabatucci Date: Fri, 5 Jul 2024 11:27:55 +0200 Subject: [PATCH 01/11] Added the Squid attribute to DetectorInfo() class. --- litebird_sim/detectors.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/litebird_sim/detectors.py b/litebird_sim/detectors.py index a489c13b..f2d0d981 100644 --- a/litebird_sim/detectors.py +++ b/litebird_sim/detectors.py @@ -75,6 +75,9 @@ class DetectorInfo: - channel (Union[str, None]): The channel. The default is None + - squid (Union[int, None]): The squid number of the detector. + The default value is None. + - sampling_rate_hz (float): The sampling rate of the ADC associated with this detector. The default is 0.0 @@ -136,6 +139,7 @@ class DetectorInfo: pixel: Union[int, None] = None pixtype: Union[str, None] = None channel: Union[str, None] = None + squid: Union[int, None] = None sampling_rate_hz: float = 0.0 fwhm_arcmin: float = 0.0 ellipticity: float = 0.0 @@ -148,8 +152,6 @@ class DetectorInfo: fknee_mhz: float = 0.0 fmin_hz: float = 0.0 alpha: float = 0.0 - bandcenter_ghz: float = 0.0 - bandwidth_ghz: float = 0.0 pol: Union[str, None] = None orient: Union[str, None] = None quat: Any = None @@ -175,6 +177,7 @@ def from_dict(dictionary: Dict[str, Any]): - ``pixel`` - ``pixtype`` - ``channel`` + - ``squid`` - ``bandcenter_ghz`` - ``bandwidth_ghz`` - ``band_freqs_ghz`` From dc42297675c24ebe1e7319cc363fa6855861e5a6 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:37:42 +0900 Subject: [PATCH 02/11] Preliminary implementation of attribute-based detector distribution scheme --- litebird_sim/distribute.py | 39 ++++++++++ litebird_sim/observations.py | 144 +++++++++++++++++++++++++++++++---- litebird_sim/simulations.py | 9 ++- 3 files changed, 177 insertions(+), 15 deletions(-) diff --git a/litebird_sim/distribute.py b/litebird_sim/distribute.py index 3649a877..edf2ea43 100644 --- a/litebird_sim/distribute.py +++ b/litebird_sim/distribute.py @@ -84,6 +84,45 @@ def distribute_evenly(num_of_elements, num_of_groups): return result +def distribute_detector_blocks(detector_blocks): + """Similar to the function :func:`distribute_evenly()`, this function returns the named-tuples of the starting index of the detectors in a group + with respect to the global list of detectors and the number of detectors + in the group. Unlike the :func:`distribute_evenly()`, this function simply + uses the detector groups given in `detector_blocks` attribute. + + Example: + Following the example given in + :meth:`litebird_sim.Observation._make_detector_blocks()`, + `distribute_detector_blocks()` will return + + ``` + [ + Span(start_idx=0, num_of_elements=2), + Span(start_idx=2, num_of_elements=2), + Span(start_idx=4, num_of_elements=1), + ] + ``` + + Args: + detector_blocks (dict): The detector block object. See :meth:`litebird_sim.Observation._make_detector_blocks()`. + + Returns: + A list of 2-elements named-tuples containing (1) the starting index of + the detectors of the block with respect to the flatten list of entire + detector blocks and (2) the number of elements in the detector block. + """ + cur_position = 0 + prev_length = 0 + result = [] + for key in detector_blocks: + cur_length = len(detector_blocks[key]) + cur_position += prev_length + prev_length = cur_length + result.append(Span(start_idx=cur_position, num_of_elements=cur_length)) + + return result + + # The following implementation of the painter's partition problem is # heavily inspired by the code at # https://www.geeksforgeeks.org/painters-partition-problem-set-2/?ref=rp diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index 7997de56..2ad9cbf1 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -7,8 +7,10 @@ import numpy as np import numpy.typing as npt +from collections import defaultdict + from .coordinates import DEFAULT_TIME_SCALE -from .distribute import distribute_evenly +from .distribute import distribute_evenly, distribute_detector_blocks @dataclass @@ -80,11 +82,19 @@ class Observation: sampling_rate_hz (float): The sampling frequency, in Hertz. + det_blocks_attributes (list of strings): The list of detector attributes that + will be used to divide detector axis of the tod (and all their attributes). + For example, with ``det_blocks_attributes = ["wafer", "pixel"]``, the + detectors will be divided into the blocks such that all detectors in a + block will have same ``wafer`` and ``pixel`` attribute. + n_blocks_det (int): divide the detector axis of the tod (and all the - arrays of detector attributes) in `n_blocks_det` blocks + arrays of detector attributes) in `n_blocks_det` blocks. Will be ignored + if ``det_blocks_attributes`` is not `None`. n_blocks_time (int): divide the time axis of the tod in - `n_blocks_time` blocks + `n_blocks_time` blocks. Will be ignored if ``det_blocks_attributes`` + is not `None`. comm: either `None` (do not use MPI) or a MPI communicator object, like `mpi4py.MPI.COMM_WORLD`. Its size is required to be at @@ -103,6 +113,7 @@ def __init__( sampling_rate_hz: float, allocate_tod=True, tods=None, + det_blocks_attributes: Union[List[str], None] = None, n_blocks_det=1, n_blocks_time=1, comm=None, @@ -123,6 +134,10 @@ def __init__( delta = 1.0 / sampling_rate_hz self.end_time_global = start_time_global + n_samples_global * delta + self._sampling_rate_hz = sampling_rate_hz + self._det_blocks_attributes = det_blocks_attributes + self.detector_blocks = None + if isinstance(detectors, int): self._n_detectors_global = detectors else: @@ -131,9 +146,12 @@ def __init__( else: self._n_detectors_global = len(detectors) - self._sampling_rate_hz = sampling_rate_hz + if self._det_blocks_attributes is not None and comm.size > 1: + n_blocks_det, n_blocks_time = self._make_detector_blocks( + detectors, comm + ) - # Neme of the attributes that store an array with the value of a + # Name of the attributes that store an array with the value of a # property for each of the (local) detectors self._attr_det_names = [] self._check_blocks(n_blocks_det, n_blocks_time) @@ -159,8 +177,17 @@ def __init__( setattr(self, cur_tod.name, None) self.setattr_det_global("det_idx", np.arange(self._n_detectors_global), root) + + self.detectors_global = [] + + if self.detector_blocks is not None: + for key in self.detector_blocks: + self.detectors_global += self.detector_blocks[key] + else: + self.detectors_global = detectors + if not isinstance(detectors, int): - self._set_attributes_from_list_of_dict(detectors, root) + self._set_attributes_from_list_of_dict(self.detectors_global, root) ( self.start_time, @@ -203,7 +230,7 @@ def _get_local_start_time_start_and_n_samples(self): return self.start_time_global + start * delta, start, num def _set_attributes_from_list_of_dict(self, list_of_dict, root): - assert len(list_of_dict) == self.n_detectors_global + np.testing.assert_equal(len(list_of_dict), self.n_detectors_global) # Turn list of dict into dict of arrays if not self.comm or self.comm.rank == root: @@ -273,10 +300,86 @@ def n_blocks_time(self): def n_blocks_det(self): return self._n_blocks_det + def _make_detector_blocks(self, detectors, comm): + """This function distributes the detectors in groups such that each + group has same set of attributes specified by the strings in + `self._det_block_attributes`. Once the groups are made, number of + detector blocks is set to be the total number of detector groups + whereas the number of time blocks is computed using the number of + detector blocks and the size of `comm` communicator. + + The blocks of detectors are stored in `self.detector_blocks`. It is a + dictionary object with the tuple of `self._det_blocks_attributes` values + as dictionary keys and the list of detectors corresponding to the key + as the dictionary value. This dictionary is sorted so that that the + group with largest number of detectors comes first and the one with + the least number of detectors, comes last. + + Example: + For + + ``` + detectors = [ + "000_002_123_xx_140_x", + "000_005_321_xx_140_x", + "000_004_456_xx_119_x", + "000_002_654_xx_140_x", + "000_004_789_xx_119_x", + ] + ``` + + and `self._det_blocks_attributes = ["channel", "wafer"]`, + `_make_detector_blocks()` will set + + ``` + self.detector_blocks = { + ("140", "L02"): ["000_002_123_xx_140_x", "000_002_654_xx_140_x"], + ("119", "L04"): ["000_004_456_xx_119_x", "000_004_789_xx_119_x"], + ("140", "L05"): ["000_005_321_xx_140_x"], + } + ``` + + and return `n_blocks_det = 3` + + Args: + detectors (List[dict]): List of detectors + + comm: The MPI communicator + + Returns: + n_blocks_det (int): Number of detector blocks + + n_blocks_time (int): Number of time blocks + + """ + self.detector_blocks = defaultdict(list) + for det in detectors: + key = tuple(det[attribute] for attribute in self._det_blocks_attributes) + self.detector_blocks[key].append(det) + + self.detector_blocks = dict( + sorted( + self.detector_blocks.items(), + key=lambda item: len(item[1]), + reverse=True, + ) + ) + n_blocks_det = len(self.detector_blocks) + n_blocks_time = comm.size // n_blocks_det + + return n_blocks_det, n_blocks_time + def _check_blocks(self, n_blocks_det, n_blocks_time): if self.comm is None: if n_blocks_det != 1 or n_blocks_time != 1: raise ValueError("Only one block allowed without an MPI comm") + elif n_blocks_det == 0 or n_blocks_time == 0: + raise ValueError( + "The number of detector blocks and the number of time blocks " + "must be must be non-zero\n" + f"n_blocks_det = {n_blocks_det}, " + f"n_blocks_time = {n_blocks_time}" + ) elif n_blocks_det > self.n_detectors_global: raise ValueError( "You can not have more detector blocks than detectors " @@ -296,21 +399,33 @@ def _check_blocks(self, n_blocks_det, n_blocks_time): def _get_start_and_num(self, n_blocks_det, n_blocks_time): """For both detectors and time, returns the starting (global) - index and lenght of each block if the number of blocks is changed to the + index and length of each block if the number of blocks is changed to the values passed as arguments """ - det_start, det_n = np.array( - [ - [span.start_idx, span.num_of_elements] - for span in distribute_evenly(self._n_detectors_global, n_blocks_det) - ] - ).T + if self._det_blocks_attributes is None or self.comm.size == 1: + det_start, det_n = np.array( + [ + [span.start_idx, span.num_of_elements] + for span in distribute_evenly( + self._n_detectors_global, n_blocks_det + ) + ] + ).T + else: + det_start, det_n = np.array( + [ + [span.start_idx, span.num_of_elements] + for span in distribute_detector_blocks(self.detector_blocks) + ] + ).T + time_start, time_n = np.array( [ [span.start_idx, span.num_of_elements] for span in distribute_evenly(self._n_samples_global, n_blocks_time) ] ).T + return ( np.array(det_start), np.array(det_n), @@ -327,6 +442,7 @@ def _get_tod_shape(self, n_blocks_det, n_blocks_time): return (self._n_detectors_global, self._n_samples_global) _, det_n, _, time_n = self._get_start_and_num(n_blocks_det, n_blocks_time) + try: return ( det_n[self.comm.rank // n_blocks_time], diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index 7a213a0a..b98a6eec 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -889,6 +889,7 @@ def create_observations( detectors: List[DetectorInfo], num_of_obs_per_detector: int = 1, split_list_over_processes=True, + det_blocks_attributes: Union[List[str], None] = None, n_blocks_det=1, n_blocks_time=1, root=0, @@ -922,7 +923,12 @@ def create_observations( simulating 10 detectors and you specify ``n_blocks_det=5``, this means that each observation will handle ``10 / 5 = 2`` detectors. The default is that *all* the detectors be kept - together (``n_blocks_det=1``). + together (``n_blocks_det=1``). On the other hand, the parameter + `det_blocks_attributes` specifies the list of detector attributes + to be used to create the groups of detectors. For example, with + ``det_blocks_attributes = ["wafer", "pixel"]``, the detectors will + be divided into the groups such that all detectors in a group will + have same ``wafer`` and ``pixel`` attribute. The parameter `n_blocks_time` specifies the number of time splits of the observations. In the case of a 3-month-long @@ -1013,6 +1019,7 @@ def create_observations( start_time_global=cur_time, sampling_rate_hz=sampfreq_hz, n_samples_global=nsamples, + det_blocks_attributes=det_blocks_attributes, n_blocks_det=n_blocks_det, n_blocks_time=n_blocks_time, comm=(None if split_list_over_processes else self.mpi_comm), From e87dd3fe3d68903d07d8367b68cc5d9a0f062efe Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Tue, 15 Oct 2024 13:26:13 +0900 Subject: [PATCH 03/11] resolved the review request --- litebird_sim/distribute.py | 14 ++++++++------ litebird_sim/observations.py | 37 ++++++++++++++++++------------------ 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/litebird_sim/distribute.py b/litebird_sim/distribute.py index edf2ea43..ef893d82 100644 --- a/litebird_sim/distribute.py +++ b/litebird_sim/distribute.py @@ -50,7 +50,7 @@ def distribute_evenly(num_of_elements, num_of_groups): # If leftovers == 0, then the number of elements is divided evenly # by num_of_groups, and the solution is trivial. If it's not, then - # each of the "leftoverss" is placed in one of the first groups. + # each of the "leftovers" is placed in one of the first groups. # # Example: let's split 8 elements in 3 groups. In this case, # base_length=2 and leftovers=2 (elements #7 and #8): @@ -68,7 +68,7 @@ def distribute_evenly(num_of_elements, num_of_groups): cur_length = base_length + 1 cur_pos = cur_length * i else: - # No need to accomodate for leftovers, but consider their + # No need to accommodate for leftovers, but consider their # presence in fixing the starting position for this group cur_length = base_length cur_pos = base_length * i + leftovers @@ -85,10 +85,12 @@ def distribute_evenly(num_of_elements, num_of_groups): def distribute_detector_blocks(detector_blocks): - """Similar to the function :func:`distribute_evenly()`, this function returns the named-tuples of the starting index of the detectors in a group - with respect to the global list of detectors and the number of detectors - in the group. Unlike the :func:`distribute_evenly()`, this function simply - uses the detector groups given in `detector_blocks` attribute. + """Similar to the :func:`distribute_evenly()` function, this function + returns a list of named-tuples, with fields `start_idx` (the starting + index of the detector in a group within the global list of detectors) and + num_of_elements` (the number of detectors in the group). Unlike + :func:`distribute_evenly()`, this function simply uses the detector groups + given in the `detector_blocks` attribute. Example: Following the example given in diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index 2ad9cbf1..144823a1 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -82,19 +82,20 @@ class Observation: sampling_rate_hz (float): The sampling frequency, in Hertz. - det_blocks_attributes (list of strings): The list of detector attributes that - will be used to divide detector axis of the tod (and all their attributes). - For example, with ``det_blocks_attributes = ["wafer", "pixel"]``, the - detectors will be divided into the blocks such that all detectors in a - block will have same ``wafer`` and ``pixel`` attribute. + det_blocks_attributes (list of strings): The list of detector + attributes that will be used to divide the detector axis of the + tod matrix and all its attributes. For example, with + ``det_blocks_attributes = ["wafer", "pixel"]``, the detectors will + be divided into blocks such that all detectors in a block will + have the same ``wafer`` and ``pixel`` attribute. n_blocks_det (int): divide the detector axis of the tod (and all the - arrays of detector attributes) in `n_blocks_det` blocks. Will be ignored - if ``det_blocks_attributes`` is not `None`. + arrays of detector attributes) in `n_blocks_det` blocks. It will + be ignored if ``det_blocks_attributes`` is not `None`. n_blocks_time (int): divide the time axis of the tod in - `n_blocks_time` blocks. Will be ignored if ``det_blocks_attributes`` - is not `None`. + `n_blocks_time` blocks. It will be ignored + if ``det_blocks_attributes`` is not `None`. comm: either `None` (do not use MPI) or a MPI communicator object, like `mpi4py.MPI.COMM_WORLD`. Its size is required to be at @@ -302,18 +303,18 @@ def n_blocks_det(self): def _make_detector_blocks(self, detectors, comm): """This function distributes the detectors in groups such that each - group has same set of attributes specified by the strings in - `self._det_block_attributes`. Once the groups are made, number of - detector blocks is set to be the total number of detector groups + group has the same set of attributes specified by the strings in + `self._det_block_attributes`. Once the groups are made, the number of + detector blocks is set to be the total number of detector groups, whereas the number of time blocks is computed using the number of detector blocks and the size of `comm` communicator. - The blocks of detectors are stored in `self.detector_blocks`. It is a - dictionary object with the tuple of `self._det_blocks_attributes` values + The detector blocks are stored in `self.detector_blocks`. This + dictionary object has the tuple of `self._det_blocks_attributes` values as dictionary keys and the list of detectors corresponding to the key - as the dictionary value. This dictionary is sorted so that that the - group with largest number of detectors comes first and the one with - the least number of detectors, comes last. + as dictionary values. This dictionary is sorted so that the + group with the largest number of detectors comes first and the one with + the fewest detectors comes last. Example: For @@ -778,7 +779,7 @@ def get_pointings( pointing_buffer: Optional[npt.NDArray] = None, hwp_buffer: Optional[npt.NDArray] = None, pointings_dtype=np.float32, - ) -> (npt.NDArray, Optional[npt.NDArray]): + ) -> tuple[npt.NDArray, Optional[npt.NDArray]]: """ Compute the pointings for one or more detectors in this observation From 35e5d4d6106e323707dacfdcd2e772359e1c4c49 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Tue, 15 Oct 2024 13:31:17 +0900 Subject: [PATCH 04/11] resolved the review request --- litebird_sim/simulations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index b98a6eec..e11267b4 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -925,10 +925,10 @@ def create_observations( detectors. The default is that *all* the detectors be kept together (``n_blocks_det=1``). On the other hand, the parameter `det_blocks_attributes` specifies the list of detector attributes - to be used to create the groups of detectors. For example, with + to create the groups of detectors. For example, with ``det_blocks_attributes = ["wafer", "pixel"]``, the detectors will - be divided into the groups such that all detectors in a group will - have same ``wafer`` and ``pixel`` attribute. + be divided into groups such that all detectors in a group will + have the same ``wafer`` and ``pixel`` attribute. The parameter `n_blocks_time` specifies the number of time splits of the observations. In the case of a 3-month-long From 2eeab2c4b3f49e119b3f1f2872387b4ac27cf0b8 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:02:45 +0900 Subject: [PATCH 05/11] added the documentation for detector blocks; minor fix for obs.det_idx --- docs/source/observations.rst | 244 ++++++++++++++++++++++++++++++++--- litebird_sim/observations.py | 5 +- 2 files changed, 230 insertions(+), 19 deletions(-) diff --git a/docs/source/observations.rst b/docs/source/observations.rst index 9afe6798..a7c5bf74 100644 --- a/docs/source/observations.rst +++ b/docs/source/observations.rst @@ -97,8 +97,15 @@ With this memory layout, typical operations look like this:: Parallel applications --------------------- -The only work that the :class:`.Observation` class actually does is handling -parallelism. ``obs.tod`` can be distributed over a +The :class:`.Observation` class allows the distribution of ``obs.tod`` over multiple MPI +processes to enable the parallelization of computations. The distribution of ``obs.tod`` +can be achieved in two different ways: + +1. Uniform distribution of detectors along the detector axis +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +With ``n_blocks_det`` and ``n_blocks_time`` arguments of :class:`.Observation` class, +the ``obs.tod`` is evenly distributed over a ``n_blocks_det`` by ``n_blocks_time`` grid of MPI ranks. The blocks can be changed at run-time. @@ -111,7 +118,7 @@ The main advantage is that the example operations in the Serial section are achieved with the same lines of code. The price to pay is that you have to set detector properties with special methods. -:: +.. code-block:: python import litebird_sim as lbs from mpi4py import MPI @@ -158,21 +165,222 @@ TOD) gets distributed. .. image:: ./images/observation_data_distribution.png -When ``n_blocks_det != 1``, keep in mind that ``obs.tod[0]`` or -``obs.wn_levels[0]`` are quantities of the first *local* detector, not global. -This should not be a problem as the only thing that matters is that the two -quantities refer to the same detector. If you need the global detector index, -you can get it with ``obs.det_idx[0]``, which is created -at construction time. - -To get a better understanding of how observations are being used in a -MPI simulation, use the method :meth:`.Simulation.describe_mpi_distribution`. -This method must be called *after* the observations have been allocated using -:meth:`.Simulation.create_observations`; it will return an instance of the -class :class:`.MpiDistributionDescr`, which can be inspected to determine -which detectors and time spans are covered by each observation in all the -MPI processes that are being used. For more information, refer to the Section -:ref:`simulations`. +2. Custom grouping of detectors along the detector axis +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +While uniform distribution of detectors along the detector axis optimizes load +balancing, it is less suitable for simulating the some effects, like crosstalk and +noise correlation between the detectors. This uniform distribution across MPI +processes necessitates the transfer of large TOD arrays across multiple MPI processes, +which complicates the code implementation and may potentially lead to significant +performance overhead. To save us from this situation, the :class:`Observation` class +accepts an argument ``det_blocks_attributes`` that is a list of string objects +specifying the detector attributes to create the group of detectors. Once the +detector groups are made, the detectors are distributed to the MPI processes in such +a way that all the detectors of a group reside on the same MPI process. + +If a valid ``det_blocks_attributes`` argument is passed to the :class:`Observation` +class, the arguments ``n_blocks_det`` and ``n_blocks_time`` are ignored. Since the +``det_blocks_attributes`` creates the detector blocks dynamically, the +``n_blocks_time`` is computed during runtime using the size of MPI communicator and +the number of detector blocks (``n_blocks_time = comm.size // n_blocks_det``). + +The detector blocks made in this way can be accessed with +``Observation.detector_blocks``. It is a dictionary object has the tuple of +``self._det_blocks_attributes`` values as dictionary keys and the list of detectors +corresponding to the key as dictionary values. This dictionary is sorted so that the +group with the largest number of detectors comes first and the one with +the fewest detectors comes last. + +The following example illustrates the distribution of ``obs.tod`` matrix across the +MPI processes when ``det_blocks_attributes`` is specified. + +.. code-block:: python + + import litebird_sim as lbs + + comm = lbs.MPI_COMM_WORLD + + start_time = 456 + duration_s = 100 + sampling_freq_Hz = 1 + + # Creating a list of detectors. + dets = [ + lbs.DetectorInfo( + name="channel1_w9_detA", + wafer="wafer_9", + channel="channel1", + sampling_rate_hz=1, + ), + lbs.DetectorInfo( + name="channel1_w3_detB", + wafer="wafer_3", + channel="channel1", + sampling_rate_hz=1, + ), + lbs.DetectorInfo( + name="channel1_w1_detC", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=1, + ), + lbs.DetectorInfo( + name="channel1_w1_detD", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=1, + ), + lbs.DetectorInfo( + name="channel2_w4_detA", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=1, + ), + lbs.DetectorInfo( + name="channel2_w4_detB", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=1, + ), + ] + + # Initializing a simulation + sim = lbs.Simulation( + start_time=start_time, + duration_s=duration_s, + random_seed=12345, + mpi_comm=comm, + ) + + # Creating the observations with detector blocks + sim.create_observations( + detectors=dets, + split_list_over_processes=False, + num_of_obs_per_detector=3, + det_blocks_attributes=["channel"], # case 1 and 2 + # det_blocks_attributes=["channel", "wafer"] # case 3 + ) + +With the list of detectors defined in the code snippet above, we can see how the +detectors axis and time axis is divided depending on the size of MPI communicator and +``det_blocks_attributes``. + +**Case 1** + +*Size of MPI communicator = 3*, ``det_blocks_attributes=["channel"]`` + +:: + + Detector axis ---> + Two blocks ---> + +------------------+ +------------------+ + | Rank 0 | | Rank 1 | + +------------------+ +------------------+ + | channel1_w9_detA | | channel2_w4_detA | + T + + + + + i O | channel1_w3_detB | | channel2_w4_detB | + m n + + +------------------+ + e e | channel1_w1_detC | + + + + a b | channel1_w1_detD | + x l +------------------+ + i o + s c ........................................... + k + | | +------------------+ + | | | Rank 2 | + ⋎ ⋎ +------------------+ + | (Unused) | + +------------------+ + +**Case 2** + +*Size of MPI communicator = 4*, ``det_blocks_attributes=["channel"]`` + +:: + + Detector axis ---> + Two blocks ---> + +------------------+ +------------------+ + | Rank 0 | | Rank 2 | + +------------------+ +------------------+ + | channel1_w9_detA | | channel2_w4_detA | + + + + + + | channel1_w3_detB | | channel2_w4_detB | + T + + +------------------+ + i T | channel1_w1_detC | + m w + + + e o | channel1_w1_detD | + +------------------+ + a b + x l ........................................... + i o + s c +------------------+ +------------------+ + k | Rank 1 | | Rank 3 | + | | +------------------+ +------------------+ + | | | channel1_w9_detA | | channel2_w4_detA | + ⋎ ⋎ + + + + + | channel1_w3_detB | | channel2_w4_detB | + + + +------------------+ + | channel1_w1_detC | + + + + | channel1_w1_detD | + +------------------+ + +**Case 3** + +*Size of MPI communicator = 10*, ``det_blocks_attributes=["channel", "wafer"]`` + +:: + + Detector axis ---> + Four blocks ---> + +------------------+ +------------------+ +------------------+ +------------------+ + | Rank 0 | | Rank 2 | | Rank 4 | | Rank 6 | + +------------------+ +------------------+ +------------------+ +------------------+ + T | channel1_w1_detC | | channel2_w4_detA | | channel1_w9_detA | | channel1_w3_detB | + i T + + + + +------------------+ +------------------+ + m w | channel1_w1_detD | | channel2_w4_detB | + e o +------------------+ +------------------+ + + ......................................................................................... + + a b +------------------+ +------------------+ +------------------+ +------------------+ + x l | Rank 1 | | Rank 3 | | Rank 5 | | Rank 7 | + i o +------------------+ +------------------+ +------------------+ +------------------+ + s c | channel1_w1_detC | | channel2_w4_detA | | channel1_w9_detA | | channel1_w3_detB | + k + + + + +------------------+ +------------------+ + | | | channel1_w1_detD | | channel2_w4_detB | + | | +------------------+ +------------------+ + ⋎ ⋎ + ......................................................................................... + + +------------------+ +------------------+ + | Rank 8 | | Rank 9 | + +------------------+ +------------------+ + | (Unused) | | (Unused) | + +------------------+ +------------------+ + +.. note:: + When ``n_blocks_det != 1``, keep in mind that ``obs.tod[0]`` or + ``obs.wn_levels[0]`` are quantities of the first *local* detector, not global. + This should not be a problem as the only thing that matters is that the two + quantities refer to the same detector. If you need the global detector index, + you can get it with ``obs.det_idx[0]``, which is created at construction time. + ``obs.det_idx`` stores the detector indices of the detectors available to an + :class:`Observation` class, with respect to the list of detectors stored in + ``obs.detectors_global`` variable. + +.. note:: + To get a better understanding of how observations are being used in a + MPI simulation, use the method :meth:`.Simulation.describe_mpi_distribution`. + This method must be called *after* the observations have been allocated using + :meth:`.Simulation.create_observations`; it will return an instance of the + class :class:`.MpiDistributionDescr`, which can be inspected to determine + which detectors and time spans are covered by each observation in all the + MPI processes that are being used. For more information, refer to the Section + :ref:`simulations`. Other notable functionalities ----------------------------- diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index 144823a1..f62df323 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -204,7 +204,10 @@ def sampling_rate_hz(self): @property def n_detectors(self): - return len(self.det_idx) + if self.det_idx is None: + return 0 + else: + return len(self.det_idx) def _get_local_start_time_start_and_n_samples(self): _, _, start, num = self._get_start_and_num( From a941fd0f7fbf72f441c17dd4e2edfe732fb9909b Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 30 Oct 2024 21:08:24 +0900 Subject: [PATCH 06/11] added test for detector block distribution; fixed the assignment of some attributes in observations.py --- docs/source/observations.rst | 14 ++--- litebird_sim/observations.py | 21 +++++--- test/test_detector_blocks.py | 100 +++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 15 deletions(-) create mode 100644 test/test_detector_blocks.py diff --git a/docs/source/observations.rst b/docs/source/observations.rst index a7c5bf74..48b4e4ef 100644 --- a/docs/source/observations.rst +++ b/docs/source/observations.rst @@ -187,7 +187,7 @@ the number of detector blocks (``n_blocks_time = comm.size // n_blocks_det``). The detector blocks made in this way can be accessed with ``Observation.detector_blocks``. It is a dictionary object has the tuple of -``self._det_blocks_attributes`` values as dictionary keys and the list of detectors +``det_blocks_attributes`` values as dictionary keys and the list of detectors corresponding to the key as dictionary values. This dictionary is sorted so that the group with the largest number of detectors comes first and the one with the fewest detectors comes last. @@ -211,37 +211,37 @@ MPI processes when ``det_blocks_attributes`` is specified. name="channel1_w9_detA", wafer="wafer_9", channel="channel1", - sampling_rate_hz=1, + sampling_rate_hz=sampling_freq_Hz, ), lbs.DetectorInfo( name="channel1_w3_detB", wafer="wafer_3", channel="channel1", - sampling_rate_hz=1, + sampling_rate_hz=sampling_freq_Hz, ), lbs.DetectorInfo( name="channel1_w1_detC", wafer="wafer_1", channel="channel1", - sampling_rate_hz=1, + sampling_rate_hz=sampling_freq_Hz, ), lbs.DetectorInfo( name="channel1_w1_detD", wafer="wafer_1", channel="channel1", - sampling_rate_hz=1, + sampling_rate_hz=sampling_freq_Hz, ), lbs.DetectorInfo( name="channel2_w4_detA", wafer="wafer_4", channel="channel2", - sampling_rate_hz=1, + sampling_rate_hz=sampling_freq_Hz, ), lbs.DetectorInfo( name="channel2_w4_detB", wafer="wafer_4", channel="channel2", - sampling_rate_hz=1, + sampling_rate_hz=sampling_freq_Hz, ), ] diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index f62df323..7485a940 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Union, List, Any, Optional +import numbers import astropy.time import numpy as np @@ -11,6 +12,7 @@ from .coordinates import DEFAULT_TIME_SCALE from .distribute import distribute_evenly, distribute_detector_blocks +from .detectors import DetectorInfo @dataclass @@ -142,16 +144,16 @@ def __init__( if isinstance(detectors, int): self._n_detectors_global = detectors else: - if comm and comm.size > 1: - self._n_detectors_global = comm.bcast(len(detectors), root) + if self.comm and self.comm.size > 1: + self._n_detectors_global = self.comm.bcast(len(detectors), root) + + if self._det_blocks_attributes is not None: + n_blocks_det, n_blocks_time = self._make_detector_blocks( + detectors, self.comm + ) else: self._n_detectors_global = len(detectors) - if self._det_blocks_attributes is not None and comm.size > 1: - n_blocks_det, n_blocks_time = self._make_detector_blocks( - detectors, comm - ) - # Name of the attributes that store an array with the value of a # property for each of the (local) detectors self._attr_det_names = [] @@ -673,7 +675,10 @@ def setattr_det_global(self, name, info, root=0): is_in_grid = self.comm.rank < self._n_blocks_det * self._n_blocks_time comm_grid = self.comm.Split(int(is_in_grid)) if not is_in_grid: # The process does not own any detector (and TOD) - setattr(self, name, None) + null_det = DetectorInfo() + attribute = getattr(null_det, name, None) + value = 0 if isinstance(attribute, numbers.Number) else None + setattr(self, name, value) return my_col = comm_grid.rank % self._n_blocks_time diff --git a/test/test_detector_blocks.py b/test/test_detector_blocks.py new file mode 100644 index 00000000..1c1e9346 --- /dev/null +++ b/test/test_detector_blocks.py @@ -0,0 +1,100 @@ +import numpy as np +import litebird_sim as lbs + + +def test_detector_blocks(): + comm = lbs.MPI_COMM_WORLD + + start_time = 456 + duration_s = 100 + sampling_freq_Hz = 1 + nobs_per_det = 3 + + # Creating a list of detectors. + dets = [ + lbs.DetectorInfo( + name="channel1_w9_detA", + wafer="wafer_9", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w3_detB", + wafer="wafer_3", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detC", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detD", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detA", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detB", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), + ] + + sim = lbs.Simulation( + start_time=start_time, + duration_s=duration_s, + random_seed=12345, + mpi_comm=comm, + ) + + sim.create_observations( + detectors=dets, + split_list_over_processes=False, + num_of_obs_per_detector=nobs_per_det, + det_blocks_attributes=["channel", "wafer"], + ) + + tod_len_per_det_per_proc = 0 + for obs in sim.observations: + tod_shape = obs.tod.shape + + n_blocks_det = obs.n_blocks_det + n_blocks_time = obs.n_blocks_time + tod_len_per_det_per_proc += obs.tod.shape[1] + + # No testing required if the proc doesn't owns a detector + if obs.det_idx is not None: + det_names_per_obs = [ + obs.detectors_global[idx]["name"] for idx in obs.det_idx + ] + + # Testing if the mapping between the obs.name and + # obs.det_idx is consistent with obs.detectors_global + np.testing.assert_equal(obs.name, det_names_per_obs) + + # Testing the distribution of the number of detectors per + # detector block + np.testing.assert_equal(obs.name.shape[0], tod_shape[0]) + + # Testing if the distribution of samples along the time axis is consistent + if comm.rank < n_blocks_det * n_blocks_time: + arr = [ + span.num_of_elements + for span in lbs.distribute.distribute_evenly( + duration_s * sampling_freq_Hz, n_blocks_time * nobs_per_det + ) + ] + + start_idx = (comm.rank % n_blocks_time) * nobs_per_det + stop_idx = start_idx + nobs_per_det + np.testing.assert_equal(sum(arr[start_idx:stop_idx]), tod_len_per_det_per_proc) From 719ac8e974e0c9756f32774ede47c095f8d069d3 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Mon, 25 Nov 2024 16:19:42 +0900 Subject: [PATCH 07/11] added a subcommunicator for MPI processes that contain non-zero observations --- docs/source/observations.rst | 6 +- litebird_sim/__init__.py | 3 +- litebird_sim/mapmaking/binner.py | 2 +- litebird_sim/mapmaking/common.py | 5 +- litebird_sim/mapmaking/destriper.py | 224 +++++++++++++++------------- litebird_sim/mpi.py | 39 +++++ litebird_sim/observations.py | 12 ++ litebird_sim/simulations.py | 10 +- 8 files changed, 190 insertions(+), 111 deletions(-) diff --git a/docs/source/observations.rst b/docs/source/observations.rst index 48b4e4ef..3c9b9077 100644 --- a/docs/source/observations.rst +++ b/docs/source/observations.rst @@ -173,13 +173,13 @@ balancing, it is less suitable for simulating the some effects, like crosstalk a noise correlation between the detectors. This uniform distribution across MPI processes necessitates the transfer of large TOD arrays across multiple MPI processes, which complicates the code implementation and may potentially lead to significant -performance overhead. To save us from this situation, the :class:`Observation` class +performance overhead. To save us from this situation, the :class:`.Observation` class accepts an argument ``det_blocks_attributes`` that is a list of string objects specifying the detector attributes to create the group of detectors. Once the detector groups are made, the detectors are distributed to the MPI processes in such a way that all the detectors of a group reside on the same MPI process. -If a valid ``det_blocks_attributes`` argument is passed to the :class:`Observation` +If a valid ``det_blocks_attributes`` argument is passed to the :class:`.Observation` class, the arguments ``n_blocks_det`` and ``n_blocks_time`` are ignored. Since the ``det_blocks_attributes`` creates the detector blocks dynamically, the ``n_blocks_time`` is computed during runtime using the size of MPI communicator and @@ -369,7 +369,7 @@ detectors axis and time axis is divided depending on the size of MPI communicato quantities refer to the same detector. If you need the global detector index, you can get it with ``obs.det_idx[0]``, which is created at construction time. ``obs.det_idx`` stores the detector indices of the detectors available to an - :class:`Observation` class, with respect to the list of detectors stored in + :class:`.Observation` class, with respect to the list of detectors stored in ``obs.detectors_global`` variable. .. note:: diff --git a/litebird_sim/__init__.py b/litebird_sim/__init__.py index b4a8aa6b..ab720c05 100644 --- a/litebird_sim/__init__.py +++ b/litebird_sim/__init__.py @@ -72,7 +72,7 @@ ) from .madam import save_simulation_for_madam from .mbs.mbs import Mbs, MbsParameters, MbsSavedMapInfo -from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION +from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, comm_grid from .noise import ( add_white_noise, add_one_over_f_noise, @@ -218,6 +218,7 @@ def destripe_with_toast2(*args, **kwargs): "MPI_COMM_WORLD", "MPI_ENABLED", "MPI_CONFIGURATION", + "comm_grid", # observations.py "Observation", "TodDescription", diff --git a/litebird_sim/mapmaking/binner.py b/litebird_sim/mapmaking/binner.py index 682a5fbe..9622a234 100644 --- a/litebird_sim/mapmaking/binner.py +++ b/litebird_sim/mapmaking/binner.py @@ -68,7 +68,7 @@ class BinnerResult: @njit def _solve_binning(nobs_matrix, atd): - # Sove the map-making equation + # Solve the map-making equation # # This method alters the parameter `nobs_matrix`, so that after its completion # each 3×3 matrix in nobs_matrix[idx, :, :] will be the *inverse*. diff --git a/litebird_sim/mapmaking/common.py b/litebird_sim/mapmaking/common.py index 3f60e1c3..ac2f1393 100644 --- a/litebird_sim/mapmaking/common.py +++ b/litebird_sim/mapmaking/common.py @@ -219,7 +219,10 @@ def _compute_pixel_indices( if output_coordinate_system == CoordinateSystem.Galactic: # Free curr_pointings_det if the output map is already in Galactic coordinates - del curr_pointings_det + try: + del curr_pointings_det + except UnboundLocalError: + pass return pixidx_all, polang_all diff --git a/litebird_sim/mapmaking/destriper.py b/litebird_sim/mapmaking/destriper.py index c56fe6df..16e023da 100644 --- a/litebird_sim/mapmaking/destriper.py +++ b/litebird_sim/mapmaking/destriper.py @@ -20,7 +20,7 @@ from numba import njit, prange import healpy as hp -from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD +from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, comm_grid from typing import Callable, Union, List, Optional, Tuple, Any, Dict from litebird_sim.hwp import HWP from litebird_sim.observations import Observation @@ -44,7 +44,7 @@ __DESTRIPER_RESULTS_FILE_NAME = "destriper_results.fits" -__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_WORLD.rank:04d}.fits" +__BASELINES_FILE_NAME = f"baselines_mpi{comm_grid.COMM_OBS_GRID.rank:04d}.fits" def _split_items_into_n_segments(n: int, num_of_segments: int) -> List[int]: @@ -498,8 +498,10 @@ def _build_nobs_matrix( ) # Now we must accumulate the result of every MPI process - if MPI_ENABLED: - MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM) + if MPI_ENABLED and comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL: + comm_grid.COMM_OBS_GRID.Allreduce( + mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM + ) # `nobs_matrix_cholesky` will *not* contain the M_i maps shown in # Eq. 9 of KurkiSuonio2009, but its Cholesky decomposition, i.e., @@ -746,8 +748,12 @@ def _compute_binned_map( ) if MPI_ENABLED: - MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM) - MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM) + comm_grid.COMM_OBS_GRID.Allreduce( + mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM + ) + comm_grid.COMM_OBS_GRID.Allreduce( + mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM + ) # Step 2: compute the “binned map” (Eq. 21) _sum_map_to_binned_map( @@ -987,7 +993,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float: # the dot product local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)]) if MPI_ENABLED: - return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.SUM) + return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM) else: return local_result @@ -1004,7 +1010,7 @@ def _get_stopping_factor(residual: List[npt.ArrayLike]) -> float: """ local_result = np.max(np.abs(residual)) if MPI_ENABLED: - return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.MAX) + return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX) else: return local_result @@ -1418,7 +1424,7 @@ def _run_destriper( bytes_in_temporary_buffers += mask.nbytes if MPI_ENABLED: - bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce( + bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce( bytes_in_temporary_buffers, op=mpi4py.MPI.SUM, ) @@ -1613,91 +1619,103 @@ def my_gui_callback( binned_map = np.empty((3, number_of_pixels)) hit_map = np.empty(number_of_pixels) - if do_destriping: - try: - # This will fail if the parameter is a scalar - len(params.samples_per_baseline) - - baseline_lengths_list = params.samples_per_baseline - assert len(baseline_lengths_list) == len(obs_list), ( - f"The list baseline_lengths_list has {len(baseline_lengths_list)} " - f"elements, but there are {len(obs_list)} observations" - ) - except TypeError: - # Ok, params.samples_per_baseline is a scalar, so we must - # figure out the number of samples in each baseline within - # each observation - baseline_lengths_list = [ - split_items_evenly( - n=getattr(cur_obs, components[0]).shape[1], - sub_n=int(params.samples_per_baseline), + if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL: + # perform the following operations when MPI is not being used + # OR when the comm_grid.COMM_OBS_GRID is not a NULL communicator + if do_destriping: + try: + # This will fail if the parameter is a scalar + len(params.samples_per_baseline) + + baseline_lengths_list = params.samples_per_baseline + assert len(baseline_lengths_list) == len(obs_list), ( + f"The list baseline_lengths_list has {len(baseline_lengths_list)} " + f"elements, but there are {len(obs_list)} observations" ) - for cur_obs in obs_list - ] + except TypeError: + # Ok, params.samples_per_baseline is a scalar, so we must + # figure out the number of samples in each baseline within + # each observation + baseline_lengths_list = [ + split_items_evenly( + n=getattr(cur_obs, components[0]).shape[1], + sub_n=int(params.samples_per_baseline), + ) + for cur_obs in obs_list + ] + + # Each element of this list is a 2D array with shape (N_det, N_baselines), + # where N_det is the number of detectors in the i-th Observation object + recycle_baselines = False + if baselines_list is None: + baselines_list = [ + np.zeros( + (getattr(cur_obs, components[0]).shape[0], len(cur_baseline)) + ) + for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list) + ] + else: + recycle_baselines = True + + destriped_map = np.empty((3, number_of_pixels)) + ( + baselines_list, + baseline_errors_list, + history_of_stopping_factors, + best_stopping_factor, + converged, + bytes_in_temporary_buffers, + ) = _run_destriper( + obs_list=obs_list, + nobs_matrix_cholesky=nobs_matrix_cholesky, + binned_map=binned_map, + destriped_map=destriped_map, + hit_map=hit_map, + baseline_lengths_list=baseline_lengths_list, + baselines_list_start=baselines_list, + recycle_baselines=recycle_baselines, + recycled_convergence=recycled_convergence, + dm_list=detector_mask_list, + tm_list=time_mask_list, + component=components[0], + threshold=params.threshold, + max_steps=params.iter_max, + use_preconditioner=params.use_preconditioner, + callback=callback, + callback_kwargs=callback_kwargs if callback_kwargs else {}, + ) - # Each element of this list is a 2D array with shape (N_det, N_baselines), - # where N_det is the number of detectors in the i-th Observation object - recycle_baselines = False - if baselines_list is None: - baselines_list = [ - np.zeros((getattr(cur_obs, components[0]).shape[0], len(cur_baseline))) - for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list) - ] + if MPI_ENABLED: + bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce( + bytes_in_temporary_buffers, + op=mpi4py.MPI.SUM, + ) else: - recycle_baselines = True - - destriped_map = np.empty((3, number_of_pixels)) - ( - baselines_list, - baseline_errors_list, - history_of_stopping_factors, - best_stopping_factor, - converged, - bytes_in_temporary_buffers, - ) = _run_destriper( - obs_list=obs_list, - nobs_matrix_cholesky=nobs_matrix_cholesky, - binned_map=binned_map, - destriped_map=destriped_map, - hit_map=hit_map, - baseline_lengths_list=baseline_lengths_list, - baselines_list_start=baselines_list, - recycle_baselines=recycle_baselines, - recycled_convergence=recycled_convergence, - dm_list=detector_mask_list, - tm_list=time_mask_list, - component=components[0], - threshold=params.threshold, - max_steps=params.iter_max, - use_preconditioner=params.use_preconditioner, - callback=callback, - callback_kwargs=callback_kwargs if callback_kwargs else {}, - ) - - if MPI_ENABLED: - bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce( - bytes_in_temporary_buffers, - op=mpi4py.MPI.SUM, + # No need to run the destriping, just compute the binned map with + # one single baseline set to zero + _compute_binned_map( + obs_list=obs_list, + output_sky_map=binned_map, + output_hit_map=hit_map, + nobs_matrix_cholesky=nobs_matrix_cholesky, + component=components[0], + dm_list=detector_mask_list, + tm_list=time_mask_list, + baselines_list=None, + baseline_lengths_list=[ + np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int) + for cur_obs in obs_list + ], ) + bytes_in_temporary_buffers = 0 + destriped_map = None + baseline_lengths_list = None + baselines_list = None + baseline_errors_list = None + history_of_stopping_factors = None + best_stopping_factor = None + converged = True else: - # No need to run the destriping, just compute the binned map with - # one single baseline set to zero - _compute_binned_map( - obs_list=obs_list, - output_sky_map=binned_map, - output_hit_map=hit_map, - nobs_matrix_cholesky=nobs_matrix_cholesky, - component=components[0], - dm_list=detector_mask_list, - tm_list=time_mask_list, - baselines_list=None, - baseline_lengths_list=[ - np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int) - for cur_obs in obs_list - ], - ) - bytes_in_temporary_buffers = 0 - destriped_map = None baseline_lengths_list = None baselines_list = None @@ -1707,14 +1725,18 @@ def my_gui_callback( converged = True # Add the temporary memory that was allocated *before* calling the destriper - bytes_in_temporary_buffers += sum( - [ - cur_obs.destriper_weights.nbytes - + cur_obs.destriper_pixel_idx.nbytes - + cur_obs.destriper_pol_angle_rad.nbytes - for cur_obs in obs_list - ] - ) + try: + bytes_in_temporary_buffers += sum( + [ + cur_obs.destriper_weights.nbytes + + cur_obs.destriper_pixel_idx.nbytes + + cur_obs.destriper_pol_angle_rad.nbytes + for cur_obs in obs_list + ] + ) + except UnboundLocalError: + # The case when `bytes_in_temporary_buffers` is not defined + bytes_in_temporary_buffers = 0 # We're nearly done! Let's clean up some stuff… if not keep_weights: @@ -1992,11 +2014,11 @@ def _save_baselines(results: DestriperResult, output_file: Path) -> None: primary_hdu = fits.PrimaryHDU() primary_hdu.header["MPIRANK"] = ( - MPI_COMM_WORLD.rank, + comm_grid.COMM_OBS_GRID.rank, "The rank of the MPI process that wrote this file", ) primary_hdu.header["MPISIZE"] = ( - MPI_COMM_WORLD.size, + comm_grid.COMM_OBS_GRID.size, "The number of MPI processes used in the computation", ) @@ -2212,11 +2234,11 @@ def load_destriper_results( baselines_file_name = folder / __BASELINES_FILE_NAME with fits.open(baselines_file_name) as inpf: - assert MPI_COMM_WORLD.rank == inpf[0].header["MPIRANK"], ( + assert comm_grid.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results " ) - assert MPI_COMM_WORLD.size == inpf[0].header["MPISIZE"], ( + assert comm_grid.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results" ) diff --git a/litebird_sim/mpi.py b/litebird_sim/mpi.py index fe623248..d6c31c4d 100644 --- a/litebird_sim/mpi.py +++ b/litebird_sim/mpi.py @@ -22,10 +22,47 @@ class _SerialMpiCommunicator: size = 1 +class _GridCommClass: + """ + This class encapsulates the `COMM_OBS_GRID` and `COMM_NULL` communicators. It + offers explicitly defined setter functions so that the communicators cannot be + changed accidentally. + + Attributes: + + COMM_OBS_GRID (mpi4py.MPI.Intracomm): A subset of `MPI.COMM_WORLD` that + contain all the processes associated with non-zero observations. + + COMM_NULL (mpi4py.MPI.Comm): A NULL communicator. When MPI is not enabled, it + is set as `None`. If MPI is enabled, it is set as `MPI.COMM_NULL` + + """ + + def __init__(self, comm_obs_grid=_SerialMpiCommunicator(), comm_null=None): + self._MPI_COMM_OBS_GRID = comm_obs_grid + self._MPI_COMM_NULL = comm_null + + @property + def COMM_OBS_GRID(self): + return self._MPI_COMM_OBS_GRID + + @property + def COMM_NULL(self): + return self._MPI_COMM_NULL + + def _set_comm_obs_grid(self, comm_obs_grid): + self._MPI_COMM_OBS_GRID = comm_obs_grid + + def _set_null_comm(self, comm_null): + self._MPI_COMM_NULL = comm_null + + #: Global variable equal either to `mpi4py.MPI.COMM_WORLD` or a object #: that defines the member variables `rank = 0` and `size = 1`. MPI_COMM_WORLD = _SerialMpiCommunicator() +comm_grid = _GridCommClass() + #: `True` if MPI should be used by the application. The value of this #: variable is set according to the following rules: #: @@ -53,6 +90,8 @@ class _SerialMpiCommunicator: from mpi4py import MPI MPI_COMM_WORLD = MPI.COMM_WORLD + comm_grid._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD) + comm_grid._set_null_comm(comm_null=MPI.COMM_NULL) MPI_ENABLED = True MPI_CONFIGURATION = mpi4py.get_config() except ImportError: diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index 7485a940..38582ef1 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -13,6 +13,7 @@ from .coordinates import DEFAULT_TIME_SCALE from .distribute import distribute_evenly, distribute_detector_blocks from .detectors import DetectorInfo +from .mpi import comm_grid @dataclass @@ -165,6 +166,17 @@ def __init__( self._n_blocks_det = 1 self._n_blocks_time = 1 + if comm and comm.size > 1: + if comm.rank < self.n_blocks_det * n_blocks_time: + matrix_color = 1 + else: + from .mpi import MPI + + matrix_color = MPI.UNDEFINED + + comm_obs_grid = comm.Split(matrix_color, comm.rank) + comm_grid._set_comm_obs_grid(comm_obs_grid=comm_obs_grid) + self.tod_list = tods for cur_tod in self.tod_list: if allocate_tod: diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index e11267b4..44148fe0 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -44,7 +44,7 @@ DestriperResult, destriper_log_callback, ) -from .mpi import MPI_ENABLED, MPI_COMM_WORLD +from .mpi import MPI_ENABLED, MPI_COMM_WORLD, comm_grid from .noise import add_noise_to_observations from .observations import Observation, TodDescription from .pointings_in_obs import prepare_pointings, precompute_pointings @@ -1221,7 +1221,8 @@ def set_scanning_strategy( num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - num_of_obs = MPI_COMM_WORLD.allreduce(num_of_obs) + if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL: + num_of_obs = comm_grid.COMM_OBS_GRID.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_quaternions.md") @@ -1318,8 +1319,9 @@ def prepare_pointings( memory_occupation = pointing_provider.bore2ecliptic_quats.quats.nbytes num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - memory_occupation = MPI_COMM_WORLD.allreduce(memory_occupation) - num_of_obs = MPI_COMM_WORLD.allreduce(num_of_obs) + if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL: + memory_occupation = comm_grid.COMM_OBS_GRID.allreduce(memory_occupation) + num_of_obs = comm_grid.COMM_OBS_GRID.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_pointings.md") From bee1e5f2659cf7a4dc3f9cc2e0ac9954178fc371 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 27 Nov 2024 17:37:48 +0900 Subject: [PATCH 08/11] renamed comm_grid to MPI_COMM_GRID --- litebird_sim/__init__.py | 4 ++-- litebird_sim/mapmaking/destriper.py | 32 ++++++++++++++--------------- litebird_sim/mpi.py | 6 +++--- litebird_sim/observations.py | 4 ++-- litebird_sim/simulations.py | 14 +++++++------ 5 files changed, 31 insertions(+), 29 deletions(-) diff --git a/litebird_sim/__init__.py b/litebird_sim/__init__.py index ab720c05..44285c14 100644 --- a/litebird_sim/__init__.py +++ b/litebird_sim/__init__.py @@ -72,7 +72,7 @@ ) from .madam import save_simulation_for_madam from .mbs.mbs import Mbs, MbsParameters, MbsSavedMapInfo -from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, comm_grid +from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, MPI_COMM_GRID from .noise import ( add_white_noise, add_one_over_f_noise, @@ -218,7 +218,7 @@ def destripe_with_toast2(*args, **kwargs): "MPI_COMM_WORLD", "MPI_ENABLED", "MPI_CONFIGURATION", - "comm_grid", + "MPI_COMM_GRID", # observations.py "Observation", "TodDescription", diff --git a/litebird_sim/mapmaking/destriper.py b/litebird_sim/mapmaking/destriper.py index 16e023da..71e323af 100644 --- a/litebird_sim/mapmaking/destriper.py +++ b/litebird_sim/mapmaking/destriper.py @@ -20,7 +20,7 @@ from numba import njit, prange import healpy as hp -from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, comm_grid +from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID from typing import Callable, Union, List, Optional, Tuple, Any, Dict from litebird_sim.hwp import HWP from litebird_sim.observations import Observation @@ -44,7 +44,7 @@ __DESTRIPER_RESULTS_FILE_NAME = "destriper_results.fits" -__BASELINES_FILE_NAME = f"baselines_mpi{comm_grid.COMM_OBS_GRID.rank:04d}.fits" +__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_GRID.COMM_OBS_GRID.rank:04d}.fits" def _split_items_into_n_segments(n: int, num_of_segments: int) -> List[int]: @@ -498,8 +498,8 @@ def _build_nobs_matrix( ) # Now we must accumulate the result of every MPI process - if MPI_ENABLED and comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL: - comm_grid.COMM_OBS_GRID.Allreduce( + if MPI_ENABLED and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM ) @@ -748,10 +748,10 @@ def _compute_binned_map( ) if MPI_ENABLED: - comm_grid.COMM_OBS_GRID.Allreduce( + MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM ) - comm_grid.COMM_OBS_GRID.Allreduce( + MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM ) @@ -993,7 +993,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float: # the dot product local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)]) if MPI_ENABLED: - return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM) + return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM) else: return local_result @@ -1010,7 +1010,7 @@ def _get_stopping_factor(residual: List[npt.ArrayLike]) -> float: """ local_result = np.max(np.abs(residual)) if MPI_ENABLED: - return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX) + return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX) else: return local_result @@ -1424,7 +1424,7 @@ def _run_destriper( bytes_in_temporary_buffers += mask.nbytes if MPI_ENABLED: - bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce( + bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( bytes_in_temporary_buffers, op=mpi4py.MPI.SUM, ) @@ -1619,9 +1619,9 @@ def my_gui_callback( binned_map = np.empty((3, number_of_pixels)) hit_map = np.empty(number_of_pixels) - if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL: + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: # perform the following operations when MPI is not being used - # OR when the comm_grid.COMM_OBS_GRID is not a NULL communicator + # OR when the MPI_COMM_GRID.COMM_OBS_GRID is not a NULL communicator if do_destriping: try: # This will fail if the parameter is a scalar @@ -1686,7 +1686,7 @@ def my_gui_callback( ) if MPI_ENABLED: - bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce( + bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( bytes_in_temporary_buffers, op=mpi4py.MPI.SUM, ) @@ -2014,11 +2014,11 @@ def _save_baselines(results: DestriperResult, output_file: Path) -> None: primary_hdu = fits.PrimaryHDU() primary_hdu.header["MPIRANK"] = ( - comm_grid.COMM_OBS_GRID.rank, + MPI_COMM_GRID.COMM_OBS_GRID.rank, "The rank of the MPI process that wrote this file", ) primary_hdu.header["MPISIZE"] = ( - comm_grid.COMM_OBS_GRID.size, + MPI_COMM_GRID.COMM_OBS_GRID.size, "The number of MPI processes used in the computation", ) @@ -2234,11 +2234,11 @@ def load_destriper_results( baselines_file_name = folder / __BASELINES_FILE_NAME with fits.open(baselines_file_name) as inpf: - assert comm_grid.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], ( + assert MPI_COMM_GRID.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results " ) - assert comm_grid.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], ( + assert MPI_COMM_GRID.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results" ) diff --git a/litebird_sim/mpi.py b/litebird_sim/mpi.py index d6c31c4d..80bb2dbc 100644 --- a/litebird_sim/mpi.py +++ b/litebird_sim/mpi.py @@ -61,7 +61,7 @@ def _set_null_comm(self, comm_null): #: that defines the member variables `rank = 0` and `size = 1`. MPI_COMM_WORLD = _SerialMpiCommunicator() -comm_grid = _GridCommClass() +MPI_COMM_GRID = _GridCommClass() #: `True` if MPI should be used by the application. The value of this #: variable is set according to the following rules: @@ -90,8 +90,8 @@ def _set_null_comm(self, comm_null): from mpi4py import MPI MPI_COMM_WORLD = MPI.COMM_WORLD - comm_grid._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD) - comm_grid._set_null_comm(comm_null=MPI.COMM_NULL) + MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD) + MPI_COMM_GRID._set_null_comm(comm_null=MPI.COMM_NULL) MPI_ENABLED = True MPI_CONFIGURATION = mpi4py.get_config() except ImportError: diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index 38582ef1..fa5ab634 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -13,7 +13,7 @@ from .coordinates import DEFAULT_TIME_SCALE from .distribute import distribute_evenly, distribute_detector_blocks from .detectors import DetectorInfo -from .mpi import comm_grid +from .mpi import MPI_COMM_GRID @dataclass @@ -175,7 +175,7 @@ def __init__( matrix_color = MPI.UNDEFINED comm_obs_grid = comm.Split(matrix_color, comm.rank) - comm_grid._set_comm_obs_grid(comm_obs_grid=comm_obs_grid) + MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=comm_obs_grid) self.tod_list = tods for cur_tod in self.tod_list: diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index 44148fe0..51dde0cf 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -44,7 +44,7 @@ DestriperResult, destriper_log_callback, ) -from .mpi import MPI_ENABLED, MPI_COMM_WORLD, comm_grid +from .mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID from .noise import add_noise_to_observations from .observations import Observation, TodDescription from .pointings_in_obs import prepare_pointings, precompute_pointings @@ -1221,8 +1221,8 @@ def set_scanning_strategy( num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL: - num_of_obs = comm_grid.COMM_OBS_GRID.allreduce(num_of_obs) + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_quaternions.md") @@ -1319,9 +1319,11 @@ def prepare_pointings( memory_occupation = pointing_provider.bore2ecliptic_quats.quats.nbytes num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL: - memory_occupation = comm_grid.COMM_OBS_GRID.allreduce(memory_occupation) - num_of_obs = comm_grid.COMM_OBS_GRID.allreduce(num_of_obs) + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + memory_occupation = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( + memory_occupation + ) + num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_pointings.md") From 241a401e11df3b43125480c69e8e8fd6e78b05ff Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Wed, 27 Nov 2024 23:38:42 +0900 Subject: [PATCH 09/11] added documentation for MPI_COMM_GRID --- docs/source/mpi.rst | 25 +++++++++++++++++++++++++ litebird_sim/mpi.py | 10 ++++++++++ 2 files changed, 35 insertions(+) diff --git a/docs/source/mpi.rst b/docs/source/mpi.rst index ec64e75a..79735ad1 100644 --- a/docs/source/mpi.rst +++ b/docs/source/mpi.rst @@ -133,6 +133,31 @@ variable :data:`.MPI_ENABLED`:: To ensure that your code uses MPI in the proper way, you should always use :data:`.MPI_COMM_WORLD` instead of importing ``mpi4py`` directly. +The simulation framework also provides a global object +:data:`.MPI_COMM_GRID`. It has two attributes: + +- ``COMM_OBS_GRID``: This is an MPI communicator that contains all the + MPI processes with the global rank less than ``n_blocks_time * n_blocks_det``. + It provides a safety net to the operations and MPI communications + that are needed to be performed only on the partition of :data:`.MPI_COMM_WORLD` + that contain non-zero number of pointings and TODs. By default, + ``COMM_OBS_GRID`` points to the global MPI communicator :data:`.MPI_COMM_WORLD`. + It is updated once :class:`.Observation` are defined. For example, + consider the case when a user runs the simulation with 10 MPI + processes but due some specific ``det_blocks_attributes`` argument + in :class:`.Observation` class, the number of detector and time + blocks are determined to be 2 and 4 respectively. Then the + simulation framework will store the pointings and TODs only on + :math:`2\times4=8` MPI processes and the last two ranks of :data:`.MPI_COMM_WORLD` + will be left unused. Once this happens, ``COMM_OBS_GRID`` on first 8 + ranks (rank 0 to 7) will point to the local sub-communicator + containing the processes with global rank 0 to 7. On the unused + ranks, it will simply point to the NULL communicator. +- ``COMM_NULL``: If :data:`.MPI_ENABLED` is ``True``, this object + points to a NULL MPI communicator (``mpi4py.MPI.COMM_NULL``). + Otherwise it is set to ``None``. The user should compare + ``COMM_OBS_GRID`` with ``COMM_NULL`` on every MPI process in order + to avoid running a piece of code on unused MPI processes. Enabling/disabling MPI ---------------------- diff --git a/litebird_sim/mpi.py b/litebird_sim/mpi.py index 80bb2dbc..64c31181 100644 --- a/litebird_sim/mpi.py +++ b/litebird_sim/mpi.py @@ -61,6 +61,16 @@ def _set_null_comm(self, comm_null): #: that defines the member variables `rank = 0` and `size = 1`. MPI_COMM_WORLD = _SerialMpiCommunicator() + +#: Global object with two attributes: +#: +#: - ``COMM_OBS_GRID``: It is a partition of ``MPI_COMM_WORLD`` that includes all the +#: MPI processes with global rank less than ``n_blocks_time * n_blocks_det``. On MPI +#: processes with higher ranks, it points to NULL MPI communicator +#: ``mpi4py.MPI.COMM_NULL``. +#: +#: - ``COMM_NULL``: If :data:`.MPI_ENABLED` is ``True``, this object points to a NULL +#: MPI communicator (``mpi4py.MPI.COMM_NULL``). Otherwise it is ``None``. MPI_COMM_GRID = _GridCommClass() #: `True` if MPI should be used by the application. The value of this From bcfee9b936bada66e9b49aa1d77f842fda17ef38 Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Thu, 28 Nov 2024 11:41:49 +0900 Subject: [PATCH 10/11] added sub-communicators for detector and time blocks --- litebird_sim/observations.py | 67 +++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index fa5ab634..deab36fe 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -13,7 +13,7 @@ from .coordinates import DEFAULT_TIME_SCALE from .distribute import distribute_evenly, distribute_detector_blocks from .detectors import DetectorInfo -from .mpi import MPI_COMM_GRID +from .mpi import MPI_COMM_GRID, _SerialMpiCommunicator @dataclass @@ -159,23 +159,14 @@ def __init__( # property for each of the (local) detectors self._attr_det_names = [] self._check_blocks(n_blocks_det, n_blocks_time) - if comm and comm.size > 1: + if self.comm and self.comm.size > 1: self._n_blocks_det = n_blocks_det self._n_blocks_time = n_blocks_time else: self._n_blocks_det = 1 self._n_blocks_time = 1 - if comm and comm.size > 1: - if comm.rank < self.n_blocks_det * n_blocks_time: - matrix_color = 1 - else: - from .mpi import MPI - - matrix_color = MPI.UNDEFINED - - comm_obs_grid = comm.Split(matrix_color, comm.rank) - MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=comm_obs_grid) + self._set_mpi_subcommunicators() self.tod_list = tods for cur_tod in self.tod_list: @@ -922,3 +913,55 @@ def get_pointings( ) return pointing_buffer, hwp_buffer + + def _set_mpi_subcommunicators(self): + """ + This function splits the global MPI communicator into three kinds of + sub-communicators: + + 1. A sub-communicator containing all the processes with global rank less than + `n_blocks_det * n_blocks_time`. Outside of this global rank, the + sub-communicator is NULL. + + 2. A sub-communicator for each block of detectors, that contains all the + processes corresponding to that detector block. If a process doesn't + contain a detector, the sub-communicator is NULL. + + 3. A sub-communicator for each block of time that contains all the processes + corresponding to that time block. If a process doesn't contain a detector, + the sub-communicator is NULL. + """ + + # Set the detector and time block sub-communicators to + # `_SerialMpiCommunicator()` when MPI is not being used + self.comm_det_block = _SerialMpiCommunicator() + self.comm_time_block = _SerialMpiCommunicator() + + if self.comm and self.comm.size > 1: + if self.comm.rank < self.n_blocks_det * self.n_blocks_time: + matrix_color = 1 + else: + from .mpi import MPI + + matrix_color = MPI.UNDEFINED + + # Case1: For `0 < rank < n_blocks_det * n_blocks_time`, + # `comm_obs_grid` is a sub-communicator that includes processes + # from rank 0 to `n_blocks_det * n_blocks_time - 1`. + # Case 2: For `n_blocks_det * n_blocks_time <= rank < comm.size`, + # `comm_obs_grid = MPI.COMM_NULL` + comm_obs_grid = self.comm.Split(matrix_color, self.comm.rank) + MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=comm_obs_grid) + + # If the `MPI_COMM_GRID.COMM_OBS_GRID` is not NULL, we split it in + # communicators corresponding to each detector and time block + # If `MPI_COMM_GRID.COMM_OBS_GRID` is NULL, we set the communicators + # corresponding to detector and time blocks to NULL. + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + det_color = MPI_COMM_GRID.COMM_OBS_GRID.rank // self.n_blocks_time + time_color = MPI_COMM_GRID.COMM_OBS_GRID.rank % self.n_blocks_time + self.comm_det_block = MPI_COMM_GRID.COMM_OBS_GRID.Split(det_color) + self.comm_time_block = MPI_COMM_GRID.COMM_OBS_GRID.Split(time_color) + else: + self.comm_det_block = MPI_COMM_GRID.COMM_NULL + self.comm_time_block = MPI_COMM_GRID.COMM_NULL From f1db957168ec5366bb5232e6f4868bfdca6e6d1a Mon Sep 17 00:00:00 2001 From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com> Date: Thu, 28 Nov 2024 13:36:41 +0900 Subject: [PATCH 11/11] updated setattr_det_global() and added tests for the sub-communicators --- .github/workflows/tests.yml | 2 + litebird_sim/observations.py | 16 ++-- test/test_detector_blocks.py | 142 ++++++++++++++++++++++++----------- 3 files changed, 109 insertions(+), 51 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4018509f..cf6dd5a8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -87,6 +87,7 @@ jobs: for proc in 1 5 9 ; do echo "Running MPI test ($MPI) with $proc processes" PYTHONPATH=. mpiexec -n $proc python3 ./test/test_mpi.py + PYTHONPATH=. mpiexec -n $proc python3 ./test/test_detector_blocks.py done - name: Tests OpenMPI @@ -95,4 +96,5 @@ jobs: for proc in 1 2 ; do echo "Running MPI test ($MPI) with $proc processes" PYTHONPATH=. mpiexec -n $proc python3 ./test/test_mpi.py + PYTHONPATH=. mpiexec -n $proc python3 ./test/test_detector_blocks.py done diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index deab36fe..4bcac47e 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -675,29 +675,27 @@ def setattr_det_global(self, name, info, root=0): setattr(self, name, info) return - is_in_grid = self.comm.rank < self._n_blocks_det * self._n_blocks_time - comm_grid = self.comm.Split(int(is_in_grid)) - if not is_in_grid: # The process does not own any detector (and TOD) + if ( + MPI_COMM_GRID.COMM_OBS_GRID == MPI_COMM_GRID.COMM_NULL + ): # The process does not own any detector (and TOD) null_det = DetectorInfo() attribute = getattr(null_det, name, None) value = 0 if isinstance(attribute, numbers.Number) else None setattr(self, name, value) return - my_col = comm_grid.rank % self._n_blocks_time - comm_col = comm_grid.Split(my_col) + my_col = MPI_COMM_GRID.COMM_OBS_GRID.rank % self._n_blocks_time root_col = root // self._n_blocks_det if my_col == root_col: - if comm_grid.rank == root: + if MPI_COMM_GRID.COMM_OBS_GRID.rank == root: starts, nums, _, _ = self._get_start_and_num( self._n_blocks_det, self._n_blocks_time ) info = [info[s : s + n] for s, n in zip(starts, nums)] - info = comm_col.scatter(info, root) + info = self.comm_time_block.scatter(info, root) - comm_row = comm_grid.Split(comm_grid.rank // self._n_blocks_time) - info = comm_row.bcast(info, root_col) + info = self.comm_det_block.bcast(info, root_col) assert (not self.tod_list) or len(info) == len( getattr(self, self.tod_list[0].name) ) diff --git a/test/test_detector_blocks.py b/test/test_detector_blocks.py index 1c1e9346..e5f9dce9 100644 --- a/test/test_detector_blocks.py +++ b/test/test_detector_blocks.py @@ -1,54 +1,59 @@ import numpy as np import litebird_sim as lbs +# data for testing detector blocks and MPI sub-communicators +sampling_freq_Hz = 1 +dets = [ + lbs.DetectorInfo( + name="channel1_w9_detA", + wafer="wafer_9", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w3_detB", + wafer="wafer_3", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detC", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detD", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detA", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detB", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), +] -def test_detector_blocks(): + +def test_detector_blocks(dets=dets, sampling_freq_Hz=sampling_freq_Hz): comm = lbs.MPI_COMM_WORLD start_time = 456 duration_s = 100 - sampling_freq_Hz = 1 nobs_per_det = 3 - # Creating a list of detectors. - dets = [ - lbs.DetectorInfo( - name="channel1_w9_detA", - wafer="wafer_9", - channel="channel1", - sampling_rate_hz=sampling_freq_Hz, - ), - lbs.DetectorInfo( - name="channel1_w3_detB", - wafer="wafer_3", - channel="channel1", - sampling_rate_hz=sampling_freq_Hz, - ), - lbs.DetectorInfo( - name="channel1_w1_detC", - wafer="wafer_1", - channel="channel1", - sampling_rate_hz=sampling_freq_Hz, - ), - lbs.DetectorInfo( - name="channel1_w1_detD", - wafer="wafer_1", - channel="channel1", - sampling_rate_hz=sampling_freq_Hz, - ), - lbs.DetectorInfo( - name="channel2_w4_detA", - wafer="wafer_4", - channel="channel2", - sampling_rate_hz=sampling_freq_Hz, - ), - lbs.DetectorInfo( - name="channel2_w4_detB", - wafer="wafer_4", - channel="channel2", - sampling_rate_hz=sampling_freq_Hz, - ), - ] + if comm.size > 4: + det_blocks_attribute = ["channel", "wafer"] + else: + det_blocks_attribute = ["channel"] sim = lbs.Simulation( start_time=start_time, @@ -61,7 +66,7 @@ def test_detector_blocks(): detectors=dets, split_list_over_processes=False, num_of_obs_per_detector=nobs_per_det, - det_blocks_attributes=["channel", "wafer"], + det_blocks_attributes=det_blocks_attribute, ) tod_len_per_det_per_proc = 0 @@ -98,3 +103,56 @@ def test_detector_blocks(): start_idx = (comm.rank % n_blocks_time) * nobs_per_det stop_idx = start_idx + nobs_per_det np.testing.assert_equal(sum(arr[start_idx:stop_idx]), tod_len_per_det_per_proc) + + +def test_mpi_subcommunicators(dets=dets): + comm = lbs.MPI_COMM_WORLD + + start_time = 456 + duration_s = 100 + nobs_per_det = 3 + + if comm.size > 4: + det_blocks_attribute = ["channel", "wafer"] + else: + det_blocks_attribute = ["channel"] + + sim = lbs.Simulation( + start_time=start_time, + duration_s=duration_s, + random_seed=12345, + mpi_comm=comm, + ) + + sim.create_observations( + detectors=dets, + split_list_over_processes=False, + num_of_obs_per_detector=nobs_per_det, + det_blocks_attributes=det_blocks_attribute, + ) + + if lbs.MPI_COMM_GRID.COMM_OBS_GRID != lbs.MPI_COMM_GRID.COMM_NULL: + # since unused MPI processes stay at the end of global, + # communicator, the rank of the used processes in + # `MPI_COMM_GRID.COMM_OBS_GRID` must be same as their rank in + # global communicator + np.testing.assert_equal(lbs.MPI_COMM_GRID.COMM_OBS_GRID.rank, comm.rank) + + for obs in sim.observations: + # comm_det_block.rank + comm_time_block.rank * n_block_time + # must be equal to the global communicator rank for the + # used processes. It follows from the way split colors + # were defined. + np.testing.assert_equal( + obs.comm_det_block.rank + obs.comm_time_block.rank * obs.n_blocks_time, + comm.rank, + ) + else: + for obs in sim.observations: + # the global rank of the unused MPI processes must be larger than the number of used processes. + assert comm.rank > (obs.n_blocks_det * obs.n_blocks_time - 1) + + # The block communicators on the unused MPI processes must + # be the NULL communicators + np.testing.assert_equal(obs.comm_det_block, lbs.MPI_COMM_GRID.COMM_NULL) + np.testing.assert_equal(obs.comm_time_block, lbs.MPI_COMM_GRID.COMM_NULL)