diff --git a/examples/notebooks/beam-ml/run_inference_multi_model.ipynb b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb index 9a99ad2cf475..7cd144223cae 100644 --- a/examples/notebooks/beam-ml/run_inference_multi_model.ipynb +++ b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb @@ -47,8 +47,7 @@ { "cell_type": "markdown", "source": [ - "# Ensemble model using an image captioning and ranking example", - "\n", + "# Ensemble model using an image captioning and ranking example\n", "\n", "
\n", " Run in Google Colab\n", @@ -65,12 +64,12 @@ { "cell_type": "markdown", "source": [ - "A single machine learning model might not be the right solution for your task. Often, machine learning model tasks involve aggregating mutliple models together to produce one optimal predictive model and to boost performance. \n", - " \n", + "When performing complex tasks like image captioning, using a single ML model may not be the best solution.\n", + "\n", "\n", "This notebook shows how to implement a cascade model in Apache Beam using the [RunInference API](https://beam.apache.org/documentation/sdks/python-machine-learning/). The RunInference API enables you to run your Beam transforms as part of your pipeline for optimal machine learning inference.\n", "\n", - "For more information about the RunInference API, review the [RunInference notebook](https://colab.research.google.com/drive/111USL4VhUa0xt_mKJxl5nC1YLOC8_yF4?usp=sharing#scrollTo=746b67a7-3562-467f-bea3-d8cd18c14927).\n", + "For more information about the RunInference API, review the [RunInference notebook](https://colab.research.google.com/drive/111USL4VhUa0xt_mKJxl5nC1YLOC8_yF4?usp=sharing#scrollTo=746b67a7-3562-467f-bea3-d8cd18c14927) or the [Beam ML documentation](https://beam.apache.org/documentation/ml/overview/).\n", "\n", "**Note:** All images are licensed CC-BY, and creators are listed in the [LICENSE.txt](https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt) file." ], @@ -94,7 +93,7 @@ "\n", "This example shows how to generate captions on a a large set of images. Apache Beam is the ideal tool to handle this workflow. We use two models for this task:\n", "\n", - "* [BLIP](https://github.com/salesforce/BLIP): Generates a set of candidate captions for a given image. \n", + "* [BLIP](https://github.com/salesforce/BLIP): Generates a set of candidate captions for a given image.\n", "* [CLIP](https://github.com/openai/CLIP): Ranks the generated captions based on accuracy." ], "metadata": { @@ -119,7 +118,7 @@ "* Run inference with BLIP to generate a list of caption candidates.\n", "* Aggregate the generated captions with their source image.\n", "* Preprocess the aggregated image-caption pairs to rank them with CLIP.\n", - "* Run inference with CLIP to generate the caption ranking. \n", + "* Run inference with CLIP to generate the caption ranking.\n", "* Print the image names and the captions sorted according to their ranking.\n", "\n", "\n", @@ -139,13 +138,13 @@ "metadata": { "colab": { "base_uri": "https://localhost:8080/", - "height": 440 + "height": 460 }, "id": "3suC5woJLW_N", - "outputId": "d2f9f67b-361b-4ae9-f9db-ce2ff9abd509", + "outputId": "2b5e78bf-f212-4a77-9325-8808ef024c2e", "cellView": "form" }, - "execution_count": null, + "execution_count": 1, "outputs": [ { "output_type": "execute_result", @@ -158,7 +157,7 @@ ] }, "metadata": {}, - "execution_count": 3 + "execution_count": 1 } ] }, @@ -184,68 +183,34 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 1, "metadata": { - "id": "tTUZpG9_q-OW", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "9ee6407a-8e4b-4520-fe5d-54a886b6e0b1" + "id": "tTUZpG9_q-OW" }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |████████████████████████████████| 2.1 MB 7.0 MB/s \n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m47.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.3/3.3 MB\u001b[0m \u001b[31m90.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m182.4/182.4 kB\u001b[0m \u001b[31m21.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m880.6/880.6 kB\u001b[0m \u001b[31m69.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m377.0/377.0 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.5/6.5 MB\u001b[0m \u001b[31m60.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.8/12.8 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m235.4/235.4 kB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for fairscale (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], + "outputs": [], "source": [ "!pip install --upgrade pip --quiet\n", - "!pip install transformers==4.15.0 --quiet\n", + "!pip install transformers==4.30.2 --quiet\n", "!pip install timm==0.4.12 --quiet\n", "!pip install ftfy==6.1.1 --quiet\n", "!pip install spacy==3.4.1 --quiet\n", "!pip install fairscale==0.4.4 --quiet\n", - "!pip install apache_beam[gcp]>=2.40.0 \n", + "!pip install apache_beam[gcp]>=2.48.0\n", "\n", "# To use the newly installed versions, restart the runtime.\n", - "exit() " + "exit()" ] }, { "cell_type": "code", "source": [ "import requests\n", - "import os \n", + "import os\n", "import urllib\n", - "import json \n", + "import json\n", "import io\n", "from io import BytesIO\n", + "from typing import Sequence\n", "from typing import Iterator\n", "from typing import Iterable\n", "from typing import Tuple\n", @@ -303,7 +268,7 @@ "base_uri": "https://localhost:8080/" }, "id": "Ud4sUXV2x8LO", - "outputId": "9e12ea04-a347-426f-8145-280a5676e78b" + "outputId": "cc814ff8-d424-4880-e006-56803e0508aa" }, "execution_count": 2, "outputs": [ @@ -311,7 +276,6 @@ "output_type": "stream", "name": "stdout", "text": [ - "Error: Failed to call git rev-parse --git-dir --show-toplevel: \"fatal: not a git repository (or any of the parent directories): .git\\n\"\n", "Git LFS initialized.\n", "Cloning into 'clip-vit-base-patch32'...\n", "remote: Enumerating objects: 51, done.\u001b[K\n", @@ -362,7 +326,7 @@ "base_uri": "https://localhost:8080/" }, "id": "g4-6WwqUtxea", - "outputId": "3b04b933-aab0-4f5b-c967-ed784125bc6a" + "outputId": "29112ca0-f111-48b7-d8cc-a4e04fb7a02b" }, "execution_count": 4, "outputs": [ @@ -388,8 +352,8 @@ "from BLIP.models.blip import blip_decoder\n", "\n", "!gdown 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n", - "# The blip model is saved as a checkoint, load it and save it as a state dict since RunInference required \n", - "# a state dict for model instantiation \n", + "# The blip model is saved as a checkpoint, load it and save it as a state dict since RunInference required\n", + "# a state dict for model instantiation\n", "blip_state_dict_path = '/content/BLIP/blip_state_dict.pth'\n", "torch.save(torch.load('/content/BLIP/model*_base_caption.pth')['model'], blip_state_dict_path)" ], @@ -398,7 +362,7 @@ "base_uri": "https://localhost:8080/" }, "id": "GCvOP_iZh41c", - "outputId": "224c22b1-eda6-463c-c926-1341ec9edef8" + "outputId": "a96f0ff5-cdf7-4394-be6e-d5bfca2f3a1f" }, "execution_count": 5, "outputs": [ @@ -409,7 +373,7 @@ "Downloading...\n", "From: https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth\n", "To: /content/BLIP/model*_base_caption.pth\n", - "100% 896M/896M [00:04<00:00, 198MB/s] \n" + "100% 896M/896M [00:04<00:00, 198MB/s]\n" ] } ] @@ -500,9 +464,9 @@ "\n", " \"\"\"\n", " Process the raw image input to a format suitable for BLIP inference. The processed\n", - " images are duplicated to the number of desired captions per image. \n", + " images are duplicated to the number of desired captions per image.\n", "\n", - " Preprocessing transformation taken from: \n", + " Preprocessing transformation taken from:\n", " https://github.com/salesforce/BLIP/blob/d10be550b2974e17ea72e74edc7948c9e5eab884/predict.py\n", " \"\"\"\n", "\n", @@ -510,7 +474,7 @@ " self._captions_per_image = captions_per_image\n", "\n", " def setup(self):\n", - " \n", + "\n", " # Initialize the image transformer.\n", " self._transform = transforms.Compose([\n", " transforms.Resize((384, 384),interpolation=InterpolationMode.BICUBIC),\n", @@ -519,7 +483,7 @@ " ])\n", "\n", " def process(self, element):\n", - " image_url, image = element \n", + " image_url, image = element\n", " # The following lines provide a workaround to turn off BatchElements.\n", " preprocessed_img = self._transform(image).unsqueeze(0)\n", " preprocessed_img = preprocessed_img.repeat(self._captions_per_image, 1, 1, 1)\n", @@ -533,7 +497,7 @@ " Process the PredictionResult to get the generated image captions\n", " \"\"\"\n", " def process(self, element : Tuple[str, Iterable[PredictionResult]]):\n", - " image_url, prediction = element \n", + " image_url, prediction = element\n", "\n", " return [(image_url, prediction.inference)]" ], @@ -546,7 +510,7 @@ { "cell_type": "markdown", "source": [ - "### Define CLIP functions \n", + "### Define CLIP functions\n", "\n", "Define the preprocessing and postprocessing functions for CLIP." ], @@ -560,9 +524,9 @@ "class PreprocessCLIPInput(beam.DoFn):\n", "\n", " \"\"\"\n", - " Process the image-caption pair to a format suitable for CLIP inference. \n", + " Process the image-caption pair to a format suitable for CLIP inference.\n", "\n", - " After grouping the raw images with the generated captions, we need to \n", + " After grouping the raw images with the generated captions, we need to\n", " preprocess them before passing them to the ranking stage (CLIP model).\n", " \"\"\"\n", "\n", @@ -572,12 +536,12 @@ " merges_file_config_path: str):\n", "\n", " self._feature_extractor_config_path = feature_extractor_config_path\n", - " self._tokenizer_vocab_config_path = tokenizer_vocab_config_path \n", + " self._tokenizer_vocab_config_path = tokenizer_vocab_config_path\n", " self._merges_file_config_path = merges_file_config_path\n", "\n", "\n", " def setup(self):\n", - " \n", + "\n", " # Initialize the CLIP feature extractor.\n", " feature_extractor_config = CLIPConfig.from_pretrained(self._feature_extractor_config_path)\n", " feature_extractor = CLIPFeatureExtractor(feature_extractor_config)\n", @@ -585,14 +549,14 @@ " # Initialize the CLIP tokenizer.\n", " tokenizer = CLIPTokenizer(self._tokenizer_vocab_config_path,\n", " self._merges_file_config_path)\n", - " \n", + "\n", " # Initialize the CLIP processor used to process the image-caption pair.\n", " self._processor = CLIPProcessor(feature_extractor=feature_extractor,\n", " tokenizer=tokenizer)\n", "\n", " def process(self, element: Tuple[str, Dict[str, List[Any]]]):\n", "\n", - " image_url, image_captions_pair = element \n", + " image_url, image_captions_pair = element\n", " # Unpack the image and captions after grouping them with 'CoGroupByKey()'.\n", " image = image_captions_pair['image'][0]\n", " captions = image_captions_pair['captions'][0]\n", @@ -600,7 +564,7 @@ " text = captions,\n", " return_tensors=\"pt\",\n", " padding=True)\n", - " \n", + "\n", " image_url_caption_pair = (image_url, captions)\n", " return [(image_url_caption_pair, preprocessed_clip_input)]\n", "\n", @@ -612,7 +576,7 @@ " The logits are the output of the CLIP model. Here, we apply a softmax activation\n", " function to the logits to get the probabilistic distribution of the relevance\n", " of each caption to the target image. After that, we sort the captions in descending\n", - " order with respect to the probabilities as a caption-probability pair. \n", + " order with respect to the probabilities as a caption-probability pair.\n", " \"\"\"\n", "\n", " def process(self, element : Tuple[Tuple[str, List[str]], Iterable[PredictionResult]]):\n", @@ -642,7 +606,9 @@ { "cell_type": "markdown", "source": [ - "Use a `KeyedModelHandler` for both models to attach a key to the general `ModelHandler`.\n", + "A `ModelHandler` is Beam's method for defining the configuration needed to load and invoke your model. Since both the BLIP and CLIP models use Pytorch and take KeyedTensors as inputs, we will use `PytorchModelHandlerKeyedTensor` for both.\n", + "\n", + "We will use a `KeyedModelHandler` for both models to attach a key to the general `ModelHandler`.\n", "The key is used for the following purposes:\n", "* To keep a reference to the image that the inference is associated with.\n", "* To aggregate transforms of different inputs.\n", @@ -654,36 +620,6 @@ "id": "BTmSPnjj8M2m" } }, - { - "cell_type": "code", - "source": [ - "class PytorchNoBatchModelHandlerKeyedTensor(PytorchModelHandlerKeyedTensor):\n", - " \"\"\"Wrapper to PytorchModelHandler to limit batch size to 1.\n", - " The caption strings generated from the BLIP tokenizer might have different\n", - " lengths. Different length strings don't work with torch.stack() in the current RunInference\n", - " implementation, because stack() requires tensors to be the same size.\n", - " Restricting max_batch_size to 1 means there is only 1 example per `batch`\n", - " in the run_inference() call.\n", - " \"\"\"\n", - " # The following lines provide a workaround to turn off BatchElements.\n", - " def batch_elements_kwargs(self):\n", - " return {'max_batch_size': 1}" - ], - "metadata": { - "id": "OaR02_wxTMpc" - }, - "execution_count": 9, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Note that we use a `KeyedModelHandler` for both models to attach a key to the general `ModelHandler`. The key is used for aggregation transforms of different inputs." - ], - "metadata": { - "id": "gNLRO0EwvcGP" - } - }, { "cell_type": "markdown", "source": [ @@ -713,48 +649,36 @@ { "cell_type": "code", "source": [ - "class BLIPWrapper(torch.nn.Module):\n", - " \"\"\"\n", - " Wrapper around the BLIP model to overwrite the default \"forward\" method with the \"generate\" method, because BLIP uses the \n", - " \"generate\" method to produce the image captions.\n", - " \"\"\"\n", - " \n", - " def __init__(self, base_model: blip_decoder, num_beams: int, max_length: int,\n", - " min_length: int):\n", - " super().__init__()\n", - " self._model = base_model()\n", - " self._num_beams = num_beams\n", - " self._max_length = max_length\n", - " self._min_length = min_length\n", - "\n", - " def forward(self, inputs: torch.Tensor):\n", - " # Squeeze because RunInference adds an extra dimension, which is empty.\n", - " # The following lines provide a workaround to turn off BatchElements.\n", - " inputs = inputs.squeeze(0)\n", - " captions = self._model.generate(inputs,\n", - " sample=True,\n", - " num_beams=self._num_beams,\n", - " max_length=self._max_length,\n", - " min_length=self._min_length)\n", - " return [captions]\n", - "\n", - " def load_state_dict(self, state_dict: dict):\n", - " self._model.load_state_dict(state_dict)\n", - "\n", - "\n", - "BLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n", + "def blip_keyed_tensor_inference_fn(\n", + " batch: Sequence[Dict[str, torch.Tensor]],\n", + " model: torch.nn.Module,\n", + " device: str,\n", + " inference_args: Optional[Dict[str, Any]] = None,\n", + " model_id: Optional[str] = None,\n", + ") -> Iterable[PredictionResult]:\n", + " # By default, Beam batches inputs for bulk inference and calls model(batch)\n", + " # Since we want to call model.generate on a single unbatched input (BLIP/CLIP\n", + " # don't handle batched inputs), we define a custom inference function.\n", + " captions = model.generate(batch[0]['inputs'],\n", + " sample=True,\n", + " num_beams=NUM_BEAMS,\n", + " max_length=MAX_CAPTION_LENGTH,\n", + " min_length=MIN_CAPTION_LENGTH)\n", + " return [PredictionResult(batch[0], captions, model_id)]\n", + "\n", + "\n", + "BLIP_model_handler = PytorchModelHandlerKeyedTensor(\n", " state_dict_path=blip_state_dict_path,\n", - " model_class=BLIPWrapper,\n", - " model_params={'base_model': blip_decoder, 'num_beams': NUM_BEAMS,\n", - " 'max_length': MAX_CAPTION_LENGTH, 'min_length': MIN_CAPTION_LENGTH},\n", - " device='GPU')\n", + " model_class=blip_decoder,\n", + " inference_fn=blip_keyed_tensor_inference_fn,\n", + " max_batch_size=1)\n", "\n", "BLIP_keyed_model_handler = KeyedModelHandler(BLIP_model_handler)" ], "metadata": { "id": "RCKBJjujVw4q" }, - "execution_count": 11, + "execution_count": 10, "outputs": [] }, { @@ -771,29 +695,33 @@ { "cell_type": "code", "source": [ - "class CLIPWrapper(CLIPModel):\n", - "\n", - " def forward(self, **kwargs: Dict[str, torch.Tensor]):\n", - " # Squeeze because RunInference adds an extra dimension, which is empty.\n", - " # The following lines provide a workaround to turn off BatchElements.\n", - " kwargs = {key: tensor.squeeze(0) for key, tensor in kwargs.items()}\n", - " output = super().forward(**kwargs)\n", - " logits = output.logits_per_image\n", - " return logits\n", - "\n", - "\n", - "CLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n", + "def clip_keyed_tensor_inference_fn(\n", + " batch: Sequence[Dict[str, torch.Tensor]],\n", + " model: torch.nn.Module,\n", + " device: str,\n", + " inference_args: Optional[Dict[str, Any]] = None,\n", + " model_id: Optional[str] = None,\n", + ") -> Iterable[PredictionResult]:\n", + " # By default, Beam batches inputs for bulk inference and calls model(batch)\n", + " # Since we want to call model on a single unbatched input (BLIP/CLIP don't\n", + " # handle batched inputs), we define a custom inference function.\n", + " output = model(**batch[0], **inference_args)\n", + " return [PredictionResult(batch[0], output.logits_per_image[0], model_id)]\n", + "\n", + "\n", + "CLIP_model_handler = PytorchModelHandlerKeyedTensor(\n", " state_dict_path=clip_state_dict_path,\n", - " model_class=CLIPWrapper,\n", + " model_class=CLIPModel,\n", " model_params={'config': CLIPConfig.from_pretrained(clip_model_config_path)},\n", - " device='GPU')\n", + " inference_fn=clip_keyed_tensor_inference_fn,\n", + " max_batch_size=1)\n", "\n", "CLIP_keyed_model_handler = KeyedModelHandler(CLIP_model_handler)\n" ], "metadata": { "id": "EJw_OnZ1ZfuH" }, - "execution_count": 12, + "execution_count": 11, "outputs": [] }, { @@ -817,7 +745,7 @@ "metadata": { "id": "VJwE0bquoXOf" }, - "execution_count": 13, + "execution_count": 12, "outputs": [] }, { @@ -834,7 +762,7 @@ "source": [ "#@title\n", "license_txt_url = 'https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt'\n", - "license_dict = json.loads(urllib.request.urlopen(license_txt_url).read().decode(\"utf-8\")) \n", + "license_dict = json.loads(urllib.request.urlopen(license_txt_url).read().decode(\"utf-8\"))\n", "\n", "for image_url in images_url:\n", " response = requests.get(image_url)\n", @@ -855,7 +783,7 @@ "outputId": "6e771e4e-a76a-4855-b466-976cdf35b506", "cellView": "form" }, - "execution_count": 16, + "execution_count": null, "outputs": [ { "output_type": "display_data", @@ -918,7 +846,7 @@ "metadata": { "id": "Dcz_M9GW0Kan" }, - "execution_count": 14, + "execution_count": 13, "outputs": [] }, { @@ -947,13 +875,13 @@ "with beam.Pipeline() as pipeline:\n", "\n", " read_images = (\n", - " pipeline \n", + " pipeline\n", " | \"ReadUrl\" >> beam.Create(images_url)\n", " | \"ReadImages\" >> beam.ParDo(ReadImagesFromUrl()))\n", "\n", " blip_caption_generation = (\n", " read_images\n", - " | \"PreprocessBlipInput\" >> beam.ParDo(PreprocessBLIPInput(NUM_CAPTIONS_PER_IMAGE)) \n", + " | \"PreprocessBlipInput\" >> beam.ParDo(PreprocessBLIPInput(NUM_CAPTIONS_PER_IMAGE))\n", " | \"GenerateCaptions\" >> RunInference(BLIP_keyed_model_handler)\n", " | \"PostprocessCaptions\" >> beam.ParDo(PostprocessBLIPOutput()))\n", "\n", @@ -966,19 +894,21 @@ " clip_tokenizer_vocab_config_path,\n", " clip_merges_config_path))\n", " | \"GetRankingLogits\" >> RunInference(CLIP_keyed_model_handler)\n", - " | \"RankClipOutput\" >> beam.ParDo(RankCLIPOutput()))\n", + " | \"RankClipOutput\" >> beam.ParDo(RankCLIPOutput())\n", + " )\n", "\n", " clip_captions_ranking | \"FormatCaptions\" >> beam.ParDo(FormatCaptions(NUM_TOP_CAPTIONS_TO_DISPLAY))\n", - " " + "" ], "metadata": { "colab": { - "base_uri": "https://localhost:8080/" + "base_uri": "https://localhost:8080/", + "height": 428 }, "id": "002e-FNbmuB8", - "outputId": "49c646f1-8612-433f-b134-ea8af0ff5591" + "outputId": "1b540b1e-b146-45d6-f8d3-ccaf461a87b7" }, - "execution_count": 18, + "execution_count": 14, "outputs": [ { "output_type": "stream", @@ -986,29 +916,41 @@ "text": [ "Image: Paris-sunset\n", "\tTop 3 captions ranked by CLIP:\n", - "\t\t1: the eiffel tower in paris is silhouetted at sunset. (Caption probability: 0.23)\n", - "\t\t2: the sun sets over the city of paris, with the eiffel tower in the distance. (Caption probability: 0.19)\n", - "\t\t3: the sun sets over the eiffel tower in paris. (Caption probability: 0.17)\n", + "\t\t1: the setting sun is reflected in an orange setting sky over paris. (Caption probability: 0.28)\n", + "\t\t2: the sun rising above the eiffel tower over paris. (Caption probability: 0.23)\n", + "\t\t3: the sun setting over the eiffel tower and rooftops. (Caption probability: 0.15)\n", "\n", "\n", "Image: Wedges\n", "\tTop 3 captions ranked by CLIP:\n", - "\t\t1: a basket of baked fries with a sauce in it. (Caption probability: 0.60)\n", - "\t\t2: cooked french fries with ketchup and dip sitting in napkin. (Caption probability: 0.16)\n", - "\t\t3: some french fries with dipping sauce on the side. (Caption probability: 0.08)\n", + "\t\t1: sweet potato fries with ketchup served in bowl. (Caption probability: 0.73)\n", + "\t\t2: this is a plate of sweet potato fries with ketchup. (Caption probability: 0.16)\n", + "\t\t3: sweet potato fries and a dipping sauce are on the tray. (Caption probability: 0.06)\n", "\n", "\n", "Image: Hamsters\n", "\tTop 3 captions ranked by CLIP:\n", - "\t\t1: a person petting two small hamsters while in their home. (Caption probability: 0.51)\n", - "\t\t2: a woman holding two small white baby animals. (Caption probability: 0.23)\n", - "\t\t3: a hand holding a small mouse that looks tiny. (Caption probability: 0.09)\n", + "\t\t1: person holding two small animals in their hands. (Caption probability: 0.62)\n", + "\t\t2: a person's hand holding a small hamster in front of them. (Caption probability: 0.20)\n", + "\t\t3: a person holding a small animal in their hands. (Caption probability: 0.09)\n", "\n", "\n" ] } ] }, + { + "cell_type": "markdown", + "source": [ + "# Conclusion\n", + "\n", + "After running the pipeline, you can see the captions generated by the BLIP model and ranked by the CLIP model with all of our pre/postprocessing logic applied.\n", + "As you can see, running multi-model inference is easy with the power of Beam.\n" + ], + "metadata": { + "id": "gPCMXWgOtM_0" + } + }, { "cell_type": "markdown", "source": [