Skip to content

Commit

Permalink
fix item configuration save
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Dec 7, 2024
1 parent 2f78134 commit 2c49ff8
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions lenskit/lenskit/knn/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ItemKNNScorer(Component, Trainable):
min_nbrs: int = 1
min_sim: float
save_nbrs: int | None = None
feeedback: FeedbackType
feedback: FeedbackType
block_size: int

items_: Vocabulary
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(
self.save_nbrs = save_nbrs
self.block_size = block_size

self.feeedback = feedback
self.feedback = feedback

if self.min_sim < 0:
_log.warning("item-item does not currently support negative similarities")
Expand Down Expand Up @@ -127,10 +127,11 @@ def train(self, data: Dataset):
# 1. Normalize item vectors to be mean-centered and unit-normalized
# 2. Compute similarities with pairwise dot products
self._timer = util.Stopwatch()
_log.info("training IKNN for %d users in %s feedback mode", data.item_count, self.feedback)

_log.debug("[%s] beginning fit, memory use %s", self._timer, util.max_memory())

field = "rating" if self.feeedback == "explicit" else None
field = "rating" if self.feedback == "explicit" else None
init_rmat = data.interaction_matrix("torch", field=field)
n_items = data.item_count
_log.info(
Expand All @@ -145,7 +146,7 @@ def train(self, data: Dataset):
# we operate on *transposed* rating matrix: items on the rows
rmat = init_rmat.transpose(0, 1).to_sparse_csr().to(torch.float64)

if self.feeedback == "explicit":
if self.feedback == "explicit":
rmat, means = normalize_sparse_rows(rmat, "center")
if np.allclose(rmat.values(), 0.0):
_log.warning("normalized ratings are zero, centering is not recommended")
Expand Down Expand Up @@ -220,7 +221,7 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
n_valid = len(ri_vpos)
_log.debug("user %s: %d of %d rated items in model", query.user_id, n_valid, len(ratings))

if self.feeedback == "explicit":
if self.feedback == "explicit":
ri_vals = ratings.field("rating", "torch")
if ri_vals is None:
raise RuntimeError("explicit-feedback scorer must have ratings")
Expand All @@ -229,12 +230,12 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
ri_vals = torch.full((n_valid,), 1.0, dtype=torch.float64)

# mean-center the rating array
if self.feeedback == "explicit":
if self.feedback == "explicit":
assert self.item_means_ is not None
ri_vals -= self.item_means_[ri_vpos]

# now compute the predictions
if self.feeedback == "explicit":
if self.feedback == "explicit":
sims = _predict_weighted_average(
self.sim_matrix_, (self.min_nbrs, self.nnbrs), ri_vals, ri_vpos
)
Expand Down

0 comments on commit 2c49ff8

Please sign in to comment.