Skip to content

Commit

Permalink
get matrix close to working
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Jul 12, 2024
1 parent 3e961fc commit 04a9f93
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 88 deletions.
65 changes: 54 additions & 11 deletions lenskit/lenskit/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
from numpy.typing import ArrayLike

from lenskit.data.matrix import InteractionMatrix
from lenskit.data.matrix import CSRStructure, InteractionMatrix

from . import EntityId
from .tables import NumpyUserItemTable, TorchUserItemTable
Expand All @@ -30,6 +30,15 @@
_log = logging.getLogger(__name__)


class FieldError(KeyError):
    """
    The requested field does not exist.

    The exception message has the form ``entity[field]``.  The two components
    are also kept as attributes so handlers can inspect them without parsing
    the message.

    Args:
        entity:
            The kind of entity whose field was requested (e.g.
            ``"interaction"``).
        field:
            The name of the missing field.
    """

    def __init__(self, entity: str, field: str):
        super().__init__(f"{entity}[{field}]")
        self.entity = entity
        self.field = field

    def __reduce__(self):
        # The default exception reduction replays ``self.args`` (the single
        # formatted message) into ``__init__``, which needs two arguments and
        # would therefore fail on copy / unpickle.  Replay the original
        # constructor arguments instead.
        return (type(self), (self.entity, self.field))


class Dataset:
"""
Representation of a data set for LensKit training, evaluation, etc. Data can
Expand Down Expand Up @@ -388,9 +397,7 @@ def interaction_matrix(
format: Literal["structure"],
*,
layout: Literal["csr"] | None = None,
field: str | None = None,
combine: MAT_AGG | None = None,
) -> sps.coo_matrix: ...
) -> CSRStructure: ...
def interaction_matrix(
self,
format: str,
Expand Down Expand Up @@ -431,6 +438,15 @@ def interaction_matrix(
field:
Which field to return in the matrix. Common fields include
``"rating"`` and ``"timestamp"``.
If unspecified (``None``), this will yield an implicit-feedback
indicator matrix, with 1s for observed items; the ``"pandas"``
format will only include user and item columns.
If the ``rating`` field is requested but is not defined in the
underlying data, then the result is the same indicator matrix
produced when ``field`` is unspecified, except that the
``"pandas"`` format will include a ``"rating"`` column of all 1s.
combine:
How to combine multiple observations for a single user-item
pair. Available methods are:
Expand All @@ -449,14 +465,41 @@ def interaction_matrix(
legacy:
``True`` to return a legacy SciPy sparse matrix instead of
sparse array.
original_ids:
If ``True``, return user and item IDs as represented in the
original source data in columns named ``user_id`` and
``item_id``, instead of the user and item numbers typically
returned. Only applicable to the ``pandas`` format. See
:ref:`data-identifiers`.
"""
pass
match format:
case "structure":
if layout and layout != "csr":
raise ValueError(f"unsupported layout {layout} for structure")
if field:
raise ValueError("structure does not support fields")
return self._int_mat_structure()
case "pandas":
if layout and layout != "coo":
raise ValueError(f"unsupported layout {layout} for Pandas")
return self._int_mat_pandas(field)
case _:
raise ValueError(f"unsupported format “{format}”")

def _int_mat_structure(self) -> CSRStructure:
    """
    Package the interaction matrix's user pointers, item numbers, and shape
    into a :class:`CSRStructure` (no values, structure only).
    """
    matrix = self._matrix
    return CSRStructure(matrix.user_ptrs, matrix.item_nums, matrix.shape)

def _int_mat_pandas(self, field: str | None) -> pd.DataFrame:
cols: dict[str, ArrayLike] = {
"user_num": self._matrix.user_nums,
"item_num": self._matrix.item_nums,
}
if field == "rating":
if self._matrix.ratings is not None:
cols["rating"] = self._matrix.ratings
else:
cols["rating"] = np.ones(self._matrix.n_obs)
elif field == "timestamp":
if self._matrix.timestamps is None:
raise FieldError("interaction", field)
cols["timestamp"] = self._matrix.timestamps
elif field:
raise FieldError("interaction", field)
return pd.DataFrame(cols)


def from_interactions_df(
Expand Down
16 changes: 13 additions & 3 deletions lenskit/lenskit/data/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ class InteractionMatrix:
"""
Internal helper class used by :class:`lenskit.data.Dataset` to store the
user-item interaction matrix. The data is stored simultaneously in CSR and
COO format.
COO format. Most code has no need to interact with this class directly —
:class:`~lenskit.data.Dataset` methods provide data in a range of formats.
"""

n_obs: int
n_users: int
n_items: int

Expand All @@ -73,9 +75,9 @@ class InteractionMatrix:
"User (row) offsets / pointers."
item_nums: np.ndarray[int, np.dtype[np.int32]]
"Item (column) numbers."
ratings: Optional[np.ndarray[int, np.dtype[np.float32]]]
ratings: Optional[np.ndarray[int, np.dtype[np.float32]]] = None
"Rating values."
timestamps: Optional[np.ndarray[int, np.dtype[np.int64]]]
timestamps: Optional[np.ndarray[int, np.dtype[np.int64]]] = None
"Timestamps as 64-bit Unix timestamps."

def __init__(
Expand All @@ -94,6 +96,7 @@ def __init__(
if timestamps is not None:
self.timestamps = np.asarray(timestamps, np.int64)

self.n_obs = len(self.user_nums)
self.n_items = n_items
self.n_users = len(user_counts)
cp1 = np.zeros(self.n_users + 1, np.int32)
Expand All @@ -102,6 +105,13 @@ def __init__(
if self.user_ptrs[-1] != len(self.user_nums):
raise ValueError("mismatched counts and array sizes")

@property
def shape(self) -> tuple[int, int]:
    """
    Dimensions of the interaction matrix as ``(n_users, n_items)``, i.e.
    rows by columns.
    """
    n_rows, n_cols = self.n_users, self.n_items
    return n_rows, n_cols


class RatingMatrix(NamedTuple, Generic[M]):
"""
Expand Down
Loading

0 comments on commit 04a9f93

Please sign in to comment.