Distribution of detector blocks across MPI processes #334

Open · wants to merge 11 commits into master
6 changes: 3 additions & 3 deletions docs/source/observations.rst
@@ -173,13 +173,13 @@ balancing, it is less suitable for simulating some effects, like crosstalk and
noise correlation between the detectors. This uniform distribution across MPI
processes necessitates the transfer of large TOD arrays across multiple MPI processes,
which complicates the code implementation and may potentially lead to significant
performance overhead. To avoid this situation, the :class:`Observation` class
performance overhead. To avoid this situation, the :class:`.Observation` class
accepts an argument ``det_blocks_attributes``, a list of strings
naming the detector attributes used to build the detector groups. Once the
detector groups are made, the detectors are distributed to the MPI processes in
such a way that all the detectors of a group reside on the same MPI process.

If a valid ``det_blocks_attributes`` argument is passed to the :class:`Observation`
If a valid ``det_blocks_attributes`` argument is passed to the :class:`.Observation`
class, the arguments ``n_blocks_det`` and ``n_blocks_time`` are ignored. Since
``det_blocks_attributes`` creates the detector blocks dynamically,
``n_blocks_time`` is computed at runtime from the size of the MPI communicator and
@@ -369,7 +369,7 @@ detectors axis and time axis is divided depending on the size of MPI communicator
quantities refer to the same detector. If you need the global detector index,
you can get it with ``obs.det_idx[0]``, which is created at construction time.
``obs.det_idx`` stores the indices of the detectors available to an
:class:`Observation` class, with respect to the list of detectors stored in
:class:`.Observation` class, with respect to the list of detectors stored in
the ``obs.detectors_global`` variable.

.. note::
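As a usage illustration of the documented behavior above, here is a minimal sketch. The wafer names, sample counts, and the exact constructor parameters other than ``det_blocks_attributes`` are assumptions for illustration, not taken from this PR:

```python
# Sketch: group detectors by wafer so that all detectors sharing a wafer
# land on the same MPI process. All parameter values are illustrative.
import litebird_sim as lbs

dets = [
    lbs.DetectorInfo(name=f"det{i:02d}", wafer=f"W{i % 2}", sampling_rate_hz=1.0)
    for i in range(4)
]

obs = lbs.Observation(
    detectors=dets,
    n_samples_global=3_600,
    start_time_global=0.0,
    sampling_rate_hz=1.0,
    det_blocks_attributes=["wafer"],  # build detector blocks from `wafer`
    comm=lbs.MPI_COMM_WORLD,
)

# obs.det_idx holds the global indices (into obs.detectors_global) of the
# detectors assigned to this MPI process.
print(obs.det_idx, len(obs.detectors_global))
```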
3 changes: 2 additions & 1 deletion litebird_sim/__init__.py
@@ -72,7 +72,7 @@
)
from .madam import save_simulation_for_madam
from .mbs.mbs import Mbs, MbsParameters, MbsSavedMapInfo
from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION
from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, comm_grid
from .noise import (
add_white_noise,
add_one_over_f_noise,
@@ -218,6 +218,7 @@ def destripe_with_toast2(*args, **kwargs):
"MPI_COMM_WORLD",
"MPI_ENABLED",
"MPI_CONFIGURATION",
"comm_grid",
# observations.py
"Observation",
"TodDescription",
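Since ``comm_grid`` is now part of the public namespace, downstream code can guard collective operations the same way this PR does. A brief sketch, using only names exported above:

```python
# Sketch: check membership in the observation grid before touching the
# grid communicator; ranks outside the grid hold a NULL communicator.
import litebird_sim as lbs

if lbs.MPI_ENABLED and lbs.comm_grid.COMM_OBS_GRID != lbs.comm_grid.COMM_NULL:
    print("observation-grid rank:", lbs.comm_grid.COMM_OBS_GRID.rank)
```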
2 changes: 1 addition & 1 deletion litebird_sim/mapmaking/binner.py
@@ -68,7 +68,7 @@ class BinnerResult:

@njit
def _solve_binning(nobs_matrix, atd):
# Sove the map-making equation
# Solve the map-making equation
#
# This method alters the parameter `nobs_matrix`, so that after its completion
# each 3×3 matrix in nobs_matrix[idx, :, :] will be the *inverse*.
5 changes: 4 additions & 1 deletion litebird_sim/mapmaking/common.py
@@ -219,7 +219,10 @@ def _compute_pixel_indices(

if output_coordinate_system == CoordinateSystem.Galactic:
# Free curr_pointings_det if the output map is already in Galactic coordinates
del curr_pointings_det
try:
del curr_pointings_det
except UnboundLocalError:
pass

return pixidx_all, polang_all

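The rationale for this guard, shown with a standalone illustration (not code from this PR): ``del`` on a local name that was never bound raises ``UnboundLocalError``, which can happen when ``curr_pointings_det`` is not created on the current code path.

```python
# Illustration: deleting a never-bound local raises UnboundLocalError.
def free_if_bound(flag: bool) -> None:
    if flag:
        buf = [0.0] * 1024  # bound only on this branch
    try:
        del buf  # raises UnboundLocalError when flag is False
    except UnboundLocalError:
        pass

free_if_bound(True)   # frees the list
free_if_bound(False)  # guard swallows the error
```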
224 changes: 123 additions & 101 deletions litebird_sim/mapmaking/destriper.py
@@ -20,7 +20,7 @@
from numba import njit, prange
import healpy as hp

from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD
from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, comm_grid
from typing import Callable, Union, List, Optional, Tuple, Any, Dict
from litebird_sim.hwp import HWP
from litebird_sim.observations import Observation
@@ -44,7 +44,7 @@


__DESTRIPER_RESULTS_FILE_NAME = "destriper_results.fits"
__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_WORLD.rank:04d}.fits"
__BASELINES_FILE_NAME = f"baselines_mpi{comm_grid.COMM_OBS_GRID.rank:04d}.fits"


def _split_items_into_n_segments(n: int, num_of_segments: int) -> List[int]:
@@ -498,8 +498,10 @@ def _build_nobs_matrix(
)

# Now we must accumulate the result of every MPI process
if MPI_ENABLED:
MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM)
if MPI_ENABLED and comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
comm_grid.COMM_OBS_GRID.Allreduce(
mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM
)

# `nobs_matrix_cholesky` will *not* contain the M_i maps shown in
# Eq. 9 of KurkiSuonio2009, but its Cholesky decomposition, i.e.,
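This is the guarded-reduction pattern that recurs throughout the PR: only ranks belonging to the observation grid take part in the collective call, since a NULL communicator must never be used. A condensed sketch with an illustrative array shape:

```python
# Sketch of the guarded in-place reduction used across this PR.
import numpy as np
import mpi4py.MPI
from litebird_sim.mpi import MPI_ENABLED, comm_grid

nobs_matrix = np.zeros((12, 3, 3))  # illustrative shape
if MPI_ENABLED and comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
    comm_grid.COMM_OBS_GRID.Allreduce(
        mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM
    )
```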
@@ -746,8 +748,12 @@ def _compute_binned_map(
)

if MPI_ENABLED:
MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM)
MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM)
comm_grid.COMM_OBS_GRID.Allreduce(
mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM
)
comm_grid.COMM_OBS_GRID.Allreduce(
mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM
)

# Step 2: compute the “binned map” (Eq. 21)
_sum_map_to_binned_map(
@@ -987,7 +993,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float:
# the dot product
local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)])
if MPI_ENABLED:
return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.SUM)
return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM)
else:
return local_result

@@ -1004,7 +1010,7 @@ def _get_stopping_factor(residual: List[npt.ArrayLike]) -> float:
"""
local_result = np.max(np.abs(residual))
if MPI_ENABLED:
return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.MAX)
return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX)
else:
return local_result

@@ -1418,7 +1424,7 @@ def _run_destriper(
bytes_in_temporary_buffers += mask.nbytes

if MPI_ENABLED:
bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce(
bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce(
bytes_in_temporary_buffers,
op=mpi4py.MPI.SUM,
)
@@ -1613,91 +1619,103 @@ def my_gui_callback(
binned_map = np.empty((3, number_of_pixels))
hit_map = np.empty(number_of_pixels)

if do_destriping:
try:
# This will fail if the parameter is a scalar
len(params.samples_per_baseline)

baseline_lengths_list = params.samples_per_baseline
assert len(baseline_lengths_list) == len(obs_list), (
f"The list baseline_lengths_list has {len(baseline_lengths_list)} "
f"elements, but there are {len(obs_list)} observations"
)
except TypeError:
# Ok, params.samples_per_baseline is a scalar, so we must
# figure out the number of samples in each baseline within
# each observation
baseline_lengths_list = [
split_items_evenly(
n=getattr(cur_obs, components[0]).shape[1],
sub_n=int(params.samples_per_baseline),
if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
# Run the destriping/binning below either when MPI is not being used
# or when comm_grid.COMM_OBS_GRID is not a NULL communicator
if do_destriping:
try:
# This will fail if the parameter is a scalar
len(params.samples_per_baseline)

baseline_lengths_list = params.samples_per_baseline
assert len(baseline_lengths_list) == len(obs_list), (
f"The list baseline_lengths_list has {len(baseline_lengths_list)} "
f"elements, but there are {len(obs_list)} observations"
)
for cur_obs in obs_list
]
except TypeError:
# Ok, params.samples_per_baseline is a scalar, so we must
# figure out the number of samples in each baseline within
# each observation
baseline_lengths_list = [
split_items_evenly(
n=getattr(cur_obs, components[0]).shape[1],
sub_n=int(params.samples_per_baseline),
)
for cur_obs in obs_list
]

# Each element of this list is a 2D array with shape (N_det, N_baselines),
# where N_det is the number of detectors in the i-th Observation object
recycle_baselines = False
if baselines_list is None:
baselines_list = [
np.zeros(
(getattr(cur_obs, components[0]).shape[0], len(cur_baseline))
)
for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list)
]
else:
recycle_baselines = True

destriped_map = np.empty((3, number_of_pixels))
(
baselines_list,
baseline_errors_list,
history_of_stopping_factors,
best_stopping_factor,
converged,
bytes_in_temporary_buffers,
) = _run_destriper(
obs_list=obs_list,
nobs_matrix_cholesky=nobs_matrix_cholesky,
binned_map=binned_map,
destriped_map=destriped_map,
hit_map=hit_map,
baseline_lengths_list=baseline_lengths_list,
baselines_list_start=baselines_list,
recycle_baselines=recycle_baselines,
recycled_convergence=recycled_convergence,
dm_list=detector_mask_list,
tm_list=time_mask_list,
component=components[0],
threshold=params.threshold,
max_steps=params.iter_max,
use_preconditioner=params.use_preconditioner,
callback=callback,
callback_kwargs=callback_kwargs if callback_kwargs else {},
)

# Each element of this list is a 2D array with shape (N_det, N_baselines),
# where N_det is the number of detectors in the i-th Observation object
recycle_baselines = False
if baselines_list is None:
baselines_list = [
np.zeros((getattr(cur_obs, components[0]).shape[0], len(cur_baseline)))
for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list)
]
if MPI_ENABLED:
bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce(
bytes_in_temporary_buffers,
op=mpi4py.MPI.SUM,
)
else:
recycle_baselines = True

destriped_map = np.empty((3, number_of_pixels))
(
baselines_list,
baseline_errors_list,
history_of_stopping_factors,
best_stopping_factor,
converged,
bytes_in_temporary_buffers,
) = _run_destriper(
obs_list=obs_list,
nobs_matrix_cholesky=nobs_matrix_cholesky,
binned_map=binned_map,
destriped_map=destriped_map,
hit_map=hit_map,
baseline_lengths_list=baseline_lengths_list,
baselines_list_start=baselines_list,
recycle_baselines=recycle_baselines,
recycled_convergence=recycled_convergence,
dm_list=detector_mask_list,
tm_list=time_mask_list,
component=components[0],
threshold=params.threshold,
max_steps=params.iter_max,
use_preconditioner=params.use_preconditioner,
callback=callback,
callback_kwargs=callback_kwargs if callback_kwargs else {},
)

if MPI_ENABLED:
bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce(
bytes_in_temporary_buffers,
op=mpi4py.MPI.SUM,
# No need to run the destriping, just compute the binned map with
# one single baseline set to zero
_compute_binned_map(
obs_list=obs_list,
output_sky_map=binned_map,
output_hit_map=hit_map,
nobs_matrix_cholesky=nobs_matrix_cholesky,
component=components[0],
dm_list=detector_mask_list,
tm_list=time_mask_list,
baselines_list=None,
baseline_lengths_list=[
np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int)
for cur_obs in obs_list
],
)
bytes_in_temporary_buffers = 0
destriped_map = None
baseline_lengths_list = None
baselines_list = None
baseline_errors_list = None
history_of_stopping_factors = None
best_stopping_factor = None
converged = True
else:
# No need to run the destriping, just compute the binned map with
# one single baseline set to zero
_compute_binned_map(
obs_list=obs_list,
output_sky_map=binned_map,
output_hit_map=hit_map,
nobs_matrix_cholesky=nobs_matrix_cholesky,
component=components[0],
dm_list=detector_mask_list,
tm_list=time_mask_list,
baselines_list=None,
baseline_lengths_list=[
np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int)
for cur_obs in obs_list
],
)
bytes_in_temporary_buffers = 0

destriped_map = None
baseline_lengths_list = None
baselines_list = None
@@ -1707,14 +1725,18 @@ def my_gui_callback(
converged = True

# Add the temporary memory that was allocated *before* calling the destriper
bytes_in_temporary_buffers += sum(
[
cur_obs.destriper_weights.nbytes
+ cur_obs.destriper_pixel_idx.nbytes
+ cur_obs.destriper_pol_angle_rad.nbytes
for cur_obs in obs_list
]
)
try:
bytes_in_temporary_buffers += sum(
[
cur_obs.destriper_weights.nbytes
+ cur_obs.destriper_pixel_idx.nbytes
+ cur_obs.destriper_pol_angle_rad.nbytes
for cur_obs in obs_list
]
)
except UnboundLocalError:
# `bytes_in_temporary_buffers` was never assigned, e.g. because this
# process skipped the branches above
bytes_in_temporary_buffers = 0

# We're nearly done! Let's clean up some stuff…
if not keep_weights:
@@ -1992,11 +2014,11 @@ def _save_baselines(results: DestriperResult, output_file: Path) -> None:

primary_hdu = fits.PrimaryHDU()
primary_hdu.header["MPIRANK"] = (
MPI_COMM_WORLD.rank,
comm_grid.COMM_OBS_GRID.rank,
"The rank of the MPI process that wrote this file",
)
primary_hdu.header["MPISIZE"] = (
MPI_COMM_WORLD.size,
comm_grid.COMM_OBS_GRID.size,
"The number of MPI processes used in the computation",
)
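For reference, a sketch of reading back the layout metadata written here; the file name follows the ``baselines_mpiNNNN.fits`` pattern defined near the top of this module:

```python
# Sketch: inspect the MPI layout recorded in a per-rank baselines file.
from astropy.io import fits

with fits.open("baselines_mpi0000.fits") as inpf:
    print("rank", inpf[0].header["MPIRANK"], "of", inpf[0].header["MPISIZE"])
```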

@@ -2212,11 +2234,11 @@ def load_destriper_results(
baselines_file_name = folder / __BASELINES_FILE_NAME

with fits.open(baselines_file_name) as inpf:
assert MPI_COMM_WORLD.rank == inpf[0].header["MPIRANK"], (
assert comm_grid.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], (
"You must call load_destriper_results using the "
"same MPI layout that was used for save_destriper_results "
)
assert MPI_COMM_WORLD.size == inpf[0].header["MPISIZE"], (
assert comm_grid.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], (
"You must call load_destriper_results using the "
"same MPI layout that was used for save_destriper_results"
)