Makes RetrievalModelV2 support item tower with transforms (e.g. pre-trained embeddings) #1198

Merged: gabrielspmoreira merged 7 commits into main from tf/pretrained_emb_item_tower on Jul 20, 2023

gabrielspmoreira (Member) commented Jul 11, 2023

Goals ⚽

  • Currently, the methods based on batch_predict() accept only a Dataset as input. That prevents running batch_predict() on a Loader, which might carry transforms that the model depends on.
  • One use case this limitation blocks is generating item embeddings from the item tower of RetrievalModelV2 when the item features include pre-trained embeddings loaded with the EmbeddingOperator transform.
  • This PR adds support for a Loader with transforms to all batch_predict()-based methods (e.g. to_top_k_encoder()).

Implementation Details 🚧

  • The batch_predict() method and its associated methods now accept a Loader in addition to a Dataset. When a Loader is passed, its transforms (such as the EmbeddingOperator) are cascaded to the Loader that is created inside data_iterator_func.
  • It also checks that the model schema matches the Loader output schema, which differs from the dataset schema because transforms can add features.
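
The two points above can be sketched as follows. This is a minimal, self-contained illustration and not the Merlin implementation: FakeDataset, FakeLoader, its output_schema property, and the free-standing batch_predict function are hypothetical stand-ins that only mimic the behaviour described in this PR.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Sequence, Union


# Hypothetical stand-ins for merlin.io.Dataset and the Merlin Loader.
# They are NOT the real classes; they expose only what this sketch needs.
@dataclass
class FakeDataset:
    rows: List[dict]
    schema: Sequence[str]  # column names before any transform


@dataclass
class FakeLoader:
    dataset: FakeDataset
    batch_size: int
    transforms: List[Callable[[dict], dict]] = field(default_factory=list)

    @property
    def output_schema(self) -> List[str]:
        # Columns after transforms; probe one row to find added columns.
        if not self.dataset.rows:
            return list(self.dataset.schema)
        row = dict(self.dataset.rows[0])
        for transform in self.transforms:
            row = transform(row)
        return list(row.keys())

    def __iter__(self):
        rows = self.dataset.rows
        for start in range(0, len(rows), self.batch_size):
            batch = []
            for row in rows[start:start + self.batch_size]:
                row = dict(row)
                for transform in self.transforms:
                    row = transform(row)
                batch.append(row)
            yield batch


def batch_predict(model_fn, model_columns, data: Union[FakeDataset, FakeLoader], batch_size: int):
    # Accept either input type; keep the transforms when a loader is given.
    if isinstance(data, FakeLoader):
        dataset, transforms = data.dataset, list(data.transforms)
    else:
        dataset, transforms = data, []
    # The internally created loader inherits (cascades) the transforms.
    loader = FakeLoader(dataset, batch_size=batch_size, transforms=transforms)
    # Validate against the loader output schema, not the dataset schema.
    missing = set(model_columns) - set(loader.output_schema)
    if missing:
        raise ValueError(f"Model inputs missing from loader output: {sorted(missing)}")
    return [model_fn(row) for batch in loader for row in batch]


# Toy transform standing in for a pre-trained embedding lookup.
def add_pretrained_embedding(row: dict) -> dict:
    table = {1: [1, 2], 2: [3, 4]}
    return {**row, "item_emb": table[row["item_id"]]}


items = FakeDataset(rows=[{"item_id": 1}, {"item_id": 2}], schema=["item_id"])
loader = FakeLoader(items, batch_size=1, transforms=[add_pretrained_embedding])
preds = batch_predict(lambda r: sum(r["item_emb"]), ["item_id", "item_emb"], loader, batch_size=2)
```

In this sketch, passing the bare items dataset with the same model columns raises a ValueError, because the item_emb column exists only after the transform runs. That mirrors why the schema check has to look at the Loader output rather than the dataset.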

Testing Details 🔍

  • Added test_two_tower_model_topk_evaluation_with_pretrained_emb to ensure that a model with pre-trained embeddings in the item tower can be converted with retrieval_model.to_top_k_encoder().

@gabrielspmoreira gabrielspmoreira marked this pull request as draft July 11, 2023 16:29
@github-actions

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1198

…atch_predict() support Loader with transforms for pre-trained embeddings in item tower
…upport Loader with transforms (which include pre-trained embeddings)
@gabrielspmoreira gabrielspmoreira force-pushed the tf/pretrained_emb_item_tower branch from ea1db75 to d4dc945 Compare July 17, 2023 20:24
@gabrielspmoreira gabrielspmoreira marked this pull request as ready for review July 17, 2023 20:24
@gabrielspmoreira gabrielspmoreira added the enhancement New feature or request label Jul 17, 2023

rnyak (Contributor) commented Jul 20, 2023

@oliverholworthy Hello. Do you think you can review this PR? This is an important feature that both users and we need. Thanks.

sararb (Contributor) left a comment

The changes lgtm. Thanks for the PR!!

@@ -1553,7 +1553,7 @@ def predict(
         return out

     def batch_predict(
-        self, dataset: merlin.io.Dataset, batch_size: int, **kwargs
+        self, dataset: Union[merlin.io.Dataset, Loader], batch_size: int, **kwargs
Can you update the docstrings to tell users that a Loader is now supported as a possible input dataset?
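
One way the requested docstring update could read is sketched below. The wording is an assumption for illustration and is not taken from the merged code; Dataset and Loader here are empty placeholder classes so the snippet runs on its own, and the function body is deliberately left unimplemented.

```python
from typing import Any, Union


class Dataset:
    """Placeholder for merlin.io.Dataset (hypothetical stand-in)."""


class Loader:
    """Placeholder for the Merlin Loader (hypothetical stand-in)."""


def batch_predict(self, dataset: Union[Dataset, Loader], batch_size: int, **kwargs: Any):
    """Run batched prediction over the given data.

    Parameters
    ----------
    dataset : Union[merlin.io.Dataset, Loader]
        Data to predict on. If a Loader is passed, its transforms
        (for example an EmbeddingOperator that adds pre-trained
        embeddings) are applied before the model sees each batch.
    batch_size : int
        Number of rows per prediction batch.
    """
    raise NotImplementedError("docstring sketch only")
```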

@gabrielspmoreira gabrielspmoreira merged commit a1d0be2 into main Jul 20, 2023