diff --git a/merlin/models/tf/core/encoder.py b/merlin/models/tf/core/encoder.py index 3f32a73f67..19d620418d 100644 --- a/merlin/models/tf/core/encoder.py +++ b/merlin/models/tf/core/encoder.py @@ -27,6 +27,7 @@ from merlin.models.tf.core.prediction import TopKPrediction from merlin.models.tf.inputs.base import InputBlockV2 from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable +from merlin.models.tf.loader import Loader from merlin.models.tf.models.base import BaseModel, get_output_schema from merlin.models.tf.outputs.topk import TopKOutput from merlin.models.tf.transforms.features import PrepareFeatures @@ -84,7 +85,7 @@ def __init__( def encode( self, - dataset: merlin.io.Dataset, + dataset: Union[merlin.io.Dataset, Loader], index: Union[str, ColumnSchema, Schema, Tags], batch_size: int, **kwargs, @@ -93,7 +94,7 @@ def encode( Parameters ---------- - dataset: merlin.io.Dataset + dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader] The dataset to encode. index: Union[str, ColumnSchema, Schema, Tags] The index to use for encoding. @@ -127,7 +128,7 @@ def encode( def batch_predict( self, - dataset: merlin.io.Dataset, + dataset: Union[merlin.io.Dataset, Loader], batch_size: int, output_schema: Optional[Schema] = None, index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, @@ -137,7 +138,7 @@ def batch_predict( Parameters ---------- - dataset: merlin.io.Dataset + dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader] Dataset to predict on. batch_size: int Batch size to use for prediction. @@ -162,17 +163,28 @@ def batch_predict( index = index.first.name if hasattr(dataset, "schema"): - if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)): + data_schema = dataset.schema + if isinstance(dataset, Loader): + data_schema = dataset.output_schema + if not set(self.schema.column_names).issubset(set(data_schema.column_names)): raise ValueError( f"Model schema {self.schema.column_names} does not match dataset schema" - + f" {dataset.schema.column_names}" + + f" {data_schema.column_names}" ) + loader_transforms = None + if isinstance(dataset, Loader): + loader_transforms = dataset.transforms + batch_size = dataset.batch_size + dataset = dataset.dataset + # Check if merlin-dataset is passed if hasattr(dataset, "to_ddf"): dataset = dataset.to_ddf() - model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs) + model_encode = TFModelEncode( + self, batch_size=batch_size, loader_transforms=loader_transforms, **kwargs + ) encode_kwargs = {} if output_schema: diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 365c8fe9d9..808a8852ae 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -2510,20 +2510,20 @@ def query_embeddings( def candidate_embeddings( self, - dataset: Optional[merlin.io.Dataset] = None, + data: Optional[Union[merlin.io.Dataset, Loader]] = None, index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, **kwargs, ) -> merlin.io.Dataset: if self.has_candidate_encoder: candidate = self.candidate_encoder - if dataset is not None and hasattr(candidate, "encode"): - return candidate.encode(dataset, index=index, **kwargs) + if data is not None and hasattr(candidate, "encode"): + return candidate.encode(data, index=index, **kwargs) if hasattr(candidate, "to_dataset"): return candidate.to_dataset(**kwargs) - return candidate.encode(dataset, index=index, **kwargs) + return candidate.encode(data, index=index, **kwargs) if isinstance(self.last, (ContrastiveOutput, CategoricalOutput)): return self.last.to_dataset() diff --git a/merlin/models/tf/utils/batch_utils.py b/merlin/models/tf/utils/batch_utils.py index bc48da9ddd..2e83957eb3 100644 --- a/merlin/models/tf/utils/batch_utils.py +++ b/merlin/models/tf/utils/batch_utils.py @@ -74,6 +74,7 @@ def __init__( block_load_func: tp.Optional[tp.Callable[[str], Block]] = None, schema: tp.Optional[Schema] = None, output_concat_func=None, + loader_transforms=None, ): save_path = save_path or tempfile.mkdtemp() model.save(save_path) @@ -95,7 +96,9 @@ def __init__( super().__init__( save_path, output_names, - data_iterator_func=data_iterator_func(self.schema, batch_size=batch_size), + data_iterator_func=data_iterator_func( + self.schema, batch_size=batch_size, loader_transforms=loader_transforms + ), model_load_func=model_load_func, model_encode_func=model_encode, output_concat_func=output_concat_func, @@ -172,7 +175,7 @@ def encode_output(output: tf.Tensor): return output.numpy() -def data_iterator_func(schema, batch_size: int = 512): +def data_iterator_func(schema, batch_size: int = 512, loader_transforms=None): import merlin.io.dataset def data_iterator(dataset): @@ -180,6 +183,7 @@ def data_iterator(dataset): merlin.io.dataset.Dataset(dataset, schema=schema), batch_size=batch_size, shuffle=False, + transforms=loader_transforms, ) return data_iterator diff --git a/tests/unit/tf/models/test_retrieval.py b/tests/unit/tf/models/test_retrieval.py index 252a666fba..2e461d7c0c 100644 --- a/tests/unit/tf/models/test_retrieval.py +++ b/tests/unit/tf/models/test_retrieval.py @@ -1,11 +1,13 @@ from pathlib import Path +import numpy as np import nvtabular as nvt import pytest import tensorflow as tf import merlin.models.tf as mm from merlin.core.dispatch import make_df +from merlin.dataloader.ops.embeddings import EmbeddingOperator from merlin.io import Dataset from merlin.models.tf.metrics.topk import ( AvgPrecisionAt, @@ -435,6 +437,66 @@ def test_two_tower_model_topk_evaluation(ecommerce_data: Dataset, run_eagerly): assert all([metric >= 0 for metric in metrics.values()]) +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_two_tower_model_topk_evaluation_with_pretrained_emb(music_streaming_data, run_eagerly): + music_streaming_data.schema = music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]) + + cardinality = music_streaming_data.schema["item_category"].int_domain.max + 1 + pretrained_embedding = np.random.rand(cardinality, 12) + + loader_transforms = [ + EmbeddingOperator( + pretrained_embedding, + lookup_key="item_category", + embedding_name="pretrained_category_embeddings", + ), + ] + loader = mm.Loader( + music_streaming_data, + schema=music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]), + batch_size=10, + transforms=loader_transforms, + ) + schema = loader.output_schema + + pretrained_embeddings = mm.PretrainedEmbeddings( + schema.select_by_tag(Tags.EMBEDDING), + output_dims=16, + ) + + schema = loader.output_schema + + query_input = mm.InputBlockV2(schema.select_by_tag(Tags.USER)) + query = mm.Encoder(query_input, mm.MLPBlock([4], no_activation_last_layer=True)) + candidate_input = mm.InputBlockV2( + schema.select_by_tag(Tags.ITEM), pretrained_embeddings=pretrained_embeddings + ) + candidate = mm.Encoder(candidate_input, mm.MLPBlock([4], no_activation_last_layer=True)) + model = mm.TwoTowerModelV2( + query, + candidate, + negative_samplers=["in-batch"], + ) + model.compile(optimizer="adam", run_eagerly=run_eagerly) + _ = testing_utils.model_test(model, loader) + + # Top-K evaluation + candidate_features_data = unique_rows_by_features(music_streaming_data, Tags.ITEM, Tags.ITEM_ID) + loader_candidates = mm.Loader( + candidate_features_data, + batch_size=16, + transforms=loader_transforms, + ) + + topk_model = model.to_top_k_encoder(loader_candidates, k=20, batch_size=16) + topk_model.compile(run_eagerly=run_eagerly) + + loader = mm.Loader(music_streaming_data, batch_size=32).map(mm.ToTarget(schema, "item_id")) + + metrics = topk_model.evaluate(loader, return_dict=True) + assert all([metric >= 0 for metric in metrics.values()]) + + @pytest.mark.parametrize("run_eagerly", [True, False]) @pytest.mark.parametrize("logits_pop_logq_correction", [True, False]) @pytest.mark.parametrize("loss", ["categorical_crossentropy", "bpr-max", "binary_crossentropy"])