Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] [FEA] error macros: determining buffer size instead of fixed 2048 chars #420

Merged
merged 4 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 37 additions & 19 deletions cpp/include/raft/error.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

#pragma once

#include <cstdio>
#include <execinfo.h>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -91,16 +93,23 @@ struct logic_error : public raft::exception {

// FIXME: Need to be replaced with RAFT_FAIL
/** macro to throw a runtime error */
#define THROW(fmt, ...) \
do { \
std::string msg; \
char errMsg[2048]; /* NOLINT */ \
std::snprintf( \
errMsg, sizeof(errMsg), "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \
msg += errMsg; \
std::snprintf(errMsg, sizeof(errMsg), fmt, ##__VA_ARGS__); \
msg += errMsg; \
throw raft::exception(msg); \
#define THROW(fmt, ...) \
do { \
int size1 = \
std::snprintf(nullptr, 0, "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \
int size2 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \
if (size1 < 0 || size2 < 0) \
throw raft::exception("Error in snprintf, cannot handle raft exception."); \
auto size = size1 + size2 + 1; /* +1 for final '\0' */ \
auto buf = std::make_unique<char[]>(size_t(size)); \
std::snprintf(buf.get(), \
size1 + 1 /* +1 for '\0' */, \
"exception occured! file=%s line=%d: ", \
__FILE__, \
__LINE__); \
std::snprintf(buf.get() + size1, size2 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \
std::string msg(buf.get(), buf.get() + size - 1); /* -1 to remove final '\0' */ \
throw raft::exception(msg); \
} while (0)

// FIXME: Need to be replaced with RAFT_EXPECTS
Expand All @@ -110,15 +119,24 @@ struct logic_error : public raft::exception {
if (!(check)) THROW(fmt, ##__VA_ARGS__); \
} while (0)

#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \
do { \
char err_msg[2048]; /* NOLINT */ \
std::snprintf(err_msg, sizeof(err_msg), location_prefix); \
msg += err_msg; \
std::snprintf(err_msg, sizeof(err_msg), "file=%s line=%d: ", __FILE__, __LINE__); \
msg += err_msg; \
std::snprintf(err_msg, sizeof(err_msg), fmt, ##__VA_ARGS__); \
msg += err_msg; \
/**
* Macro to append error message to first argument.
* This should only be called in contexts where it is OK to throw exceptions!
*/
#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \
do { \
int size1 = std::snprintf(nullptr, 0, "%s", location_prefix); \
int size2 = std::snprintf(nullptr, 0, "file=%s line=%d: ", __FILE__, __LINE__); \
int size3 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \
if (size1 < 0 || size2 < 0 || size3 < 0) \
throw raft::exception("Error in snprintf, cannot handle raft exception."); \
auto size = size1 + size2 + size3 + 1; /* +1 for final '\0' */ \
auto buf = std::make_unique<char[]>(size_t(size)); \
std::snprintf(buf.get(), size1 + 1 /* +1 for '\0' */, "%s", location_prefix); \
std::snprintf( \
buf.get() + size1, size2 + 1 /* +1 for '\0' */, "file=%s line=%d: ", __FILE__, __LINE__); \
std::snprintf(buf.get() + size1 + size2, size3 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \
msg += std::string(buf.get(), buf.get() + size - 1); /* -1 to remove final '\0' */ \
} while (0)

/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/linalg/cholesky_r1_update.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ void choleskyRank1Update(const raft::handle_t& handle,
math_t* s = reinterpret_cast<math_t*>(((char*)workspace) + offset);
math_t* L_22 = L + (n - 1) * ld + n - 1;

math_t* A_new;
math_t* A_row;
math_t* A_new = nullptr;
math_t* A_row = nullptr;
if (uplo == CUBLAS_FILL_MODE_UPPER) {
// A_new is stored as the n-1 th column of L
A_new = L + (n - 1) * ld;
Expand Down
58 changes: 57 additions & 1 deletion cpp/test/cudart_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,74 @@
* limitations under the License.
*/

#include <raft/cudart_utils.h>

#include <gtest/gtest.h>

#include <iostream>
#include <raft/cudart_utils.h>
#include <regex>

namespace raft {

#define TEST_ADD_FILENAME(s) \
{ \
s += std::string{__FILE__}; \
}

std::string reg_escape(const std::string& s)
{
static const std::regex SPECIAL_CHARS{R"([-[\]{}()*+?.,\^$|#\s])"};
return std::regex_replace(s, SPECIAL_CHARS, R"(\$&)");
}

TEST(Raft, Utils)
{
ASSERT_NO_THROW(ASSERT(1 == 1, "Should not assert!"));
ASSERT_THROW(ASSERT(1 != 1, "Should assert!"), exception);
ASSERT_THROW(THROW("Should throw!"), exception);
ASSERT_NO_THROW(RAFT_CUDA_TRY(cudaFree(nullptr)));

// test for long error message strings
std::string test{"This is a test string repeated many times. "};
for (size_t i = 0; i < 6; ++i)
test += test;
EXPECT_TRUE(test.size() > 2048) << "size of test string is: " << test.size();
auto test_format = test + "%d";
auto* test_format_c = test_format.c_str();

std::string file{};
TEST_ADD_FILENAME(file);
std::string reg_file = reg_escape(file);

// THROW has to convert the test string into an exception string
try {
ASSERT(1 != 1, test_format_c, 121);
} catch (const raft::exception& e) {
std::string msg_full{e.what()};
// only use first line
std::string msg = msg_full.substr(0, msg_full.find('\n'));
std::string re_exp{"^exception occured! file="};
re_exp += reg_file;
// test code must be at line >10 (copyright), assume line is never >9999
re_exp += " line=\\d{2,4}: ";
re_exp += reg_escape(test);
re_exp += "121$";
EXPECT_TRUE(std::regex_match(msg, std::regex(re_exp))) << "message:'" << msg << "'" << std::endl
<< "expected regex:'" << re_exp << "'";
}

// Now we test SET_ERROR_MSG instead of THROW
std::string msg{"prefix:"};
ASSERT_NO_THROW(SET_ERROR_MSG(msg, "location prefix:", test_format_c, 123));

std::string re_exp{"^prefix:location prefix:file="};
re_exp += reg_file;
// test code must be at line >10 (copyright), assume line is never >9999
re_exp += " line=\\d{2,4}: ";
re_exp += reg_escape(test);
re_exp += "123$";
EXPECT_TRUE(std::regex_match(msg, std::regex(re_exp))) << "message:'" << msg << "'" << std::endl
<< "expected regex:'" << re_exp << "'";
}

} // namespace raft