Skip to content

Commit

Permalink
Use thrust::binary_search to verify negative samples on GPU (#524)
Browse files Browse the repository at this point in the history
Use thrust::binary_search to verify negative samples on the GPU instead of
doing a linear scan in the BPR model.

This leads to a noticeable perf increase on larger datasets. For instance on the
Github stars dataset:
    * Using linear_search: 3.09s/it
    * using thrust::binary_search 1.53s/it
    * w/ verify_negative_samples=False  1.18s/it

This change doubles the BPR training performance on that dataset, and also
leads to times that are only 30% slower than not verifying samples at all.
  • Loading branch information
benfred authored Jan 21, 2022
1 parent f960bea commit 07bfd8f
Showing 1 changed file with 4 additions and 24 deletions.
28 changes: 4 additions & 24 deletions implicit/gpu/bpr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>

#include "implicit/gpu/als.h"
#include "implicit/gpu/dot.cuh"
Expand All @@ -12,28 +14,6 @@
namespace implicit {
namespace gpu {

// TODO: we could use an n-ary search here instead, but
// that will only be faster when the number of likes for a user is
// much greater than the number of threads (factors) we are using.
// Since most users on most datasets have relatively few likes, I'm
// using a simple linear scan here instea
__inline__ __device__ bool linear_search(int *start, int *end, int target) {
__shared__ bool ret;

if (threadIdx.x == 0)
ret = false;
__syncthreads();

int size = end - start;
for (int i = threadIdx.x; i < size; i += blockDim.x) {
if (start[i] == target) {
ret = true;
}
}
__syncthreads();
return ret;
}

__global__ void bpr_update_kernel(int samples, unsigned int *random_likes,
unsigned int *random_dislikes, int *itemids,
int *userids, int *indptr, int factors,
Expand All @@ -53,8 +33,8 @@ __global__ void bpr_update_kernel(int samples, unsigned int *random_likes,
dislikedid = itemids[disliked_index];

if (verify_negative_samples &&
linear_search(&itemids[indptr[userid]], &itemids[indptr[userid + 1]],
dislikedid)) {
thrust::binary_search(thrust::seq, &itemids[indptr[userid]],
&itemids[indptr[userid + 1]], dislikedid)) {
skipped += 1;
continue;
}
Expand Down

0 comments on commit 07bfd8f

Please sign in to comment.