From a1d0be2a00248696e1c1c6d9b299ae904b94b9aa Mon Sep 17 00:00:00 2001 From: Gabriel Moreira Date: Thu, 20 Jul 2023 19:42:28 -0300 Subject: [PATCH] Makes RetrievalModelV2 support item tower with transforms (e.g. pre-trained embeddings) (#1198) * Making retrieval model to_top_k_model(), candidate_embeddings() and batch_predict() support Loader with transforms for pre-trained embeddings in item tower * Fixing test error and ensuring all batch_predict() with the new API support Loader with transforms (which include pre-trained embeddings) * Fixing retrieval example, which was using wrong schema to export query and item embeddings * Added missing importorskip on torch and pytorch_lightning for torch integration tests * Skiping a test if nvtabular is available --- examples/05-Retrieval-Model.ipynb | 8 ++-- merlin/models/tf/core/encoder.py | 53 +++++++++++++++------ merlin/models/tf/models/base.py | 34 ++++++++++---- merlin/models/tf/utils/batch_utils.py | 8 +++- tests/integration/torch/__init__.py | 19 ++++++++ tests/unit/tf/models/test_retrieval.py | 65 +++++++++++++++++++++++++- 6 files changed, 160 insertions(+), 27 deletions(-) diff --git a/examples/05-Retrieval-Model.ipynb b/examples/05-Retrieval-Model.ipynb index c2553347e5..5306610e01 100644 --- a/examples/05-Retrieval-Model.ipynb +++ b/examples/05-Retrieval-Model.ipynb @@ -1616,7 +1616,8 @@ } ], "source": [ - "queries = model.query_embeddings(Dataset(user_features, schema=schema), batch_size=1024, index=Tags.USER_ID)\n", + "queries = model.query_embeddings(Dataset(user_features, schema=schema.select_by_tag(Tags.USER)), \n", + " batch_size=1024, index=Tags.USER_ID)\n", "query_embs_df = queries.compute(scheduler=\"synchronous\").reset_index()" ] }, @@ -1996,7 +1997,8 @@ } ], "source": [ - "item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema), batch_size=1024, index=Tags.ITEM_ID)" + "item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema.select_by_tag(Tags.ITEM)), \n", + " batch_size=1024, index=Tags.ITEM_ID)" ] }, { @@ -2460,7 +2462,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.8.10" }, "merlin": { "containers": [ diff --git a/merlin/models/tf/core/encoder.py b/merlin/models/tf/core/encoder.py index 3f32a73f67..5833b2d674 100644 --- a/merlin/models/tf/core/encoder.py +++ b/merlin/models/tf/core/encoder.py @@ -27,6 +27,7 @@ from merlin.models.tf.core.prediction import TopKPrediction from merlin.models.tf.inputs.base import InputBlockV2 from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable +from merlin.models.tf.loader import Loader from merlin.models.tf.models.base import BaseModel, get_output_schema from merlin.models.tf.outputs.topk import TopKOutput from merlin.models.tf.transforms.features import PrepareFeatures @@ -84,7 +85,7 @@ def __init__( def encode( self, - dataset: merlin.io.Dataset, + dataset: Union[merlin.io.Dataset, Loader], index: Union[str, ColumnSchema, Schema, Tags], batch_size: int, **kwargs, @@ -93,7 +94,7 @@ def encode( Parameters ---------- - dataset: merlin.io.Dataset + dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader] The dataset to encode. index: Union[str, ColumnSchema, Schema, Tags] The index to use for encoding. @@ -127,7 +128,7 @@ def encode( def batch_predict( self, - dataset: merlin.io.Dataset, + dataset: Union[merlin.io.Dataset, Loader], batch_size: int, output_schema: Optional[Schema] = None, index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, @@ -137,8 +138,8 @@ def batch_predict( Parameters ---------- - dataset: merlin.io.Dataset - Dataset to predict on. + dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader] + Dataset or Loader to predict on. batch_size: int Batch size to use for prediction. @@ -161,18 +162,35 @@ def batch_predict( raise ValueError("Only one column can be used as index") index = index.first.name + dataset_schema = None if hasattr(dataset, "schema"): - if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)): + dataset_schema = dataset.schema + data_output_schema = dataset_schema + if isinstance(dataset, Loader): + data_output_schema = dataset.output_schema + if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)): raise ValueError( f"Model schema {self.schema.column_names} does not match dataset schema" - + f" {dataset.schema.column_names}" + + f" {data_output_schema.column_names}" ) + loader_transforms = None + if isinstance(dataset, Loader): + loader_transforms = dataset.transforms + batch_size = dataset.batch_size + dataset = dataset.dataset + # Check if merlin-dataset is passed if hasattr(dataset, "to_ddf"): dataset = dataset.to_ddf() - model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs) + model_encode = TFModelEncode( + self, + batch_size=batch_size, + loader_transforms=loader_transforms, + schema=dataset_schema, + **kwargs, + ) encode_kwargs = {} if output_schema: @@ -583,7 +601,7 @@ def encode_candidates( def batch_predict( self, - dataset: merlin.io.Dataset, + dataset: Union[merlin.io.Dataset, Loader], batch_size: int, output_schema: Optional[Schema] = None, **kwargs, @@ -592,8 +610,8 @@ def batch_predict( Parameters ---------- - dataset : merlin.io.Dataset - Raw queries features dataset + dataset : Union[merlin.io.Dataset, merlin.models.tf.loader.Loader] + Raw queries features dataset or Loader batch_size : int The number of queries to process at each prediction step output_schema: Schema, optional @@ -606,15 +624,24 @@ def batch_predict( """ from merlin.models.tf.utils.batch_utils import TFModelEncode + loader_transforms = None + if isinstance(dataset, Loader): + loader_transforms = dataset.transforms + batch_size = dataset.batch_size + dataset = dataset.dataset + + dataset_schema = dataset.schema + dataset = dataset.to_ddf() + model_encode = TFModelEncode( model=self, batch_size=batch_size, + loader_transforms=loader_transforms, + schema=dataset_schema, output_names=TopKPrediction.output_names(self.k), **kwargs, ) - dataset = dataset.to_ddf() - encode_kwargs = {} if output_schema: encode_kwargs["filter_input_columns"] = output_schema.column_names diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 365c8fe9d9..ec1f950e77 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -1553,7 +1553,7 @@ def predict( return out def batch_predict( - self, dataset: merlin.io.Dataset, batch_size: int, **kwargs + self, dataset: Union[merlin.io.Dataset, Loader], batch_size: int, **kwargs ) -> merlin.io.Dataset: """Batched prediction using the Dask. Parameters @@ -1565,20 +1565,38 @@ def batch_predict( Returns merlin.io.Dataset ------- """ + dataset_schema = None if hasattr(dataset, "schema"): - if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)): + dataset_schema = dataset.schema + data_output_schema = dataset_schema + if isinstance(dataset, Loader): + data_output_schema = dataset.output_schema + + if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)): raise ValueError( f"Model schema {self.schema.column_names} does not match dataset schema" - + f" {dataset.schema.column_names}" + + f" {data_output_schema.column_names}" ) + loader_transforms = None + if isinstance(dataset, Loader): + loader_transforms = dataset.transforms + batch_size = dataset.batch_size + dataset = dataset.dataset + # Check if merlin-dataset is passed if hasattr(dataset, "to_ddf"): dataset = dataset.to_ddf() from merlin.models.tf.utils.batch_utils import TFModelEncode - model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs) + model_encode = TFModelEncode( + self, + batch_size=batch_size, + loader_transforms=loader_transforms, + schema=dataset_schema, + **kwargs, + ) # Processing a sample of the dataset with the model encoder # to get the output dataframe dtypes @@ -2510,20 +2528,20 @@ def query_embeddings( def candidate_embeddings( self, - dataset: Optional[merlin.io.Dataset] = None, + data: Optional[Union[merlin.io.Dataset, Loader]] = None, index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, **kwargs, ) -> merlin.io.Dataset: if self.has_candidate_encoder: candidate = self.candidate_encoder - if dataset is not None and hasattr(candidate, "encode"): - return candidate.encode(dataset, index=index, **kwargs) + if data is not None and hasattr(candidate, "encode"): + return candidate.encode(data, index=index, **kwargs) if hasattr(candidate, "to_dataset"): return candidate.to_dataset(**kwargs) - return candidate.encode(dataset, index=index, **kwargs) + return candidate.encode(data, index=index, **kwargs) if isinstance(self.last, (ContrastiveOutput, CategoricalOutput)): return self.last.to_dataset() diff --git a/merlin/models/tf/utils/batch_utils.py b/merlin/models/tf/utils/batch_utils.py index bc48da9ddd..2e83957eb3 100644 --- a/merlin/models/tf/utils/batch_utils.py +++ b/merlin/models/tf/utils/batch_utils.py @@ -74,6 +74,7 @@ def __init__( block_load_func: tp.Optional[tp.Callable[[str], Block]] = None, schema: tp.Optional[Schema] = None, output_concat_func=None, + loader_transforms=None, ): save_path = save_path or tempfile.mkdtemp() model.save(save_path) @@ -95,7 +96,9 @@ def __init__( super().__init__( save_path, output_names, - data_iterator_func=data_iterator_func(self.schema, batch_size=batch_size), + data_iterator_func=data_iterator_func( + self.schema, batch_size=batch_size, loader_transforms=loader_transforms + ), model_load_func=model_load_func, model_encode_func=model_encode, output_concat_func=output_concat_func, @@ -172,7 +175,7 @@ def encode_output(output: tf.Tensor): return output.numpy() -def data_iterator_func(schema, batch_size: int = 512): +def data_iterator_func(schema, batch_size: int = 512, loader_transforms=None): import merlin.io.dataset def data_iterator(dataset): @@ -180,6 +183,7 @@ def data_iterator(dataset): merlin.io.dataset.Dataset(dataset, schema=schema), batch_size=batch_size, shuffle=False, + transforms=loader_transforms, ) return data_iterator diff --git a/tests/integration/torch/__init__.py b/tests/integration/torch/__init__.py index e69de29bb2..598c04b960 100644 --- a/tests/integration/torch/__init__.py +++ b/tests/integration/torch/__init__.py @@ -0,0 +1,19 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +pytest.importorskip("torch") +pytest.importorskip("pytorch_lightning") diff --git a/tests/unit/tf/models/test_retrieval.py b/tests/unit/tf/models/test_retrieval.py index 252a666fba..c7531c5f0a 100644 --- a/tests/unit/tf/models/test_retrieval.py +++ b/tests/unit/tf/models/test_retrieval.py @@ -1,11 +1,12 @@ from pathlib import Path -import nvtabular as nvt +import numpy as np import pytest import tensorflow as tf import merlin.models.tf as mm from merlin.core.dispatch import make_df +from merlin.dataloader.ops.embeddings import EmbeddingOperator from merlin.io import Dataset from merlin.models.tf.metrics.topk import ( AvgPrecisionAt, @@ -24,6 +25,8 @@ def test_two_tower_shared_embeddings(): + nvt = pytest.importorskip("nvtabular") + train = make_df( { "user_id": [1, 3, 3, 4, 3, 1, 2, 4, 6, 7, 8, 9] * 100, @@ -435,6 +438,66 @@ def test_two_tower_model_topk_evaluation(ecommerce_data: Dataset, run_eagerly): assert all([metric >= 0 for metric in metrics.values()]) +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_two_tower_model_topk_evaluation_with_pretrained_emb(music_streaming_data, run_eagerly): + music_streaming_data.schema = music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]) + + cardinality = music_streaming_data.schema["item_category"].int_domain.max + 1 + pretrained_embedding = np.random.rand(cardinality, 12) + + loader_transforms = [ + EmbeddingOperator( + pretrained_embedding, + lookup_key="item_category", + embedding_name="pretrained_category_embeddings", + ), + ] + loader = mm.Loader( + music_streaming_data, + schema=music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]), + batch_size=10, + transforms=loader_transforms, + ) + schema = loader.output_schema + + pretrained_embeddings = mm.PretrainedEmbeddings( + schema.select_by_tag(Tags.EMBEDDING), + output_dims=16, + ) + + schema = loader.output_schema + + query_input = mm.InputBlockV2(schema.select_by_tag(Tags.USER)) + query = mm.Encoder(query_input, mm.MLPBlock([4], no_activation_last_layer=True)) + candidate_input = mm.InputBlockV2( + schema.select_by_tag(Tags.ITEM), pretrained_embeddings=pretrained_embeddings + ) + candidate = mm.Encoder(candidate_input, mm.MLPBlock([4], no_activation_last_layer=True)) + model = mm.TwoTowerModelV2( + query, + candidate, + negative_samplers=["in-batch"], + ) + model.compile(optimizer="adam", run_eagerly=run_eagerly) + _ = testing_utils.model_test(model, loader) + + # Top-K evaluation + candidate_features_data = unique_rows_by_features(music_streaming_data, Tags.ITEM, Tags.ITEM_ID) + loader_candidates = mm.Loader( + candidate_features_data, + batch_size=16, + transforms=loader_transforms, + ) + + topk_model = model.to_top_k_encoder(loader_candidates, k=20, batch_size=16) + topk_model.compile(run_eagerly=run_eagerly) + + loader = mm.Loader(music_streaming_data, batch_size=32).map(mm.ToTarget(schema, "item_id")) + + metrics = topk_model.evaluate(loader, return_dict=True) + assert all([metric >= 0 for metric in metrics.values()]) + + @pytest.mark.parametrize("run_eagerly", [True, False]) @pytest.mark.parametrize("logits_pop_logq_correction", [True, False]) @pytest.mark.parametrize("loss", ["categorical_crossentropy", "bpr-max", "binary_crossentropy"])