diff --git a/cpp/include/raft/common/detail/logger.hpp b/cpp/include/raft/common/detail/logger.hpp index 053b6e3c88..619fb89452 100644 --- a/cpp/include/raft/common/detail/logger.hpp +++ b/cpp/include/raft/common/detail/logger.hpp @@ -15,139 +15,8 @@ */ #pragma once -#include +#pragma message(__FILE__ \ + " is deprecated and will be removed in future releases." \ + " Please use the version instead.") -#define SPDLOG_HEADER_ONLY -#include // NOLINT -#include // NOLINT - -#include - -#include -#include -#include -#include - -#include - -/** - * @defgroup logging levels used in raft - * - * @note exactly match the corresponding ones (but reverse in terms of value) - * in spdlog for wrapping purposes - * - * @{ - */ -#define RAFT_LEVEL_TRACE 6 -#define RAFT_LEVEL_DEBUG 5 -#define RAFT_LEVEL_INFO 4 -#define RAFT_LEVEL_WARN 3 -#define RAFT_LEVEL_ERROR 2 -#define RAFT_LEVEL_CRITICAL 1 -#define RAFT_LEVEL_OFF 0 -/** @} */ - -#if !defined(RAFT_ACTIVE_LEVEL) -#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_DEBUG -#endif - -namespace spdlog { -class logger; -namespace sinks { -template -class CallbackSink; -using callback_sink_mt = CallbackSink; -}; // namespace sinks -}; // namespace spdlog - -namespace raft::detail { - -/** - * @defgroup CStringFormat Expand a C-style format string - * - * @brief Expands C-style formatted string into std::string - * - * @param[in] fmt format string - * @param[in] vl respective values for each of format modifiers in the string - * - * @return the expanded `std::string` - * - * @{ - */ -std::string format(const char* fmt, va_list& vl) -{ - char buf[4096]; - vsnprintf(buf, sizeof(buf), fmt, vl); - return std::string(buf); -} - -std::string format(const char* fmt, ...) -{ - va_list vl; - va_start(vl, fmt); - std::string str = format(fmt, vl); - va_end(vl); - return str; -} -/** @} */ - -int convert_level_to_spdlog(int level) -{ - level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); - return RAFT_LEVEL_TRACE - level; -} - -}; // namespace raft::detail - -/** - * @defgroup loggerMacros Helper macros for dealing with logging - * @{ - */ -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) -#define RAFT_LOG_TRACE(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get().log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_TRACE(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) -#define RAFT_LOG_DEBUG(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get().log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_DEBUG(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) -#define RAFT_LOG_INFO(fmt, ...) raft::logger::get().log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_INFO(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) -#define RAFT_LOG_WARN(fmt, ...) raft::logger::get().log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_WARN(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) -#define RAFT_LOG_ERROR(fmt, ...) raft::logger::get().log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_ERROR(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) -#define RAFT_LOG_CRITICAL(fmt, ...) raft::logger::get().log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_CRITICAL(fmt, ...) void(0) -#endif -/** @} */ \ No newline at end of file +#include diff --git a/cpp/include/raft/core/cudart_utils.hpp b/cpp/include/raft/core/cudart_utils.hpp index 5adc0227a8..a1e7e6bc32 100644 --- a/cpp/include/raft/core/cudart_utils.hpp +++ b/cpp/include/raft/core/cudart_utils.hpp @@ -36,7 +36,7 @@ #include #include #include -#include +#include ///@todo: enable once logging has been enabled in raft //#include "logger.hpp" @@ -294,7 +294,7 @@ void print_host_vector(const char* variable_name, if (i != 0) out << ","; out << host_mem[i]; } - out << "];\n"; + out << "];" << std::endl; } template @@ -303,10 +303,32 @@ void print_device_vector(const char* variable_name, size_t componentsCount, OutStream& out) { - T* host_mem = new T[componentsCount]; - CUDA_CHECK(cudaMemcpy(host_mem, devMem, componentsCount * sizeof(T), cudaMemcpyDeviceToHost)); - print_host_vector(variable_name, host_mem, componentsCount, out); - delete[] host_mem; + std::vector host_mem(componentsCount); + CUDA_CHECK( + cudaMemcpy(host_mem.data(), devMem, componentsCount * sizeof(T), cudaMemcpyDeviceToHost)); + print_host_vector(variable_name, host_mem.data(), componentsCount, out); +} + +/** + * @brief Print an array given a device or a host pointer. + * + * @param[in] variable_name + * @param[in] ptr any pointer (device/host/managed, etc) + * @param[in] componentsCount array length + * @param out the output stream + */ +template +void print_vector(const char* variable_name, const T* ptr, size_t componentsCount, OutStream& out) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + if (attr.hostPointer != nullptr) { + print_host_vector(variable_name, reinterpret_cast(attr.hostPointer), componentsCount, out); + } else if (attr.type == cudaMemoryTypeUnregistered) { + print_host_vector(variable_name, ptr, componentsCount, out); + } else { + print_device_vector(variable_name, ptr, componentsCount, out); + } } /** @} */ @@ -425,4 +447,4 @@ constexpr T upper_bound() } // namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/core/error.hpp b/cpp/include/raft/core/error.hpp index a65b9a8469..8348595db3 100644 --- a/cpp/include/raft/core/error.hpp +++ b/cpp/include/raft/core/error.hpp @@ -26,6 +26,7 @@ #include #include #include +#include namespace raft { @@ -126,20 +127,20 @@ struct logic_error : public raft::exception { * Macro to append error message to first argument. * This should only be called in contexts where it is OK to throw exceptions! */ -#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \ - do { \ - int size1 = std::snprintf(nullptr, 0, "%s", location_prefix); \ - int size2 = std::snprintf(nullptr, 0, "file=%s line=%d: ", __FILE__, __LINE__); \ - int size3 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \ - if (size1 < 0 || size2 < 0 || size3 < 0) \ - throw raft::exception("Error in snprintf, cannot handle raft exception."); \ - auto size = size1 + size2 + size3 + 1; /* +1 for final '\0' */ \ - auto buf = std::make_unique(size_t(size)); \ - std::snprintf(buf.get(), size1 + 1 /* +1 for '\0' */, "%s", location_prefix); \ - std::snprintf( \ - buf.get() + size1, size2 + 1 /* +1 for '\0' */, "file=%s line=%d: ", __FILE__, __LINE__); \ - std::snprintf(buf.get() + size1 + size2, size3 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \ - msg += std::string(buf.get(), buf.get() + size - 1); /* -1 to remove final '\0' */ \ +#define SET_ERROR_MSG(msg, location_prefix, fmt, ...) \ + do { \ + int size1 = std::snprintf(nullptr, 0, "%s", location_prefix); \ + int size2 = std::snprintf(nullptr, 0, "file=%s line=%d: ", __FILE__, __LINE__); \ + int size3 = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__); \ + if (size1 < 0 || size2 < 0 || size3 < 0) \ + throw raft::exception("Error in snprintf, cannot handle raft exception."); \ + auto size = size1 + size2 + size3 + 1; /* +1 for final '\0' */ \ + std::vector buf(size); \ + std::snprintf(buf.data(), size1 + 1 /* +1 for '\0' */, "%s", location_prefix); \ + std::snprintf( \ + buf.data() + size1, size2 + 1 /* +1 for '\0' */, "file=%s line=%d: ", __FILE__, __LINE__); \ + std::snprintf(buf.data() + size1 + size2, size3 + 1 /* +1 for '\0' */, fmt, ##__VA_ARGS__); \ + msg += std::string(buf.data(), buf.data() + size - 1); /* -1 to remove final '\0' */ \ } while (0) /** @@ -173,4 +174,4 @@ struct logic_error : public raft::exception { throw raft::logic_error(msg); \ } while (0) -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/core/logger.hpp b/cpp/include/raft/core/logger.hpp index 9066e103d0..927eb8943e 100644 --- a/cpp/include/raft/core/logger.hpp +++ b/cpp/include/raft/core/logger.hpp @@ -29,6 +29,7 @@ #define SPDLOG_HEADER_ONLY #include +#include #include // NOLINT #include // NOLINT @@ -50,7 +51,7 @@ /** @} */ #if !defined(RAFT_ACTIVE_LEVEL) -#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_DEBUG +#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_INFO #endif namespace raft { @@ -58,6 +59,8 @@ namespace raft { static const std::string RAFT_NAME = "raft"; static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); +namespace detail { + /** * @defgroup CStringFormat Expand a C-style format string * @@ -70,14 +73,16 @@ static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); * * @{ */ -std::string format(const char* fmt, va_list& vl) +inline std::string format(const char* fmt, va_list& vl) { - char buf[4096]; - vsnprintf(buf, sizeof(buf), fmt, vl); - return std::string(buf); + int length = std::vsnprintf(nullptr, 0, fmt, vl); + assert(length >= 0); + std::vector buf(length + 1); + std::vsnprintf(buf.data(), length + 1, fmt, vl); + return std::string(buf.data()); } -std::string format(const char* fmt, ...) +inline std::string format(const char* fmt, ...) { va_list vl; va_start(vl, fmt); @@ -87,12 +92,14 @@ std::string format(const char* fmt, ...) } /** @} */ -int convert_level_to_spdlog(int level) +inline int convert_level_to_spdlog(int level) { level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); return RAFT_LEVEL_TRACE - level; } +} // namespace detail + /** * @brief The main Logging class for raft library. * @@ -112,7 +119,7 @@ class logger { cur_pattern() { set_pattern(default_log_pattern); - set_level(RAFT_LEVEL_INFO); + set_level(RAFT_ACTIVE_LEVEL); } /** * @brief Singleton method to get the underlying logger object @@ -140,7 +147,7 @@ class logger { */ void set_level(int level) { - level = convert_level_to_spdlog(level); + level = raft::detail::convert_level_to_spdlog(level); spdlogger->set_level(static_cast(level)); } @@ -179,7 +186,7 @@ class logger { */ bool should_log_for(int level) const { - level = convert_level_to_spdlog(level); + level = raft::detail::convert_level_to_spdlog(level); auto level_e = static_cast(level); return spdlogger->should_log(level_e); } @@ -209,13 +216,13 @@ class logger { */ void log(int level, const char* fmt, ...) { - level = convert_level_to_spdlog(level); + level = raft::detail::convert_level_to_spdlog(level); auto level_e = static_cast(level); // explicit check to make sure that we only expand messages when required if (spdlogger->should_log(level_e)) { va_list vl; va_start(vl, fmt); - auto msg = format(fmt, vl); + auto msg = raft::detail::format(fmt, vl); va_end(vl); spdlogger->log(level_e, msg); } @@ -256,12 +263,24 @@ class logger { #define RAFT_LOG_TRACE(fmt, ...) void(0) #endif +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE_VEC(ptr, len) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + print_vector(#ptr, ptr, len, ss); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE_VEC(ptr, len) void(0) +#endif + #if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) #define RAFT_LOG_DEBUG(fmt, ...) \ do { \ std::stringstream ss; \ - ss << raft::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::format(fmt, ##__VA_ARGS__); \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ } while (0) #else diff --git a/cpp/test/common/logger.cpp b/cpp/test/common/logger.cpp index 813ce2b5f1..a8460e45ca 100644 --- a/cpp/test/common/logger.cpp +++ b/cpp/test/common/logger.cpp @@ -15,7 +15,7 @@ */ #include -#include +#include #include namespace raft { @@ -61,6 +61,15 @@ class loggerTest : public ::testing::Test { } }; +// The logging macros depend on `RAFT_ACTIVE_LEVEL` as well as the logger verbosity; +// The verbosity is set to `RAFT_LEVEL_TRACE`, but `RAFT_ACTIVE_LEVEL` is set outside of here. +auto check_if_logged(const std::string& msg, int log_level_def) -> bool +{ + bool actually_logged = logged.find(msg) != std::string::npos; + bool should_be_logged = RAFT_ACTIVE_LEVEL >= log_level_def; + return actually_logged == should_be_logged; +} + TEST_F(loggerTest, callback) { std::string testMsg; @@ -68,23 +77,27 @@ TEST_F(loggerTest, callback) testMsg = "This is a critical message"; RAFT_LOG_CRITICAL(testMsg.c_str()); - ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + ASSERT_TRUE(check_if_logged(testMsg, RAFT_LEVEL_CRITICAL)); testMsg = "This is an error message"; RAFT_LOG_ERROR(testMsg.c_str()); - ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + ASSERT_TRUE(check_if_logged(testMsg, RAFT_LEVEL_ERROR)); testMsg = "This is a warning message"; RAFT_LOG_WARN(testMsg.c_str()); - ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + ASSERT_TRUE(check_if_logged(testMsg, RAFT_LEVEL_WARN)); testMsg = "This is an info message"; RAFT_LOG_INFO(testMsg.c_str()); - ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + ASSERT_TRUE(check_if_logged(testMsg, RAFT_LEVEL_INFO)); testMsg = "This is a debug message"; RAFT_LOG_DEBUG(testMsg.c_str()); - ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + ASSERT_TRUE(check_if_logged(testMsg, RAFT_LEVEL_DEBUG)); + + testMsg = "This is a trace message"; + RAFT_LOG_TRACE(testMsg.c_str()); + ASSERT_TRUE(check_if_logged(testMsg, RAFT_LEVEL_TRACE)); } TEST_F(loggerTest, flush)