Skip to content

Commit

Permalink
Merge pull request #158 from rapidsai/fea-019-fusedl2nn
Browse files Browse the repository at this point in the history
[REVIEW] Moving optimized L2 1-nearest neighbors implementation from cuml
  • Loading branch information
cjnolet authored Mar 2, 2021
2 parents 539b8e9 + ffa8a8f commit b055cf8
Show file tree
Hide file tree
Showing 5 changed files with 1,193 additions and 0 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ if(BUILD_RAFT_TESTS)
# keep the files in alphabetical order!
add_executable(test_raft
test/cudart_utils.cpp
test/distance/fused_l2_nn.cu
test/handle.cpp
test/integer_utils.cpp
test/lap/lap.cu
Expand Down
120 changes: 120 additions & 0 deletions cpp/include/raft/common/device_loads_stores.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/cuda_utils.cuh>

namespace raft {

/**
* @defgroup SmemStores Shared memory store operations
* @{
* @brief Stores to shared memory (both vectorized and non-vectorized forms)
* @param[out] addr shared memory address
* @param[in] x data to be stored at this address
*/
DI void sts(float* addr, const float& x) { *addr = x; }
DI void sts(float* addr, const float (&x)[1]) { *addr = x[0]; }
DI void sts(float* addr, const float (&x)[2]) {
float2 v2 = make_float2(x[0], x[1]);
auto* s2 = reinterpret_cast<float2*>(addr);
*s2 = v2;
}
DI void sts(float* addr, const float (&x)[4]) {
float4 v4 = make_float4(x[0], x[1], x[2], x[3]);
auto* s4 = reinterpret_cast<float4*>(addr);
*s4 = v4;
}
DI void sts(double* addr, const double& x) { *addr = x; }
DI void sts(double* addr, const double (&x)[1]) { *addr = x[0]; }
DI void sts(double* addr, const double (&x)[2]) {
double2 v2 = make_double2(x[0], x[1]);
auto* s2 = reinterpret_cast<double2*>(addr);
*s2 = v2;
}
/** @} */

/**
* @defgroup SmemLoads Shared memory load operations
* @{
* @brief Loads from shared memory (both vectorized and non-vectorized forms)
* @param[out] x the data to be loaded
* @param[in] addr shared memory address from where to load
*/
DI void lds(float& x, float* addr) { x = *addr; }
DI void lds(float (&x)[1], float* addr) { x[0] = *addr; }
DI void lds(float (&x)[2], float* addr) {
auto* s2 = reinterpret_cast<float2*>(addr);
auto v2 = *s2;
x[0] = v2.x;
x[1] = v2.y;
}
DI void lds(float (&x)[4], float* addr) {
auto* s4 = reinterpret_cast<float4*>(addr);
auto v4 = *s4;
x[0] = v4.x;
x[1] = v4.y;
x[2] = v4.z;
x[3] = v4.w;
}
DI void lds(double& x, double* addr) { x = *addr; }
DI void lds(double (&x)[1], double* addr) { x[0] = *addr; }
DI void lds(double (&x)[2], double* addr) {
auto* s2 = reinterpret_cast<double2*>(addr);
auto v2 = *s2;
x[0] = v2.x;
x[1] = v2.y;
}
/** @} */

/**
* @defgroup GlobalLoads Global cached load operations
* @{
* @brief Load from global memory with caching at L1 level
* @param[out] x data to be loaded from global memory
* @param[in] addr address in global memory from where to load
*/
DI void ldg(float& x, const float* addr) {
asm volatile("ld.global.cg.f32 %0, [%1];" : "=f"(x) : "l"(addr));
}
DI void ldg(float (&x)[1], const float* addr) {
asm volatile("ld.global.cg.f32 %0, [%1];" : "=f"(x[0]) : "l"(addr));
}
DI void ldg(float (&x)[2], const float* addr) {
asm volatile("ld.global.cg.v2.f32 {%0, %1}, [%2];"
: "=f"(x[0]), "=f"(x[1])
: "l"(addr));
}
DI void ldg(float (&x)[4], const float* addr) {
asm volatile("ld.global.cg.v4.f32 {%0, %1, %2, %3}, [%4];"
: "=f"(x[0]), "=f"(x[1]), "=f"(x[2]), "=f"(x[3])
: "l"(addr));
}
DI void ldg(double& x, const double* addr) {
asm volatile("ld.global.cg.f64 %0, [%1];" : "=d"(x) : "l"(addr));
}
DI void ldg(double (&x)[1], const double* addr) {
asm volatile("ld.global.cg.f64 %0, [%1];" : "=d"(x[0]) : "l"(addr));
}
DI void ldg(double (&x)[2], const double* addr) {
asm volatile("ld.global.cg.v2.f64 {%0, %1}, [%2];"
: "=d"(x[0]), "=d"(x[1])
: "l"(addr));
}
/** @} */

} // namespace raft
Loading

0 comments on commit b055cf8

Please sign in to comment.