Skip to content

Commit

Permalink
Improved error handling for maybe uninmplemented ectrans GPU features
Browse files Browse the repository at this point in the history
  • Loading branch information
wdeconinck committed Dec 19, 2024
1 parent bfc54b2 commit cd69b7e
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 23 deletions.
1 change: 1 addition & 0 deletions src/tests/trans/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# granted to it by virtue of its status as an intergovernmental organisation nor
# does it submit to any jurisdiction.

list( APPEND ATLAS_TEST_ENVIRONMENT ATLAS_TEST_IGNORE_ECTRANS_NOT_IMPLEMENTED=1 )
if( HAVE_FCTEST )

if( atlas_HAVE_ECTRANS )
Expand Down
70 changes: 51 additions & 19 deletions src/tests/trans/test_trans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
*/

#include <algorithm>
#include <cstdlib>

#include "eckit/exception/Exceptions.h"
#include "eckit/filesystem/PathName.h"
Expand Down Expand Up @@ -48,6 +49,44 @@
#endif
#endif

bool ignore_ectrans_not_implemented(std::exception& e) {
static bool ATLAS_TEST_IGNORE_ECTRANS_NOT_IMPLEMENTED = []() -> bool {
const char* env = ::getenv("ATLAS_TEST_IGNORE_ECTRANS_NOT_IMPLEMENTED");
if (env) {
return std::atoi(env);
}
return false;
}();
if (ATLAS_TEST_IGNORE_ECTRANS_NOT_IMPLEMENTED == false) {
return false;
}
std::string errstr(e.what());
for(auto& c : errstr){ c = std::tolower(c); }
std::vector<std::string> search{"trans", "not", "implemented"};
return std::all_of(search.begin() ,search.end(),[&](const std::string& s)->bool {
return errstr.find(s) != std::string::npos;
});
}

#define ECTRANS_MAYBE_NOT_IMPLEMENTED(expr) \
do { \
try { \
expr; \
} \
catch (std::exception & e) { \
if( ignore_ectrans_not_implemented(e) ) { \
atlas::Log::error() << "ERROR IGNORED: Not implemented with ectrans code path\n" \
<< "Skipping remainder of test" \
<< std::endl; \
return; \
} \
throw eckit::testing::TestException("Unexpected exception caught: "+std::string(e.what()), Here()); \
} \
catch (...) { \
throw eckit::testing::TestException("Unexpected and unknown exception caught", Here()); \
} \
} while (false)

using namespace eckit;
using atlas::grid::detail::partitioner::EqualRegionsPartitioner;
using atlas::grid::detail::partitioner::TransPartitioner;
Expand Down Expand Up @@ -166,7 +205,6 @@ CASE("test_trans_options") {
Log::info() << "trans_opts = " << opts << std::endl;
}

#ifdef TRANS_HAVE_IO
CASE("test_write_read_cache") {
Log::info() << "test_write_read_cache" << std::endl;
using namespace trans;
Expand All @@ -193,7 +231,6 @@ CASE("test_write_read_cache") {
Trans trans_cache_O24(legendre_cache_O24, Grid("O24"), 23, option::flt(false));
}
}
#endif

CASE("test_distspec") {
trans::TransIFS trans(Grid("F80"), 159);
Expand Down Expand Up @@ -516,13 +553,11 @@ CASE("test_trans_using_functionspace_StructuredColumns") {
EXPECT_THROWS_AS(trans.dirtrans(gpfields, spfields), eckit::Exception);
}

#if 0
// NOT SUPPORTED IN ECTRANS WITH GPU
CASE("test_trans_MIR_lonlat") {
Log::info() << "test_trans_MIR_lonlat" << std::endl;

Grid grid("L48");
trans::Trans trans(grid, 47);
trans::Trans trans;
ECTRANS_MAYBE_NOT_IMPLEMENTED((trans = trans::Trans(grid, 47)));

// global fields
std::vector<double> spf(trans.spectralCoefficients(), 0.);
Expand All @@ -534,10 +569,7 @@ CASE("test_trans_MIR_lonlat") {
EXPECT_NO_THROW(trans.dirtrans(1, gpf.data(), spf.data(), option::global()));
}
}
#endif

#if 0
// NOT SUPPORTED IN ECTRANS WITH GPU
CASE("test_trans_VorDivToUV") {
int nfld = 1; // TODO: test for nfld>1
std::vector<int> truncation_array{1}; // truncation_array{159,160,1279};
Expand Down Expand Up @@ -566,7 +598,9 @@ CASE("test_trans_VorDivToUV") {
std::vector<double> field_U(nfld * nspec2);
std::vector<double> field_V(nfld * nspec2);

vordiv_to_UV.execute(nspec2, nfld, field_vor.data(), field_div.data(), field_U.data(), field_V.data());
ECTRANS_MAYBE_NOT_IMPLEMENTED(
vordiv_to_UV.execute(nspec2, nfld, field_vor.data(), field_div.data(), field_U.data(), field_V.data());
);

// TODO: do some meaningful checks
Log::info() << "Trans library" << std::endl;
Expand All @@ -585,8 +619,9 @@ CASE("test_trans_VorDivToUV") {
std::vector<double> field_U(nfld * nspec2);
std::vector<double> field_V(nfld * nspec2);

vordiv_to_UV.execute(nspec2, nfld, field_vor.data(), field_div.data(), field_U.data(), field_V.data());

EXPECT_NO_THROW(
vordiv_to_UV.execute(nspec2, nfld, field_vor.data(), field_div.data(), field_U.data(), field_V.data());
);
// TODO: do some meaningful checks
Log::info() << "Local transform" << std::endl;
Log::info() << "U: " << std::endl;
Expand All @@ -597,9 +632,7 @@ CASE("test_trans_VorDivToUV") {
}
}
}
#endif

#ifdef TRANS_HAVE_IO
CASE("ATLAS-256: Legendre coefficient expected unique identifiers") {
if (mpi::comm().size() == 1) {
util::Config options;
Expand Down Expand Up @@ -685,21 +718,20 @@ CASE("ATLAS-256: Legendre coefficient expected unique identifiers") {
for (auto& domain : domains) {
for (int T : spectral_T) {
for (auto name : grids) {
Log::info() << "Case name:" << name << ", T:" << T << ", domain:" << domain << ", UID:'" << *uid
<< "'" << std::endl;
//Log::info() << "Case name:" << name << ", T:" << T << ", domain:" << domain << ", UID:'" << *uid
// << "'" << std::endl;

Grid grid(name, domain);
auto test = trans::LegendreCacheCreator(grid, T, options).uid();
ATLAS_DEBUG_VAR(test);
EXPECT(test == *uid);
//ATLAS_DEBUG_VAR(test);
EXPECT_EQ(test, *uid);

uid++;
}
}
}
}
}
#endif

//-----------------------------------------------------------------------------

Expand Down
48 changes: 44 additions & 4 deletions src/tests/trans/test_transgeneral.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,45 @@
#endif
#endif

bool ignore_ectrans_not_implemented(std::exception& e) {
static bool ATLAS_TEST_IGNORE_ECTRANS_NOT_IMPLEMENTED = []() -> bool {
const char* env = ::getenv("ATLAS_TEST_IGNORE_ECTRANS_NOT_IMPLEMENTED");
if (env) {
std::cout << "ATLAS_TEST_IGNORE_ECTRANS_NOT_IMPLEMENTED=" << env << std::endl;
return std::atoi(env);
}
return false;
}();
if (ATLAS_TEST_IGNORE_ECTRANS_NOT_IMPLEMENTED == false) {
return false;
}
std::string errstr(e.what());
for(auto& c : errstr){ c = std::tolower(c); }
std::vector<std::string> search{"trans", "not", "implemented"};
return std::all_of(search.begin() ,search.end(),[&](const std::string& s)->bool {
return errstr.find(s) != std::string::npos;
});
}

#define ECTRANS_MAYBE_NOT_IMPLEMENTED(expr) \
do { \
try { \
expr; \
} \
catch (std::exception & e) { \
if( ignore_ectrans_not_implemented(e) ) { \
atlas::Log::error() << "ERROR IGNORED: Not implemented with ectrans code path\n" \
<< "Skipping remainder of test" \
<< std::endl; \
return; \
} \
throw eckit::testing::TestException("Unexpected exception caught: "+std::string(e.what()), Here()); \
} \
catch (...) { \
throw eckit::testing::TestException("Unexpected and unknown exception caught", Here()); \
} \
} while (false)

using namespace eckit;

using atlas::array::Array;
Expand Down Expand Up @@ -1581,7 +1620,7 @@ CASE("test_trans_levels") {
}
#endif

#if 0
#if 1
// ECTRANS GPU VERSION DOES NOT YET SUPPORT THIS
#if ATLAS_HAVE_TRANS
#if ATLAS_HAVE_ECTRANS || defined(TRANS_HAVE_INVTRANS_ADJ)
Expand Down Expand Up @@ -1656,7 +1695,7 @@ CASE("test_2level_adjoint_test_with_powerspectrum_convolution") {

// transform fields to spectral and view
if (test_name[test_type].compare("inverse") == 0) {
transIFS.invtrans_adj(gpf, spf);
ECTRANS_MAYBE_NOT_IMPLEMENTED(transIFS.invtrans_adj(gpf, spf));
} else if (test_name[test_type].compare("direct") == 0) {
transIFS.dirtrans(gpf, spf);
}
Expand Down Expand Up @@ -1696,8 +1735,9 @@ CASE("test_2level_adjoint_test_with_powerspectrum_convolution") {
if (test_name[test_type].compare("inverse") == 0) {
transIFS.invtrans(spf, gpf2);
} else if (test_name[test_type].compare("direct") == 0) {
transIFS.dirtrans_adj(spf, gpf2);
ECTRANS_MAYBE_NOT_IMPLEMENTED(transIFS.dirtrans_adj(spf, gpf2));
}

Log::info() << "adjoint test transforms " << test_name[test_type] << std::endl;


Expand Down Expand Up @@ -1795,7 +1835,7 @@ CASE("test_2level_adjoint_test_with_vortdiv") {
}
atlas::mpi::comm().allReduceInPlace(adj_value, eckit::mpi::sum());

transIFS.dirtrans_wind2vordiv_adj(spfvor, spfdiv, gpfuv2);
ECTRANS_MAYBE_NOT_IMPLEMENTED(transIFS.dirtrans_wind2vordiv_adj(spfvor, spfdiv, gpfuv2));

double adj_value2(0.0);
for (atlas::idx_t j = gridFS.j_begin(); j < gridFS.j_end(); ++j) {
Expand Down

0 comments on commit cd69b7e

Please sign in to comment.