Skip to content

Commit

Permalink
changed order of arguments according to best practice
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Mar 30, 2023
1 parent f57be13 commit 3f61b64
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 77 deletions.
72 changes: 36 additions & 36 deletions cpp/include/raft/distance/detail/kernels/gram_matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -64,118 +64,118 @@ class GramMatrixBase {
/** Convenience function to evaluate the Gram matrix for two vector sets.
* Vector sets are provided in Matrix format
*
* @param [in] handle raft handle
* @param [in] x1 dense device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param [in] handle raft handle
* @param norm_x1 optional L2-norm of x1's rows for computation within RBF.
* @param norm_x2 optional L2-norm of x2's rows for computation within RBF.
*/
void operator()(dense_input_matrix_view_t<math_t> x1,
void operator()(raft::device_resources const& handle,
dense_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
raft::device_resources const& handle,
math_t* norm_x1 = nullptr,
math_t* norm_x2 = nullptr)
{
evaluate(x1, x2, out, handle, norm_x1, norm_x2);
evaluate(handle, x1, x2, out, norm_x1, norm_x2);
}

/** Convenience function to evaluate the Gram matrix for two vector sets.
* Vector sets are provided in Matrix format
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param [in] handle raft handle
* @param norm_x1 optional L2-norm of x1's rows for computation within RBF.
* @param norm_x2 optional L2-norm of x2's rows for computation within RBF.
*/
void operator()(csr_input_matrix_view_t<math_t> x1,
void operator()(raft::device_resources const& handle,
csr_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
raft::device_resources const& handle,
math_t* norm_x1 = nullptr,
math_t* norm_x2 = nullptr)
{
evaluate(x1, x2, out, handle, norm_x1, norm_x2);
evaluate(handle, x1, x2, out, norm_x1, norm_x2);
}

/** Convenience function to evaluate the Gram matrix for two vector sets.
* Vector sets are provided in Matrix format
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 csr device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param [in] handle raft handle
* @param norm_x1 optional L2-norm of x1's rows for computation within RBF.
* @param norm_x2 optional L2-norm of x2's rows for computation within RBF.
*/
void operator()(csr_input_matrix_view_t<math_t> x1,
void operator()(raft::device_resources const& handle,
csr_input_matrix_view_t<math_t> x1,
csr_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
raft::device_resources const& handle,
math_t* norm_x1 = nullptr,
math_t* norm_x2 = nullptr)
{
evaluate(x1, x2, out, handle, norm_x1, norm_x2);
evaluate(handle, x1, x2, out, norm_x1, norm_x2);
}

// unfortunately, 'evaluate' cannot be templatized as it needs to be virtual

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] handle raft handle
* @param [in] x1 dense device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param [in] handle raft handle
* @param norm_x1 unused.
* @param norm_x2 unused.
*/
virtual void evaluate(dense_input_matrix_view_t<math_t> x1,
virtual void evaluate(raft::device_resources const& handle,
dense_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
raft::device_resources const& handle,
math_t* norm_x1,
math_t* norm_x2)
{
linear(x1, x2, out, handle);
linear(handle, x1, x2, out);
}
/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param [in] handle raft handle
* @param norm_x1 unused.
* @param norm_x2 unused.
*/
virtual void evaluate(csr_input_matrix_view_t<math_t> x1,
virtual void evaluate(raft::device_resources const& handle,
csr_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
raft::device_resources const& handle,
math_t* norm_x1,
math_t* norm_x2)
{
linear(x1, x2, out, handle);
linear(handle, x1, x2, out);
}
/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 csr device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param [in] handle raft handle
* @param norm_x1 unused.
* @param norm_x2 unused.
*/
virtual void evaluate(csr_input_matrix_view_t<math_t> x1,
virtual void evaluate(raft::device_resources const& handle,
csr_input_matrix_view_t<math_t> x1,
csr_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
raft::device_resources const& handle,
math_t* norm_x1,
math_t* norm_x2)
{
linear(x1, x2, out, handle);
linear(handle, x1, x2, out);
}

/** Evaluate the Gram matrix for two vector sets using simple dot product.
Expand Down Expand Up @@ -340,15 +340,15 @@ class GramMatrixBase {
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] handle raft handle
* @param [in] x1 dense device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param [in] handle raft handle
*/
void linear(dense_input_matrix_view_t<math_t> x1,
void linear(raft::device_resources const& handle,
dense_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
raft::device_resources const& handle)
dense_output_matrix_view_t<math_t> out)
{
// check is_row_major consistency
bool is_row_major = get_is_row_major(x1) && get_is_row_major(x2) && get_is_row_major(out);
Expand Down Expand Up @@ -416,15 +416,15 @@ class GramMatrixBase {
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param [in] handle raft handle
*/
void linear(csr_input_matrix_view_t<math_t> x1,
void linear(raft::device_resources const& handle,
csr_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
raft::device_resources const& handle)
dense_output_matrix_view_t<math_t> out)
{
// check is_row_major consistency
bool is_row_major = get_is_row_major(x2) && get_is_row_major(out);
Expand Down Expand Up @@ -453,15 +453,15 @@ class GramMatrixBase {
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 csr device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param [in] handle raft handle
*/
void linear(csr_input_matrix_view_t<math_t> x1,
void linear(raft::device_resources const& handle,
csr_input_matrix_view_t<math_t> x1,
csr_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
raft::device_resources const& handle)
dense_output_matrix_view_t<math_t> out)
{
// check is_row_major consistency
bool is_row_major = get_is_row_major(out);
Expand Down
Loading

0 comments on commit 3f61b64

Please sign in to comment.