Skip to content

Commit

Permalink
Adding logger (#550)
Browse files Browse the repository at this point in the history
We've talked about moving this for a while, so I figured it's time to do it. This is also going to help @jnke2016 and I trace through a segfault in the cugraph mnmg code.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Brad Rees (https://github.com/BradReesWork)

URL: #550
  • Loading branch information
cjnolet authored Mar 11, 2022
1 parent c0925f3 commit b28c705
Show file tree
Hide file tree
Showing 7 changed files with 528 additions and 37 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ auto input = raft::make_device_matrix<float>(handle, n_samples, n_features);
auto labels = raft::make_device_vector<int>(handle, n_samples);
auto output = raft::make_device_matrix<float>(handle, n_samples, n_samples);

raft::random::make_blobs(handle, input, labels);
raft::random::make_blobs(handle, input.view(), labels.view());

auto metric = raft::distance::DistanceType::L2SqrtExpanded;
raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric);
Expand Down
71 changes: 71 additions & 0 deletions cpp/include/raft/common/detail/callback_sink.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <iostream>
#include <mutex>

#define SPDLOG_HEADER_ONLY
#include <spdlog/common.h>
#include <spdlog/details/log_msg.h>
#include <spdlog/sinks/base_sink.h>

namespace spdlog::sinks {

/**
 * Signature of a user-supplied logging callback: receives the spdlog log
 * level (cast to int) and the fully formatted, null-terminated message.
 */
typedef void (*LogCallback)(int lvl, const char* msg);

/**
 * @brief spdlog sink that forwards every formatted log record to a
 * user-supplied C callback, falling back to stdout when no callback is set.
 *
 * @tparam Mutex mutex type used by base_sink to serialize sink_it_/flush_
 *         (std::mutex for callback_sink_mt, details::null_mutex for
 *         callback_sink_st).
 */
template <class Mutex>
class CallbackSink : public base_sink<Mutex> {
 public:
  /**
   * @param tag currently unused; kept for signature compatibility with
   *        other spdlog sinks that take a tag string.
   * @param callback invoked with (level, message) for each record, or
   *        nullptr to print records to stdout instead.
   * @param flush invoked on flush requests, or nullptr to flush stdout.
   */
  explicit CallbackSink([[maybe_unused]] std::string tag = "spdlog",
                        LogCallback callback             = nullptr,
                        void (*flush)()                  = nullptr)
    : _callback{callback}, _flush{flush}
  {
  }

  void set_callback(LogCallback callback) { _callback = callback; }
  void set_flush(void (*flush)()) { _flush = flush; }

 protected:
  // Format the record with the sink's formatter, then hand the resulting
  // string to the callback (or stdout when no callback is registered).
  void sink_it_(const details::log_msg& msg) override
  {
    spdlog::memory_buf_t formatted;
    base_sink<Mutex>::formatter_->format(msg, formatted);
    const std::string msg_string = fmt::to_string(formatted);

    if (_callback) {
      _callback(static_cast<int>(msg.level), msg_string.c_str());
    } else {
      std::cout << msg_string;
    }
  }

  // Delegate flushing to the user hook, or flush stdout as the fallback.
  void flush_() override
  {
    if (_flush) {
      _flush();
    } else {
      std::cout << std::flush;
    }
  }

  LogCallback _callback;  // nullptr => records go to stdout
  void (*_flush)();       // nullptr => flush_ flushes stdout
};

using callback_sink_mt = CallbackSink<std::mutex>;
using callback_sink_st = CallbackSink<details::null_mutex>;

}  // end namespace spdlog::sinks
52 changes: 52 additions & 0 deletions cpp/include/raft/common/detail/scatter.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/cuda_utils.cuh>
#include <raft/vectorized.cuh>

namespace raft::detail {

/**
 * @brief Vectorized kernel: each thread produces VecLen consecutive outputs,
 * where output element (tid + i) is op(in[idx[tid + i]], tid + i).
 *
 * NOTE(review): despite the name, the access pattern here is a *gather* —
 * reads from `in` are indirect through `idx`, writes to `out` are
 * contiguous. Confirm this is the intended semantics for "scatter".
 *
 * NOTE(review): a thread returns only when tid >= len, yet it loads/stores
 * a full VecLen-wide vector at tid — this appears to assume len is a
 * multiple of VecLen (for VecLen > 1). TODO confirm callers guarantee this.
 *
 * @param out output array, written VecLen elements per thread
 * @param in  input array, read at the positions named by idx
 * @param idx index array; idx[k] selects the input element for output k
 * @param len number of elements to produce
 * @param op  functor applied as op(input_value, output_position)
 */
template <typename DataT, int VecLen, typename Lambda, typename IdxT>
__global__ void scatterKernel(DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op)
{
  typedef TxN_t<DataT, VecLen> DataVec;
  typedef TxN_t<IdxT, VecLen> IdxVec;
  // Flat global thread id, scaled so each thread owns VecLen elements.
  IdxT tid = threadIdx.x + ((IdxT)blockIdx.x * blockDim.x);
  tid *= VecLen;
  if (tid >= len) return;
  // Vector-load the VecLen indices this thread is responsible for.
  IdxVec idxIn;
  idxIn.load(idx, tid);
  DataVec dataIn;
#pragma unroll
  for (int i = 0; i < VecLen; ++i) {
    auto inPos         = idxIn.val.data[i];
    dataIn.val.data[i] = op(in[inPos], tid + i);
  }
  // Single vectorized store of the VecLen results.
  dataIn.store(out, tid);
}

/**
 * @brief Host-side launcher for scatterKernel: sizes the grid so that each
 * thread covers VecLen elements, launches on the given stream, and checks
 * for a launch error.
 */
template <typename DataT, int VecLen, typename Lambda, typename IdxT, int TPB>
void scatterImpl(
  DataT* out, const DataT* in, const IdxT* idx, IdxT len, Lambda op, cudaStream_t stream)
{
  // Each thread handles VecLen elements, so the grid is sized on the
  // number of VecLen-wide work items (all of len when VecLen is 0).
  const IdxT work_items = VecLen ? len / VecLen : len;
  const IdxT grid_dim   = raft::ceildiv(work_items, static_cast<IdxT>(TPB));
  scatterKernel<DataT, VecLen, Lambda, IdxT><<<grid_dim, TPB, 0, stream>>>(out, in, idx, len, op);
  RAFT_CUDA_TRY(cudaGetLastError());
}

} // namespace raft::detail
Loading

0 comments on commit b28c705

Please sign in to comment.