Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: remove comm_2D from Parallel_2D (Useful Information: BLACS encapsulates MPI and maintains its own internal data structure for MPI communicators) #4658

Merged
merged 10 commits into from
Jul 12, 2024
77 changes: 77 additions & 0 deletions source/module_base/blacs_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#ifndef BLACS_CONNECTOR_H
#define BLACS_CONNECTOR_H

#include <complex>
#include <type_traits>

extern "C"
{
void Cblacs_pinfo(int *myid, int *nprocs);
Expand All @@ -41,13 +43,88 @@ extern "C"
int Cblacs_pnum(int icontxt, int prow, int pcol);
void Cblacs_pcoord(int icontxt, int pnum, int *prow, int *pcol);
void Cblacs_exit(int icontxt);

// broadcast (send/recv)
void Cigebs2d(int ConTxt, char *scope, char *top, int m, int n, int *A, int lda);
void Cigebr2d(int ConTxt, char *scope, char *top, int m, int n, int *A, int lda, int rsrc, int csrc);

void Csgebs2d(int ConTxt, char *scope, char *top, int m, int n, float *A, int lda);
void Csgebr2d(int ConTxt, char *scope, char *top, int m, int n, float *A, int lda, int rsrc, int csrc);

void Cdgebs2d(int ConTxt, char *scope, char *top, int m, int n, double *A, int lda);
void Cdgebr2d(int ConTxt, char *scope, char *top, int m, int n, double *A, int lda, int rsrc, int csrc);

void Ccgebs2d(int ConTxt, char *scope, char *top, int m, int n, std::complex<float> *A, int lda);
void Ccgebr2d(int ConTxt, char *scope, char *top, int m, int n, std::complex<float> *A, int lda, int rsrc, int csrc);

void Czgebs2d(int ConTxt, char *scope, char *top, int m, int n, std::complex<double> *A, int lda);
void Czgebr2d(int ConTxt, char *scope, char *top, int m, int n, std::complex<double> *A, int lda, int rsrc, int csrc);
}

// unified interface for broadcast
/**
 * @brief Unified broadcast-send: dispatch to the BLACS C?gebs2d routine
 *        matching the element type T.
 *
 * Supported element types: int, float, double, std::complex<float>,
 * std::complex<double>. Any other T is rejected at compile time.
 *
 * @param ConTxt BLACS context handle
 * @param scope  scope of participating processes (per BLACS: "Row", "Column" or "All")
 * @param top    network topology hint (" " selects the default)
 * @param m      number of rows of the matrix block
 * @param n      number of columns of the matrix block
 * @param A      pointer to the local matrix data to broadcast
 * @param lda    leading dimension of A
 */
template <typename T>
void Cxgebs2d(int ConTxt, char *scope, char *top, int m, int n, T *A, int lda)
{
    static_assert(std::is_same<T, int>::value ||
                      std::is_same<T, float>::value ||
                      std::is_same<T, double>::value ||
                      std::is_same<T, std::complex<float>>::value ||
                      std::is_same<T, std::complex<double>>::value,
                  "Type not supported");

    // Exactly one branch matches T. The reinterpret_casts are no-ops in the
    // taken branch; they exist only so the non-matching branches still compile.
    if (std::is_same<T, int>::value) {
        Cigebs2d(ConTxt, scope, top, m, n, reinterpret_cast<int*>(A), lda);
    } else if (std::is_same<T, float>::value) {
        Csgebs2d(ConTxt, scope, top, m, n, reinterpret_cast<float*>(A), lda);
    } else if (std::is_same<T, double>::value) {
        Cdgebs2d(ConTxt, scope, top, m, n, reinterpret_cast<double*>(A), lda);
    } else if (std::is_same<T, std::complex<float>>::value) {
        Ccgebs2d(ConTxt, scope, top, m, n, reinterpret_cast<std::complex<float>*>(A), lda);
    } else { // std::complex<double>, guaranteed by the static_assert above
        Czgebs2d(ConTxt, scope, top, m, n, reinterpret_cast<std::complex<double>*>(A), lda);
    }
}

/**
 * @brief Unified broadcast-receive: dispatch to the BLACS C?gebr2d routine
 *        matching the element type T.
 *
 * Supported element types: int, float, double, std::complex<float>,
 * std::complex<double>. Any other T is rejected at compile time.
 *
 * @param ConTxt BLACS context handle
 * @param scope  scope of participating processes (per BLACS: "Row", "Column" or "All")
 * @param top    network topology hint (" " selects the default)
 * @param m      number of rows of the matrix block
 * @param n      number of columns of the matrix block
 * @param A      pointer to the local buffer receiving the broadcast data
 * @param lda    leading dimension of A
 * @param rsrc   process-grid row coordinate of the broadcast source
 * @param csrc   process-grid column coordinate of the broadcast source
 */
template <typename T>
void Cxgebr2d(int ConTxt, char *scope, char *top, int m, int n, T *A, int lda, int rsrc, int csrc)
{
    static_assert(std::is_same<T, int>::value ||
                      std::is_same<T, float>::value ||
                      std::is_same<T, double>::value ||
                      std::is_same<T, std::complex<float>>::value ||
                      std::is_same<T, std::complex<double>>::value,
                  "Type not supported");

    // Exactly one branch matches T. The reinterpret_casts are no-ops in the
    // taken branch; they exist only so the non-matching branches still compile.
    if (std::is_same<T, int>::value) {
        Cigebr2d(ConTxt, scope, top, m, n, reinterpret_cast<int*>(A), lda, rsrc, csrc);
    } else if (std::is_same<T, float>::value) {
        Csgebr2d(ConTxt, scope, top, m, n, reinterpret_cast<float*>(A), lda, rsrc, csrc);
    } else if (std::is_same<T, double>::value) {
        Cdgebr2d(ConTxt, scope, top, m, n, reinterpret_cast<double*>(A), lda, rsrc, csrc);
    } else if (std::is_same<T, std::complex<float>>::value) {
        Ccgebr2d(ConTxt, scope, top, m, n, reinterpret_cast<std::complex<float>*>(A), lda, rsrc, csrc);
    } else { // std::complex<double>, guaranteed by the static_assert above
        Czgebr2d(ConTxt, scope, top, m, n, reinterpret_cast<std::complex<double>*>(A), lda, rsrc, csrc);
    }
}


#ifdef __MPI
#include <mpi.h>
extern "C"
{
int Csys2blacs_handle(MPI_Comm SysCtxt);
MPI_Comm Cblacs2sys_handle(int BlacsCtxt);
}
#endif // __MPI

Expand Down
39 changes: 20 additions & 19 deletions source/module_basis/module_ao/parallel_2d.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "parallel_2d.h"

#include "module_base/blacs_connector.h"
#include "module_base/scalapack_connector.h"

#include <cassert>
Expand Down Expand Up @@ -30,6 +29,22 @@ int Parallel_2D::get_global_col_size() const
}

#ifdef __MPI
/**
 * @brief Return the MPI communicator underlying this object's BLACS context.
 *
 * @return The MPI communicator associated with blacs_ctxt, obtained via the
 *         BLACS system-context query; MPI_COMM_NULL if the BLACS context has
 *         not been initialized (blacs_ctxt < 0).
 */
MPI_Comm Parallel_2D::comm() const
{
    // it is an error to call blacs_get with an invalid BLACS context
    if (blacs_ctxt < 0)
    {
        return MPI_COMM_NULL;
    }

    int sys_ctxt = 0;
    Cblacs_get(blacs_ctxt, 10, &sys_ctxt);
    // blacs_get with "what" = 10 takes a BLACS context and returns the index
    // of the associated system context (MPI communicator) that can be used by
    // blacs2sys_handle to get the MPI communicator.
    return Cblacs2sys_handle(sys_ctxt);
}

void Parallel_2D::_init_proc_grid(const MPI_Comm comm, const bool mode)
{
// determine the number of rows and columns of the process grid
Expand All @@ -47,23 +62,11 @@ void Parallel_2D::_init_proc_grid(const MPI_Comm comm, const bool mode)
std::swap(dim0, dim1);
}

// create a 2D Cartesian MPI communicator (row-major by default)
int period[2] = {1, 1};
int dim[2] = {dim0, dim1};
const int reorder = 0;
MPI_Cart_create(comm, 2, dim, period, reorder, &comm_2D);
MPI_Cart_get(comm_2D, 2, dim, period, coord);

// initialize the BLACS grid accordingly
blacs_ctxt = Csys2blacs_handle(comm_2D);
blacs_ctxt = Csys2blacs_handle(comm);
char order = 'R'; // row-major
Cblacs_gridinit(&blacs_ctxt, &order, dim0, dim1);

// TODO Currently MPI and BLACS are made to have the same Cartesian grid.
// In theory, however, BLACS would split any given communicator to create
// new ones for its own purpose when initializing the process grid, so it
// might be unnecessary to create an MPI communicator with Cartesian topology.
// ***This needs to be verified***
Cblacs_gridinfo(blacs_ctxt, &dim0, &dim1, &coord[0], &coord[1]);
}

void Parallel_2D::_set_dist_info(const int mg, const int ng, const int nb)
Expand Down Expand Up @@ -105,9 +108,8 @@ int Parallel_2D::init(const int mg, const int ng, const int nb, const MPI_Comm c
return nrow == 0 || ncol == 0;
}

int Parallel_2D::set(const int mg, const int ng, const int nb, const MPI_Comm comm_2D, const int blacs_ctxt)
int Parallel_2D::set(const int mg, const int ng, const int nb, const int blacs_ctxt)
{
this->comm_2D = comm_2D;
this->blacs_ctxt = blacs_ctxt;
Cblacs_gridinfo(blacs_ctxt, &dim0, &dim1, &coord[0], &coord[1]);
_set_dist_info(mg, ng, nb);
Expand All @@ -124,15 +126,14 @@ void Parallel_2D::set_serial(const int mg, const int ng)
coord[0] = coord[1] = 0;
nrow = mg;
ncol = ng;
nloc = nrow * ncol;
nloc = static_cast<int64_t>(nrow) * ncol;
local2global_row_.resize(nrow);
local2global_col_.resize(ncol);
std::iota(local2global_row_.begin(), local2global_row_.end(), 0);
std::iota(local2global_col_.begin(), local2global_col_.end(), 0);
global2local_row_ = local2global_row_;
global2local_col_ = local2global_col_;
#ifdef __MPI
comm_2D = MPI_COMM_NULL;
blacs_ctxt = -1;
#endif
}
15 changes: 7 additions & 8 deletions source/module_basis/module_ao/parallel_2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
#include <cstdint>
#include <vector>

#ifdef __MPI
#include <mpi.h>
#endif
#include "module_base/blacs_connector.h"

/// @brief This class packs the basic information of
/// 2D-block-cyclic parallel distribution of an arbitrary matrix.
Expand Down Expand Up @@ -89,13 +87,12 @@ class Parallel_2D

/**
* @brief Set up the info of a block-cyclic distribution using given
* MPI communicator and BLACS context.
* BLACS context.
*
*/
int set(const int mg,
const int ng,
const int nb, // square block is assumed
const MPI_Comm comm_2D,
const int blacs_ctxt);

/// BLACS context
Expand All @@ -104,8 +101,7 @@ class Parallel_2D
/// ScaLAPACK descriptor
int desc[9] = {};

/// 2D Cartesian MPI communicator
MPI_Comm comm_2D = MPI_COMM_NULL;
MPI_Comm comm() const;
#endif

void set_serial(const int mg, const int ng);
Expand All @@ -118,6 +114,9 @@ class Parallel_2D
int nrow = 0;
int ncol = 0;
int64_t nloc = 0;
// NOTE: ScaLAPACK descriptors use int type for the number of rows and columns of
// both the global and local matrices, so nrow & ncol have to be int type. Their
// product, however, can exceed the range of int type.

/// block size
int nb = 1;
Expand All @@ -126,7 +125,7 @@ class Parallel_2D
int dim0 = 0;
int dim1 = 0;

/// process coordinate in the MPI Cartesian grid
/// process coordinate in the BLACS grid
int coord[2] = {-1, -1};

protected:
Expand Down
1 change: 0 additions & 1 deletion source/module_basis/module_ao/parallel_orbitals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ void Parallel_Orbitals::set_desc_wfc_Eij(const int& nbasis, const int& nbands, c
{
ModuleBase::TITLE("Parallel_2D", "set_desc_wfc_Eij");
#ifdef __DEBUG
assert(this->comm_2D != MPI_COMM_NULL);
assert(nbasis > 0 && nbands > 0 && lld > 0);
assert(this->nb > 0 && this->dim0 > 0 && this->dim1 > 0);
#endif
Expand Down
21 changes: 13 additions & 8 deletions source/module_basis/module_ao/test/parallel_2d_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ TEST_F(test_para2d, Divide2D)

// 1. dim0 and dim1
EXPECT_EQ(p2d.dim0 * p2d.dim1, dsize);
if (mode)
if (mode) {
EXPECT_LE(p2d.dim1, p2d.dim0);
else
} else {
EXPECT_LE(p2d.dim0, p2d.dim1);
}

// 2. MPI 2d communicator
EXPECT_NE(p2d.comm_2D, MPI_COMM_NULL);
//EXPECT_NE(p2d.comm_2D, MPI_COMM_NULL);

// 3. local2global and local sizes
int lr = p2d.get_row_size();
Expand Down Expand Up @@ -96,18 +97,22 @@ TEST_F(test_para2d, Divide2D)
auto sum_array = [&p2d](const int& gr, const int& gc) -> std::pair<int, int> {
int sum_row = 0;
int sum_col = 0;
for (int i = 0; i < gr; ++i)
for (int i = 0; i < gr; ++i) {
sum_row += p2d.global2local_row(i);
for (int i = 0; i < gc; ++i)
}
for (int i = 0; i < gc; ++i) {
sum_col += p2d.global2local_col(i);
}
return {sum_row, sum_col};
};
std::pair<int, int> sumrc = sum_array(gr, gc);
EXPECT_EQ(std::get<0>(sumrc), lr * (lr - 1) / 2 - (gr - lr));
EXPECT_EQ(std::get<1>(sumrc), lc * (lc - 1) / 2 - (gc - lc));
for (int i = 0; i < lr; ++i)
for (int j = 0; j < lc; ++j)
for (int i = 0; i < lr; ++i) {
for (int j = 0; j < lc; ++j) {
EXPECT_TRUE(p2d.in_this_processor(p2d.local2global_row(i), p2d.local2global_col(j)));
}
}

EXPECT_EQ(p2d.get_global_row_size(), gr);
EXPECT_EQ(p2d.get_global_col_size(), gc);
Expand All @@ -124,7 +129,7 @@ TEST_F(test_para2d, DescReuseCtxt)
p1.init(sizes[0].first, sizes[0].second, nb, MPI_COMM_WORLD);

Parallel_2D p2; // use 2 different sizes, but they can share the same ctxt
p2.set(sizes[1].first, sizes[1].second, nb, p1.comm_2D, p1.blacs_ctxt);
p2.set(sizes[1].first, sizes[1].second, nb, p1.blacs_ctxt);

EXPECT_EQ(p1.desc[1], p2.desc[1]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ TEST_F(TestParaO, Divide2D)
else EXPECT_LE(po.dim0, po.dim1);

//2. comm_2D
EXPECT_NE(po.comm_2D, MPI_COMM_NULL);
//EXPECT_NE(po.comm_2D, MPI_COMM_NULL);

//3. local2global and local sizes
int lr = po.get_row_size();
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ void ESolver_KS_LCAO<TK, TR>::init_basis_lcao(Input& inp, UnitCell& ucell)
try_nb += ParaV.set_nloc_wfc_Eij(GlobalV::NBANDS, GlobalV::ofs_running, GlobalV::ofs_warning);
if (try_nb != 0)
{
ParaV.set(GlobalV::NLOCAL, GlobalV::NLOCAL, 1, ParaV.comm_2D, ParaV.blacs_ctxt);
ParaV.set(GlobalV::NLOCAL, GlobalV::NLOCAL, 1, ParaV.blacs_ctxt);
try_nb = ParaV.set_nloc_wfc_Eij(GlobalV::NBANDS, GlobalV::ofs_running, GlobalV::ofs_warning);
}

Expand Down
11 changes: 3 additions & 8 deletions source/module_hamilt_lcao/module_tddft/bandenergy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,15 @@ void compute_ekb(const Parallel_Orbitals* pv,
}

int info;
int myid;
int naroc[2];
MPI_Comm_rank(pv->comm_2D, &myid);

double* Eii = new double[nband];
ModuleBase::GlobalFunc::ZEROS(Eii, nband);
for (int iprow = 0; iprow < pv->dim0; ++iprow)
{
for (int ipcol = 0; ipcol < pv->dim1; ++ipcol)
{
const int coord[2] = {iprow, ipcol};
int src_rank;
info = MPI_Cart_rank(pv->comm_2D, coord, &src_rank);
if (myid == src_rank)
if (iprow == pv->coord[0] && ipcol == pv->coord[1])
{
naroc[0] = pv->nrow;
naroc[1] = pv->ncol;
Expand All @@ -139,7 +134,7 @@ void compute_ekb(const Parallel_Orbitals* pv,
}
} // loop ipcol
} // loop iprow
info = MPI_Allreduce(Eii, ekb, nband, MPI_DOUBLE, MPI_SUM, pv->comm_2D);
info = MPI_Allreduce(Eii, ekb, nband, MPI_DOUBLE, MPI_SUM, pv->comm());

delete[] tmp1;
delete[] Eij;
Expand All @@ -148,4 +143,4 @@ void compute_ekb(const Parallel_Orbitals* pv,

#endif

} // namespace module_tddft
} // namespace module_tddft
Loading
Loading