diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py
index cbce2b624af3..b9f2e8a944b0 100644
--- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py
+++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py
@@ -24,7 +24,6 @@
 from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
 from apache_beam.ml.inference.base import RunInference
-from apache_beam.ml.inference.tensorflow_inference import ModelType
 from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn
 from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
 from apache_beam.ml.transforms.base import EmbeddingsManager
@@ -34,16 +33,7 @@
 import tensorflow_text as text  # required to register TF ops. # pylint: disable=unused-import
 import tensorflow_hub as hub
 
-
-class TensorflowHubModelHandler(TFModelHandlerTensor):
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-
-  def load_model(self):
-    if self._model_type == ModelType.SAVED_WEIGHTS:
-      raise NotImplementedError
-    model = hub.KerasLayer(self._model_uri)
-    return model
+__all__ = ['TensorflowHubTextEmbeddings']
 
 
 def yield_elements(elements: List[Dict[str, Any]]):
@@ -61,8 +51,18 @@ def yield_elements(elements: List[Dict[str, Any]]):
     yield element
 
 
-# TODO: many models requires preprocessing.
-class TensorflowHubEmbeddings(EmbeddingsManager):
+class _TensorflowHubModelHandler(TFModelHandlerTensor):
+  """
+  Note: Intended for internal use only. No backwards compatibility guarantees.
+  """
+  def load_model(self):
+    # Unable to load the models with tf.keras.models.load_model, so
+    # using hub.KerasLayer instead.
+    model = hub.KerasLayer(self._model_uri)
+    return model
+
+
+class TensorflowHubTextEmbeddings(EmbeddingsManager):
   def __init__(
       self, hub_url: str, preprocessing_url: Optional[str] = None, **kwargs):
     super().__init__(**kwargs)
@@ -81,16 +81,20 @@ def custom_inference_fn(self, model, batch, inference_args, model_id):
       preprocessor_fn = hub.KerasLayer(self.preprocessing_url)
       vectorized_batch = preprocessor_fn(vectorized_batch)
     predictions = model(vectorized_batch)
-    # pooled output is the embeedings.
-    # TODO: Do other keys need to be returned?
+    # https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model # pylint: disable=line-too-long
+    # pooled_output -> represents the text as a whole. This is an embedding
+    # of the whole text. The shape is [batch_size, embedding_dimension].
+    # sequence_output -> represents the text as a sequence of tokens. This is
+    # an embedding of each token in the text. The shape is
+    # [batch_size, max_sequence_length, embedding_dimension].
+    # pooled_output is the embedding as per the documentation, so let's use
+    # that.
     embeddings = predictions['pooled_output']
     return utils._convert_to_result(batch, embeddings, model_id)
 
   def get_model_handler(self) -> ModelHandler:
     # override the default inference function
-    if not self.inference_fn:
-      self.inference_fn = self.custom_inference_fn
-    return TensorflowHubModelHandler(
+    return _TensorflowHubModelHandler(
         model_uri=self.model_uri,
         preprocessor_uri=self.preprocessing_url,
         inference_fn=self.custom_inference_fn,
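
For context, a minimal usage sketch of the renamed TensorflowHubTextEmbeddings transform inside MLTransform. This is not part of the change itself; the TF Hub URLs, the 'text' column name, and the artifact location below are illustrative placeholders.

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings

# Illustrative only: embeds the 'text' column with a public BERT encoder from
# TF Hub. preprocessing_url points at the matching preprocessing model so raw
# strings can be fed directly; tensorflow_text must be installed.
with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([{'text': 'hello world'}, {'text': 'apache beam'}])
      | MLTransform(write_artifact_location='/tmp/ml_transform_artifacts').with_transform(
          TensorflowHubTextEmbeddings(
              columns=['text'],
              hub_url='https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4',
              preprocessing_url='https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'))
      | beam.Map(print))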