-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Makes RetrievalModelV2 support item tower with transforms (e.g. pre-trained embeddings) #1198
Conversation
Documentation preview |
…atch_predict() support Loader with transforms for pre-trained embeddings in item tower
…upport Loader with transforms (which include pre-trained embeddings)
ea1db75
to
d4dc945
Compare
…y and item embeddings
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
@oliverholworthy hello. do you think you can review this PR? this is an important feature that'd be needed by users and us. thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes lgtm. Thanks for the PR!!
@@ -1553,7 +1553,7 @@ def predict( | |||
return out | |||
|
|||
def batch_predict( | |||
self, dataset: merlin.io.Dataset, batch_size: int, **kwargs | |||
self, dataset: Union[merlin.io.Dataset, Loader], batch_size: int, **kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update the docstrings to expose to the user that we are supporting Loader
as a possible input dataset
Goals ⚽
batch_predict()
support only aDataset
as input. That prevents performingbatch_predict()
using aLoader
, which might contain transforms that are important to the modelRetrievalModelV2
, when among item features there are pre-trained embeddings loaded withEmbeddingOperator
transform.Loader
withtransforms
to allbatch_predict()
based methods (e.g.to_top_k_encoder()
.Implementation Details 🚧
batch_predict()
method and associated methods were changed to acceptLoader
besidesDataset
. When aLoader
is passed, it cascades any transforms (like theEmbeddingOperator
) to theLoader
that is created inside thedata_iterator_func
.Testing Details 🔍
test_two_tower_model_topk_evaluation_with_pretrained_emb
to ensure that it is possible to have pre-trained embeddings in the item tower and convert theretrieval_model.to_top_k_encoder()
.