Fix GPU version of FAISS with > 1024 results
Querying a GPU Faiss index for more than 1024 results caused an assert that
killed the program. This most frequently happened when trying to recommend
items for a user with more than 1K liked items (see #149).

Fix by falling back to the exact CPU version in this case.
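
For context, a minimal sketch of the failure mode (assuming a faiss build with
GPU support and an available CUDA device; sizes are illustrative):

import numpy as np
import faiss

d = 32                                          # factor dimensionality
items = np.random.rand(10000, d).astype('float32')

cpu_index = faiss.IndexFlatIP(d)                # exact inner-product index
cpu_index.add(items)

res = faiss.StandardGpuResources()
gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)

query = items[:1]
distances, ids = cpu_index.search(query, 2048)  # fine: CPU indexes have no 1K limit
# gpu_index.search(query, 2048)                 # asserts and aborts the whole process
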
benfred committed Aug 28, 2018
1 parent 7f30220 commit e4817ab
Showing 3 changed files with 48 additions and 1 deletion.
implicit/approximate_als.py: 9 additions, 1 deletion
@@ -364,7 +364,7 @@ def fit(self, Ciu):
         self.similar_items_index = index
 
     def similar_items(self, itemid, N=10):
-        if not self.approximate_similar_items:
+        if not self.approximate_similar_items or (self.use_gpu and N >= 1024):
             return super(FaissAlternatingLeastSquares, self).similar_items(itemid, N)
 
         factors = self.item_factors[itemid]
@@ -389,6 +389,14 @@ def recommend(self, userid, user_items, N=10, filter_items=None, recalculate_user=False):
             liked.update(filter_items)
         count = N + len(liked)
 
+        # the GPU variant of faiss doesn't support returning more than 1024 results.
+        # fall back to the exact match when this happens
+        if self.use_gpu and count >= 1024:
+            return super(FaissAlternatingLeastSquares,
+                         self).recommend(userid, user_items, N=N,
+                                         filter_items=filter_items,
+                                         recalculate_user=recalculate_user)
+
         # faiss expects multiple queries - convert query to a matrix
         # and results back to single vectors
         query = user.reshape(1, -1).astype('float32')
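A hedged usage sketch of the new code path (synthetic data; constructor
arguments are illustrative, and use_gpu=True needs implicit's CUDA extension
plus a GPU build of faiss): asking a GPU-backed model for more than 1024
results now takes the exact superclass implementation instead of crashing.

import numpy as np
from scipy.sparse import csr_matrix
from implicit.approximate_als import FaissAlternatingLeastSquares

# fit() takes an item_users matrix: items as rows, users as columns
plays = csr_matrix(np.random.randint(0, 2, size=(2048, 2048)).astype('float64'))

model = FaissAlternatingLeastSquares(factors=32, use_gpu=True)
model.fit(plays)

# N (plus the number of liked/filtered items) >= 1024 used to trip the
# faiss assert; it is now answered exactly on the CPU
recs = model.recommend(0, plays.T.tocsr(), N=1050)
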
implicit/nearest_neighbours.py: 3 additions, 0 deletions
@@ -25,6 +25,9 @@ def fit(self, weighted):
 
     def recommend(self, userid, user_items, N=10, filter_items=None, recalculate_user=False):
         """ returns the best N recommendations for a user given its id"""
+        if userid >= user_items.shape[0]:
+            raise ValueError("userid is out of bounds of the user_items matrix")
+
         # recalculate_user is ignored because this is not a model based algorithm
         items = N
         if filter_items:
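A minimal sketch of the new guard (synthetic data; shapes are illustrative):
recommending for a userid past the end of user_items now fails fast with a
ValueError rather than indexing past the end of the matrix.

import numpy as np
from scipy.sparse import csr_matrix
from implicit.nearest_neighbours import CosineRecommender

item_users = csr_matrix(np.random.rand(50, 20))  # 50 items, 20 users
model = CosineRecommender()
model.fit(item_users)

user_items = item_users.T.tocsr()                # 20 users, 50 items
model.recommend(5, user_items)                   # fine: userid 5 < 20 users
# model.recommend(20, user_items)                # ValueError: userid is out of bounds ...
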
tests/approximate_als_test.py: 36 additions, 0 deletions
@@ -4,6 +4,7 @@
 
 from implicit.approximate_als import (AnnoyAlternatingLeastSquares, FaissAlternatingLeastSquares,
                                       NMSLibAlternatingLeastSquares)
+from implicit.cuda import HAS_CUDA
 
 from .recommender_base_test import TestRecommenderBaseMixin
 
@@ -34,6 +35,41 @@ class FaissALSTest(unittest.TestCase, TestRecommenderBaseMixin):
         def _get_model(self):
             return FaissAlternatingLeastSquares(nlist=1, nprobe=1, factors=2, regularization=0,
                                                 use_gpu=False)
+
+    if HAS_CUDA:
+        class FaissALSGPUTest(unittest.TestCase, TestRecommenderBaseMixin):
+            __regularization = 0
+
+            def _get_model(self):
+                return FaissAlternatingLeastSquares(nlist=1, nprobe=1, factors=32,
+                                                    regularization=self.__regularization,
+                                                    use_gpu=True)
+
+            def test_similar_items(self):
+                # For the GPU version, factors currently has to be a multiple of 32
+                # (a limitation that I think is caused by how we are calculating the
+                # dot product in CUDA; TODO: eventually fix that code). This causes
+                # the test_similar_items call to fail if we set regularization to 0.
+                self.__regularization = 1.0
+                try:
+                    super(FaissALSGPUTest, self).test_similar_items()
+                finally:
+                    self.__regularization = 0.0
+
+            def test_large_recommend(self):
+                # the GPU version of FAISS can't return more than 1K results (and will assert/exit);
+                # this tests that we fall back to the exact version in this case and don't die
+                plays = self.get_checker_board(2048)
+                model = self._get_model()
+                model.show_progress = False
+                model.fit(plays)
+
+                recs = model.similar_items(0, N=1050)
+                self.assertEqual(recs[0][0], 0)
+
+                recs = model.recommend(0, plays.T.tocsr(), N=1050)
+                self.assertEqual(recs[0][0], 0)
+
 except ImportError:
     pass
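Note that get_checker_board is a helper on TestRecommenderBaseMixin and is not
part of this diff; a plausible sketch of what it might produce (the real helper
may differ) is a checkerboard interaction matrix, in which items of the same
parity share identical rows:

import numpy as np
from scipy.sparse import csr_matrix

def get_checker_board(X):
    # hypothetical reconstruction: 1 where (row + col) is even, 0 elsewhere
    grid = (np.indices((X, X)).sum(axis=0) + 1) % 2
    return csr_matrix(grid.astype(np.float64))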
