Skip to content

Commit

Permalink
error macros: determining buffer size instead of fixed 2048 chars (#420)
Browse files Browse the repository at this point in the history
Determining buffer size for error macros instead of using fixed size of 2048 bytes.
I ran all tests (which use these macros extensively) as well as tested without CUDA driver to see that output looks as expected.

Closes #419.

Authors:
  - Matt Joux (https://github.com/MatthiasKohl)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Brad Rees (https://github.com/BradReesWork)

URL: #420
  • Loading branch information
MatthiasKohl authored Jan 12, 2022
1 parent 15fd1d3 commit 605458e
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 22 deletions.
56 changes: 37 additions & 19 deletions cpp/include/raft/error.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

#pragma once

#include <cstdio>
#include <execinfo.h>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -91,16 +93,23 @@ struct logic_error : public raft::exception {

// FIXME: Need to be replaced with RAFT_FAIL
/** macro to throw a runtime error */
#define THROW(fmt, ...) \
do { \
std::string msg; \
char errMsg[2048]; /* NOLINT */ \
std::snprintf( \
errMsg, sizeof(errMsg), "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \
msg += errMsg; \
std::snprintf(errMsg, sizeof(errMsg), fmt, ##__VA_ARGS__); \
msg += errMsg; \
throw raft::exception(msg); \
#define THROW(fmt, ...) \
do { \
int size1 = \
std::snprintf(nullptr, 0, "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \
int size2 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \
if (size1 < 0 || size2 < 0) \
throw raft::exception("Error in snprintf, cannot handle raft exception."); \
auto size = size1 + size2 + 1; /* +1 for final '\0' */ \
auto buf = std::make_unique<char[]>(size_t(size)); \
std::snprintf(buf.get(), \
size1 + 1 /* +1 for '\0' */, \
"exception occured! file=%s line=%d: ", \
__FILE__, \
__LINE__); \
std::snprintf(buf.get() + size1, size2 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \
std::string msg(buf.get(), buf.get() + size - 1); /* -1 to remove final '\0' */ \
throw raft::exception(msg); \
} while (0)

// FIXME: Need to be replaced with RAFT_EXPECTS
Expand All @@ -110,15 +119,24 @@ struct logic_error : public raft::exception {
if (!(check)) THROW(fmt, ##__VA_ARGS__); \
} while (0)

#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \
do { \
char err_msg[2048]; /* NOLINT */ \
std::snprintf(err_msg, sizeof(err_msg), location_prefix); \
msg += err_msg; \
std::snprintf(err_msg, sizeof(err_msg), "file=%s line=%d: ", __FILE__, __LINE__); \
msg += err_msg; \
std::snprintf(err_msg, sizeof(err_msg), fmt, ##__VA_ARGS__); \
msg += err_msg; \
/**
* Macro to append error message to first argument.
* This should only be called in contexts where it is OK to throw exceptions!
*/
#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \
do { \
int size1 = std::snprintf(nullptr, 0, "%s", location_prefix); \
int size2 = std::snprintf(nullptr, 0, "file=%s line=%d: ", __FILE__, __LINE__); \
int size3 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \
if (size1 < 0 || size2 < 0 || size3 < 0) \
throw raft::exception("Error in snprintf, cannot handle raft exception."); \
auto size = size1 + size2 + size3 + 1; /* +1 for final '\0' */ \
auto buf = std::make_unique<char[]>(size_t(size)); \
std::snprintf(buf.get(), size1 + 1 /* +1 for '\0' */, "%s", location_prefix); \
std::snprintf( \
buf.get() + size1, size2 + 1 /* +1 for '\0' */, "file=%s line=%d: ", __FILE__, __LINE__); \
std::snprintf(buf.get() + size1 + size2, size3 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \
msg += std::string(buf.get(), buf.get() + size - 1); /* -1 to remove final '\0' */ \
} while (0)

/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/linalg/cholesky_r1_update.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ void choleskyRank1Update(const raft::handle_t& handle,
math_t* s = reinterpret_cast<math_t*>(((char*)workspace) + offset);
math_t* L_22 = L + (n - 1) * ld + n - 1;

math_t* A_new;
math_t* A_row;
math_t* A_new = nullptr;
math_t* A_row = nullptr;
if (uplo == CUBLAS_FILL_MODE_UPPER) {
// A_new is stored as the n-1 th column of L
A_new = L + (n - 1) * ld;
Expand Down
58 changes: 57 additions & 1 deletion cpp/test/cudart_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,74 @@
* limitations under the License.
*/

#include <raft/cudart_utils.h>

#include <gtest/gtest.h>

#include <iostream>
#include <raft/cudart_utils.h>
#include <regex>

namespace raft {

#define TEST_ADD_FILENAME(s) \
{ \
s += std::string{__FILE__}; \
}

std::string reg_escape(const std::string& s)
{
static const std::regex SPECIAL_CHARS{R"([-[\]{}()*+?.,\^$|#\s])"};
return std::regex_replace(s, SPECIAL_CHARS, R"(\$&)");
}

TEST(Raft, Utils)
{
ASSERT_NO_THROW(ASSERT(1 == 1, "Should not assert!"));
ASSERT_THROW(ASSERT(1 != 1, "Should assert!"), exception);
ASSERT_THROW(THROW("Should throw!"), exception);
ASSERT_NO_THROW(RAFT_CUDA_TRY(cudaFree(nullptr)));

// test for long error message strings
std::string test{"This is a test string repeated many times. "};
for (size_t i = 0; i < 6; ++i)
test += test;
EXPECT_TRUE(test.size() > 2048) << "size of test string is: " << test.size();
auto test_format = test + "%d";
auto* test_format_c = test_format.c_str();

std::string file{};
TEST_ADD_FILENAME(file);
std::string reg_file = reg_escape(file);

// THROW has to convert the test string into an exception string
try {
ASSERT(1 != 1, test_format_c, 121);
} catch (const raft::exception& e) {
std::string msg_full{e.what()};
// only use first line
std::string msg = msg_full.substr(0, msg_full.find('\n'));
std::string re_exp{"^exception occured! file="};
re_exp += reg_file;
// test code must be at line >10 (copyright), assume line is never >9999
re_exp += " line=\\d{2,4}: ";
re_exp += reg_escape(test);
re_exp += "121$";
EXPECT_TRUE(std::regex_match(msg, std::regex(re_exp))) << "message:'" << msg << "'" << std::endl
<< "expected regex:'" << re_exp << "'";
}

// Now we test SET_ERROR_MSG instead of THROW
std::string msg{"prefix:"};
ASSERT_NO_THROW(SET_ERROR_MSG(msg, "location prefix:", test_format_c, 123));

std::string re_exp{"^prefix:location prefix:file="};
re_exp += reg_file;
// test code must be at line >10 (copyright), assume line is never >9999
re_exp += " line=\\d{2,4}: ";
re_exp += reg_escape(test);
re_exp += "123$";
EXPECT_TRUE(std::regex_match(msg, std::regex(re_exp))) << "message:'" << msg << "'" << std::endl
<< "expected regex:'" << re_exp << "'";
}

} // namespace raft

0 comments on commit 605458e

Please sign in to comment.