From e9223367248372cfad2f49707a2776a64b95c879 Mon Sep 17 00:00:00 2001 From: rhdong Date: Fri, 12 Jan 2024 13:23:43 -0800 Subject: [PATCH] add back the `spmm.cuh` for better compatibility --- cpp/include/raft/sparse/linalg/spmm.cuh | 22 ++++++++++++++++++++++ cpp/test/sparse/sddmm.cu | 20 ++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 cpp/include/raft/sparse/linalg/spmm.cuh diff --git a/cpp/include/raft/sparse/linalg/spmm.cuh b/cpp/include/raft/sparse/linalg/spmm.cuh new file mode 100644 index 0000000000..439ed8c341 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/spmm.cuh @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the spmm.hpp at the same path instead.") + +#include diff --git a/cpp/test/sparse/sddmm.cu b/cpp/test/sparse/sddmm.cu index 9288d89199..9323ee8c2b 100644 --- a/cpp/test/sparse/sddmm.cu +++ b/cpp/test/sparse/sddmm.cu @@ -24,6 +24,7 @@ #include #include #include +#include #include "../test_utils.cuh" @@ -49,6 +50,14 @@ struct SDDMMInputs { unsigned long long int seed; }; +template +struct sum_abs_op { + __host__ __device__ ValueType operator()(const ValueType& x, const ValueType& y) const + { + return y >= ValueType(0.0) ? (x + y) : (x - y); + } +}; + template ::std::ostream& operator<<(::std::ostream& os, const SDDMMInputs& params) { @@ -271,6 +280,17 @@ class SDDMMTest : public ::testing::TestWithParam(params.tolerance), stream)); + + thrust::device_ptr expected_data_ptr = + thrust::device_pointer_cast(c_expected_data_d.data()); + ValueType sum_abs = thrust::reduce(thrust::cuda::par.on(stream), + expected_data_ptr, + expected_data_ptr + c_expected_data_d.size(), + ValueType(0.0f), + sum_abs_op()); + ValueType avg = sum_abs / (1.0f * c_expected_data_d.size()); + + ASSERT_GE(avg, (params.tolerance * static_cast(0.001f))); } raft::resources handle;