Skip to content

Commit

Permalink
speed up and de-compile weighted avg
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed May 6, 2024
1 parent 6f5e0bc commit 790ef3d
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions lenskit/algorithms/item_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _sim_blocks(
)


@torch.jit.script
@profile

Check failure on line 141 in lenskit/algorithms/item_knn.py

View workflow job for this annotation

GitHub Actions / Check Source Code

Ruff (F821)

lenskit/algorithms/item_knn.py:141:2: F821 Undefined name `profile`
def _predict_weighted_average(
model: torch.Tensor,
nrange: tuple[int, int],
Expand All @@ -154,8 +154,11 @@ def _predict_weighted_average(
scores = torch.zeros(nitems)
t_sims = torch.zeros(nitems)
counts = torch.zeros(nitems, dtype=torch.int32)
# these store the similarities and values for neighbors, so we can un-count
nbr_sims = torch.empty((nitems, max_nbrs))
nbr_vals = torch.empty((nitems, max_nbrs))
# and this stores the smallest similarity so far for each item
nbr_min = torch.full((nitems,), torch.finfo().max)

for i, iidx in enumerate(rated):
row = model[int(iidx)]
Expand All @@ -177,6 +180,7 @@ def _predict_weighted_average(
counts[ris_fast] += 1
t_sims[ris_fast] += avs_fast
scores[ris_fast] += vals_fast
nbr_min[ris_fast] = torch.minimum(nbr_min[ris_fast], vs_fast)

# skip early if we're done
if torch.all(fast):
Expand All @@ -185,25 +189,30 @@ def _predict_weighted_average(
# now we have the slow-path items
slow = torch.logical_not(fast)
ris_slow = row_is[slow]
rvs_slow = row_vs[slow]
# which slow items might actually need an update?
exc = rvs_slow > nbr_min[ris_slow]
ris_slow = ris_slow[exc]
rvs_slow = rvs_slow[exc]

# this is brute-force linear search for simplicity right now
# for each, find the neighbor that's the smallest:
mins = torch.argmin(nbr_sims[ris_slow], dim=1)
min_sims, mins = torch.min(nbr_sims[ris_slow], dim=1)
# find the items where this neighbor exceeds the smallest so far:
min_sims = nbr_sims[ris_slow, mins]
exc = min_sims < row_vs[slow]
if not torch.any(exc):
continue
assert torch.all(min_sims < rvs_slow)

# now we need to update values: add in new and remove old
min_vals = nbr_vals[ris_slow, mins]
ris_exc = ris_slow[exc]
ravs_exc = row_avs[slow][exc]
rvs_exc = row_vs[slow][exc]
t_sims[ris_exc] += ravs_exc - min_sims[exc].abs()
scores[ris_exc] += rvs_exc * rate_v[i] - min_vals[exc]
ravs_slow = row_avs[slow][exc]
slow_vals = rvs_slow * rate_v[i]
t_sims[ris_slow] += ravs_slow - min_sims.abs()
scores[ris_slow] += slow_vals - min_vals
# and save
nbr_sims[ris_exc, mins[exc]] = ravs_exc
nbr_vals[ris_exc, mins[exc]] = rvs_exc * rate_v[i]
nbr_sims[ris_slow, mins] = ravs_slow
nbr_vals[ris_slow, mins] = slow_vals
# and now we need to update the saved minimums
nm_sims, _nm_is = torch.min(nbr_sims[ris_slow], dim=1)
nbr_min[ris_slow] = nm_sims

# compute averages for items that meet the minimum neighbor count
mask = counts >= min_nbrs
Expand Down Expand Up @@ -430,6 +439,10 @@ def _check_setup(self):
ConfigWarning,
)

if self.min_sim < 0:
_logger.warning("item-item does not currently support negative similarities")
warnings.warn("item-item does not currently support negative similarities")

def fit(self, ratings, **kwargs):
"""
Train a model.
Expand Down

0 comments on commit 790ef3d

Please sign in to comment.