From 1b76101c0b24df26d1d1615d0b4dd6d049a5f7a9 Mon Sep 17 00:00:00 2001 From: Ritesh Ghorse Date: Wed, 2 Aug 2023 13:03:27 -0400 Subject: [PATCH] [Python] Hugging Face pipeline support (#27399) * automodel first pass * new model * updated model handler api * add model_class param * update doc comments * updated integration test and example * unit test, modified params * add test setup for hugging face tests * fix lints * fix import order * refactor, doc, lints * refactor, doc comments * change test file * update types * add hugging face pipeline support * integration test for pipeline * add doc, gs link * test raises exception * fix python lints * add inference fn * update doc * docs, lint * docs, lint * remove optional from inference_fn * add enum for tasks * update pydoc * update pydoc * doc, formatting changes * fix doc * fix optional in doc * pin model version --- .../huggingface_question_answering.py | 164 ++++++++++++++ .../ml/inference/huggingface_inference.py | 200 +++++++++++++++++- .../huggingface_inference_it_test.py | 35 ++- .../huggingface_tests_requirements.txt | 2 +- 4 files changed, 393 insertions(+), 8 deletions(-) create mode 100644 sdks/python/apache_beam/examples/inference/huggingface_question_answering.py diff --git a/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py new file mode 100644 index 000000000000..9005ea5d11d7 --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py @@ -0,0 +1,164 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""""A pipeline that uses RunInference to perform Question Answering using the +model from Hugging Face Models Hub. + +This pipeline takes questions and context from a custom text file separated by +a semicolon. These are converted to SquadExamples by using the utility provided +by transformers.QuestionAnsweringPipeline and passed to the model handler. +We just provide the model name here because the model repository specifies the +task that it will do. The pipeline then writes the prediction to an output +file in which users can then compare against the original context. +""" + +import argparse +import logging +from typing import Iterable +from typing import Tuple + +import apache_beam as beam +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler +from apache_beam.ml.inference.huggingface_inference import PipelineTask +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.runner import PipelineResult +from transformers import QuestionAnsweringPipeline + + +class PostProcessor(beam.DoFn): + """Processes the PredictionResult to get the predicted answer. + + Hugging Face Pipeline for Question Answering returns a dictionary + with score, start and end index of answer and the answer. + """ + def process(self, result: Tuple[str, PredictionResult]) -> Iterable[str]: + text, prediction = result + predicted_answer = prediction.inference['answer'] + yield text + ';' + predicted_answer + + +def preprocess(text): + """ + preprocess separates the text into question and context + by splitting on semi-colon. + + Args: + text (str): string with question and context separated by semi-colon. + + Yields: + (str, str): yields question and context from text. + """ + if len(text.strip()) > 0: + question, context = text.split(';') + yield (question, context) + + +def create_squad_example(text): + """Creates SquadExample objects to be fed to QuestionAnsweringPipeline + supported by Hugging Face. + + Check out https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.QuestionAnsweringPipeline.__call__.X #pylint: disable=line-too-long + to learn about valid input types for QuestionAnswering Pipeline. + Args: + text (Tuple[str,str]): a tuple of question and context. + """ + question, context = text + yield question, QuestionAnsweringPipeline.create_sample(question, context) + + +def parse_known_args(argv): + """Parses args for the workflow.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--input', + dest='input', + help='Path of file containing question and context separated by semicolon' + ) + parser.add_argument( + '--output', + dest='output', + required=True, + help='Path of file in which to save the output predictions.') + parser.add_argument( + '--model_name', + dest='model_name', + default="deepset/roberta-base-squad2", + help='Model repository-id from Hugging Face Models Hub.') + parser.add_argument( + '--revision', + dest='revision', + help= + 'Specific model version to use - branch name, tag name, or a commit-id.') + return parser.parse_known_args(argv) + + +def run( + argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult: + """ + Args: + argv: Command line arguments defined for this example. + save_main_session: Used for internal testing. + test_pipeline: Used for internal testing. + """ + known_args, pipeline_args = parse_known_args(argv) + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + + pipeline = test_pipeline + if not test_pipeline: + pipeline = beam.Pipeline(options=pipeline_options) + + model_handler = HuggingFacePipelineModelHandler( + task=PipelineTask.QuestionAnswering, + model=known_args.model_name, + load_model_args={ + 'framework': 'pt', 'revision': known_args.revision + }) + if not known_args.input: + text = ( + pipeline | 'CreateSentences' >> beam.Create([ + "What does Apache Beam do?;" + "Apache Beam enables batch and streaming data processing.", + "What is the capital of France?;The capital of France is Paris .", + "Where was beam summit?;Apache Beam Summit 2023 was in NYC.", + ])) + else: + text = ( + pipeline | 'ReadSentences' >> beam.io.ReadFromText(known_args.input)) + processed_text = ( + text + | 'PreProcess' >> beam.ParDo(preprocess) + | 'SquadExample' >> beam.ParDo(create_squad_example)) + output = ( + processed_text + | 'RunInference' >> RunInference(KeyedModelHandler(model_handler)) + | 'ProcessOutput' >> beam.ParDo(PostProcessor())) + _ = output | "WriteOutput" >> beam.io.WriteToText( + known_args.output, shard_name_template='', append_trailing_newlines=True) + + result = pipeline.run() + result.wait_until_finish() + return result + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py index 35c3a1686c70..09201d3b080b 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py @@ -20,6 +20,7 @@ import logging import sys from collections import defaultdict +from enum import Enum from typing import Any from typing import Callable from typing import Dict @@ -35,13 +36,16 @@ from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.pytorch_inference import _convert_to_device from transformers import AutoModel +from transformers import Pipeline from transformers import TFAutoModel +from transformers import pipeline _LOGGER = logging.getLogger(__name__) __all__ = [ "HuggingFaceModelHandlerTensor", "HuggingFaceModelHandlerKeyedTensor", + "HuggingFacePipelineModelHandler", ] TensorInferenceFn = Callable[[ @@ -59,10 +63,50 @@ Union[AutoModel, TFAutoModel], str, Optional[Dict[str, Any]], - Optional[str], + Optional[str] ], - Iterable[PredictionResult], - ] + Iterable[PredictionResult]] + +PipelineInferenceFn = Callable[ + [Sequence[str], Pipeline, Optional[Dict[str, Any]]], + Iterable[PredictionResult]] + + +class PipelineTask(str, Enum): + """ + PipelineTask lists all the tasks supported by the Hugging Face Pipelines. + Only these tasks can be passed to HuggingFacePipelineModelHandler. + """ + AudioClassification = 'audio-classification' + AutomaticSpeechRecognition = 'automatic-speech-recognition' + Conversational = 'conversational' + DepthEstimation = 'depth-estimation' + DocumentQuestionAnswering = 'document-question-answering' + FeatureExtraction = 'feature-extraction' + FillMask = 'fill-mask' + ImageClassification = 'image-classification' + ImageSegmentation = 'image-segmentation' + ImageToText = 'image-to-text' + MaskGeneration = 'mask-generation' + NER = 'ner' + ObjectDetection = 'object-detection' + QuestionAnswering = 'question-answering' + SentimentAnalysis = 'sentiment-analysis' + Summarization = 'summarization' + TableQuestionAnswering = 'table-question-answering' + TextClassification = 'text-classification' + TextGeneration = 'text-generation' + Text2TextGeneration = 'text2text-generation' + TokenClassification = 'token-classification' + Translation = 'translation' + VideoClassification = 'video-classification' + VisualQuestionAnswering = 'visual-question-answering' + VQA = 'vqa' + ZeroShotAudioClassification = 'zero-shot-audio-classification' + ZeroShotClassification = 'zero-shot-classification' + ZeroShotImageClassification = 'zero-shot-image-classification' + ZeroShotObjectDetection = 'zero-shot-object-detection' + Translation_XX_to_YY = 'translation_XX_to_YY' def _validate_constructor_args(model_uri, model_class): @@ -109,6 +153,14 @@ def is_gpu_available_tensorflow(device): return True +def _validate_constructor_args_hf_pipeline(task, model): + if not task and not model: + raise RuntimeError( + 'Please provide either task or model to the ' + 'HuggingFacePipelineModelHandler. If the model already defines the ' + 'task, no need to specify the task.') + + def _run_inference_torch_keyed_tensor( batch: Sequence[Dict[str, torch.Tensor]], model: AutoModel, @@ -447,7 +499,7 @@ def run_inference( else: self._framework = "pt" - if (self._framework == 'pt' and self._device == "GPU" and + if (self._framework == "pt" and self._device == "GPU" and is_gpu_available_torch()): model.to(torch.device("cuda")) @@ -462,6 +514,9 @@ def run_inference( return _default_inference_fn_torch( batch, model, self._device, inference_args, self._model_uri) + def update_model_path(self, model_path: Optional[str] = None): + self._model_uri = model_path if model_path else self._model_uri + def get_num_bytes( self, batch: Sequence[Union[tf.Tensor, torch.Tensor]]) -> int: """ @@ -483,6 +538,139 @@ def share_model_across_processes(self) -> bool: def get_metrics_namespace(self) -> str: """ Returns: - A namespace for metrics collected by the RunInference transform. + A namespace for metrics collected by the RunInference transform. + """ + return 'BeamML_HuggingFaceModelHandler_Tensor' + + +def _convert_to_result( + batch: Iterable, + predictions: Union[Iterable, Dict[Any, Iterable]], + model_id: Optional[str] = None, +) -> Iterable[PredictionResult]: + return [ + PredictionResult(x, y, model_id) for x, y in zip(batch, [predictions]) + ] + + +def _default_pipeline_inference_fn( + batch, pipeline, inference_args) -> Iterable[PredictionResult]: + predicitons = pipeline(batch, **inference_args) + return predicitons + + +class HuggingFacePipelineModelHandler(ModelHandler[str, + PredictionResult, + Pipeline]): + def __init__( + self, + task: Union[str, PipelineTask] = "", + model=None, + *, + inference_fn: PipelineInferenceFn = _default_pipeline_inference_fn, + load_pipeline_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[Dict[str, Any]] = None, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + large_model: bool = False, + **kwargs): + """ + Implementation of the ModelHandler interface for Hugging Face Pipelines. + + **Note:** To specify which device to use (CPU/GPU), + use the load_pipeline_args with key-value as you would do in the usual + Hugging Face pipeline. Ex: load_pipeline_args={'device':0}) + + Example Usage model:: + pcoll | RunInference(HuggingFacePipelineModelHandler( + task="fill-mask")) + + Args: + task (str or enum.Enum): task supported by HuggingFace Pipelines. + Accepts a string task or an enum.Enum from PipelineTask. + model : path to pretrained model on Hugging Face Models Hub to use custom + model for the chosen task. If the model already defines the task then + no need to specify the task parameter. + inference_fn: the inference function to use during RunInference. + Default is _default_pipeline_inference_fn. + load_pipeline_args (Dict[str, Any]): keyword arguments to provide load + options while loading pipelines from Hugging Face. Defaults to None. + inference_args (Dict[str, Any]): Non-batchable arguments + required as inputs to the model's inference function. + Defaults to None. + min_batch_size: the minimum batch size to use when batching inputs. + max_batch_size: the maximum batch size to use when batching inputs. + large_model: set to true if your model is large enough to run into + memory pressure if you load multiple copies. Given a model that + consumes N memory and a machine with W cores and M memory, you should + set this to True if N*W > M. + kwargs: 'env_vars' can be used to set environment variables + before loading the model. + + **Supported Versions:** HuggingFacePipelineModelHandler supports + transformers>=4.18.0. + """ + self._task = task + self._model = model + self._inference_fn = inference_fn + self._load_pipeline_args = load_pipeline_args if load_pipeline_args else {} + self._inference_args = inference_args if inference_args else {} + self._batching_kwargs = {} + self._framework = "torch" + self._env_vars = kwargs.get('env_vars', {}) + if min_batch_size is not None: + self._batching_kwargs['min_batch_size'] = min_batch_size + if max_batch_size is not None: + self._batching_kwargs['max_batch_size'] = max_batch_size + self._large_model = large_model + _validate_constructor_args_hf_pipeline(self._task, self._model) + + def load_model(self): + return pipeline( + task=self._task, model=self._model, **self._load_pipeline_args) + + def run_inference( + self, + batch: Sequence[str], + pipeline: Pipeline, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """ + Runs inferences on a batch of examples passed as a string resource. + These can either be string sentences, or string path to images or + audio files. + + Args: + batch: A sequence of strings resources. + pipeline: A Hugging Face Pipeline. + inference_args: Non-batchable arguments required as inputs to the model's + inference function. + Returns: + An Iterable of type PredictionResult. + """ + inference_args = {} if not inference_args else inference_args + predictions = self._inference_fn(batch, pipeline, inference_args) + return _convert_to_result(batch, predictions) + + def update_model_path(self, model_path: Optional[str] = None): + """ + Updates the pretrained model used by the Hugging Face Pipeline task. + Make sure that the new model does the same task as initial model. + + Args: + model_path (str): (Optional) Path to the new trained model + from Hugging Face. Defaults to None. """ - return "BeamML_HuggingFaceModelHandler_Tensor" + self._model = model_path if model_path else self._model + + def get_num_bytes(self, batch: Sequence[str]) -> int: + return sum(sys.getsizeof(element) for element in batch) + + def batch_elements_kwargs(self): + return self._batching_kwargs + + def share_model_across_processes(self) -> bool: + return self._large_model + + def get_metrics_namespace(self) -> str: + return 'BeamML_HuggingFacePipelineModelHandler' diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py b/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py index ed442a4b801a..0be359a87196 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference_it_test.py @@ -28,6 +28,7 @@ try: from apache_beam.examples.inference import huggingface_language_modeling + from apache_beam.examples.inference import huggingface_question_answering from apache_beam.ml.inference import pytorch_inference_it_test except ImportError: raise unittest.SkipTest( @@ -38,8 +39,8 @@ @pytest.mark.uses_transformers @pytest.mark.it_postcommit +@pytest.mark.timeout(1800) class HuggingFaceInference(unittest.TestCase): - @pytest.mark.timeout(1800) def test_hf_language_modeling(self): test_pipeline = TestPipeline(is_integration_test=True) # Path to text file containing some sentences @@ -74,6 +75,38 @@ def test_hf_language_modeling(self): predicted_predicted_text = predictions_dict[text] self.assertEqual(actual_predicted_text, predicted_predicted_text) + def test_hf_pipeline(self): + test_pipeline = TestPipeline(is_integration_test=True) + # Path to text file containing some questions and context + input_file = 'gs://apache-beam-ml/datasets/custom/questions.txt' + output_file_dir = 'gs://apache-beam-ml/hf/testing/predictions' + output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt']) + extra_opts = { + 'input': input_file, + 'output': output_file, + 'revision': 'deedc3e42208524e0df3d9149d1f26aa6934f05f', + } + huggingface_question_answering.run( + test_pipeline.get_full_options_as_args(**extra_opts), + save_main_session=False) + self.assertEqual(FileSystems().exists(output_file), True) + predictions = pytorch_inference_it_test.process_outputs( + filepath=output_file) + actuals_file = ( + 'gs://apache-beam-ml/testing/expected_outputs/' + 'test_hf_pipeline_answers.txt') + actuals = pytorch_inference_it_test.process_outputs(filepath=actuals_file) + + predictions_dict = {} + for prediction in predictions: + text, predicted_text = prediction.split(';') + predictions_dict[text] = predicted_text.strip() + + for actual in actuals: + text, actual_predicted_text = actual.split(';') + predicted_predicted_text = predictions_dict[text] + self.assertEqual(actual_predicted_text, predicted_predicted_text) + if __name__ == '__main__': logging.getLogger().setLevel(logging.DEBUG) diff --git a/sdks/python/apache_beam/ml/inference/huggingface_tests_requirements.txt b/sdks/python/apache_beam/ml/inference/huggingface_tests_requirements.txt index 09c1fa8ca90c..adb4816cab6b 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_tests_requirements.txt +++ b/sdks/python/apache_beam/ml/inference/huggingface_tests_requirements.txt @@ -16,5 +16,5 @@ # torch>=1.7.1 -transformers>=4.18.0 +transformers==4.30.0 tensorflow>=2.12.0 \ No newline at end of file