Skip to content

Commit

Permalink
Merge pull request #543 from mdekstrand/feature/fix-bulk-analysis
Browse files Browse the repository at this point in the history
Correct item list collection merging and implement bulk analysis grouped summaries
  • Loading branch information
mdekstrand authored Dec 7, 2024
2 parents 2d66234 + 94d3a52 commit 8cf227c
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 8 deletions.
4 changes: 2 additions & 2 deletions lenskit/lenskit/data/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def add_from(self, other: ItemListCollection, **fields: ID):
"""
for key, list in other:
if fields:
fields = key._asdict() | fields
key = self._key_class(**fields)
cf = key._asdict() | fields
key = self._key_class(**cf)
self._add(key, list)

def _add(self, key: K, list: ItemList):
Expand Down
17 changes: 13 additions & 4 deletions lenskit/lenskit/metrics/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def global_metrics(self) -> pd.Series:
"""
return self._global_metrics

def list_metrics(self, *, fill_missing=True) -> pd.DataFrame:
def list_metrics(self, fill_missing=True) -> pd.DataFrame:
"""
Get the per-list scores of the results. This is a data frame with one
row per list (with the list key on the inded), and one metric per
Expand All @@ -91,7 +91,7 @@ def list_metrics(self, *, fill_missing=True) -> pd.DataFrame:
"""
return self._list_metrics.fillna(self._defaults)

def list_summary(self) -> pd.DataFrame:
def list_summary(self, *keys: str) -> pd.DataFrame:
"""
Sumamry statistics for the per-list metrics. Each metric is on its own
row, with columns reporting the following:
Expand All @@ -105,10 +105,19 @@ def list_summary(self) -> pd.DataFrame:
Additional columns are added based on other options. Missing metric
values are filled with their defaults before computing statistics.
Args:
keys:
Identifiers for different conditions that should be reported
separately (grouping keys for the final result).
"""
scores = self.list_metrics(fill_missing=True)
df = scores.agg(["mean", "median", "std"]).T
df.index.name = "metric"
if keys:
df = scores.groupby(list(keys)).agg(["mean", "median", "std"]).stack(level=0)
assert isinstance(df, pd.DataFrame)
else:
df = scores.agg(["mean", "median", "std"]).T
df.index.name = "metric"
return df

def merge_from(self, other: RunAnalysisResult):
Expand Down
7 changes: 5 additions & 2 deletions lenskit/tests/data/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,16 @@ def test_lookup_projected():
def test_add_from():
ilc = ItemListCollection(["model", "user_id"])

ilc1 = ItemListCollection.from_dict({72: ItemList(["a", "b"])}, key="user_id")
ilc1 = ItemListCollection.from_dict({72: ItemList(["a", "b"]), 48: ItemList()}, key="user_id")
ilc.add_from(ilc1, model="foo")

assert len(ilc) == 1
assert len(ilc) == 2
il = ilc.lookup(("foo", 72))
assert il is not None
assert il.ids().tolist() == ["a", "b"]
il = ilc.lookup(("foo", 48))
assert il is not None
assert len(il) == 0


def test_from_df(rng, ml_ratings: pd.DataFrame):
Expand Down
23 changes: 23 additions & 0 deletions lenskit/tests/eval/test_bulk_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,26 @@ def test_recs(demo_recs):
print(stats)
for m in bms.metrics:
assert stats.loc[m.label, "mean"] == approx(scores[m.label].mean())


def test_recs_multi(demo_recs):
split, recs = demo_recs

il2 = ItemListCollection(["rep", "user_id"])
il2.add_from(recs, rep=1)
il2.add_from(recs, rep=2)

bms = RunAnalysis()
bms.add_metric(ListLength())
bms.add_metric(Precision())
bms.add_metric(NDCG())
bms.add_metric(RBP)
bms.add_metric(RecipRank)

metrics = bms.compute(il2, split.test)
scores = metrics.list_metrics()
stats = metrics.list_summary("rep")
print(stats)
for m in bms.metrics:
assert stats.loc[(1, m.label), "mean"] == approx(scores.loc[1, m.label].mean())
assert stats.loc[(2, m.label), "mean"] == approx(scores.loc[2, m.label].mean())

0 comments on commit 8cf227c

Please sign in to comment.