Skip to content

Commit

Permalink
Add tensorflow hub text embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
AnandInguva committed Nov 14, 2023
1 parent 368fa5d commit 058c6b4
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import ModelType
from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
from apache_beam.ml.transforms.base import EmbeddingsManager
Expand All @@ -34,16 +33,7 @@
import tensorflow_text as text # required to register TF ops. # pylint: disable=unused-import
import tensorflow_hub as hub


class TensorflowHubModelHandler(TFModelHandlerTensor):
  """Model handler that loads a TensorFlow Hub model from a hub URI.

  Overrides ``load_model`` because ``tf.keras.models.load_model`` cannot
  load hub modules directly; the URI is wrapped in a ``hub.KerasLayer``
  instead. Loading from saved weights (``ModelType.SAVED_WEIGHTS``) is
  not supported.
  """
  # NOTE: the redundant pass-through __init__ was removed — the inherited
  # TFModelHandlerTensor constructor is used unchanged.

  def load_model(self):
    """Return the hub model wrapped as a ``hub.KerasLayer``.

    Raises:
      NotImplementedError: if the handler was configured with
        ``ModelType.SAVED_WEIGHTS``, which hub models do not support.
    """
    if self._model_type == ModelType.SAVED_WEIGHTS:
      raise NotImplementedError(
          'Loading from saved weights is not supported for '
          'TensorFlow Hub models.')
    model = hub.KerasLayer(self._model_uri)
    return model
__all__ = ['TensorflowHubTextEmbeddings']


def yield_elements(elements: List[Dict[str, Any]]):
Expand All @@ -61,8 +51,18 @@ def yield_elements(elements: List[Dict[str, Any]]):
yield element


# TODO: many models require preprocessing.
class TensorflowHubEmbeddings(EmbeddingsManager):
class _TensorflowHubModelHandler(TFModelHandlerTensor):
  """Model handler that loads TensorFlow Hub models as Keras layers.

  Note: Intended for internal use only. No backwards compatibility
  guarantees.
  """
  def load_model(self):
    # tf.keras.models.load_model cannot read these hub models directly,
    # so the model URI is wrapped in a hub.KerasLayer instead.
    return hub.KerasLayer(self._model_uri)


class TensorflowHubTextEmbeddings(EmbeddingsManager):
def __init__(
self, hub_url: str, preprocessing_url: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
Expand All @@ -81,16 +81,20 @@ def custom_inference_fn(self, model, batch, inference_args, model_id):
preprocessor_fn = hub.KerasLayer(self.preprocessing_url)
vectorized_batch = preprocessor_fn(vectorized_batch)
predictions = model(vectorized_batch)
# pooled output is the embeddings.
# TODO: Do other keys need to be returned?
# https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model # pylint: disable=line-too-long
# pooled_output -> represents the text as a whole. This is an embeddings
# of the whole text. The shape is [batch_size, embedding_dimension]
# sequence_output -> represents the text as a sequence of tokens. This is
# an embeddings of each token in the text. The shape is
# [batch_size, max_sequence_length, embedding_dimension]
# pooled output is the embeddings as per the documentation, so let's use
# that.
embeddings = predictions['pooled_output']
return utils._convert_to_result(batch, embeddings, model_id)

def get_model_handler(self) -> ModelHandler:
# override the default inference function
if not self.inference_fn:
self.inference_fn = self.custom_inference_fn
return TensorflowHubModelHandler(
return _TensorflowHubModelHandler(
model_uri=self.model_uri,
preprocessor_uri=self.preprocessing_url,
inference_fn=self.custom_inference_fn,
Expand Down

0 comments on commit 058c6b4

Please sign in to comment.