diff --git a/source/adios2/helper/adiosComm.h b/source/adios2/helper/adiosComm.h index 750f54dcc4..3ebf6878b8 100644 --- a/source/adios2/helper/adiosComm.h +++ b/source/adios2/helper/adiosComm.h @@ -218,6 +218,23 @@ class Comm const size_t *recvcounts, const size_t *displs, int root, const std::string &hint = std::string()) const; + /* Gather to root with size_t counts and displacements; defined in adiosComm.inl, where it forwards to the CommImpl backend's Gatherv64. */ template + void Gatherv64(const TSend *sendbuf, size_t sendcount, TRecv *recvbuf, + const size_t *recvcounts, const size_t *displs, int root, + const std::string &hint = std::string()) const; + + /* Typed front end for the one-sided "push" gather; same parameter list as Gatherv64, defined in adiosComm.inl. */ template + void Gatherv64OneSidedPush(const TSend *sendbuf, size_t sendcount, + TRecv *recvbuf, const size_t *recvcounts, + const size_t *displs, int root, + const std::string &hint = std::string()) const; + + /* Typed front end for the one-sided "pull" gather; same parameter list as Gatherv64, defined in adiosComm.inl. */ template + void Gatherv64OneSidedPull(const TSend *sendbuf, size_t sendcount, + TRecv *recvbuf, const size_t *recvcounts, + const size_t *displs, int root, + const std::string &hint = std::string()) const; + template void Reduce(const T *sendbuf, T *recvbuf, size_t count, Op op, int root, const std::string &hint = std::string()) const; @@ -400,6 +417,26 @@ class CommImpl Datatype recvtype, int root, const std::string &hint) const = 0; + /* Backend hook behind Comm::Gatherv64. Buffers are type-erased to void*, so the element types travel as Datatype tags. CommImplMPI implements it with chunked MPI_Isend/MPI_Irecv. */ virtual void Gatherv64(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, const size_t *displs, + Datatype recvtype, int root, + const std::string &hint) const = 0; + + /* Backend hook for the one-sided push gather; CommImplMPI implements it with MPI_Put into a window over recvbuf. */ virtual void Gatherv64OneSidedPush(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, + const size_t *displs, Datatype recvtype, + int root, + const std::string &hint) const = 0; + + /* Backend hook for the one-sided pull gather; CommImplMPI implements it with MPI_Get from a window over sendbuf. */ virtual void Gatherv64OneSidedPull(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, + const size_t *displs, Datatype recvtype, + int root, + const std::string &hint) const = 0; + virtual void Reduce(const void *sendbuf, void *recvbuf, size_t count, Datatype datatype, Comm::Op op, int root, const 
std::string &hint) const = 0; diff --git a/source/adios2/helper/adiosComm.inl b/source/adios2/helper/adiosComm.inl index 0ff88d892c..d57d6b6338 100644 --- a/source/adios2/helper/adiosComm.inl +++ b/source/adios2/helper/adiosComm.inl @@ -217,6 +217,38 @@ void Comm::Gatherv(const TSend *sendbuf, size_t sendcount, TRecv *recvbuf, CommImpl::GetDatatype(), root, hint); } +/* Forwards to the backend's Gatherv64, passing the send and receive element types as CommImpl::GetDatatype() tags. */ template +void Comm::Gatherv64(const TSend *sendbuf, size_t sendcount, TRecv *recvbuf, + const size_t *recvcounts, const size_t *displs, int root, + const std::string &hint) const +{ + return m_Impl->Gatherv64(sendbuf, sendcount, CommImpl::GetDatatype(), + recvbuf, recvcounts, displs, + CommImpl::GetDatatype(), root, hint); +} + +/* Forwards to the backend's Gatherv64OneSidedPush with the same arguments plus the two datatype tags. */ template +void Comm::Gatherv64OneSidedPush(const TSend *sendbuf, size_t sendcount, + TRecv *recvbuf, const size_t *recvcounts, + const size_t *displs, int root, + const std::string &hint) const +{ + return m_Impl->Gatherv64OneSidedPush( + sendbuf, sendcount, CommImpl::GetDatatype(), recvbuf, recvcounts, + displs, CommImpl::GetDatatype(), root, hint); +} + +/* Forwards to the backend's Gatherv64OneSidedPull. It must dispatch to the Pull hook, not the Push one, otherwise the pull implementation can never be reached through Comm. */ template +void Comm::Gatherv64OneSidedPull(const TSend *sendbuf, size_t sendcount, + TRecv *recvbuf, const size_t *recvcounts, + const size_t *displs, int root, + const std::string &hint) const +{ + return m_Impl->Gatherv64OneSidedPull( + sendbuf, sendcount, CommImpl::GetDatatype(), recvbuf, recvcounts, + displs, CommImpl::GetDatatype(), root, hint); +} + template void Comm::Reduce(const T *sendbuf, T *recvbuf, size_t count, Op op, int root, const std::string &hint) const diff --git a/source/adios2/helper/adiosCommDummy.cpp b/source/adios2/helper/adiosCommDummy.cpp index d9c431a096..de7ba1419f 100644 --- a/source/adios2/helper/adiosCommDummy.cpp +++ b/source/adios2/helper/adiosCommDummy.cpp @@ -80,6 +80,23 @@ class CommImplDummy : public CommImpl Datatype recvtype, int root, const std::string &hint) const override; + void Gatherv64(const void *sendbuf, size_t sendcount, Datatype sendtype, + void *recvbuf, const size_t 
*recvcounts, + const size_t *displs, Datatype recvtype, int root, + const std::string &hint) const override; + + /* Dummy-backend override of the one-sided push gather hook. */ void Gatherv64OneSidedPush(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, const size_t *displs, + Datatype recvtype, int root, + const std::string &hint) const override; + + /* Dummy-backend override of the one-sided pull gather hook. */ void Gatherv64OneSidedPull(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, const size_t *displs, + Datatype recvtype, int root, + const std::string &hint) const override; + void Reduce(const void *sendbuf, void *recvbuf, size_t count, Datatype datatype, Comm::Op op, int root, const std::string &hint) const override; @@ -211,6 +228,53 @@ void CommImplDummy::Gatherv(const void *sendbuf, size_t sendcount, recvtype, root, hint); } +/* Dummy backend: reads only recvcounts[0]; if it differs from sendcount the result of CommDummyError is returned, otherwise the call is delegated to CommImplDummy::Gather. displs is never read. */ void CommImplDummy::Gatherv64(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, const size_t *displs, + Datatype recvtype, int root, + const std::string &hint) const +{ + const size_t recvcount = recvcounts[0]; + if (recvcount != sendcount) + { + return CommDummyError("send and recv counts differ"); + } + CommImplDummy::Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, + recvtype, root, hint); +} + +/* Same body as CommImplDummy::Gatherv64: compare recvcounts[0] with sendcount, then delegate to CommImplDummy::Gather; no one-sided communication is involved. */ void CommImplDummy::Gatherv64OneSidedPush(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, + const size_t *displs, + Datatype recvtype, int root, + const std::string &hint) const +{ + const size_t recvcount = recvcounts[0]; + if (recvcount != sendcount) + { + return CommDummyError("send and recv counts differ"); + } + CommImplDummy::Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, + recvtype, root, hint); +} + +/* Same body as CommImplDummy::Gatherv64: compare recvcounts[0] with sendcount, then delegate to CommImplDummy::Gather. */ void CommImplDummy::Gatherv64OneSidedPull(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, + const size_t *displs, + Datatype recvtype, int root, + const std::string &hint) 
const +{ + /* Only the first receive count is consulted; displs is never read. */ const size_t recvcount = recvcounts[0]; + if (recvcount != sendcount) + { + return CommDummyError("send and recv counts differ"); + } + CommImplDummy::Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, + recvtype, root, hint); +} + void CommImplDummy::Reduce(const void *sendbuf, void *recvbuf, size_t count, Datatype datatype, Comm::Op, int, const std::string &) const diff --git a/source/adios2/helper/adiosCommMPI.cpp b/source/adios2/helper/adiosCommMPI.cpp index 07f33d9abf..6eee7e268a 100644 --- a/source/adios2/helper/adiosCommMPI.cpp +++ b/source/adios2/helper/adiosCommMPI.cpp @@ -153,6 +153,23 @@ class CommImplMPI : public CommImpl Datatype recvtype, int root, const std::string &hint) const override; + /* MPI override: gather to root with chunked nonblocking point-to-point messages. */ void Gatherv64(const void *sendbuf, size_t sendcount, Datatype sendtype, + void *recvbuf, const size_t *recvcounts, + const size_t *displs, Datatype recvtype, int root, + const std::string &hint) const override; + + /* MPI override: root reads every rank's sendbuf with MPI_Get. */ void Gatherv64OneSidedPull(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, const size_t *displs, + Datatype recvtype, int root, + const std::string &hint) const override; + + /* MPI override: every rank writes its sendbuf to root with MPI_Put. */ void Gatherv64OneSidedPush(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, const size_t *displs, + Datatype recvtype, int root, + const std::string &hint) const override; + void Reduce(const void *sendbuf, void *recvbuf, size_t count, Datatype datatype, Comm::Op op, int root, const std::string &hint) const override; @@ -344,6 +361,195 @@ void CommImplMPI::Gatherv(const void *sendbuf, size_t sendcount, hint); } +/* Gather to root using nonblocking point-to-point messages: root posts an MPI_Irecv for each chunk of each rank's contribution, every rank (root included) posts MPI_Isend calls to root, and all requests are completed with one MPI_Waitall. Each message carries at most chunksize elements. hint is not used in this function. */ void CommImplMPI::Gatherv64(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, const size_t *displs, + Datatype recvtype, int root, + const std::string &hint) const +{ + + /* Upper bound on the element count handed to a single MPI call. */ const int chunksize = std::numeric_limits::max(); + + int mpiSize; + int mpiRank; + MPI_Comm_size(m_MPIComm, &mpiSize); + 
MPI_Comm_rank(m_MPIComm, &mpiRank); + + int recvTypeSize; + int sendTypeSize; + + MPI_Type_size(ToMPI(recvtype), &recvTypeSize); + MPI_Type_size(ToMPI(sendtype), &sendTypeSize); + + /* One request per posted chunk; all are completed by the MPI_Waitall at the end. */ std::vector requests; + /* Only root posts receives: one MPI_Irecv per chunk of every rank's contribution. */ if (mpiRank == root) + { + for (int i = 0; i < mpiSize; ++i) + { + /* Elements of rank i still to be posted; displs[i] + recvcounts[i] - recvcount is the element offset of the next chunk, scaled by recvTypeSize below. */ size_t recvcount = recvcounts[i]; + while (recvcount > 0) + { + requests.emplace_back(); + if (recvcount > chunksize) + { + MPI_Irecv(reinterpret_cast(recvbuf) + + (displs[i] + recvcounts[i] - recvcount) * + recvTypeSize, + chunksize, ToMPI(recvtype), i, 0, m_MPIComm, + &requests.back()); + recvcount -= chunksize; + } + else + { + MPI_Irecv(reinterpret_cast(recvbuf) + + (displs[i] + recvcounts[i] - recvcount) * + recvTypeSize, + static_cast(recvcount), ToMPI(recvtype), i, + 0, m_MPIComm, &requests.back()); + recvcount = 0; + } + } + } + } + + /* Every rank, root included, sends its own buffer to root in chunks of at most chunksize elements. */ size_t sendcountvar = sendcount; + + while (sendcountvar > 0) + { + requests.emplace_back(); + if (sendcountvar > chunksize) + { + MPI_Isend(reinterpret_cast(sendbuf) + + (sendcount - sendcountvar) * sendTypeSize, + chunksize, ToMPI(sendtype), root, 0, m_MPIComm, + &requests.back()); + sendcountvar -= chunksize; + } + else + { + MPI_Isend(reinterpret_cast(sendbuf) + + (sendcount - sendcountvar) * sendTypeSize, + static_cast(sendcountvar), ToMPI(sendtype), root, 0, + m_MPIComm, &requests.back()); + sendcountvar = 0; + } + } + + MPI_Waitall(static_cast(requests.size()), requests.data(), + MPI_STATUSES_IGNORE); +} + +/* One-sided push gather: all ranks create an MPI window over recvbuf and each rank MPI_Puts its sendbuf into root's window starting at element offset displs[mpiRank], in chunks of at most chunksize elements. NOTE(review): displs and recvcounts are indexed on every rank (displs[mpiRank], displs[mpiSize - 1], recvcounts[mpiSize - 1]), so callers presumably must pass valid arrays on non-root ranks too; confirm against callers. */ void CommImplMPI::Gatherv64OneSidedPush(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, + const size_t *displs, Datatype recvtype, + int root, const std::string &hint) const +{ + const int chunksize = std::numeric_limits::max(); + + int mpiSize; + int mpiRank; + MPI_Comm_size(m_MPIComm, &mpiSize); + MPI_Comm_rank(m_MPIComm, &mpiRank); + + int recvTypeSize; + int sendTypeSize; + + MPI_Type_size(ToMPI(recvtype), &recvTypeSize); + MPI_Type_size(ToMPI(sendtype), &sendTypeSize); + + size_t recvsize = 
displs[mpiSize - 1] + recvcounts[mpiSize - 1]; + + MPI_Win win; + /* The window spans the whole gathered extent (last displacement plus last count) with recvTypeSize as the displacement unit, so MPI_Put targets are expressed in elements. */ MPI_Win_create(recvbuf, recvsize * recvTypeSize, recvTypeSize, + MPI_INFO_NULL, m_MPIComm, &win); + /* Open the RMA epoch on all ranks: MPI_Put is only valid between window synchronization calls. */ MPI_Win_fence(0, win); + size_t sendcountvar = sendcount; + + while (sendcountvar > 0) + { + if (sendcountvar > chunksize) + { + MPI_Put(reinterpret_cast(sendbuf) + + (sendcount - sendcountvar) * sendTypeSize, + chunksize, ToMPI(sendtype), root, + displs[mpiRank] + sendcount - sendcountvar, chunksize, + ToMPI(sendtype), win); + sendcountvar -= chunksize; + } + else + { + MPI_Put(reinterpret_cast(sendbuf) + + (sendcount - sendcountvar) * sendTypeSize, + static_cast(sendcountvar), ToMPI(sendtype), root, + static_cast(displs[mpiRank]) + sendcount - + sendcountvar, + static_cast(sendcountvar), ToMPI(sendtype), win); + sendcountvar = 0; + } + } + /* Close the epoch so every MPI_Put has completed at origin and target before the window is freed. */ MPI_Win_fence(0, win); + MPI_Win_free(&win); +} + +/* One-sided pull gather: every rank exposes its sendbuf in an MPI window (displacement unit sendTypeSize) and root MPI_Gets each rank's recvcounts[i] elements into recvbuf at element offset displs[i], in chunks of at most chunksize elements. recvcounts and displs are read on root only. Both fences are collective, so they are called by all ranks outside the root-only branch. hint is not used in this function. */ void CommImplMPI::Gatherv64OneSidedPull(const void *sendbuf, size_t sendcount, + Datatype sendtype, void *recvbuf, + const size_t *recvcounts, + const size_t *displs, Datatype recvtype, + int root, const std::string &hint) const +{ + + const int chunksize = std::numeric_limits::max(); + + int mpiSize; + int mpiRank; + MPI_Comm_size(m_MPIComm, &mpiSize); + MPI_Comm_rank(m_MPIComm, &mpiRank); + + int recvTypeSize; + int sendTypeSize; + + MPI_Type_size(ToMPI(recvtype), &recvTypeSize); + MPI_Type_size(ToMPI(sendtype), &sendTypeSize); + + MPI_Win win; + MPI_Win_create(const_cast(sendbuf), sendcount * sendTypeSize, + sendTypeSize, MPI_INFO_NULL, m_MPIComm, &win); + /* Open the RMA epoch on all ranks: MPI_Get is only valid between window synchronization calls. */ MPI_Win_fence(0, win); + if (mpiRank == root) + { + for (int i = 0; i < mpiSize; ++i) + { + size_t recvcount = recvcounts[i]; + while (recvcount > 0) + { + if (recvcount > chunksize) + { + MPI_Get(reinterpret_cast(recvbuf) + + (displs[i] + recvcounts[i] - recvcount) * + recvTypeSize, + chunksize, ToMPI(recvtype), i, + recvcounts[i] - recvcount, chunksize, + ToMPI(recvtype), win); + recvcount -= chunksize; + } + else + { + MPI_Get(reinterpret_cast(recvbuf) + + (displs[i] + recvcounts[i] - recvcount) * + recvTypeSize, + 
static_cast(recvcount), ToMPI(recvtype), i, + recvcounts[i] - recvcount, + static_cast(recvcount), ToMPI(recvtype), win); + recvcount = 0; + } + } + } + } + /* Close the epoch so every MPI_Get has delivered its data into recvbuf before the window is freed. */ MPI_Win_fence(0, win); + MPI_Win_free(&win); +} + void CommImplMPI::Reduce(const void *sendbuf, void *recvbuf, size_t count, Datatype datatype, Comm::Op op, int root, const std::string &hint) const