Add dataset-based data splitting code #461

Merged: 22 commits, Aug 1, 2024
85 changes: 0 additions & 85 deletions docs/crossfold.rst

This file was deleted.

2 changes: 1 addition & 1 deletion docs/index.rst
@@ -39,7 +39,7 @@ Resources
   :caption: Running Experiments

   data
   splitting
   batch
   evaluation/index
   documenting
11 changes: 11 additions & 0 deletions docs/releases/2024.rst
@@ -41,6 +41,17 @@ Significant Changes
without round-tripping through Pandas and NumPy, and keep this transparent
to client code).

* Data splitting for offline evaluation has been moved into
:mod:`lenskit.splitting` and updated to work with data sets and item lists
instead of raw data frames; splitting functions have been renamed (e.g.
``rows`` to ``records``) and their parameters updated for clarity and
consistency.

* Where Pandas data frames are still used, the standard user and item columns
have been renamed to ``user_id`` and ``item_id`` respectively, with
``user_num`` and ``item_num`` for 0-based user and item numbers. This is to
remove ambiguity about how users and items are being referenced.

* **PyTorch**. LensKit now uses PyTorch to implement most of its algorithms,
instead of Numba-accelerated NumPy code. Algorithms using PyTorch are:

92 changes: 92 additions & 0 deletions docs/splitting.rst
@@ -0,0 +1,92 @@
Splitting Data
==============

.. module:: lenskit.splitting

The LKPY ``splitting`` module splits data sets for offline evaluation using
cross-validation and other strategies. The various splitters are implemented as
functions that operate on a :class:`~lenskit.data.Dataset` and return one or
more train-test splits (as :class:`TTSplit` objects).

.. versionchanged:: 2024.1
    Data splitting was moved from ``lenskit.crossfold`` to the
    ``lenskit.splitting`` module, and its functions were renamed and had their
    interfaces revised.

Experiment code should generally use these functions to prepare train-test files
for training and evaluating algorithms. For example, the following will perform
a user-based 5-fold cross-validation as was the default in the old LensKit:

.. code:: python

    from lenskit.data import load_movielens
    from lenskit.splitting import crossfold_users, SampleN

    dataset = load_movielens('data/ml-20m.zip')
    for i, tp in enumerate(crossfold_users(dataset, 5, SampleN(5))):
        tp.train_df.to_parquet(f'ml-20m.exp/train-{i}.parquet')
        tp.test_df.to_parquet(f'ml-20m.exp/test-{i}.parquet')

Record-based Random Splitting
-----------------------------

The simplest preparation methods sample or partition the records in the input
data. A 5-fold :func:`crossfold_records` split will result in 5 splits, each of
which extracts 20% of the user-item interaction records for testing and leaves
80% for training.
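
As a minimal sketch (reusing the ``dataset`` loaded above, and assuming
:func:`crossfold_records` takes the partition count as its second positional
argument):

.. code:: python

    from lenskit.splitting import crossfold_records

    # 5 disjoint test sets, each holding roughly 20% of the records
    for i, split in enumerate(crossfold_records(dataset, 5)):
        split.test_df.to_parquet(f'ml-20m.exp/rec-test-{i}.parquet')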

.. note::

    When a dataset has repeated interactions, these functions operate only on
    the *matrix* view of the data (user-item observations are deduplicated).
    Specifically, they operate on the results of calling
    :meth:`~lenskit.data.Dataset.interaction_matrix` with ``format="pandas"``
    and ``field="all"``.
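
For example, the deduplicated view these functions consume can be inspected
directly; this is a sketch, with the ``user_num``/``item_num`` column names
taken from the matrix layout described in
:meth:`~lenskit.data.Dataset.interaction_matrix`:

.. code:: python

    # one row per observed user-item pair, with all available fields
    log = dataset.interaction_matrix(format='pandas', field='all')
    assert not log.duplicated(['user_num', 'item_num']).any()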

.. autofunction:: crossfold_records

.. autofunction:: sample_records

User-based Splitting
--------------------

It's often desirable to use users, rather than raw interaction records, as the
basis for splitting data. This allows you to control the experimental
conditions on a user-by-user basis, e.g. by making sure each test user has the
same number of held-out ratings. These functions operate on the users in the
input :class:`~lenskit.data.Dataset`.

The algorithm used by each is as follows:

1. Sample or partition the set of user IDs into *n* sets of test users.
2. For each set of test users, select a portion of each user's interaction
records to serve as test records (see the sketch after this list).
3. Create a training set for each test set, consisting of the unselected
records from that set's test users along with all records from every
non-test user.
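
For example, a rough sketch of a single user-based sample; the
``(data, size, method)`` argument order for :func:`sample_users` is an
assumption patterned on the :func:`crossfold_users` example above:

.. code:: python

    from lenskit.splitting import sample_users, SampleN

    # one split: 1000 test users, each with 5 randomly-held-out test items
    split = sample_users(dataset, 1000, SampleN(5))
    split.test_df.to_parquet('ml-20m.exp/user-sample-test.parquet')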

.. autofunction:: crossfold_users

.. autofunction:: sample_users

Selecting user holdout records
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

These functions each take a ``method`` that decides how to select each user's
test records. The method is a function that takes an item list (containing just
that user's records) and returns the records selected for testing.
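
To illustrate that contract, here is a hypothetical custom holdout method; that
an :class:`~lenskit.data.ItemList` can be subset with an integer index array is
an assumption of this sketch:

.. code:: python

    import numpy as np

    from lenskit.data import ItemList

    def sample_third(items: ItemList) -> ItemList:
        # hold out a random third of this user's records for testing
        n = len(items) // 3
        picks = np.random.default_rng().choice(len(items), n, replace=False)
        return items[picks]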

We provide several holdout method factories:

.. autofunction:: SampleN
.. autofunction:: SampleFrac
.. autofunction:: LastN
.. autofunction:: LastFrac

Utility Classes
---------------

.. autoclass:: lenskit.splitting.holdout.HoldoutMethod
    :members:
    :special-members: __call__

.. autoclass:: TTSplit
    :members:
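
A short usage sketch for working with splits; only the ``train_df`` and
``test_df`` attributes shown in the examples above are assumed:

.. code:: python

    from lenskit.splitting import crossfold_users, SampleN

    split = next(iter(crossfold_users(dataset, 5, SampleN(5))))
    # the training and test frames partition the (matrix-view) records
    print(len(split.train_df), len(split.test_df))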
21 changes: 14 additions & 7 deletions lenskit/lenskit/data/dataset.py
@@ -116,6 +116,7 @@ def count(self, what: str) -> int:

        * users
        * items
        * pairs (observed user-item pairs)
        * interactions
        * ratings
        """
@@ -338,6 +339,11 @@ def interaction_matrix(
                underlying data, then this is equivalent to ``"indicator"``,
                except that the ``"pandas"`` format will include a ``"rating"``
                column of all 1s.

                The ``"pandas"`` format also supports the special field name
                ``"all"`` to return a data frame with all available fields. When
                ``field="all"``, a field named ``count`` (if defined) is
                combined with the ``sum`` method, and other fields use ``last``.
            combine:
                How to combine multiple observations for a single user-item
                pair. Available methods are:
@@ -348,7 +354,8 @@
                  field.
                * ``"sum"`` — sum the field values.
                * ``"first"``, ``"last"`` — take the first or last value seen
                  (in timestamp order, if timestamps are defined; otherwise,
                  their order in the original input).
            layout:
                The layout for a sparse matrix. Can be either ``csr`` or
                ``coo``, or ``None`` to use the default for the specified
@@ -488,8 +495,8 @@ def user_stats(self) -> pd.DataFrame:

class MatrixDataset(Dataset):
"""
Dataset implementation using an in-memory rating or implicit-feedback
matrix.
Dataset implementation using an in-memory rating or implicit-feedback matrix
(with no duplicate interactions).

.. note::
Client code generally should not construct this class directly. Instead
@@ -554,7 +561,7 @@ def count(self, what: str) -> int:
                return self._users.size
            case "items":
                return self._items.size
            case "pairs" | "interactions" | "ratings":
                return self._matrix.n_obs
            case _:
                raise KeyError(f"unknown entity type {what}")
@@ -603,16 +610,16 @@ def _int_mat_pandas(self, field: str | None, original_ids: bool) -> pd.DataFrame
"user_num": self._matrix.user_nums,
"item_num": self._matrix.item_nums,
}
if field == "rating":
if field == "all" or field == "rating":
if self._matrix.ratings is not None:
cols["rating"] = self._matrix.ratings
else:
cols["rating"] = np.ones(self._matrix.n_obs)
elif field == "timestamp":
elif field == "all" or field == "timestamp":
if self._matrix.timestamps is None:
raise FieldError("interaction", field)
cols["timestamp"] = self._matrix.timestamps
elif field:
elif field and field != "all":
raise FieldError("interaction", field)
return pd.DataFrame(cols)
