Skip to content

Commit

Permalink
remove double datatype API, code cleanup and other review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mdoijade committed Feb 16, 2024
1 parent 42841da commit 5384408
Show file tree
Hide file tree
Showing 10 changed files with 5 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ struct FusedDistanceNNPersistent {
void const* ptr_B,
void const* ptr_C,
void* ptr_Vector,
// volatile void* ptr_Tensor,
void* ptr_Tensor,
typename LayoutA::Stride::Index lda,
typename LayoutB::Stride::Index ldb,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,8 @@ class PredicatedTileIteratorReducedVec {
/// Parameters structure containing reference and precomputed state.
Params params_;

/// Byte-level pointer
// uint8_t* byte_pointer_;
/// Byte-level pointer first tile offset of this threadblock.
volatile uint8_t* first_tile_byte_pointer_;
// uint8_t* first_tile_byte_pointer_;

/// Array of boolean values to contain steady-state predicates
Mask mask_;
Expand Down Expand Up @@ -421,8 +418,6 @@ class PredicatedTileIteratorReducedVec {
first_tile_byte_pointer_ = reinterpret_cast<volatile uint8_t*>(pointer) +
LongIndex(block_offset.row()) * LongIndex(params_.stride);

// first_tile_byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
// LongIndex(block_offset.row()) * LongIndex(params_.stride);
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
}
Expand Down Expand Up @@ -468,13 +463,6 @@ class PredicatedTileIteratorReducedVec {
/// Destructor. The body is empty, so no explicit cleanup is performed here.
CUTLASS_DEVICE
~PredicatedTileIteratorReducedVec() {}

/// Nominally adds a pointer offset in units of Element.
/// NOTE: this is currently a no-op. The only statement in the body is commented out, so
/// `pointer_offset` is ignored and no iterator state is modified.
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset)
{
// Disabled: this would advance byte_pointer_ by the offset converted from elements to bytes.
// byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}

/// Performs reduction and Stores a reduced output to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const
Expand Down Expand Up @@ -574,29 +562,25 @@ class PredicatedTileIteratorReducedVec {
PredicatedTileIteratorReducedVec& operator++()
{
++state_[0];
// if (!ScatterD) { byte_pointer_ += params_.advance_row; }

thread_start_row_ += ThreadMap::Shape::kRow;

if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
// byte_pointer_ += params_.advance_group;

thread_start_row_ +=
(ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
// byte_pointer_ += params_.advance_cluster;

thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup *
ThreadMap::Count::kRow * ThreadMap::Shape::kRow;

if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
// byte_pointer_ += params_.advance_tile;
}
}
}
Expand Down
10 changes: 1 addition & 9 deletions cpp/include/raft/distance/fused_distance_nn-ext.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -67,20 +67,12 @@ void fusedDistanceNNMinReduce(OutT* min,
float metric_arg, \
cudaStream_t stream)

instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int);
instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t);
instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int);
instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t);

// A bare comma inside a macro argument would split it into two arguments, so we use the COMMA macro:
#define COMMA ,

instantiate_raft_distance_fusedDistanceNNMinReduce(double,
raft::KeyValuePair<int COMMA double>,
int);
instantiate_raft_distance_fusedDistanceNNMinReduce(double,
raft::KeyValuePair<int64_t COMMA double>,
int64_t);
instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair<int COMMA float>, int);
instantiate_raft_distance_fusedDistanceNNMinReduce(float,
raft::KeyValuePair<int64_t COMMA float>,
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/fused_distance_nn-inl.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/fused_distance_nn.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
1 change: 0 additions & 1 deletion cpp/include/raft/distance/fused_distance_nn_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#pragma once

#include <raft/core/resource/cuda_stream.hpp>
// #include <raft/distance/detail/fused_l2_nn.cuh>
#include <raft/distance/detail/fused_distance_nn/helper_structs.cuh>

namespace raft::distance {
Expand Down
12 changes: 0 additions & 12 deletions cpp/include/raft_runtime/distance/fused_distance_nn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,6 @@ void fused_distance_nn_min_arg(raft::resources const& handle,
bool isRowMajor,
float metric_arg);

/**
 * Double-precision overload of fused_distance_nn_min_arg.
 *
 * @param handle      raft resources handle
 * @param min         output buffer of int (presumably the arg-min index per row of x — TODO confirm)
 * @param x           first input matrix, double precision
 * @param y           second input matrix, double precision
 * @param m           presumably the number of rows of x — TODO confirm
 * @param n           presumably the number of rows of y — TODO confirm
 * @param k           presumably the number of columns of x and y — TODO confirm
 * @param sqrt        flag forwarded to the underlying computation
 * @param metric      distance metric to use; NOTE(review): only CosineExpanded appears to be
 *                    supported by the implementation — confirm
 * @param isRowMajor  layout flag
 * @param metric_arg  metric-specific argument
 */
void fused_distance_nn_min_arg(raft::resources const& handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg);

/** @} */ // end group fused_distance_nn_min_arg_runtime

} // end namespace raft::runtime::distance
8 changes: 0 additions & 8 deletions cpp/src/distance/fused_distance_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,12 @@
float metric_arg, \
cudaStream_t stream)

instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int);
instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t);
instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int);
instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t);

// A bare comma inside a macro argument would split it into two arguments, so we use the COMMA macro:
#define COMMA ,

instantiate_raft_distance_fusedDistanceNNMinReduce(double,
raft::KeyValuePair<int COMMA double>,
int);
instantiate_raft_distance_fusedDistanceNNMinReduce(double,
raft::KeyValuePair<int64_t COMMA double>,
int64_t);
instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair<int COMMA float>, int);
instantiate_raft_distance_fusedDistanceNNMinReduce(float,
raft::KeyValuePair<int64_t COMMA float>,
Expand Down
20 changes: 0 additions & 20 deletions cpp/src/raft_runtime/distance/fused_distance_min_arg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -114,24 +114,4 @@ void fused_distance_nn_min_arg(raft::resources const& handle,
}
}

/**
 * Double-precision overload of the runtime fused distance / nearest-neighbor arg-min entry point.
 *
 * Only raft::distance::DistanceType::CosineExpanded is dispatched; it is forwarded to
 * compute_fused_cosine_nn_min_arg<double, int>. Any other metric is rejected by an assertion
 * (active only when NDEBUG is not defined) and performs no computation.
 *
 * @param handle      raft resources handle, forwarded to the cosine implementation
 * @param min         output buffer, forwarded (presumably arg-min indices — TODO confirm)
 * @param x           first input matrix, double precision
 * @param y           second input matrix, double precision
 * @param m           forwarded size argument (presumably rows of x — TODO confirm)
 * @param n           forwarded size argument (presumably rows of y — TODO confirm)
 * @param k           forwarded size argument (presumably columns of x and y — TODO confirm)
 * @param sqrt        forwarded unchanged to the cosine implementation
 * @param metric      distance metric; must be CosineExpanded
 * @param isRowMajor  accepted for interface parity; not used by this overload
 * @param metric_arg  accepted for interface parity; not used by this overload
 */
void fused_distance_nn_min_arg(raft::resources const& handle,
                               int* min,
                               const double* x,
                               const double* y,
                               int m,
                               int n,
                               int k,
                               bool sqrt,
                               raft::distance::DistanceType metric,
                               bool isRowMajor,
                               float metric_arg)
{
  switch (metric) {
    case raft::distance::DistanceType::CosineExpanded:
      compute_fused_cosine_nn_min_arg<double, int>(handle, min, x, y, m, n, k, sqrt);
      break;
    default:
      // A string literal converts to a non-null pointer, so `assert("...")` is always true and
      // can never fire. `false && "..."` keeps the message in the diagnostic while making the
      // assertion actually fail for unsupported metrics.
      assert(false && "only cosine metric is supported with fusedDistanceNN\n");
      break;
  }
}

} // end namespace raft::runtime::distance
2 changes: 2 additions & 0 deletions cpp/test/distance/fused_cosine_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation

#include "../test_utils.cuh"
#include <gtest/gtest.h>
#include <raft/core/kvp.hpp>
Expand Down

0 comments on commit 5384408

Please sign in to comment.