Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Makes RetrievalModelV2 support item tower with transforms (e.g. pre-trained embeddings) #1198

Merged
merged 7 commits into from
Jul 20, 2023
8 changes: 5 additions & 3 deletions examples/05-Retrieval-Model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1616,7 +1616,8 @@
}
],
"source": [
"queries = model.query_embeddings(Dataset(user_features, schema=schema), batch_size=1024, index=Tags.USER_ID)\n",
"queries = model.query_embeddings(Dataset(user_features, schema=schema.select_by_tag(Tags.USER)), \n",
" batch_size=1024, index=Tags.USER_ID)\n",
"query_embs_df = queries.compute(scheduler=\"synchronous\").reset_index()"
]
},
Expand Down Expand Up @@ -1996,7 +1997,8 @@
}
],
"source": [
"item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema), batch_size=1024, index=Tags.ITEM_ID)"
"item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema.select_by_tag(Tags.ITEM)), \n",
" batch_size=1024, index=Tags.ITEM_ID)"
]
},
{
Expand Down Expand Up @@ -2460,7 +2462,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.8.10"
},
"merlin": {
"containers": [
Expand Down
53 changes: 40 additions & 13 deletions merlin/models/tf/core/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from merlin.models.tf.core.prediction import TopKPrediction
from merlin.models.tf.inputs.base import InputBlockV2
from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable
from merlin.models.tf.loader import Loader
from merlin.models.tf.models.base import BaseModel, get_output_schema
from merlin.models.tf.outputs.topk import TopKOutput
from merlin.models.tf.transforms.features import PrepareFeatures
Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(

def encode(
self,
dataset: merlin.io.Dataset,
dataset: Union[merlin.io.Dataset, Loader],
index: Union[str, ColumnSchema, Schema, Tags],
batch_size: int,
**kwargs,
Expand All @@ -93,7 +94,7 @@ def encode(

Parameters
----------
dataset: merlin.io.Dataset
dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
The dataset to encode.
index: Union[str, ColumnSchema, Schema, Tags]
The index to use for encoding.
Expand Down Expand Up @@ -127,7 +128,7 @@ def encode(

def batch_predict(
self,
dataset: merlin.io.Dataset,
dataset: Union[merlin.io.Dataset, Loader],
batch_size: int,
output_schema: Optional[Schema] = None,
index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
Expand All @@ -137,8 +138,8 @@ def batch_predict(

Parameters
----------
dataset: merlin.io.Dataset
Dataset to predict on.
dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
Dataset or Loader to predict on.
batch_size: int
Batch size to use for prediction.

Expand All @@ -161,18 +162,35 @@ def batch_predict(
raise ValueError("Only one column can be used as index")
index = index.first.name

dataset_schema = None
if hasattr(dataset, "schema"):
if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)):
dataset_schema = dataset.schema
data_output_schema = dataset_schema
if isinstance(dataset, Loader):
data_output_schema = dataset.output_schema
if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)):
raise ValueError(
f"Model schema {self.schema.column_names} does not match dataset schema"
+ f" {dataset.schema.column_names}"
+ f" {data_output_schema.column_names}"
)

loader_transforms = None
if isinstance(dataset, Loader):
loader_transforms = dataset.transforms
batch_size = dataset.batch_size
dataset = dataset.dataset

# Check if merlin-dataset is passed
if hasattr(dataset, "to_ddf"):
dataset = dataset.to_ddf()

model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
model_encode = TFModelEncode(
self,
batch_size=batch_size,
loader_transforms=loader_transforms,
schema=dataset_schema,
**kwargs,
)

encode_kwargs = {}
if output_schema:
Expand Down Expand Up @@ -583,7 +601,7 @@ def encode_candidates(

def batch_predict(
self,
dataset: merlin.io.Dataset,
dataset: Union[merlin.io.Dataset, Loader],
batch_size: int,
output_schema: Optional[Schema] = None,
**kwargs,
Expand All @@ -592,8 +610,8 @@ def batch_predict(

Parameters
----------
dataset : merlin.io.Dataset
Raw queries features dataset
dataset : Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
Raw queries features dataset or Loader
batch_size : int
The number of queries to process at each prediction step
output_schema: Schema, optional
Expand All @@ -606,15 +624,24 @@ def batch_predict(
"""
from merlin.models.tf.utils.batch_utils import TFModelEncode

loader_transforms = None
if isinstance(dataset, Loader):
loader_transforms = dataset.transforms
batch_size = dataset.batch_size
dataset = dataset.dataset

dataset_schema = dataset.schema
dataset = dataset.to_ddf()

model_encode = TFModelEncode(
model=self,
batch_size=batch_size,
loader_transforms=loader_transforms,
schema=dataset_schema,
output_names=TopKPrediction.output_names(self.k),
**kwargs,
)

dataset = dataset.to_ddf()

encode_kwargs = {}
if output_schema:
encode_kwargs["filter_input_columns"] = output_schema.column_names
Expand Down
34 changes: 26 additions & 8 deletions merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,7 +1553,7 @@ def predict(
return out

def batch_predict(
self, dataset: merlin.io.Dataset, batch_size: int, **kwargs
self, dataset: Union[merlin.io.Dataset, Loader], batch_size: int, **kwargs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the docstrings to make it clear to users that a Loader is also supported as an input dataset?

) -> merlin.io.Dataset:
"""Batched prediction using the Dask.
Parameters
Expand All @@ -1565,20 +1565,38 @@ def batch_predict(
Returns merlin.io.Dataset
-------
"""
dataset_schema = None
if hasattr(dataset, "schema"):
if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)):
dataset_schema = dataset.schema
data_output_schema = dataset_schema
if isinstance(dataset, Loader):
data_output_schema = dataset.output_schema

if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)):
raise ValueError(
f"Model schema {self.schema.column_names} does not match dataset schema"
+ f" {dataset.schema.column_names}"
+ f" {data_output_schema.column_names}"
)

loader_transforms = None
if isinstance(dataset, Loader):
loader_transforms = dataset.transforms
batch_size = dataset.batch_size
dataset = dataset.dataset

# Check if merlin-dataset is passed
if hasattr(dataset, "to_ddf"):
dataset = dataset.to_ddf()

from merlin.models.tf.utils.batch_utils import TFModelEncode

model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
model_encode = TFModelEncode(
self,
batch_size=batch_size,
loader_transforms=loader_transforms,
schema=dataset_schema,
**kwargs,
)

# Processing a sample of the dataset with the model encoder
# to get the output dataframe dtypes
Expand Down Expand Up @@ -2510,20 +2528,20 @@ def query_embeddings(

def candidate_embeddings(
self,
dataset: Optional[merlin.io.Dataset] = None,
data: Optional[Union[merlin.io.Dataset, Loader]] = None,
index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
**kwargs,
) -> merlin.io.Dataset:
if self.has_candidate_encoder:
candidate = self.candidate_encoder

if dataset is not None and hasattr(candidate, "encode"):
return candidate.encode(dataset, index=index, **kwargs)
if data is not None and hasattr(candidate, "encode"):
return candidate.encode(data, index=index, **kwargs)

if hasattr(candidate, "to_dataset"):
return candidate.to_dataset(**kwargs)

return candidate.encode(dataset, index=index, **kwargs)
return candidate.encode(data, index=index, **kwargs)

if isinstance(self.last, (ContrastiveOutput, CategoricalOutput)):
return self.last.to_dataset()
Expand Down
8 changes: 6 additions & 2 deletions merlin/models/tf/utils/batch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
block_load_func: tp.Optional[tp.Callable[[str], Block]] = None,
schema: tp.Optional[Schema] = None,
output_concat_func=None,
loader_transforms=None,
):
save_path = save_path or tempfile.mkdtemp()
model.save(save_path)
Expand All @@ -95,7 +96,9 @@ def __init__(
super().__init__(
save_path,
output_names,
data_iterator_func=data_iterator_func(self.schema, batch_size=batch_size),
data_iterator_func=data_iterator_func(
self.schema, batch_size=batch_size, loader_transforms=loader_transforms
),
model_load_func=model_load_func,
model_encode_func=model_encode,
output_concat_func=output_concat_func,
Expand Down Expand Up @@ -172,14 +175,15 @@ def encode_output(output: tf.Tensor):
return output.numpy()


def data_iterator_func(schema, batch_size: int = 512):
def data_iterator_func(schema, batch_size: int = 512, loader_transforms=None):
    """Build a factory that wraps a dataset in a non-shuffling ``Loader``.

    Parameters
    ----------
    schema
        Schema attached to the ``merlin.io.dataset.Dataset`` created for each input.
    batch_size: int
        Batch size handed to the ``Loader``, by default 512.
    loader_transforms
        Value forwarded unchanged as the ``transforms`` argument of the ``Loader``.

    Returns
    -------
    Callable
        Function taking a dataset and returning a ``Loader`` over it.
    """
    import merlin.io.dataset

    def data_iterator(dataset):
        # Re-wrap the incoming data so it carries the schema given to the factory.
        merlin_dataset = merlin.io.dataset.Dataset(dataset, schema=schema)
        loader_kwargs = dict(
            batch_size=batch_size,
            shuffle=False,
            transforms=loader_transforms,
        )
        return Loader(merlin_dataset, **loader_kwargs)

    return data_iterator
19 changes: 19 additions & 0 deletions tests/integration/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest

# Skip (rather than fail with ImportError) when the optional PyTorch
# dependencies are not installed in the test environment.
pytest.importorskip("torch")
pytest.importorskip("pytorch_lightning")
65 changes: 64 additions & 1 deletion tests/unit/tf/models/test_retrieval.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from pathlib import Path

import nvtabular as nvt
import numpy as np
import pytest
import tensorflow as tf

import merlin.models.tf as mm
from merlin.core.dispatch import make_df
from merlin.dataloader.ops.embeddings import EmbeddingOperator
from merlin.io import Dataset
from merlin.models.tf.metrics.topk import (
AvgPrecisionAt,
Expand All @@ -24,6 +25,8 @@


def test_two_tower_shared_embeddings():
nvt = pytest.importorskip("nvtabular")

train = make_df(
{
"user_id": [1, 3, 3, 4, 3, 1, 2, 4, 6, 7, 8, 9] * 100,
Expand Down Expand Up @@ -435,6 +438,66 @@ def test_two_tower_model_topk_evaluation(ecommerce_data: Dataset, run_eagerly):
assert all([metric >= 0 for metric in metrics.values()])


@pytest.mark.parametrize("run_eagerly", [True, False])
def test_two_tower_model_topk_evaluation_with_pretrained_emb(music_streaming_data, run_eagerly):
    """Top-k evaluation of a two-tower model whose item tower uses pre-trained embeddings.

    The pre-trained embedding table is injected through a loader transform
    (``EmbeddingOperator``), the model is trained on that loader, and the same
    transforms are reused by the candidates loader when building the top-k encoder.
    The test passes when every reported evaluation metric is non-negative.
    """
    music_streaming_data.schema = music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM])

    # One embedding row per possible "item_category" id (ids run from 0 to max).
    cardinality = music_streaming_data.schema["item_category"].int_domain.max + 1
    pretrained_embedding = np.random.rand(cardinality, 12)

    loader_transforms = [
        EmbeddingOperator(
            pretrained_embedding,
            lookup_key="item_category",
            embedding_name="pretrained_category_embeddings",
        ),
    ]
    loader = mm.Loader(
        music_streaming_data,
        schema=music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]),
        batch_size=10,
        transforms=loader_transforms,
    )
    # The loader output schema (not the dataset schema) is used so that the
    # column added by the transform is visible to the input blocks.
    schema = loader.output_schema

    pretrained_embeddings = mm.PretrainedEmbeddings(
        schema.select_by_tag(Tags.EMBEDDING),
        output_dims=16,
    )

    query_input = mm.InputBlockV2(schema.select_by_tag(Tags.USER))
    query = mm.Encoder(query_input, mm.MLPBlock([4], no_activation_last_layer=True))
    candidate_input = mm.InputBlockV2(
        schema.select_by_tag(Tags.ITEM), pretrained_embeddings=pretrained_embeddings
    )
    candidate = mm.Encoder(candidate_input, mm.MLPBlock([4], no_activation_last_layer=True))
    model = mm.TwoTowerModelV2(
        query,
        candidate,
        negative_samplers=["in-batch"],
    )
    model.compile(optimizer="adam", run_eagerly=run_eagerly)
    _ = testing_utils.model_test(model, loader)

    # Top-K evaluation
    candidate_features_data = unique_rows_by_features(music_streaming_data, Tags.ITEM, Tags.ITEM_ID)
    loader_candidates = mm.Loader(
        candidate_features_data,
        batch_size=16,
        transforms=loader_transforms,
    )

    topk_model = model.to_top_k_encoder(loader_candidates, k=20, batch_size=16)
    topk_model.compile(run_eagerly=run_eagerly)

    # Separate name for the evaluation loader so the training loader is not shadowed.
    eval_loader = mm.Loader(music_streaming_data, batch_size=32).map(
        mm.ToTarget(schema, "item_id")
    )

    metrics = topk_model.evaluate(eval_loader, return_dict=True)
    assert all(metric >= 0 for metric in metrics.values())


@pytest.mark.parametrize("run_eagerly", [True, False])
@pytest.mark.parametrize("logits_pop_logq_correction", [True, False])
@pytest.mark.parametrize("loss", ["categorical_crossentropy", "bpr-max", "binary_crossentropy"])
Expand Down