diff --git a/cpp/include/raft/error.hpp b/cpp/include/raft/error.hpp index 773b83ab13..0eba4326e6 100644 --- a/cpp/include/raft/error.hpp +++ b/cpp/include/raft/error.hpp @@ -16,8 +16,10 @@ #pragma once +#include #include #include +#include #include #include #include @@ -91,16 +93,23 @@ struct logic_error : public raft::exception { // FIXME: Need to be replaced with RAFT_FAIL /** macro to throw a runtime error */ -#define THROW(fmt, ...) \ - do { \ - std::string msg; \ - char errMsg[2048]; /* NOLINT */ \ - std::snprintf( \ - errMsg, sizeof(errMsg), "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \ - msg += errMsg; \ - std::snprintf(errMsg, sizeof(errMsg), fmt, ##__VA_ARGS__); \ - msg += errMsg; \ - throw raft::exception(msg); \ +#define THROW(fmt, ...) \ + do { \ + int size1 = \ + std::snprintf(nullptr, 0, "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \ + int size2 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \ + if (size1 < 0 || size2 < 0) \ + throw raft::exception("Error in snprintf, cannot handle raft exception."); \ + auto size = size1 + size2 + 1; /* +1 for final '\0' */ \ + auto buf = std::make_unique(size_t(size)); \ + std::snprintf(buf.get(), \ + size1 + 1 /* +1 for '\0' */, \ + "exception occured! file=%s line=%d: ", \ + __FILE__, \ + __LINE__); \ + std::snprintf(buf.get() + size1, size2 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \ + std::string msg(buf.get(), buf.get() + size - 1); /* -1 to remove final '\0' */ \ + throw raft::exception(msg); \ } while (0) // FIXME: Need to be replaced with RAFT_EXPECTS @@ -110,15 +119,24 @@ struct logic_error : public raft::exception { if (!(check)) THROW(fmt, ##__VA_ARGS__); \ } while (0) -#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \ - do { \ - char err_msg[2048]; /* NOLINT */ \ - std::snprintf(err_msg, sizeof(err_msg), location_prefix); \ - msg += err_msg; \ - std::snprintf(err_msg, sizeof(err_msg), "file=%s line=%d: ", __FILE__, __LINE__); \ - msg += err_msg; \ - std::snprintf(err_msg, sizeof(err_msg), fmt, ##__VA_ARGS__); \ - msg += err_msg; \ +/** + * Macro to append error message to first argument. + * This should only be called in contexts where it is OK to throw exceptions! + */ +#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \ + do { \ + int size1 = std::snprintf(nullptr, 0, "%s", location_prefix); \ + int size2 = std::snprintf(nullptr, 0, "file=%s line=%d: ", __FILE__, __LINE__); \ + int size3 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \ + if (size1 < 0 || size2 < 0 || size3 < 0) \ + throw raft::exception("Error in snprintf, cannot handle raft exception."); \ + auto size = size1 + size2 + size3 + 1; /* +1 for final '\0' */ \ + auto buf = std::make_unique(size_t(size)); \ + std::snprintf(buf.get(), size1 + 1 /* +1 for '\0' */, "%s", location_prefix); \ + std::snprintf( \ + buf.get() + size1, size2 + 1 /* +1 for '\0' */, "file=%s line=%d: ", __FILE__, __LINE__); \ + std::snprintf(buf.get() + size1 + size2, size3 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \ + msg += std::string(buf.get(), buf.get() + size - 1); /* -1 to remove final '\0' */ \ } while (0) /** diff --git a/cpp/include/raft/linalg/cholesky_r1_update.cuh b/cpp/include/raft/linalg/cholesky_r1_update.cuh index 40009414ed..1745b0dcc8 100644 --- a/cpp/include/raft/linalg/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/cholesky_r1_update.cuh @@ -160,8 +160,8 @@ void choleskyRank1Update(const raft::handle_t& handle, math_t* s = reinterpret_cast(((char*)workspace) + offset); math_t* L_22 = L + (n - 1) * ld + n - 1; - math_t* A_new; - math_t* A_row; + math_t* A_new = nullptr; + math_t* A_row = nullptr; if (uplo == CUBLAS_FILL_MODE_UPPER) { // A_new is stored as the n-1 th column of L A_new = L + (n - 1) * ld; diff --git a/cpp/test/cudart_utils.cpp b/cpp/test/cudart_utils.cpp index ff7588ce49..9df8600527 100644 --- a/cpp/test/cudart_utils.cpp +++ b/cpp/test/cudart_utils.cpp @@ -14,18 +14,74 @@ * limitations under the License. */ +#include + #include + #include -#include +#include namespace raft { +#define TEST_ADD_FILENAME(s) \ + { \ + s += std::string{__FILE__}; \ + } + +std::string reg_escape(const std::string& s) +{ + static const std::regex SPECIAL_CHARS{R"([-[\]{}()*+?.,\^$|#\s])"}; + return std::regex_replace(s, SPECIAL_CHARS, R"(\$&)"); +} + TEST(Raft, Utils) { ASSERT_NO_THROW(ASSERT(1 == 1, "Should not assert!")); ASSERT_THROW(ASSERT(1 != 1, "Should assert!"), exception); ASSERT_THROW(THROW("Should throw!"), exception); ASSERT_NO_THROW(RAFT_CUDA_TRY(cudaFree(nullptr))); + + // test for long error message strings + std::string test{"This is a test string repeated many times. "}; + for (size_t i = 0; i < 6; ++i) + test += test; + EXPECT_TRUE(test.size() > 2048) << "size of test string is: " << test.size(); + auto test_format = test + "%d"; + auto* test_format_c = test_format.c_str(); + + std::string file{}; + TEST_ADD_FILENAME(file); + std::string reg_file = reg_escape(file); + + // THROW has to convert the test string into an exception string + try { + ASSERT(1 != 1, test_format_c, 121); + } catch (const raft::exception& e) { + std::string msg_full{e.what()}; + // only use first line + std::string msg = msg_full.substr(0, msg_full.find('\n')); + std::string re_exp{"^exception occured! file="}; + re_exp += reg_file; + // test code must be at line >10 (copyright), assume line is never >9999 + re_exp += " line=\\d{2,4}: "; + re_exp += reg_escape(test); + re_exp += "121$"; + EXPECT_TRUE(std::regex_match(msg, std::regex(re_exp))) << "message:'" << msg << "'" << std::endl + << "expected regex:'" << re_exp << "'"; + } + + // Now we test SET_ERROR_MSG instead of THROW + std::string msg{"prefix:"}; + ASSERT_NO_THROW(SET_ERROR_MSG(msg, "location prefix:", test_format_c, 123)); + + std::string re_exp{"^prefix:location prefix:file="}; + re_exp += reg_file; + // test code must be at line >10 (copyright), assume line is never >9999 + re_exp += " line=\\d{2,4}: "; + re_exp += reg_escape(test); + re_exp += "123$"; + EXPECT_TRUE(std::regex_match(msg, std::regex(re_exp))) << "message:'" << msg << "'" << std::endl + << "expected regex:'" << re_exp << "'"; } } // namespace raft