Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved Tape Statistics #2235

Merged
merged 11 commits into from
Mar 6, 2024
82 changes: 80 additions & 2 deletions Common/include/basic_types/ad_structure.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@

#include "../code_config.hpp"
#include "../parallelization/omp_structure.hpp"
#ifdef HAVE_MPI
#include <mpi.h>
#endif

/*!
* \namespace AD
Expand Down Expand Up @@ -58,8 +61,13 @@ inline bool TapeActive() { return false; }

/*! \brief Parameterless variant of the tape statistics output; the body is empty. */
inline void PrintStatistics() {}

/*!
 * \brief Outputs statistics about the tape.
 *
 * Where applicable, the statistics are combined over OpenMP threads and MPI processes.
 * In MPI runs, the reduction across processes uses the communicator that is passed in, and every rank
 * decides through its own flag whether it prints (typically only the master rank does).
 *
 * \note The body of this definition is empty.
 *
 * \tparam Comm - Type of the communicator handle.
 * \param[in] communicator - Communicator used to reduce data across MPI processes.
 * \param[in] printingRank - Whether the calling rank prints.
 */
template <class Comm>
inline void PrintStatistics(Comm communicator, bool printingRank) {}

/*!
* \brief Registers the variable as an input. I.e. as a leaf of the computational graph.
Expand Down Expand Up @@ -348,7 +356,77 @@ FORCEINLINE void StopRecording() { AD::getTape().setPassive(); }

// Returns the result of isActive() on the tape obtained from AD::getTape().
FORCEINLINE bool TapeActive() { return AD::getTape().isActive(); }

// Calls printStatistics() on the tape obtained from AD::getTape(); takes no communicator or rank flag.
FORCEINLINE void PrintStatistics() { AD::getTape().printStatistics(); }
// Prints tape statistics to std::cout on every process for which printingRank is true.
//
// Step 1: the values from AD::getTape().getTapeValues() are combined with combineDataMPI(communicator) and
//         printed under the "Serial parts of the tape" header.
// Step 2 (only if HAVE_OPDI is defined): inside an SU2_OMP_PARALLEL region, thread 0 provides an accumulator,
//         every other thread adds its own tape values with combineData(), and thread 0 then calls
//         combineDataMPI(communicator) and prints the result.
// Step 3: the used and allocated memory sizes of both parts are summed and printed, divided by 1024^2,
//         with the unit "MB".
//
// communicator: handle that is forwarded unchanged to codi::TapeValues::combineDataMPI().
// printingRank: if true, this process writes headers and values to std::cout. The combine calls are made
//               regardless of this flag (presumably they are collective, so every rank has to take part -
//               TODO confirm against the CoDiPack documentation).
template <typename Comm>
FORCEINLINE void PrintStatistics(Comm communicator, bool printingRank) {
// Header of the first section.
if (printingRank) {
std::cout << "-------------------------------------------------------\n";
std::cout << " Serial parts of the tape\n";
#ifdef HAVE_MPI
std::cout << " (aggregated across MPI processes)\n";
#endif
std::cout << "-------------------------------------------------------\n";
}

// Values of the tape seen by the calling thread outside of any parallel region, combined via the communicator.
codi::TapeValues serialTapeValues = AD::getTape().getTapeValues();
serialTapeValues.combineDataMPI(communicator);

if (printingRank) {
serialTapeValues.formatDefault(std::cout);
}

// Running totals for the summary at the end; the OpenMP part below adds to them.
double totalMemoryUsed = serialTapeValues.getUsedMemorySize();
double totalMemoryAllocated = serialTapeValues.getAllocatedMemorySize();

#ifdef HAVE_OPDI

// Header of the second section.
if (printingRank) {
std::cout << "-------------------------------------------------------\n";
std::cout << " OpenMP parallel parts of the tape\n";
std::cout << " (aggregated across OpenMP threads)\n";
#ifdef HAVE_MPI
std::cout << " (aggregated across MPI processes)\n";
#endif
std::cout << "-------------------------------------------------------\n";
}

// Declared outside the parallel region so that the other threads can reach thread 0's values through it
// (assumes SU2_OMP_PARALLEL opens a regular OpenMP parallel region with default data sharing - TODO confirm).
codi::TapeValues* aggregatedOpenMPTapeValues = nullptr;

// clang-format off

// Both branches below contain two barriers each, so the threads meet at matching points:
// first after thread 0 has published the accumulator, then after all other threads have added their data.
SU2_OMP_PARALLEL {
if (omp_get_thread_num() == 0) { // master thread
// masterTapeValues is local to this branch, so the pointer to it is only valid until the branch ends;
// it is reset to nullptr below before that happens.
codi::TapeValues masterTapeValues = AD::getTape().getTapeValues();
aggregatedOpenMPTapeValues = &masterTapeValues;

SU2_OMP_BARRIER // master completes initialization
SU2_OMP_BARRIER // other threads complete adding their data

// Only thread 0 performs the combination via the communicator, updates the totals and prints.
aggregatedOpenMPTapeValues->combineDataMPI(communicator);
totalMemoryUsed += aggregatedOpenMPTapeValues->getUsedMemorySize();
totalMemoryAllocated += aggregatedOpenMPTapeValues->getAllocatedMemorySize();
if (printingRank) {
aggregatedOpenMPTapeValues->formatDefault(std::cout);
}
aggregatedOpenMPTapeValues = nullptr;
} else { // other threads
SU2_OMP_BARRIER // master completes initialization
// Additions to the accumulator are wrapped in SU2_OMP_CRITICAL (presumably so that only one thread at a
// time modifies it - TODO confirm the macro's expansion).
SU2_OMP_CRITICAL {
aggregatedOpenMPTapeValues->combineData(AD::getTape().getTapeValues());
} END_SU2_OMP_CRITICAL
SU2_OMP_BARRIER // other threads complete adding their data
}
} END_SU2_OMP_PARALLEL

// clang-format on
#endif

// Summary: totals divided by 1024^2 and labelled "MB" (so the sizes are presumably reported in bytes).
if (printingRank) {
std::cout << "-------------------------------------------------------\n";
std::cout << " Total memory used : " << totalMemoryUsed / 1024.0 / 1024.0 << " MB\n";
std::cout << " Total memory allocated : " << totalMemoryAllocated / 1024.0 / 1024.0 << " MB\n";
std::cout << "-------------------------------------------------------\n";
}
}

// Calls clearAdjoints() on the tape obtained from AD::getTape().
FORCEINLINE void ClearAdjoints() { AD::getTape().clearAdjoints(); }

Expand Down
14 changes: 1 addition & 13 deletions SU2_CFD/src/drivers/CDiscAdjMultizoneDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,19 +670,7 @@ void CDiscAdjMultizoneDriver::SetRecording(RECORDING kind_recording, Kind_Tape t
}

if (kind_recording != RECORDING::CLEAR_INDICES && driver_config->GetWrt_AD_Statistics()) {
if (rank == MASTER_NODE) AD::PrintStatistics();
#ifdef CODI_REVERSE_TYPE
if (size > SINGLE_NODE) {
su2double myMem = AD::getTape().getTapeValues().getUsedMemorySize(), totMem = 0.0;
SU2_MPI::Allreduce(&myMem, &totMem, 1, MPI_DOUBLE, MPI_SUM, SU2_MPI::GetComm());
if (rank == MASTER_NODE) {
cout << "MPI\n";
cout << "-------------------------------------\n";
cout << " Total memory used : " << totMem << " MB\n";
cout << "-------------------------------------\n" << endl;
}
}
#endif
AD::PrintStatistics(SU2_MPI::GetComm(), rank == MASTER_NODE);
}

AD::StopRecording();
Expand Down
14 changes: 1 addition & 13 deletions SU2_CFD/src/drivers/CDiscAdjSinglezoneDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,19 +305,7 @@ void CDiscAdjSinglezoneDriver::SetRecording(RECORDING kind_recording){
SetObjFunction();

if (kind_recording != RECORDING::CLEAR_INDICES && config_container[ZONE_0]->GetWrt_AD_Statistics()) {
if (rank == MASTER_NODE) AD::PrintStatistics();
#ifdef CODI_REVERSE_TYPE
if (size > SINGLE_NODE) {
su2double myMem = AD::getTape().getTapeValues().getUsedMemorySize(), totMem = 0.0;
SU2_MPI::Allreduce(&myMem, &totMem, 1, MPI_DOUBLE, MPI_SUM, SU2_MPI::GetComm());
if (rank == MASTER_NODE) {
cout << "MPI\n";
cout << "-------------------------------------\n";
cout << " Total memory used : " << totMem << " MB\n";
cout << "-------------------------------------\n" << endl;
}
}
#endif
AD::PrintStatistics(SU2_MPI::GetComm(), rank == MASTER_NODE);
}

AD::StopRecording();
Expand Down
2 changes: 1 addition & 1 deletion meson_scripts/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def init_submodules(

# This information about the modules is used if the project was not cloned using git
# The sha tag must be maintained manually to point to the correct commit
sha_version_codi = "bb7689fb9479818d4ab55c4f3898c88d92890315"
sha_version_codi = "c6b039e5c9edb7675f90ffc725f9dd8e66571264"
github_repo_codi = "https://github.com/scicompkl/CoDiPack"
sha_version_medi = "ab3a7688f6d518f8d940eb61a341d89f51922ba4"
github_repo_medi = "https://github.com/SciCompKL/MeDiPack"
Expand Down