From aaa72e4758fdab41aeb070998ce5db5b1da1d516 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 9 Mar 2022 18:23:53 -0500 Subject: [PATCH 1/4] adding logging. --- .../raft/common/detail/callback_sink.hpp | 71 ++++++ cpp/include/raft/common/detail/logger.hpp | 153 +++++++++++++ cpp/include/raft/common/detail/scatter.cuh | 52 +++++ cpp/include/raft/common/logger.hpp | 211 ++++++++++++++++++ cpp/include/raft/common/scatter.cuh | 42 +--- cpp/test/CMakeLists.txt | 1 + cpp/test/common/logger.cpp | 97 ++++++++ 7 files changed, 592 insertions(+), 35 deletions(-) create mode 100644 cpp/include/raft/common/detail/callback_sink.hpp create mode 100644 cpp/include/raft/common/detail/logger.hpp create mode 100644 cpp/include/raft/common/detail/scatter.cuh create mode 100644 cpp/include/raft/common/logger.hpp create mode 100644 cpp/test/common/logger.cpp diff --git a/cpp/include/raft/common/detail/callback_sink.hpp b/cpp/include/raft/common/detail/callback_sink.hpp new file mode 100644 index 0000000000..ecd869ee4f --- /dev/null +++ b/cpp/include/raft/common/detail/callback_sink.hpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#define SPDLOG_HEADER_ONLY +#include +#include +#include + +namespace spdlog::sinks { + +typedef void (*LogCallback)(int lvl, const char* msg); + +template +class CallbackSink : public base_sink { + public: + explicit CallbackSink(std::string tag = "spdlog", + LogCallback callback = nullptr, + void (*flush)() = nullptr) + : _callback{callback}, _flush{flush} {}; + + void set_callback(LogCallback callback) { _callback = callback; } + void set_flush(void (*flush)()) { _flush = flush; } + + protected: + void sink_it_(const details::log_msg& msg) override + { + spdlog::memory_buf_t formatted; + base_sink::formatter_->format(msg, formatted); + std::string msg_string = fmt::to_string(formatted); + + if (_callback) { + _callback(static_cast(msg.level), msg_string.c_str()); + } else { + std::cout << msg_string; + } + } + + void flush_() override + { + if (_flush) { + _flush(); + } else { + std::cout << std::flush; + } + } + + LogCallback _callback; + void (*_flush)(); +}; + +using callback_sink_mt = CallbackSink; +using callback_sink_st = CallbackSink; + +} // end namespace spdlog::sinks \ No newline at end of file diff --git a/cpp/include/raft/common/detail/logger.hpp b/cpp/include/raft/common/detail/logger.hpp new file mode 100644 index 0000000000..053b6e3c88 --- /dev/null +++ b/cpp/include/raft/common/detail/logger.hpp @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#define SPDLOG_HEADER_ONLY +#include // NOLINT +#include // NOLINT + +#include + +#include +#include +#include +#include + +#include + +/** + * @defgroup logging levels used in raft + * + * @note exactly match the corresponding ones (but reverse in terms of value) + * in spdlog for wrapping purposes + * + * @{ + */ +#define RAFT_LEVEL_TRACE 6 +#define RAFT_LEVEL_DEBUG 5 +#define RAFT_LEVEL_INFO 4 +#define RAFT_LEVEL_WARN 3 +#define RAFT_LEVEL_ERROR 2 +#define RAFT_LEVEL_CRITICAL 1 +#define RAFT_LEVEL_OFF 0 +/** @} */ + +#if !defined(RAFT_ACTIVE_LEVEL) +#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_DEBUG +#endif + +namespace spdlog { +class logger; +namespace sinks { +template +class CallbackSink; +using callback_sink_mt = CallbackSink; +}; // namespace sinks +}; // namespace spdlog + +namespace raft::detail { + +/** + * @defgroup CStringFormat Expand a C-style format string + * + * @brief Expands C-style formatted string into std::string + * + * @param[in] fmt format string + * @param[in] vl respective values for each of format modifiers in the string + * + * @return the expanded `std::string` + * + * @{ + */ +std::string format(const char* fmt, va_list& vl) +{ + char buf[4096]; + vsnprintf(buf, sizeof(buf), fmt, vl); + return std::string(buf); +} + +std::string format(const char* fmt, ...) +{ + va_list vl; + va_start(vl, fmt); + std::string str = format(fmt, vl); + va_end(vl); + return str; +} +/** @} */ + +int convert_level_to_spdlog(int level) +{ + level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); + return RAFT_LEVEL_TRACE - level; +} + +}; // namespace raft::detail + +/** + * @defgroup loggerMacros Helper macros for dealing with logging + * @{ + */ +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get().log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) +#define RAFT_LOG_DEBUG(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get().log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_DEBUG(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) +#define RAFT_LOG_INFO(fmt, ...) raft::logger::get().log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_INFO(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) +#define RAFT_LOG_WARN(fmt, ...) raft::logger::get().log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_WARN(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) +#define RAFT_LOG_ERROR(fmt, ...) raft::logger::get().log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_ERROR(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) +#define RAFT_LOG_CRITICAL(fmt, ...) raft::logger::get().log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_CRITICAL(fmt, ...) void(0) +#endif +/** @} */ \ No newline at end of file diff --git a/cpp/include/raft/common/detail/scatter.cuh b/cpp/include/raft/common/detail/scatter.cuh new file mode 100644 index 0000000000..e158999b1b --- /dev/null +++ b/cpp/include/raft/common/detail/scatter.cuh @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::detail { + +template +__global__ void scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op) +{ + typedef TxN_t DataVec; + typedef TxN_t IdxVec; + IdxT tid = threadIdx.x + ((IdxT)blockIdx.x * blockDim.x); + tid *= VecLen; + if (tid >= len) return; + IdxVec idxIn; + idxIn.load(idx, tid); + DataVec dataIn; +#pragma unroll + for (int i = 0; i < VecLen; ++i) { + auto inPos = idxIn.val.data[i]; + dataIn.val.data[i] = op(in[inPos], tid + i); + } + dataIn.store(out, tid); +} + +template +void scatterImpl( + DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op, cudaStream_t stream) +{ + const IdxT nblks = raft::ceildiv(VecLen ? len / VecLen : len, (IdxT)TPB); + scatterKernel<<>>(out, in, idx, len, op); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace raft::detail diff --git a/cpp/include/raft/common/logger.hpp b/cpp/include/raft/common/logger.hpp new file mode 100644 index 0000000000..aa7c55e863 --- /dev/null +++ b/cpp/include/raft/common/logger.hpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +namespace raft { + +static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); + +/** + * @brief The main Logging class for raft library. + * + * This class acts as a thin wrapper over the underlying `spdlog` interface. The + * design is done in this way in order to avoid us having to also ship `spdlog` + * header files in our installation. + * + * @todo This currently only supports logging to stdout. Need to add support in + * future to add custom loggers as well [Issue #2046] + */ +class logger { + public: + // @todo setting the logger once per process with + logger(std::string const& name_ = "") + : sink{std::make_shared()}, + spdlogger{std::make_shared(name_, sink)}, + cur_pattern() + { + set_pattern(default_log_pattern); + set_level(RAFT_LEVEL_INFO); + } + /** + * @brief Singleton method to get the underlying logger object + * + * @return the singleton logger object + */ + static logger& get(std::string const& name = "") + { + if (log_map.find(name) == log_map.end()) { + log_map[name] = std::make_shared(name); + } + return *log_map[name]; + } + + /** + * @brief Set the logging level. + * + * Only messages with level equal or above this will be printed + * + * @param[in] level logging level + * + * @note The log level will actually be set only if the input is within the + * range [RAFT_LEVEL_TRACE, RAFT_LEVEL_OFF]. If it is not, then it'll + * be ignored. See documentation of decisiontree for how this gets used + */ + void set_level(int level) + { + level = detail::convert_level_to_spdlog(level); + spdlogger->set_level(static_cast(level)); + } + + /** + * @brief Set the logging pattern + * + * @param[in] pattern the pattern to be set. Refer this link + * https://github.com/gabime/spdlog/wiki/3.-Custom-formatting + * to know the right syntax of this pattern + */ + void set_pattern(const std::string& pattern) + { + cur_pattern = pattern; + spdlogger->set_pattern(pattern); + } + + /** + * @brief Register a callback function to be run in place of usual log call + * + * @param[in] callback the function to be run on all logged messages + */ + void set_callback(void (*callback)(int lvl, const char* msg)) { sink->set_callback(callback); } + + /** + * @brief Register a flush function compatible with the registered callback + * + * @param[in] flush the function to use when flushing logs + */ + void set_flush(void (*flush)()) { sink->set_flush(flush); } + + /** + * @brief Tells whether messages will be logged for the given log level + * + * @param[in] level log level to be checked for + * @return true if messages will be logged for this level, else false + */ + bool should_log_for(int level) const + { + level = detail::convert_level_to_spdlog(level); + auto level_e = static_cast(level); + return spdlogger->should_log(level_e); + } + + /** + * @brief Query for the current log level + * + * @return the current log level + */ + int get_level() const + { + auto level_e = spdlogger->level(); + return RAFT_LEVEL_TRACE - static_cast(level_e); + } + + /** + * @brief Get the current logging pattern + * @return the pattern + */ + std::string get_pattern() const { return cur_pattern; } + + /** + * @brief Main logging method + * + * @param[in] level logging level of this message + * @param[in] fmt C-like format string, followed by respective params + */ + void log(int level, const char* fmt, ...) + { + level = detail::convert_level_to_spdlog(level); + auto level_e = static_cast(level); + // explicit check to make sure that we only expand messages when required + if (spdlogger->should_log(level_e)) { + va_list vl; + va_start(vl, fmt); + auto msg = detail::format(fmt, vl); + va_end(vl); + spdlogger->log(level_e, msg); + } + } + + /** + * @brief Flush logs by calling flush on underlying logger + */ + void flush() { spdlogger->flush(); } + + ~logger() {} + + private: + logger(); + + static inline std::unordered_map> log_map; + std::shared_ptr sink; + std::shared_ptr spdlogger; + std::string cur_pattern; + int cur_level; +}; // class logger + +/** + * @brief RAII based pattern setter for logger class + * + * @code{.cpp} + * { + * PatternSetter _("%l -- %v"); + * RAFT_LOG_INFO("Test message\n"); + * } + * @endcode + */ +class PatternSetter { + public: + /** + * @brief Set the pattern for the rest of the log messages + * @param[in] pattern pattern to be set + */ + PatternSetter(const std::string& pattern = "%v") : prev_pattern() + { + prev_pattern = logger::get().get_pattern(); + logger::get().set_pattern(pattern); + } + + /** + * @brief This will restore the previous pattern that was active during the + * moment this object was created + */ + ~PatternSetter() { logger::get().set_pattern(prev_pattern); } + + private: + std::string prev_pattern; +}; // class PatternSetter + +}; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/common/scatter.cuh b/cpp/include/raft/common/scatter.cuh index 2d25b85a50..04b4393261 100644 --- a/cpp/include/raft/common/scatter.cuh +++ b/cpp/include/raft/common/scatter.cuh @@ -16,39 +16,11 @@ #pragma once +#include #include -#include namespace raft { -template -__global__ void scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op) -{ - typedef TxN_t DataVec; - typedef TxN_t IdxVec; - IdxT tid = threadIdx.x + ((IdxT)blockIdx.x * blockDim.x); - tid *= VecLen; - if (tid >= len) return; - IdxVec idxIn; - idxIn.load(idx, tid); - DataVec dataIn; -#pragma unroll - for (int i = 0; i < VecLen; ++i) { - auto inPos = idxIn.val.data[i]; - dataIn.val.data[i] = op(in[inPos], tid + i); - } - dataIn.store(out, tid); -} - -template -void scatterImpl( - DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op, cudaStream_t stream) -{ - const IdxT nblks = raft::ceildiv(VecLen ? len / VecLen : len, (IdxT)TPB); - scatterKernel<<>>(out, in, idx, len, op); - RAFT_CUDA_TRY(cudaGetLastError()); -} - /** * @brief Performs scatter operation based on the input indexing array * @tparam DataT data type whose array gets scattered @@ -79,17 +51,17 @@ void scatter(DataT* out, constexpr size_t MaxPerElem = DataSize > IdxSize ? DataSize : IdxSize; size_t bytes = len * MaxPerElem; if (16 / MaxPerElem && bytes % 16 == 0) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else if (8 / MaxPerElem && bytes % 8 == 0) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else if (4 / MaxPerElem && bytes % 4 == 0) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else if (2 / MaxPerElem && bytes % 2 == 0) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else if (1 / MaxPerElem) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0d3121fee6..f8ae28f550 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -16,6 +16,7 @@ # keep the files in alphabetical order! add_executable(test_raft + test/common/logger.cpp test/common/seive.cu test/cudart_utils.cpp test/cluster_solvers.cu diff --git a/cpp/test/common/logger.cpp b/cpp/test/common/logger.cpp new file mode 100644 index 0000000000..ff63b8249e --- /dev/null +++ b/cpp/test/common/logger.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft { + +TEST(logger, Test) +{ + RAFT_LOG_CRITICAL("This is a critical message"); + RAFT_LOG_ERROR("This is an error message"); + RAFT_LOG_WARN("This is a warning message"); + RAFT_LOG_INFO("This is an info message"); + + logger::get().set_level(RAFT_LEVEL_WARN); + ASSERT_EQ(RAFT_LEVEL_WARN, logger::get().get_level()); + logger::get().set_level(RAFT_LEVEL_INFO); + ASSERT_EQ(RAFT_LEVEL_INFO, logger::get().get_level()); + + ASSERT_FALSE(logger::get().should_log_for(RAFT_LEVEL_TRACE)); + ASSERT_FALSE(logger::get().should_log_for(RAFT_LEVEL_DEBUG)); + ASSERT_TRUE(logger::get().should_log_for(RAFT_LEVEL_INFO)); + ASSERT_TRUE(logger::get().should_log_for(RAFT_LEVEL_WARN)); +} + +std::string logged = ""; +void exampleCallback(int lvl, const char* msg) { logged = std::string(msg); } + +int flushCount = 0; +void exampleFlush() { ++flushCount; } + +class loggerTest : public ::testing::Test { + protected: + void SetUp() override + { + flushCount = 0; + logged = ""; + logger::get().set_level(RAFT_LEVEL_TRACE); + } + + void TearDown() override + { + logger::get().set_callback(nullptr); + logger::get().set_flush(nullptr); + logger::get().set_level(RAFT_LEVEL_INFO); + } +}; + +TEST_F(loggerTest, callback) +{ + std::string testMsg; + logger::get().set_callback(exampleCallback); + + testMsg = "This is a critical message"; + RAFT_LOG_CRITICAL(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + + testMsg = "This is an error message"; + RAFT_LOG_ERROR(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + + testMsg = "This is a warning message"; + RAFT_LOG_WARN(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + + testMsg = "This is an info message"; + RAFT_LOG_INFO(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + + testMsg = "This is a debug message"; + RAFT_LOG_DEBUG(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); +} + +TEST_F(loggerTest, flush) +{ + logger::get().set_flush(exampleFlush); + logger::get().flush(); + ASSERT_EQ(1, flushCount); +} + +} // namespace raft \ No newline at end of file From b7cad57e5e49f74dadae6ad2eb43ec792879652f Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 9 Mar 2022 18:47:00 -0500 Subject: [PATCH 2/4] Consolidating into single file --- cpp/include/raft/common/detail/logger.hpp | 153 ---------------------- cpp/include/raft/common/logger.hpp | 153 +++++++++++++++++----- cpp/test/common/logger.cpp | 32 ++--- 3 files changed, 136 insertions(+), 202 deletions(-) delete mode 100644 cpp/include/raft/common/detail/logger.hpp diff --git a/cpp/include/raft/common/detail/logger.hpp b/cpp/include/raft/common/detail/logger.hpp deleted file mode 100644 index 053b6e3c88..0000000000 --- a/cpp/include/raft/common/detail/logger.hpp +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#define SPDLOG_HEADER_ONLY -#include // NOLINT -#include // NOLINT - -#include - -#include -#include -#include -#include - -#include - -/** - * @defgroup logging levels used in raft - * - * @note exactly match the corresponding ones (but reverse in terms of value) - * in spdlog for wrapping purposes - * - * @{ - */ -#define RAFT_LEVEL_TRACE 6 -#define RAFT_LEVEL_DEBUG 5 -#define RAFT_LEVEL_INFO 4 -#define RAFT_LEVEL_WARN 3 -#define RAFT_LEVEL_ERROR 2 -#define RAFT_LEVEL_CRITICAL 1 -#define RAFT_LEVEL_OFF 0 -/** @} */ - -#if !defined(RAFT_ACTIVE_LEVEL) -#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_DEBUG -#endif - -namespace spdlog { -class logger; -namespace sinks { -template -class CallbackSink; -using callback_sink_mt = CallbackSink; -}; // namespace sinks -}; // namespace spdlog - -namespace raft::detail { - -/** - * @defgroup CStringFormat Expand a C-style format string - * - * @brief Expands C-style formatted string into std::string - * - * @param[in] fmt format string - * @param[in] vl respective values for each of format modifiers in the string - * - * @return the expanded `std::string` - * - * @{ - */ -std::string format(const char* fmt, va_list& vl) -{ - char buf[4096]; - vsnprintf(buf, sizeof(buf), fmt, vl); - return std::string(buf); -} - -std::string format(const char* fmt, ...) -{ - va_list vl; - va_start(vl, fmt); - std::string str = format(fmt, vl); - va_end(vl); - return str; -} -/** @} */ - -int convert_level_to_spdlog(int level) -{ - level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); - return RAFT_LEVEL_TRACE - level; -} - -}; // namespace raft::detail - -/** - * @defgroup loggerMacros Helper macros for dealing with logging - * @{ - */ -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) -#define RAFT_LOG_TRACE(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get().log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_TRACE(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) -#define RAFT_LOG_DEBUG(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get().log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_DEBUG(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) -#define RAFT_LOG_INFO(fmt, ...) raft::logger::get().log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_INFO(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) -#define RAFT_LOG_WARN(fmt, ...) raft::logger::get().log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_WARN(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) -#define RAFT_LOG_ERROR(fmt, ...) raft::logger::get().log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_ERROR(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) -#define RAFT_LOG_CRITICAL(fmt, ...) raft::logger::get().log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_CRITICAL(fmt, ...) void(0) -#endif -/** @} */ \ No newline at end of file diff --git a/cpp/include/raft/common/logger.hpp b/cpp/include/raft/common/logger.hpp index aa7c55e863..d8d020ee58 100644 --- a/cpp/include/raft/common/logger.hpp +++ b/cpp/include/raft/common/logger.hpp @@ -25,12 +25,74 @@ #include #include -#include +#include + +#define SPDLOG_HEADER_ONLY +#include +#include // NOLINT +#include // NOLINT + +/** + * @defgroup logging levels used in raft + * + * @note exactly match the corresponding ones (but reverse in terms of value) + * in spdlog for wrapping purposes + * + * @{ + */ +#define RAFT_LEVEL_TRACE 6 +#define RAFT_LEVEL_DEBUG 5 +#define RAFT_LEVEL_INFO 4 +#define RAFT_LEVEL_WARN 3 +#define RAFT_LEVEL_ERROR 2 +#define RAFT_LEVEL_CRITICAL 1 +#define RAFT_LEVEL_OFF 0 +/** @} */ + +#if !defined(RAFT_ACTIVE_LEVEL) +#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_DEBUG +#endif namespace raft { +static const std::string RAFT_NAME = "raft"; static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); +/** + * @defgroup CStringFormat Expand a C-style format string + * + * @brief Expands C-style formatted string into std::string + * + * @param[in] fmt format string + * @param[in] vl respective values for each of format modifiers in the string + * + * @return the expanded `std::string` + * + * @{ + */ +std::string format(const char* fmt, va_list& vl) +{ + char buf[4096]; + vsnprintf(buf, sizeof(buf), fmt, vl); + return std::string(buf); +} + +std::string format(const char* fmt, ...) +{ + va_list vl; + va_start(vl, fmt); + std::string str = format(fmt, vl); + va_end(vl); + return str; +} +/** @} */ + +int convert_level_to_spdlog(int level) +{ + level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); + return RAFT_LEVEL_TRACE - level; +} + /** * @brief The main Logging class for raft library. * @@ -78,7 +140,7 @@ class logger { */ void set_level(int level) { - level = detail::convert_level_to_spdlog(level); + level = convert_level_to_spdlog(level); spdlogger->set_level(static_cast(level)); } @@ -117,7 +179,7 @@ class logger { */ bool should_log_for(int level) const { - level = detail::convert_level_to_spdlog(level); + level = convert_level_to_spdlog(level); auto level_e = static_cast(level); return spdlogger->should_log(level_e); } @@ -147,13 +209,13 @@ class logger { */ void log(int level, const char* fmt, ...) { - level = detail::convert_level_to_spdlog(level); + level = convert_level_to_spdlog(level); auto level_e = static_cast(level); // explicit check to make sure that we only expand messages when required if (spdlogger->should_log(level_e)) { va_list vl; va_start(vl, fmt); - auto msg = detail::format(fmt, vl); + auto msg = format(fmt, vl); va_end(vl); spdlogger->log(level_e, msg); } @@ -176,36 +238,61 @@ class logger { int cur_level; }; // class logger +}; // namespace raft + /** - * @brief RAII based pattern setter for logger class - * - * @code{.cpp} - * { - * PatternSetter _("%l -- %v"); - * RAFT_LOG_INFO("Test message\n"); - * } - * @endcode + * @defgroup loggerMacros Helper macros for dealing with logging + * @{ */ -class PatternSetter { - public: - /** - * @brief Set the pattern for the rest of the log messages - * @param[in] pattern pattern to be set - */ - PatternSetter(const std::string& pattern = "%v") : prev_pattern() - { - prev_pattern = logger::get().get_pattern(); - logger::get().set_pattern(pattern); - } +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE(fmt, ...) void(0) +#endif - /** - * @brief This will restore the previous pattern that was active during the - * moment this object was created - */ - ~PatternSetter() { logger::get().set_pattern(prev_pattern); } +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) +#define RAFT_LOG_DEBUG(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_DEBUG(fmt, ...) void(0) +#endif - private: - std::string prev_pattern; -}; // class PatternSetter +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) +#define RAFT_LOG_INFO(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_INFO(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) +#define RAFT_LOG_WARN(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_WARN(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) +#define RAFT_LOG_ERROR(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_ERROR(fmt, ...) void(0) +#endif -}; // namespace raft \ No newline at end of file +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) +#define RAFT_LOG_CRITICAL(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_CRITICAL(fmt, ...) void(0) +#endif +/** @} */ \ No newline at end of file diff --git a/cpp/test/common/logger.cpp b/cpp/test/common/logger.cpp index ff63b8249e..218b33050c 100644 --- a/cpp/test/common/logger.cpp +++ b/cpp/test/common/logger.cpp @@ -27,15 +27,15 @@ TEST(logger, Test) RAFT_LOG_WARN("This is a warning message"); RAFT_LOG_INFO("This is an info message"); - logger::get().set_level(RAFT_LEVEL_WARN); - ASSERT_EQ(RAFT_LEVEL_WARN, logger::get().get_level()); - logger::get().set_level(RAFT_LEVEL_INFO); - ASSERT_EQ(RAFT_LEVEL_INFO, logger::get().get_level()); - - ASSERT_FALSE(logger::get().should_log_for(RAFT_LEVEL_TRACE)); - ASSERT_FALSE(logger::get().should_log_for(RAFT_LEVEL_DEBUG)); - ASSERT_TRUE(logger::get().should_log_for(RAFT_LEVEL_INFO)); - ASSERT_TRUE(logger::get().should_log_for(RAFT_LEVEL_WARN)); + logger::get(RAFT_NAME).set_level(RAFT_LEVEL_WARN); + ASSERT_EQ(RAFT_LEVEL_WARN, logger::get(RAFT_NAME).get_level()); + logger::get(RAFT_NAME).set_level(RAFT_LEVEL_INFO); + ASSERT_EQ(RAFT_LEVEL_INFO, logger::get(RAFT_NAME).get_level()); + + ASSERT_FALSE(logger::get(RAFT_NAME).should_log_for(RAFT_LEVEL_TRACE)); + ASSERT_FALSE(logger::get(RAFT_NAME).should_log_for(RAFT_LEVEL_DEBUG)); + ASSERT_TRUE(logger::get(RAFT_NAME).should_log_for(RAFT_LEVEL_INFO)); + ASSERT_TRUE(logger::get(RAFT_NAME).should_log_for(RAFT_LEVEL_WARN)); } std::string logged = ""; @@ -50,21 +50,21 @@ class loggerTest : public ::testing::Test { { flushCount = 0; logged = ""; - logger::get().set_level(RAFT_LEVEL_TRACE); + logger::get(RAFT_NAME).set_level(RAFT_LEVEL_TRACE); } void TearDown() override { - logger::get().set_callback(nullptr); - logger::get().set_flush(nullptr); - logger::get().set_level(RAFT_LEVEL_INFO); + logger::get(RAFT_NAME).set_callback(nullptr); + logger::get(RAFT_NAME).set_flush(nullptr); + logger::get(RAFT_NAME).set_level(RAFT_LEVEL_INFO); } }; TEST_F(loggerTest, callback) { std::string testMsg; - logger::get().set_callback(exampleCallback); + logger::get(RAFT_NAME).set_callback(exampleCallback); testMsg = "This is a critical message"; RAFT_LOG_CRITICAL(testMsg.c_str()); @@ -89,8 +89,8 @@ TEST_F(loggerTest, callback) TEST_F(loggerTest, flush) { - logger::get().set_flush(exampleFlush); - logger::get().flush(); + logger::get(RAFT_NAME).set_flush(exampleFlush); + logger::get(RAFT_NAME).flush(); ASSERT_EQ(1, flushCount); } From 22972fca93f55bde421bae0a3a6cba28c0ea86f3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 9 Mar 2022 18:51:04 -0500 Subject: [PATCH 3/4] Updating copyrights --- cpp/include/raft/common/detail/callback_sink.hpp | 2 +- cpp/include/raft/common/detail/scatter.cuh | 2 +- cpp/include/raft/common/scatter.cuh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/common/detail/callback_sink.hpp b/cpp/include/raft/common/detail/callback_sink.hpp index ecd869ee4f..e6dc07b49d 100644 --- a/cpp/include/raft/common/detail/callback_sink.hpp +++ b/cpp/include/raft/common/detail/callback_sink.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/common/detail/scatter.cuh b/cpp/include/raft/common/detail/scatter.cuh index e158999b1b..4087625320 100644 --- a/cpp/include/raft/common/detail/scatter.cuh +++ b/cpp/include/raft/common/detail/scatter.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/common/scatter.cuh b/cpp/include/raft/common/scatter.cuh index 04b4393261..9735ccdf2b 100644 --- a/cpp/include/raft/common/scatter.cuh +++ b/cpp/include/raft/common/scatter.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From a7d3469d3e6a9999621f0eba9b42d782e4e8d6fe Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 9 Mar 2022 18:52:48 -0500 Subject: [PATCH 4/4] Updating readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 54dd394a69..606197cde0 100755 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ auto input = raft::make_device_matrix(handle, n_samples, n_features); auto labels = raft::make_device_vector(handle, n_samples); auto output = raft::make_device_matrix(handle, n_samples, n_samples); -raft::random::make_blobs(handle, input, labels); +raft::random::make_blobs(handle, input.view(), labels.view()); auto metric = raft::distance::DistanceType::L2SqrtExpanded; raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric);