Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] [FEA] error macros: determining buffer size instead of fixed 2048 chars #420

Merged
merged 4 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 37 additions & 19 deletions cpp/include/raft/error.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#pragma once

#include <execinfo.h>
#include <cstdio>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -91,16 +93,23 @@ struct logic_error : public raft::exception {

// FIXME: Need to be replaced with RAFT_FAIL
/** macro to throw a runtime error */
#define THROW(fmt, ...) \
do { \
std::string msg; \
char errMsg[2048]; /* NOLINT */ \
std::snprintf( \
errMsg, sizeof(errMsg), "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \
msg += errMsg; \
std::snprintf(errMsg, sizeof(errMsg), fmt, ##__VA_ARGS__); \
msg += errMsg; \
throw raft::exception(msg); \
#define THROW(fmt, ...) \
do { \
int size1 = \
std::snprintf(nullptr, 0, "exception occured! file=%s line=%d: ", __FILE__, __LINE__); \
int size2 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \
if (size1 < 0 || size2 < 0) \
throw raft::exception("Error in snprintf, cannot handle raft exception."); \
auto size = size1 + size2 + 1; /* +1 for final '\0' */ \
auto buf = std::make_unique<char[]>(size_t(size)); \
std::snprintf(buf.get(), \
size1 + 1 /* +1 for '\0' */, \
"exception occured! file=%s line=%d: ", \
__FILE__, \
__LINE__); \
std::snprintf(buf.get() + size1, size2 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \
std::string msg(buf.get(), buf.get() + size - 1); /* -1 to remove final '\0' */ \
throw raft::exception(msg); \
} while (0)

// FIXME: Need to be replaced with RAFT_EXPECTS
Expand All @@ -110,15 +119,24 @@ struct logic_error : public raft::exception {
if (!(check)) THROW(fmt, ##__VA_ARGS__); \
} while (0)

#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \
do { \
char err_msg[2048]; /* NOLINT */ \
std::snprintf(err_msg, sizeof(err_msg), location_prefix); \
msg += err_msg; \
std::snprintf(err_msg, sizeof(err_msg), "file=%s line=%d: ", __FILE__, __LINE__); \
msg += err_msg; \
std::snprintf(err_msg, sizeof(err_msg), fmt, ##__VA_ARGS__); \
msg += err_msg; \
/**
* Macro to append error message to first argument.
* This should only be called in contexts where it is OK to throw exceptions!
*/
#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \
do { \
int size1 = std::snprintf(nullptr, 0, "%s", location_prefix); \
int size2 = std::snprintf(nullptr, 0, "file=%s line=%d: ", __FILE__, __LINE__); \
int size3 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \
if (size1 < 0 || size2 < 0 || size3 < 0) \
throw raft::exception("Error in snprintf, cannot handle raft exception."); \
auto size = size1 + size2 + size3 + 1; /* +1 for final '\0' */ \
auto buf = std::make_unique<char[]>(size_t(size)); \
std::snprintf(buf.get(), size1 + 1 /* +1 for '\0' */, "%s", location_prefix); \
std::snprintf( \
buf.get() + size1, size2 + 1 /* +1 for '\0' */, "file=%s line=%d: ", __FILE__, __LINE__); \
std::snprintf(buf.get() + size1 + size2, size3 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \
msg += std::string(buf.get(), buf.get() + size - 1); /* -1 to remove final '\0' */ \
} while (0)

/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/linalg/cholesky_r1_update.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ void choleskyRank1Update(const raft::handle_t& handle,
math_t* s = reinterpret_cast<math_t*>(((char*)workspace) + offset);
math_t* L_22 = L + (n - 1) * ld + n - 1;

math_t* A_new;
math_t* A_row;
math_t* A_new = nullptr;
math_t* A_row = nullptr;
if (uplo == CUBLAS_FILL_MODE_UPPER) {
// A_new is stored as the n-1 th column of L
A_new = L + (n - 1) * ld;
Expand Down