Skip to content

Commit

Permalink
fix: model based sampler cache
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed May 14, 2024
1 parent 1091555 commit f0c2a7f
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/xpmir/letor/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import json
from pathlib import Path
from typing import Iterator, List, Tuple, Dict, Any
Expand Down Expand Up @@ -173,6 +174,7 @@ def _itertopics(
# Retrieve documents
skipped = 0
for query in tqdm(queries):
q_fp = io.StringIO()
qassessments = assessments.get(query[IDItem].id, None)
if not qassessments:
skipped += 1
Expand All @@ -185,7 +187,7 @@ def _itertopics(
positives = []
for docno, rel in qassessments.items():
if rel > 0:
fp.write(
q_fp.write(
f"{query.text if not positives else ''}"
f"\t{docno}\t0.\t{rel}\n"
)
Expand All @@ -211,7 +213,7 @@ def _itertopics(
continue

negatives.append((sd.document[IDItem].id, rel, sd.score))
fp.write(f"\t{sd.document[IDItem].id}\t{sd.score}\t{rel}\n")
q_fp.write(f"\t{sd.document[IDItem].id}\t{sd.score}\t{rel}\n")

if not negatives:
self.logger.warning(
Expand All @@ -222,6 +224,10 @@ def _itertopics(
continue

assert len(positives) > 0 and len(negatives) > 0

# Write in cache, and yield
fp.write(q_fp.getvalue())
q_fp.close()
yield query.text, positives, negatives

# Finally, move the cache file in place...
Expand Down

0 comments on commit f0c2a7f

Please sign in to comment.