diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index 27c3eefe..8a7ee377 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py @@ -140,7 +140,7 @@ def __init__( ): if detection_model is not None and element_extraction_model is not None: raise ValueError("Only one of detection_model and extraction_model should be passed.") - self.image = image + self.image: Optional[Image.Image] = image if image_metadata is None: image_metadata = {} self.image_metadata = image_metadata diff --git a/unstructured_inference/models/detectron2.py b/unstructured_inference/models/detectron2.py index 98939f88..c38f6848 100644 --- a/unstructured_inference/models/detectron2.py +++ b/unstructured_inference/models/detectron2.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path from typing import Any, Dict, Final, List, Optional, Union @@ -7,7 +9,7 @@ is_detectron2_available, ) from layoutparser.models.model_config import LayoutModelConfig -from PIL import Image +from PIL import Image as PILImage from unstructured_inference.constants import ElementType from unstructured_inference.inference.layoutelement import LayoutElement @@ -65,7 +67,7 @@ class UnstructuredDetectronModel(UnstructuredObjectDetectionModel): """Unstructured model wrapper for Detectron2LayoutModel.""" - def predict(self, x: Image): + def predict(self, x: PILImage.Image): """Makes a prediction using detectron2 model.""" super().predict(x) prediction = self.model.detect(x) diff --git a/unstructured_inference/models/donut.py b/unstructured_inference/models/donut.py index 1f753f56..bc60d2c6 100644 --- a/unstructured_inference/models/donut.py +++ b/unstructured_inference/models/donut.py @@ -3,7 +3,7 @@ from typing import Optional, Union import torch -from PIL import Image +from PIL import Image as PILImage from transformers import ( DonutProcessor, VisionEncoderDecoderConfig, @@ -16,7 +16,7 @@ class UnstructuredDonutModel(UnstructuredModel): """Unstructured model wrapper for Donut image transformer.""" - def predict(self, x: Image): + def predict(self, x: PILImage.Image): """Make prediction using donut model""" super().predict(x) return self.run_prediction(x) @@ -50,7 +50,7 @@ def initialize( raise ImportError("Review the parameters to initialize a UnstructuredDonutModel obj") self.model.to(device) - def run_prediction(self, x: Image): + def run_prediction(self, x: PILImage.Image): """Internal prediction method.""" pixel_values = self.processor(x, return_tensors="pt").pixel_values decoder_input_ids = self.processor.tokenizer( diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index ca7e75a4..b6607812 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -8,7 +8,7 @@ import cv2 import numpy as np import torch -from PIL import Image +from PIL import Image as PILImage from transformers import DetrImageProcessor, TableTransformerForObjectDetection from unstructured_inference.config import inference_config @@ -27,7 +27,7 @@ class UnstructuredTableTransformerModel(UnstructuredModel): def __init__(self): pass - def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None): + def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None): """Predict table structure deferring to run_prediction with ocr tokens Note: @@ -70,7 +70,7 @@ def initialize( def get_structure( self, - x: Image, + x: PILImage.Image, pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD, ) -> dict: """get the table structure as a dictionary contaning different types of elements as @@ -87,7 +87,7 @@ def get_structure( def run_prediction( self, - x: Image, + x: PILImage.Image, pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD, ocr_tokens: Optional[List[Dict]] = None, result_format: Optional[str] = "html", @@ -155,7 +155,7 @@ def get_class_map(data_type: str): } -def recognize(outputs: dict, img: Image, tokens: list): +def recognize(outputs: dict, img: PILImage.Image, tokens: list): """Recognize table elements.""" str_class_name2idx = get_class_map("structure") str_class_idx2name = {v: k for k, v in str_class_name2idx.items()} @@ -655,7 +655,7 @@ def cells_to_html(cells): return str(ET.tostring(table, encoding="unicode", short_empty_elements=False)) -def zoom_image(image: Image, zoom: float) -> Image: +def zoom_image(image: PILImage.Image, zoom: float) -> PILImage.Image: """scale an image based on the zoom factor using cv2; the scaled image is post processed by dilation then erosion to improve edge sharpness for OCR tasks""" if zoom <= 0: @@ -673,4 +673,4 @@ def zoom_image(image: Image, zoom: float) -> Image: new_image = cv2.dilate(new_image, kernel, iterations=1) new_image = cv2.erode(new_image, kernel, iterations=1) - return Image.fromarray(new_image) + return PILImage.fromarray(new_image) diff --git a/unstructured_inference/models/yolox.py b/unstructured_inference/models/yolox.py index 47455cf4..852e15b3 100644 --- a/unstructured_inference/models/yolox.py +++ b/unstructured_inference/models/yolox.py @@ -10,7 +10,7 @@ import onnxruntime from huggingface_hub import hf_hub_download from onnxruntime.capi import _pybind_state as C -from PIL import Image +from PIL import Image as PILImage from unstructured_inference.constants import ElementType, Source from unstructured_inference.inference.layoutelement import LayoutElement @@ -60,7 +60,7 @@ class UnstructuredYoloXModel(UnstructuredObjectDetectionModel): - def predict(self, x: Image): + def predict(self, x: PILImage.Image): """Predict using YoloX model.""" super().predict(x) return self.image_processing(x) @@ -86,7 +86,7 @@ def initialize(self, model_path: str, label_map: dict): def image_processing( self, - image: Image = None, + image: PILImage.Image, ) -> List[LayoutElement]: """Method runing YoloX for layout detection, returns a PageLayout parameters