Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPI Reduce for ValLocPair #3003

Merged
merged 1 commit into from
Oct 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions Src/Base/AMReX_ParallelDescriptor.H
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <AMReX_REAL.H>
#include <AMReX_Array.H>
#include <AMReX_Vector.H>
#include <AMReX_ValLocPair.H>

#ifndef BL_AMRPROF
#include <AMReX_Box.H>
Expand Down Expand Up @@ -211,6 +212,11 @@ while ( false )
extern AMREX_EXPORT MPI_Comm m_comm;
inline MPI_Comm Communicator () noexcept { return m_comm; }

#ifdef AMREX_USE_MPI
// Registries of lazily created MPI handles. Mpi_typemap<ValLocPair<..>>::type()
// and Mpi_op<T,F>() append the address of their function-local static handle
// here, and EndParallel() frees every registered handle and resets it to the
// corresponding NULL handle.
//
// These are referenced from templates defined in this header, i.e. from user
// translation units, so they carry the same AMREX_EXPORT annotation as m_comm.
extern AMREX_EXPORT Vector<MPI_Datatype*> m_mpi_types;
extern AMREX_EXPORT Vector<MPI_Op*> m_mpi_ops;
#endif

//! return the number of MPI ranks local to the current Parallel Context
inline int
NProcs () noexcept
Expand Down Expand Up @@ -1479,6 +1485,73 @@ void DoReduce (T* r, MPI_Op op, int cnt, int cpu)
#endif
}

#ifdef AMREX_USE_MPI
namespace ParallelDescriptor {

/**
 * Mpi_typemap specialization that describes a ValLocPair<TV,TI> to MPI.
 *
 * type() returns a cached MPI datatype. While the cached handle is
 * MPI_DATATYPE_NULL the datatype is (re)built as a two-member struct made of
 * Mpi_typemap<TV>::type() and Mpi_typemap<TI>::type(), resized when necessary
 * so that its extent equals sizeof(ValLocPair<TV,TI>), and committed. The
 * address of the cached handle is appended to m_mpi_types; EndParallel()
 * frees every handle registered there and resets it to MPI_DATATYPE_NULL.
 */
template<typename TV, typename TI>
struct Mpi_typemap<ValLocPair<TV,TI>>
{
    static MPI_Datatype type ()
    {
        using Pair = ValLocPair<TV,TI>;
        static_assert(std::is_trivially_copyable<Pair>::value,
                      "To communicate with MPI, ValLocPair must be trivially copyable.");
        static_assert(std::is_standard_layout<Pair>::value,
                      "To communicate with MPI, ValLocPair must be standard layout");

        static MPI_Datatype mpi_type = MPI_DATATYPE_NULL;
        if (mpi_type != MPI_DATATYPE_NULL) {
            return mpi_type; // already built
        }

        // Scratch object used only to measure where the members live.
        Pair probe;
        MPI_Datatype member_types[] = {
            Mpi_typemap<TV>::type(),
            Mpi_typemap<TI>::type(),
        };
        int member_counts[] = { 1, 1 };

        MPI_Aint value_addr;
        MPI_Aint index_addr;
        BL_MPI_REQUIRE( MPI_Get_address(&probe.value, &value_addr) );
        BL_MPI_REQUIRE( MPI_Get_address(&probe.index, &index_addr) );

        // Displacements are expressed relative to the first member.
        MPI_Aint member_offsets[2];
        member_offsets[0] = 0;
        member_offsets[1] = index_addr - value_addr;

        BL_MPI_REQUIRE( MPI_Type_create_struct(2, member_counts, member_offsets,
                                               member_types, &mpi_type) );

        // Make the MPI extent agree with the C++ object size.
        MPI_Aint lb, extent;
        BL_MPI_REQUIRE( MPI_Type_get_extent(mpi_type, &lb, &extent) );
        if (extent != static_cast<MPI_Aint>(sizeof(Pair))) {
            MPI_Datatype unresized = mpi_type;
            BL_MPI_REQUIRE( MPI_Type_create_resized(unresized, 0, sizeof(Pair), &mpi_type) );
            BL_MPI_REQUIRE( MPI_Type_free(&unresized) );
        }
        BL_MPI_REQUIRE( MPI_Type_commit( &mpi_type ) );

        // Register the handle so it can be released at shutdown.
        m_mpi_types.push_back(&mpi_type);
        return mpi_type;
    }
};

/**
 * Return a cached user-defined MPI reduction operation that combines
 * elements of type T with the binary functor F.
 *
 * While the cached handle is MPI_OP_NULL the operation is (re)created with
 * MPI_Op_create (commute flag = 1) and the address of the cached handle is
 * appended to m_mpi_ops; EndParallel() frees every handle registered there
 * and resets it to MPI_OP_NULL.
 *
 * \tparam T element type of the buffers being reduced
 * \tparam F default-constructible functor; F()(a, b) must yield a value
 *           assignable to T
 */
template <typename T, typename F>
MPI_Op Mpi_op ()
{
    static MPI_Op mpi_op = MPI_OP_NULL;
    if (mpi_op != MPI_OP_NULL) {
        return mpi_op; // already created
    }

    // Element-wise reduction callback: inout[k] = F()(in[k], inout[k]).
    static auto reducer = [] (void *invec, void *inoutvec, int* len,
                              MPI_Datatype * /*datatype*/)
    {
        T const* src = static_cast<T const*>(invec);
        T*       dst = static_cast<T*>(inoutvec);
        int const count = *len;
        for (int k = 0; k < count; ++k) {
            dst[k] = F()(src[k], dst[k]);
        }
    };

    BL_MPI_REQUIRE( MPI_Op_create(reducer, 1, &mpi_op) );
    // Register the handle so it can be released at shutdown.
    m_mpi_ops.push_back(&mpi_op);
    return mpi_op;
}

}
#endif

}

#endif /*BL_PARALLELDESCRIPTOR_H*/
15 changes: 15 additions & 0 deletions Src/Base/AMReX_ParallelDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ namespace amrex { namespace ParallelDescriptor {

MPI_Comm m_comm = MPI_COMM_NULL; // communicator for all ranks, probably MPI_COMM_WORLD

#ifdef AMREX_USE_MPI
// Addresses of MPI datatype handles created on demand (e.g. by
// Mpi_typemap<ValLocPair<TV,TI>>::type()). EndParallel() calls MPI_Type_free
// on each one and resets it to MPI_DATATYPE_NULL before clearing the list.
Vector<MPI_Datatype*> m_mpi_types;
// Addresses of MPI reduction-operation handles created on demand by
// Mpi_op<T,F>(). EndParallel() calls MPI_Op_free on each one and resets it
// to MPI_OP_NULL before clearing the list.
Vector<MPI_Op*> m_mpi_ops;
#endif

int m_MinTag = 1000, m_MaxTag = -1;

const int ioProcessor = 0;
Expand Down Expand Up @@ -357,10 +362,20 @@ EndParallel ()
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_indextype) );
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_box) );
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_lull_t) );
for (auto t : m_mpi_types) {
BL_MPI_REQUIRE( MPI_Type_free(t) );
*t = MPI_DATATYPE_NULL;
}
for (auto op : m_mpi_ops) {
BL_MPI_REQUIRE( MPI_Op_free(op) );
*op = MPI_OP_NULL;
}
mpi_type_intvect = MPI_DATATYPE_NULL;
mpi_type_indextype = MPI_DATATYPE_NULL;
mpi_type_box = MPI_DATATYPE_NULL;
mpi_type_lull_t = MPI_DATATYPE_NULL;
m_mpi_types.clear();
m_mpi_ops.clear();
}

if (!call_mpi_finalize) {
Expand Down
55 changes: 55 additions & 0 deletions Src/Base/AMReX_ParallelReduce.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <AMReX_Config.H>

#include <AMReX.H>
#include <AMReX_Functional.H>
#include <AMReX_ParallelDescriptor.H>
#include <AMReX_Print.H>
#include <AMReX_Vector.H>
Expand Down Expand Up @@ -120,6 +121,32 @@ namespace ParallelGather {

namespace ParallelAllReduce {

/**
 * All-reduce a ValLocPair across \p comm using amrex::Greater as the
 * combining functor (presumably selecting the pair with the larger value;
 * see AMReX_Functional.H). The reduced pair is written back into \p vi.
 * Without AMREX_USE_MPI this is a no-op.
 *
 * \param vi   in: this rank's pair; out: the reduced pair
 * \param comm communicator to reduce over
 */
template<typename TV, typename TI>
void Max (ValLocPair<TV,TI>& vi, MPI_Comm comm) {
#ifdef AMREX_USE_MPI
    using Pair = ValLocPair<TV,TI>;
    // Send from a private copy so the result can be received directly in vi.
    Pair sendbuf = vi;
    MPI_Allreduce(&sendbuf, &vi, 1,
                  ParallelDescriptor::Mpi_typemap<Pair>::type(),
                  ParallelDescriptor::Mpi_op<Pair,amrex::Greater<Pair>>(),
                  comm);
#else
    amrex::ignore_unused(vi, comm);
#endif
}

/**
 * All-reduce a ValLocPair across \p comm using amrex::Less as the combining
 * functor (presumably selecting the pair with the smaller value; see
 * AMReX_Functional.H). The reduced pair is written back into \p vi.
 * Without AMREX_USE_MPI this is a no-op.
 *
 * \param vi   in: this rank's pair; out: the reduced pair
 * \param comm communicator to reduce over
 */
template<typename TV, typename TI>
void Min (ValLocPair<TV,TI>& vi, MPI_Comm comm) {
#ifdef AMREX_USE_MPI
    using Pair = ValLocPair<TV,TI>;
    // Send from a private copy so the result can be received directly in vi.
    Pair sendbuf = vi;
    MPI_Allreduce(&sendbuf, &vi, 1,
                  ParallelDescriptor::Mpi_typemap<Pair>::type(),
                  ParallelDescriptor::Mpi_op<Pair,amrex::Less<Pair>>(),
                  comm);
#else
    amrex::ignore_unused(vi, comm);
#endif
}

template<typename T>
void Max (T& v, MPI_Comm comm) {
detail::Reduce(detail::ReduceOp::max, v, -1, comm);
Expand Down Expand Up @@ -174,6 +201,34 @@ namespace ParallelAllReduce {

namespace ParallelReduce {

/**
 * Reduce a ValLocPair onto rank \p root of \p comm using amrex::Greater as
 * the combining functor (presumably selecting the pair with the larger
 * value; see AMReX_Functional.H). \p vi is passed as the receive buffer, so
 * per MPI_Reduce semantics only the root's \p vi receives the result.
 * Without AMREX_USE_MPI this is a no-op.
 *
 * \param vi   in: this rank's pair; out (on root): the reduced pair
 * \param root rank that receives the result
 * \param comm communicator to reduce over
 */
template<typename TV, typename TI>
void Max (ValLocPair<TV,TI>& vi, int root, MPI_Comm comm) {
#ifdef AMREX_USE_MPI
    using Pair = ValLocPair<TV,TI>;
    // Send from a private copy so the result can be received directly in vi.
    Pair sendbuf = vi;
    MPI_Reduce(&sendbuf, &vi, 1,
               ParallelDescriptor::Mpi_typemap<Pair>::type(),
               ParallelDescriptor::Mpi_op<Pair,amrex::Greater<Pair>>(),
               root, comm);
#else
    amrex::ignore_unused(vi, root, comm);
#endif
}

/**
 * Reduce a ValLocPair onto rank \p root of \p comm using amrex::Less as the
 * combining functor (presumably selecting the pair with the smaller value;
 * see AMReX_Functional.H). \p vi is passed as the receive buffer, so per
 * MPI_Reduce semantics only the root's \p vi receives the result.
 * Without AMREX_USE_MPI this is a no-op.
 *
 * \param vi   in: this rank's pair; out (on root): the reduced pair
 * \param root rank that receives the result
 * \param comm communicator to reduce over
 */
template<typename TV, typename TI>
void Min (ValLocPair<TV,TI>& vi, int root, MPI_Comm comm) {
#ifdef AMREX_USE_MPI
    using Pair = ValLocPair<TV,TI>;
    // Send from a private copy so the result can be received directly in vi.
    Pair sendbuf = vi;
    MPI_Reduce(&sendbuf, &vi, 1,
               ParallelDescriptor::Mpi_typemap<Pair>::type(),
               ParallelDescriptor::Mpi_op<Pair,amrex::Less<Pair>>(),
               root, comm);
#else
    amrex::ignore_unused(vi, root, comm);
#endif
}

template<typename T>
void Max (T& v, int root, MPI_Comm comm) {
detail::Reduce(detail::ReduceOp::max, v, root, comm);
Expand Down
26 changes: 1 addition & 25 deletions Src/Base/AMReX_Reduce.H
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,14 @@
#include <AMReX_Arena.H>
#include <AMReX_OpenMP.H>
#include <AMReX_MFIter.H>
#include <AMReX_ValLocPair.H>

#include <algorithm>
#include <functional>
#include <limits>

namespace amrex {

/**
 * A value together with an associated index (location).
 *
 * Ordering (operator< and operator>) looks at the value only; the index
 * takes no part in comparisons.
 *
 * \tparam TV value type
 * \tparam TI index type
 */
template <typename TV, typename TI>
struct ValLocPair
{
    TV value;
    TI index;

    //! Pair holding std::numeric_limits<TV>::max() and a value-initialized index.
    static constexpr ValLocPair max () {
        return ValLocPair{std::numeric_limits<TV>::max(), TI()};
    }

    //! Pair holding std::numeric_limits<TV>::lowest() and a value-initialized index.
    static constexpr ValLocPair lowest () {
        return ValLocPair{std::numeric_limits<TV>::lowest(), TI()};
    }

    //! True when lhs.value < rhs.value.
    friend constexpr bool operator< (ValLocPair const& lhs, ValLocPair const& rhs)
    {
        return lhs.value < rhs.value;
    }

    //! True when lhs.value > rhs.value.
    friend constexpr bool operator> (ValLocPair const& lhs, ValLocPair const& rhs)
    {
        return lhs.value > rhs.value;
    }
};

namespace Reduce { namespace detail {

#ifdef AMREX_USE_GPU
Expand Down
35 changes: 35 additions & 0 deletions Src/Base/AMReX_ValLocPair.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef AMREX_VALLOCPAIR_H_
#define AMREX_VALLOCPAIR_H_

#include <limits>

namespace amrex {

/**
 * A value together with an associated index (location).
 *
 * Ordering (operator< and operator>) looks at the value only; the index
 * takes no part in comparisons.
 *
 * \tparam TV value type
 * \tparam TI index type
 */
template <typename TV, typename TI>
struct ValLocPair
{
    TV value;
    TI index;

    //! Pair holding std::numeric_limits<TV>::max() and a value-initialized index.
    static constexpr ValLocPair max () {
        return ValLocPair{std::numeric_limits<TV>::max(), TI()};
    }

    //! Pair holding std::numeric_limits<TV>::lowest() and a value-initialized index.
    static constexpr ValLocPair lowest () {
        return ValLocPair{std::numeric_limits<TV>::lowest(), TI()};
    }

    //! True when lhs.value < rhs.value.
    friend constexpr bool operator< (ValLocPair const& lhs, ValLocPair const& rhs)
    {
        return lhs.value < rhs.value;
    }

    //! True when lhs.value > rhs.value.
    friend constexpr bool operator> (ValLocPair const& lhs, ValLocPair const& rhs)
    {
        return lhs.value > rhs.value;
    }
};

}

#endif
1 change: 1 addition & 0 deletions Src/Base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ target_sources( amrex
AMReX_Utility.cpp
AMReX_FileSystem.H
AMReX_FileSystem.cpp
AMReX_ValLocPair.H
AMReX_Reduce.H
AMReX_Scan.H
AMReX_Partition.H
Expand Down
1 change: 1 addition & 0 deletions Src/Base/Make.package
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ C$(AMREX_BASE)_sources += AMReX_BlockMutex.cpp
C$(AMREX_BASE)_sources += AMReX_ParmParse.cpp AMReX_parmparse_fi.cpp AMReX_Utility.cpp
C$(AMREX_BASE)_headers += AMReX_ParmParse.H AMReX_Utility.H AMReX_BLassert.H AMReX_ArrayLim.H
C$(AMREX_BASE)_headers += AMReX_Functional.H AMReX_Reduce.H AMReX_Scan.H AMReX_Partition.H
C$(AMREX_BASE)_headers += AMReX_ValLocPair.H

C$(AMREX_BASE)_headers += AMReX_FileSystem.H
C$(AMREX_BASE)_sources += AMReX_FileSystem.cpp
Expand Down