Fixing test error and ensuring all batch_predict() with the new API support Loader with transforms (which include pre-trained embeddings)
gabrielspmoreira committed Jul 17, 2023
1 parent 7250eca commit d4dc945
Showing 2 changed files with 48 additions and 15 deletions.
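For context, this change lets batch_predict() accept a merlin.models.tf Loader directly, so transforms attached to the Loader (for example, operators that join pre-trained embeddings onto each batch) also run at prediction time. A minimal usage sketch follows; the parquet path, the EmbeddingOperator transform, and the already-trained model are illustrative assumptions, not part of this commit:

import numpy as np

import merlin.io
import merlin.models.tf as mm
from merlin.dataloader.ops.embeddings import EmbeddingOperator  # assumed import path

# Hypothetical inputs: a parquet dataset and a pre-trained item-embedding table.
dataset = merlin.io.Dataset("/path/to/data.parquet")
pretrained_embeddings = np.random.rand(1000, 16).astype("float32")

loader = mm.Loader(
    dataset,
    batch_size=1024,
    transforms=[
        EmbeddingOperator(
            pretrained_embeddings, lookup_key="item_id", embedding_name="item_embedding"
        )
    ],
)

# After this commit the Loader itself can be passed; its transforms and
# batch_size take precedence over the batch_size argument.
predictions = model.batch_predict(loader, batch_size=1024)  # model: a trained mm.Model
predictions_df = predictions.to_ddf().compute()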
37 changes: 26 additions & 11 deletions merlin/models/tf/core/encoder.py
@@ -139,7 +139,7 @@ def batch_predict(
         Parameters
         ----------
         dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
-            Dataset to predict on.
+            Dataset or Loader to predict on.
         batch_size: int
             Batch size to use for prediction.
@@ -162,14 +162,16 @@ def batch_predict(
                 raise ValueError("Only one column can be used as index")
             index = index.first.name
 
+        dataset_schema = None
         if hasattr(dataset, "schema"):
-            data_schema = dataset.schema
+            dataset_schema = dataset.schema
+            data_output_schema = dataset_schema
             if isinstance(dataset, Loader):
-                data_schema = dataset.output_schema
-            if not set(self.schema.column_names).issubset(set(data_schema.column_names)):
+                data_output_schema = dataset.output_schema
+            if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)):
                 raise ValueError(
                     f"Model schema {self.schema.column_names} does not match dataset schema"
-                    + f" {data_schema.column_names}"
+                    + f" {data_output_schema.column_names}"
                 )
 
         loader_transforms = None
@@ -183,7 +185,11 @@ def batch_predict(
             dataset = dataset.to_ddf()
 
         model_encode = TFModelEncode(
-            self, batch_size=batch_size, loader_transforms=loader_transforms, **kwargs
+            self,
+            batch_size=batch_size,
+            loader_transforms=loader_transforms,
+            schema=dataset_schema,
+            **kwargs,
         )
 
         encode_kwargs = {}
@@ -595,7 +601,7 @@ def encode_candidates(
 
     def batch_predict(
         self,
-        dataset: merlin.io.Dataset,
+        dataset: Union[merlin.io.Dataset, Loader],
         batch_size: int,
         output_schema: Optional[Schema] = None,
         **kwargs,
@@ -604,8 +610,8 @@ def batch_predict(
         Parameters
         ----------
-        dataset : merlin.io.Dataset
-            Raw queries features dataset
+        dataset : Union[merlin.io.Dataset, merlin.models.tf.loader.Loader]
+            Raw queries features dataset or Loader
         batch_size : int
             The number of queries to process at each prediction step
         output_schema: Schema, optional
@@ -618,15 +624,24 @@ def batch_predict(
         """
         from merlin.models.tf.utils.batch_utils import TFModelEncode
 
+        loader_transforms = None
+        if isinstance(dataset, Loader):
+            loader_transforms = dataset.transforms
+            batch_size = dataset.batch_size
+            dataset = dataset.dataset
+
+        dataset_schema = dataset.schema
+        dataset = dataset.to_ddf()
+
         model_encode = TFModelEncode(
             model=self,
             batch_size=batch_size,
+            loader_transforms=loader_transforms,
+            schema=dataset_schema,
             output_names=TopKPrediction.output_names(self.k),
             **kwargs,
         )
 
-        dataset = dataset.to_ddf()
-
         encode_kwargs = {}
         if output_schema:
             encode_kwargs["filter_input_columns"] = output_schema.column_names
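Both batch_predict() methods in encoder.py now share the same unwrap pattern: pull transforms and batch_size off an incoming Loader, keep the raw dataset's schema, and pass all three to TFModelEncode so prediction can rebuild an equivalent Loader per Dask partition. A condensed sketch of that pattern (the helper name is hypothetical):

from typing import Optional, Tuple, Union

import merlin.io
from merlin.models.tf.loader import Loader
from merlin.schema import Schema

def unwrap_loader(
    dataset: Union[merlin.io.Dataset, Loader], batch_size: int
) -> Tuple[merlin.io.Dataset, Schema, int, Optional[list]]:
    """Hypothetical helper condensing the unwrap logic this commit adds."""
    loader_transforms = None
    if isinstance(dataset, Loader):
        # The Loader's own settings win over the caller's arguments.
        loader_transforms = dataset.transforms
        batch_size = dataset.batch_size
        dataset = dataset.dataset
    # Schema of the raw data, before any loader transforms are applied.
    dataset_schema = dataset.schema
    return dataset, dataset_schema, batch_size, loader_transforms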
26 changes: 22 additions & 4 deletions merlin/models/tf/models/base.py
@@ -1553,7 +1553,7 @@ def predict(
         return out
 
     def batch_predict(
-        self, dataset: merlin.io.Dataset, batch_size: int, **kwargs
+        self, dataset: Union[merlin.io.Dataset, Loader], batch_size: int, **kwargs
     ) -> merlin.io.Dataset:
         """Batched prediction using the Dask.
         Parameters
@@ -1565,20 +1565,38 @@ def batch_predict(
         Returns merlin.io.Dataset
         -------
         """
+        dataset_schema = None
         if hasattr(dataset, "schema"):
-            if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)):
+            dataset_schema = dataset.schema
+            data_output_schema = dataset_schema
+            if isinstance(dataset, Loader):
+                data_output_schema = dataset.output_schema
+
+            if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)):
                 raise ValueError(
                     f"Model schema {self.schema.column_names} does not match dataset schema"
-                    + f" {dataset.schema.column_names}"
+                    + f" {data_output_schema.column_names}"
                 )
 
+        loader_transforms = None
+        if isinstance(dataset, Loader):
+            loader_transforms = dataset.transforms
+            batch_size = dataset.batch_size
+            dataset = dataset.dataset
+
         # Check if merlin-dataset is passed
         if hasattr(dataset, "to_ddf"):
             dataset = dataset.to_ddf()
 
         from merlin.models.tf.utils.batch_utils import TFModelEncode
 
-        model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
+        model_encode = TFModelEncode(
+            self,
+            batch_size=batch_size,
+            loader_transforms=loader_transforms,
+            schema=dataset_schema,
+            **kwargs,
+        )
 
         # Processing a sample of the dataset with the model encoder
         # to get the output dataframe dtypes
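Downstream of this diff, TFModelEncode applies the model partition by partition over the Dask DataFrame. Passing the pre-transform dataset_schema together with loader_transforms is what lets each worker rebuild a Loader over its partition, so the same transforms run at prediction time as at training time. A hypothetical sketch of that per-partition step (the function and its mechanics are assumptions for illustration, not the library's actual code):

import pandas as pd

import merlin.io
from merlin.models.tf.loader import Loader
from merlin.schema import Schema

def encode_partition(
    part_df: pd.DataFrame, model, schema: Schema, batch_size: int, loader_transforms=None
) -> pd.DataFrame:
    # Wrap the partition in a Dataset carrying the original (pre-transform)
    # schema, then in a Loader that re-applies the training-time transforms
    # before the model sees the data.
    part_ds = merlin.io.Dataset(part_df, schema=schema)
    loader = Loader(part_ds, batch_size=batch_size, shuffle=False, transforms=loader_transforms)
    predictions = model.predict(loader)
    return pd.DataFrame({"prediction": predictions.ravel()})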
