Skip to content

Commit

Permalink
add back the spmm.cuh for better compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Jan 13, 2024
1 parent 4563c3b commit e922336
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
22 changes: 22 additions & 0 deletions cpp/include/raft/sparse/linalg/spmm.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#pragma message(__FILE__ \
" is deprecated and will be removed in a future release." \
" Please use the spmm.hpp at the same path instead.")

#include <raft/sparse/linalg/detail/spmm.hpp>
20 changes: 20 additions & 0 deletions cpp/test/sparse/sddmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/random/make_blobs.cuh>
#include <raft/sparse/linalg/sddmm.hpp>
#include <raft/util/cudart_utils.hpp>
#include <thrust/reduce.h>

#include "../test_utils.cuh"

Expand All @@ -49,6 +50,14 @@ struct SDDMMInputs {
unsigned long long int seed;
};

template <typename ValueType>
struct sum_abs_op {
__host__ __device__ ValueType operator()(const ValueType& x, const ValueType& y) const
{
return y >= ValueType(0.0) ? (x + y) : (x - y);
}
};

template <typename ValueType, typename IndexType>
::std::ostream& operator<<(::std::ostream& os, const SDDMMInputs<ValueType, IndexType>& params)
{
Expand Down Expand Up @@ -271,6 +280,17 @@ class SDDMMTest : public ::testing::TestWithParam<SDDMMInputs<ValueType, IndexTy
c_expected_data_d.size(),
raft::CompareApprox<ValueType>(params.tolerance),
stream));

thrust::device_ptr<ValueType> expected_data_ptr =
thrust::device_pointer_cast(c_expected_data_d.data());
ValueType sum_abs = thrust::reduce(thrust::cuda::par.on(stream),
expected_data_ptr,
expected_data_ptr + c_expected_data_d.size(),
ValueType(0.0f),
sum_abs_op<ValueType>());
ValueType avg = sum_abs / (1.0f * c_expected_data_d.size());

ASSERT_GE(avg, (params.tolerance * static_cast<ValueType>(0.001f)));
}

raft::resources handle;
Expand Down

0 comments on commit e922336

Please sign in to comment.