Skip to content

Commit

Permalink
Making retrieval model to_top_k_encoder(), candidate_embeddings() and batch_predict() support Loader with transforms for pre-trained embeddings in item tower
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielspmoreira committed Jul 11, 2023
1 parent 8a9e5ea commit ea1db75
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 13 deletions.
26 changes: 19 additions & 7 deletions merlin/models/tf/core/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from merlin.models.tf.core.prediction import TopKPrediction
from merlin.models.tf.inputs.base import InputBlockV2
from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable
from merlin.models.tf.loader import Loader
from merlin.models.tf.models.base import BaseModel, get_output_schema
from merlin.models.tf.outputs.topk import TopKOutput
from merlin.models.tf.transforms.features import PrepareFeatures
Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(

def encode(
self,
dataset: merlin.io.Dataset,
dataset: Union[merlin.io.Dataset, Loader],
index: Union[str, ColumnSchema, Schema, Tags],
batch_size: int,
**kwargs,
Expand All @@ -93,7 +94,7 @@ def encode(
Parameters
----------
dataset: merlin.io.Dataset
dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
The dataset to encode.
index: Union[str, ColumnSchema, Schema, Tags]
The index to use for encoding.
Expand Down Expand Up @@ -127,7 +128,7 @@ def encode(

def batch_predict(
self,
dataset: merlin.io.Dataset,
dataset: Union[merlin.io.Dataset, Loader],
batch_size: int,
output_schema: Optional[Schema] = None,
index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
Expand All @@ -137,7 +138,7 @@ def batch_predict(
Parameters
----------
dataset: merlin.io.Dataset
dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
Dataset to predict on.
batch_size: int
Batch size to use for prediction.
Expand All @@ -162,17 +163,28 @@ def batch_predict(
index = index.first.name

if hasattr(dataset, "schema"):
if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)):
data_schema = dataset.schema
if isinstance(dataset, Loader):
data_schema = dataset.output_schema
if not set(self.schema.column_names).issubset(set(data_schema.column_names)):
raise ValueError(
f"Model schema {self.schema.column_names} does not match dataset schema"
+ f" {dataset.schema.column_names}"
+ f" {data_schema.column_names}"
)

loader_transforms = None
if isinstance(dataset, Loader):
loader_transforms = dataset.transforms
batch_size = dataset.batch_size
dataset = dataset.dataset

# Check if merlin-dataset is passed
if hasattr(dataset, "to_ddf"):
dataset = dataset.to_ddf()

model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
model_encode = TFModelEncode(
self, batch_size=batch_size, loader_transforms=loader_transforms, **kwargs
)

encode_kwargs = {}
if output_schema:
Expand Down
8 changes: 4 additions & 4 deletions merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,20 +2510,20 @@ def query_embeddings(

def candidate_embeddings(
self,
dataset: Optional[merlin.io.Dataset] = None,
data: Optional[Union[merlin.io.Dataset, Loader]] = None,
index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
**kwargs,
) -> merlin.io.Dataset:
if self.has_candidate_encoder:
candidate = self.candidate_encoder

if dataset is not None and hasattr(candidate, "encode"):
return candidate.encode(dataset, index=index, **kwargs)
if data is not None and hasattr(candidate, "encode"):
return candidate.encode(data, index=index, **kwargs)

if hasattr(candidate, "to_dataset"):
return candidate.to_dataset(**kwargs)

return candidate.encode(dataset, index=index, **kwargs)
return candidate.encode(data, index=index, **kwargs)

if isinstance(self.last, (ContrastiveOutput, CategoricalOutput)):
return self.last.to_dataset()
Expand Down
8 changes: 6 additions & 2 deletions merlin/models/tf/utils/batch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
block_load_func: tp.Optional[tp.Callable[[str], Block]] = None,
schema: tp.Optional[Schema] = None,
output_concat_func=None,
loader_transforms=None,
):
save_path = save_path or tempfile.mkdtemp()
model.save(save_path)
Expand All @@ -95,7 +96,9 @@ def __init__(
super().__init__(
save_path,
output_names,
data_iterator_func=data_iterator_func(self.schema, batch_size=batch_size),
data_iterator_func=data_iterator_func(
self.schema, batch_size=batch_size, loader_transforms=loader_transforms
),
model_load_func=model_load_func,
model_encode_func=model_encode,
output_concat_func=output_concat_func,
Expand Down Expand Up @@ -172,14 +175,15 @@ def encode_output(output: tf.Tensor):
return output.numpy()


def data_iterator_func(schema, batch_size: int = 512):
def data_iterator_func(schema, batch_size: int = 512, loader_transforms=None):
    """Build a factory that wraps a raw dataset into a non-shuffling ``Loader``.

    Parameters
    ----------
    schema : Schema
        Schema applied to the wrapped ``merlin.io.dataset.Dataset``.
    batch_size : int, optional
        Batch size for the returned loader, by default 512.
    loader_transforms : optional
        Transforms (e.g. pre-trained embedding operators) forwarded to the
        ``Loader`` so they are applied per batch during iteration.

    Returns
    -------
    Callable
        A function mapping a raw dataset to a configured ``Loader``.
    """
    import merlin.io.dataset

    def data_iterator(raw_dataset):
        # Wrap first so the provided schema (not the dataset's own) drives loading.
        wrapped = merlin.io.dataset.Dataset(raw_dataset, schema=schema)
        # shuffle=False keeps prediction output aligned with input row order.
        return Loader(
            wrapped,
            batch_size=batch_size,
            shuffle=False,
            transforms=loader_transforms,
        )

    return data_iterator
62 changes: 62 additions & 0 deletions tests/unit/tf/models/test_retrieval.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from pathlib import Path

import numpy as np
import nvtabular as nvt
import pytest
import tensorflow as tf

import merlin.models.tf as mm
from merlin.core.dispatch import make_df
from merlin.dataloader.ops.embeddings import EmbeddingOperator
from merlin.io import Dataset
from merlin.models.tf.metrics.topk import (
AvgPrecisionAt,
Expand Down Expand Up @@ -435,6 +437,66 @@ def test_two_tower_model_topk_evaluation(ecommerce_data: Dataset, run_eagerly):
assert all([metric >= 0 for metric in metrics.values()])


@pytest.mark.parametrize("run_eagerly", [True, False])
def test_two_tower_model_topk_evaluation_with_pretrained_emb(music_streaming_data, run_eagerly):
    """Two-tower retrieval with pre-trained item embeddings injected via Loader
    transforms: trains the model, converts it to a top-k encoder from a Loader
    of candidates (with the same transforms), and runs top-k evaluation.
    """
    music_streaming_data.schema = music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM])

    # Random pre-trained table large enough to cover every item_category id.
    cardinality = music_streaming_data.schema["item_category"].int_domain.max + 1
    pretrained_embedding = np.random.rand(cardinality, 12)

    # Loader-level transform that joins the pre-trained embeddings into each batch.
    loader_transforms = [
        EmbeddingOperator(
            pretrained_embedding,
            lookup_key="item_category",
            embedding_name="pretrained_category_embeddings",
        ),
    ]
    train_loader = mm.Loader(
        music_streaming_data,
        schema=music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]),
        batch_size=10,
        transforms=loader_transforms,
    )
    # Use the loader's output schema so the embedding column produced by the
    # transform is visible to the input blocks.
    schema = train_loader.output_schema

    pretrained_embeddings = mm.PretrainedEmbeddings(
        schema.select_by_tag(Tags.EMBEDDING),
        output_dims=16,
    )

    query_input = mm.InputBlockV2(schema.select_by_tag(Tags.USER))
    query = mm.Encoder(query_input, mm.MLPBlock([4], no_activation_last_layer=True))
    candidate_input = mm.InputBlockV2(
        schema.select_by_tag(Tags.ITEM), pretrained_embeddings=pretrained_embeddings
    )
    candidate = mm.Encoder(candidate_input, mm.MLPBlock([4], no_activation_last_layer=True))
    model = mm.TwoTowerModelV2(
        query,
        candidate,
        negative_samplers=["in-batch"],
    )
    model.compile(optimizer="adam", run_eagerly=run_eagerly)
    _ = testing_utils.model_test(model, train_loader)

    # Top-K evaluation: candidates must go through the same transforms so the
    # item tower receives its pre-trained embedding feature.
    candidate_features_data = unique_rows_by_features(music_streaming_data, Tags.ITEM, Tags.ITEM_ID)
    loader_candidates = mm.Loader(
        candidate_features_data,
        batch_size=16,
        transforms=loader_transforms,
    )

    topk_model = model.to_top_k_encoder(loader_candidates, k=20, batch_size=16)
    topk_model.compile(run_eagerly=run_eagerly)

    # Separate evaluation loader (don't shadow the training loader) that turns
    # item_id into the target column.
    eval_loader = mm.Loader(music_streaming_data, batch_size=32).map(mm.ToTarget(schema, "item_id"))

    metrics = topk_model.evaluate(eval_loader, return_dict=True)
    assert all([metric >= 0 for metric in metrics.values()])


@pytest.mark.parametrize("run_eagerly", [True, False])
@pytest.mark.parametrize("logits_pop_logq_correction", [True, False])
@pytest.mark.parametrize("loss", ["categorical_crossentropy", "bpr-max", "binary_crossentropy"])
Expand Down

0 comments on commit ea1db75

Please sign in to comment.