From b28c705dede5923f80097ea0a030aa33d0c2ab99 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 11 Mar 2022 09:04:56 -0500 Subject: [PATCH] Adding logger (#550) We've talked about moving this for awhile so I figured it's time to do it. This is also going to help @jnke2016 and I trace through a segfault in the cugraph mnmg code. Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/raft/pull/550 --- README.md | 2 +- .../raft/common/detail/callback_sink.hpp | 71 +++++ cpp/include/raft/common/detail/scatter.cuh | 52 +++ cpp/include/raft/common/logger.hpp | 298 ++++++++++++++++++ cpp/include/raft/common/scatter.cuh | 44 +-- cpp/test/CMakeLists.txt | 1 + cpp/test/common/logger.cpp | 97 ++++++ 7 files changed, 528 insertions(+), 37 deletions(-) create mode 100644 cpp/include/raft/common/detail/callback_sink.hpp create mode 100644 cpp/include/raft/common/detail/scatter.cuh create mode 100644 cpp/include/raft/common/logger.hpp create mode 100644 cpp/test/common/logger.cpp diff --git a/README.md b/README.md index 54dd394a69..606197cde0 100755 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ auto input = raft::make_device_matrix(handle, n_samples, n_features); auto labels = raft::make_device_vector(handle, n_samples); auto output = raft::make_device_matrix(handle, n_samples, n_samples); -raft::random::make_blobs(handle, input, labels); +raft::random::make_blobs(handle, input.view(), labels.view()); auto metric = raft::distance::DistanceType::L2SqrtExpanded; raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); diff --git a/cpp/include/raft/common/detail/callback_sink.hpp b/cpp/include/raft/common/detail/callback_sink.hpp new file mode 100644 index 0000000000..e6dc07b49d --- /dev/null +++ b/cpp/include/raft/common/detail/callback_sink.hpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#define SPDLOG_HEADER_ONLY +#include +#include +#include + +namespace spdlog::sinks { + +typedef void (*LogCallback)(int lvl, const char* msg); + +template +class CallbackSink : public base_sink { + public: + explicit CallbackSink(std::string tag = "spdlog", + LogCallback callback = nullptr, + void (*flush)() = nullptr) + : _callback{callback}, _flush{flush} {}; + + void set_callback(LogCallback callback) { _callback = callback; } + void set_flush(void (*flush)()) { _flush = flush; } + + protected: + void sink_it_(const details::log_msg& msg) override + { + spdlog::memory_buf_t formatted; + base_sink::formatter_->format(msg, formatted); + std::string msg_string = fmt::to_string(formatted); + + if (_callback) { + _callback(static_cast(msg.level), msg_string.c_str()); + } else { + std::cout << msg_string; + } + } + + void flush_() override + { + if (_flush) { + _flush(); + } else { + std::cout << std::flush; + } + } + + LogCallback _callback; + void (*_flush)(); +}; + +using callback_sink_mt = CallbackSink; +using callback_sink_st = CallbackSink; + +} // end namespace spdlog::sinks \ No newline at end of file diff --git a/cpp/include/raft/common/detail/scatter.cuh b/cpp/include/raft/common/detail/scatter.cuh new file mode 100644 index 0000000000..4087625320 --- /dev/null +++ b/cpp/include/raft/common/detail/scatter.cuh @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::detail { + +template +__global__ void scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op) +{ + typedef TxN_t DataVec; + typedef TxN_t IdxVec; + IdxT tid = threadIdx.x + ((IdxT)blockIdx.x * blockDim.x); + tid *= VecLen; + if (tid >= len) return; + IdxVec idxIn; + idxIn.load(idx, tid); + DataVec dataIn; +#pragma unroll + for (int i = 0; i < VecLen; ++i) { + auto inPos = idxIn.val.data[i]; + dataIn.val.data[i] = op(in[inPos], tid + i); + } + dataIn.store(out, tid); +} + +template +void scatterImpl( + DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op, cudaStream_t stream) +{ + const IdxT nblks = raft::ceildiv(VecLen ? len / VecLen : len, (IdxT)TPB); + scatterKernel<<>>(out, in, idx, len, op); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace raft::detail diff --git a/cpp/include/raft/common/logger.hpp b/cpp/include/raft/common/logger.hpp new file mode 100644 index 0000000000..d8d020ee58 --- /dev/null +++ b/cpp/include/raft/common/logger.hpp @@ -0,0 +1,298 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +#define SPDLOG_HEADER_ONLY +#include +#include // NOLINT +#include // NOLINT + +/** + * @defgroup logging levels used in raft + * + * @note exactly match the corresponding ones (but reverse in terms of value) + * in spdlog for wrapping purposes + * + * @{ + */ +#define RAFT_LEVEL_TRACE 6 +#define RAFT_LEVEL_DEBUG 5 +#define RAFT_LEVEL_INFO 4 +#define RAFT_LEVEL_WARN 3 +#define RAFT_LEVEL_ERROR 2 +#define RAFT_LEVEL_CRITICAL 1 +#define RAFT_LEVEL_OFF 0 +/** @} */ + +#if !defined(RAFT_ACTIVE_LEVEL) +#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_DEBUG +#endif + +namespace raft { + +static const std::string RAFT_NAME = "raft"; +static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); + +/** + * @defgroup CStringFormat Expand a C-style format string + * + * @brief Expands C-style formatted string into std::string + * + * @param[in] fmt format string + * @param[in] vl respective values for each of format modifiers in the string + * + * @return the expanded `std::string` + * + * @{ + */ +std::string format(const char* fmt, va_list& vl) +{ + char buf[4096]; + vsnprintf(buf, sizeof(buf), fmt, vl); + return std::string(buf); +} + +std::string format(const char* fmt, ...) +{ + va_list vl; + va_start(vl, fmt); + std::string str = format(fmt, vl); + va_end(vl); + return str; +} +/** @} */ + +int convert_level_to_spdlog(int level) +{ + level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); + return RAFT_LEVEL_TRACE - level; +} + +/** + * @brief The main Logging class for raft library. + * + * This class acts as a thin wrapper over the underlying `spdlog` interface. The + * design is done in this way in order to avoid us having to also ship `spdlog` + * header files in our installation. + * + * @todo This currently only supports logging to stdout. Need to add support in + * future to add custom loggers as well [Issue #2046] + */ +class logger { + public: + // @todo setting the logger once per process with + logger(std::string const& name_ = "") + : sink{std::make_shared()}, + spdlogger{std::make_shared(name_, sink)}, + cur_pattern() + { + set_pattern(default_log_pattern); + set_level(RAFT_LEVEL_INFO); + } + /** + * @brief Singleton method to get the underlying logger object + * + * @return the singleton logger object + */ + static logger& get(std::string const& name = "") + { + if (log_map.find(name) == log_map.end()) { + log_map[name] = std::make_shared(name); + } + return *log_map[name]; + } + + /** + * @brief Set the logging level. + * + * Only messages with level equal or above this will be printed + * + * @param[in] level logging level + * + * @note The log level will actually be set only if the input is within the + * range [RAFT_LEVEL_TRACE, RAFT_LEVEL_OFF]. If it is not, then it'll + * be ignored. See documentation of decisiontree for how this gets used + */ + void set_level(int level) + { + level = convert_level_to_spdlog(level); + spdlogger->set_level(static_cast(level)); + } + + /** + * @brief Set the logging pattern + * + * @param[in] pattern the pattern to be set. Refer this link + * https://github.com/gabime/spdlog/wiki/3.-Custom-formatting + * to know the right syntax of this pattern + */ + void set_pattern(const std::string& pattern) + { + cur_pattern = pattern; + spdlogger->set_pattern(pattern); + } + + /** + * @brief Register a callback function to be run in place of usual log call + * + * @param[in] callback the function to be run on all logged messages + */ + void set_callback(void (*callback)(int lvl, const char* msg)) { sink->set_callback(callback); } + + /** + * @brief Register a flush function compatible with the registered callback + * + * @param[in] flush the function to use when flushing logs + */ + void set_flush(void (*flush)()) { sink->set_flush(flush); } + + /** + * @brief Tells whether messages will be logged for the given log level + * + * @param[in] level log level to be checked for + * @return true if messages will be logged for this level, else false + */ + bool should_log_for(int level) const + { + level = convert_level_to_spdlog(level); + auto level_e = static_cast(level); + return spdlogger->should_log(level_e); + } + + /** + * @brief Query for the current log level + * + * @return the current log level + */ + int get_level() const + { + auto level_e = spdlogger->level(); + return RAFT_LEVEL_TRACE - static_cast(level_e); + } + + /** + * @brief Get the current logging pattern + * @return the pattern + */ + std::string get_pattern() const { return cur_pattern; } + + /** + * @brief Main logging method + * + * @param[in] level logging level of this message + * @param[in] fmt C-like format string, followed by respective params + */ + void log(int level, const char* fmt, ...) + { + level = convert_level_to_spdlog(level); + auto level_e = static_cast(level); + // explicit check to make sure that we only expand messages when required + if (spdlogger->should_log(level_e)) { + va_list vl; + va_start(vl, fmt); + auto msg = format(fmt, vl); + va_end(vl); + spdlogger->log(level_e, msg); + } + } + + /** + * @brief Flush logs by calling flush on underlying logger + */ + void flush() { spdlogger->flush(); } + + ~logger() {} + + private: + logger(); + + static inline std::unordered_map> log_map; + std::shared_ptr sink; + std::shared_ptr spdlogger; + std::string cur_pattern; + int cur_level; +}; // class logger + +}; // namespace raft + +/** + * @defgroup loggerMacros Helper macros for dealing with logging + * @{ + */ +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) +#define RAFT_LOG_DEBUG(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_DEBUG(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) +#define RAFT_LOG_INFO(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_INFO(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) +#define RAFT_LOG_WARN(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_WARN(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) +#define RAFT_LOG_ERROR(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_ERROR(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) +#define RAFT_LOG_CRITICAL(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_CRITICAL(fmt, ...) void(0) +#endif +/** @} */ \ No newline at end of file diff --git a/cpp/include/raft/common/scatter.cuh b/cpp/include/raft/common/scatter.cuh index 2d25b85a50..9735ccdf2b 100644 --- a/cpp/include/raft/common/scatter.cuh +++ b/cpp/include/raft/common/scatter.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,39 +16,11 @@ #pragma once +#include #include -#include namespace raft { -template -__global__ void scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op) -{ - typedef TxN_t DataVec; - typedef TxN_t IdxVec; - IdxT tid = threadIdx.x + ((IdxT)blockIdx.x * blockDim.x); - tid *= VecLen; - if (tid >= len) return; - IdxVec idxIn; - idxIn.load(idx, tid); - DataVec dataIn; -#pragma unroll - for (int i = 0; i < VecLen; ++i) { - auto inPos = idxIn.val.data[i]; - dataIn.val.data[i] = op(in[inPos], tid + i); - } - dataIn.store(out, tid); -} - -template -void scatterImpl( - DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op, cudaStream_t stream) -{ - const IdxT nblks = raft::ceildiv(VecLen ? len / VecLen : len, (IdxT)TPB); - scatterKernel<<>>(out, in, idx, len, op); - RAFT_CUDA_TRY(cudaGetLastError()); -} - /** * @brief Performs scatter operation based on the input indexing array * @tparam DataT data type whose array gets scattered @@ -79,17 +51,17 @@ void scatter(DataT* out, constexpr size_t MaxPerElem = DataSize > IdxSize ? DataSize : IdxSize; size_t bytes = len * MaxPerElem; if (16 / MaxPerElem && bytes % 16 == 0) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else if (8 / MaxPerElem && bytes % 8 == 0) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else if (4 / MaxPerElem && bytes % 4 == 0) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else if (2 / MaxPerElem && bytes % 2 == 0) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else if (1 / MaxPerElem) { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } else { - scatterImpl(out, in, idx, len, op, stream); + detail::scatterImpl(out, in, idx, len, op, stream); } } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0d3121fee6..f8ae28f550 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -16,6 +16,7 @@ # keep the files in alphabetical order! add_executable(test_raft + test/common/logger.cpp test/common/seive.cu test/cudart_utils.cpp test/cluster_solvers.cu diff --git a/cpp/test/common/logger.cpp b/cpp/test/common/logger.cpp new file mode 100644 index 0000000000..218b33050c --- /dev/null +++ b/cpp/test/common/logger.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace raft { + +TEST(logger, Test) +{ + RAFT_LOG_CRITICAL("This is a critical message"); + RAFT_LOG_ERROR("This is an error message"); + RAFT_LOG_WARN("This is a warning message"); + RAFT_LOG_INFO("This is an info message"); + + logger::get(RAFT_NAME).set_level(RAFT_LEVEL_WARN); + ASSERT_EQ(RAFT_LEVEL_WARN, logger::get(RAFT_NAME).get_level()); + logger::get(RAFT_NAME).set_level(RAFT_LEVEL_INFO); + ASSERT_EQ(RAFT_LEVEL_INFO, logger::get(RAFT_NAME).get_level()); + + ASSERT_FALSE(logger::get(RAFT_NAME).should_log_for(RAFT_LEVEL_TRACE)); + ASSERT_FALSE(logger::get(RAFT_NAME).should_log_for(RAFT_LEVEL_DEBUG)); + ASSERT_TRUE(logger::get(RAFT_NAME).should_log_for(RAFT_LEVEL_INFO)); + ASSERT_TRUE(logger::get(RAFT_NAME).should_log_for(RAFT_LEVEL_WARN)); +} + +std::string logged = ""; +void exampleCallback(int lvl, const char* msg) { logged = std::string(msg); } + +int flushCount = 0; +void exampleFlush() { ++flushCount; } + +class loggerTest : public ::testing::Test { + protected: + void SetUp() override + { + flushCount = 0; + logged = ""; + logger::get(RAFT_NAME).set_level(RAFT_LEVEL_TRACE); + } + + void TearDown() override + { + logger::get(RAFT_NAME).set_callback(nullptr); + logger::get(RAFT_NAME).set_flush(nullptr); + logger::get(RAFT_NAME).set_level(RAFT_LEVEL_INFO); + } +}; + +TEST_F(loggerTest, callback) +{ + std::string testMsg; + logger::get(RAFT_NAME).set_callback(exampleCallback); + + testMsg = "This is a critical message"; + RAFT_LOG_CRITICAL(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + + testMsg = "This is an error message"; + RAFT_LOG_ERROR(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + + testMsg = "This is a warning message"; + RAFT_LOG_WARN(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + + testMsg = "This is an info message"; + RAFT_LOG_INFO(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); + + testMsg = "This is a debug message"; + RAFT_LOG_DEBUG(testMsg.c_str()); + ASSERT_TRUE(logged.find(testMsg) != std::string::npos); +} + +TEST_F(loggerTest, flush) +{ + logger::get(RAFT_NAME).set_flush(exampleFlush); + logger::get(RAFT_NAME).flush(); + ASSERT_EQ(1, flushCount); +} + +} // namespace raft \ No newline at end of file