fix: CI mypy errors
scanny committed Apr 18, 2024
1 parent e744367 commit 2c3f32b
Showing 5 changed files with 18 additions and 16 deletions.
2 changes: 1 addition & 1 deletion unstructured_inference/inference/layout.py
@@ -140,7 +140,7 @@ def __init__(
     ):
         if detection_model is not None and element_extraction_model is not None:
             raise ValueError("Only one of detection_model and extraction_model should be passed.")
-        self.image = image
+        self.image: Optional[Image.Image] = image
         if image_metadata is None:
             image_metadata = {}
         self.image_metadata = image_metadata
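The layout.py change pins the attribute's type for mypy instead of leaving it to be inferred from the first assignment; the explicit Optional[Image.Image] annotation lets the attribute hold either a PIL image or None without tripping the checker elsewhere. A minimal sketch of the pattern, with an illustrative class that is not the repository's PageLayout:

    from typing import Optional

    from PIL import Image


    class Page:  # illustrative stand-in for PageLayout
        def __init__(self, image: Optional[Image.Image] = None):
            # Explicit attribute annotation: the attribute may hold a PIL image or None,
            # rather than whatever narrower type mypy would infer from this assignment.
            self.image: Optional[Image.Image] = image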
6 changes: 4 additions & 2 deletions unstructured_inference/models/detectron2.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from pathlib import Path
 from typing import Any, Dict, Final, List, Optional, Union
 
@@ -7,7 +9,7 @@
     is_detectron2_available,
 )
 from layoutparser.models.model_config import LayoutModelConfig
-from PIL import Image
+from PIL import Image as PILImage
 
 from unstructured_inference.constants import ElementType
 from unstructured_inference.inference.layoutelement import LayoutElement
@@ -65,7 +67,7 @@
 class UnstructuredDetectronModel(UnstructuredObjectDetectionModel):
     """Unstructured model wrapper for Detectron2LayoutModel."""
 
-    def predict(self, x: Image):
+    def predict(self, x: PILImage.Image):
         """Makes a prediction using detectron2 model."""
         super().predict(x)
         prediction = self.model.detect(x)
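The remaining files all hit the same mypy complaint: `from PIL import Image` binds the name Image to the PIL.Image module, and a module used in a type position is rejected (roughly, "Module ... is not valid as a type"). Importing the module under the PILImage alias and annotating with the PILImage.Image class fixes the annotations, while call sites such as Image.open or Image.fromarray simply become PILImage.open and PILImage.fromarray. The `from __future__ import annotations` added to detectron2.py postpones annotation evaluation (PEP 563); it is a common companion change rather than the fix itself. A short sketch of the alias pattern (the function is illustrative, not from the repository):

    from PIL import Image as PILImage  # PILImage is the module; PILImage.Image is the class


    def thumbnail(x: PILImage.Image) -> PILImage.Image:  # annotate with the class, not the module
        copy = x.copy()
        copy.thumbnail((128, 128))  # in-place resize that preserves aspect ratio
        return copy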
6 changes: 3 additions & 3 deletions unstructured_inference/models/donut.py
@@ -3,7 +3,7 @@
 from typing import Optional, Union
 
 import torch
-from PIL import Image
+from PIL import Image as PILImage
 from transformers import (
     DonutProcessor,
     VisionEncoderDecoderConfig,
@@ -16,7 +16,7 @@
 class UnstructuredDonutModel(UnstructuredModel):
     """Unstructured model wrapper for Donut image transformer."""
 
-    def predict(self, x: Image):
+    def predict(self, x: PILImage.Image):
         """Make prediction using donut model"""
         super().predict(x)
         return self.run_prediction(x)
@@ -50,7 +50,7 @@ def initialize(
             raise ImportError("Review the parameters to initialize a UnstructuredDonutModel obj")
         self.model.to(device)
 
-    def run_prediction(self, x: Image):
+    def run_prediction(self, x: PILImage.Image):
         """Internal prediction method."""
         pixel_values = self.processor(x, return_tensors="pt").pixel_values
         decoder_input_ids = self.processor.tokenizer(
14 changes: 7 additions & 7 deletions unstructured_inference/models/tables.py
@@ -8,7 +8,7 @@
 import cv2
 import numpy as np
 import torch
-from PIL import Image
+from PIL import Image as PILImage
 from transformers import DetrImageProcessor, TableTransformerForObjectDetection
 
 from unstructured_inference.config import inference_config
@@ -27,7 +27,7 @@ class UnstructuredTableTransformerModel(UnstructuredModel):
     def __init__(self):
         pass
 
-    def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None):
+    def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None):
         """Predict table structure deferring to run_prediction with ocr tokens
 
         Note:
@@ -70,7 +70,7 @@ def initialize(
 
     def get_structure(
         self,
-        x: Image,
+        x: PILImage.Image,
         pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
     ) -> dict:
         """get the table structure as a dictionary containing different types of elements as
@@ -87,7 +87,7 @@
 
     def run_prediction(
         self,
-        x: Image,
+        x: PILImage.Image,
         pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
         ocr_tokens: Optional[List[Dict]] = None,
         result_format: Optional[str] = "html",
@@ -155,7 +155,7 @@ def get_class_map(data_type: str):
     }
 
 
-def recognize(outputs: dict, img: Image, tokens: list):
+def recognize(outputs: dict, img: PILImage.Image, tokens: list):
     """Recognize table elements."""
     str_class_name2idx = get_class_map("structure")
     str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}
@@ -655,7 +655,7 @@ def cells_to_html(cells):
     return str(ET.tostring(table, encoding="unicode", short_empty_elements=False))
 
 
-def zoom_image(image: Image, zoom: float) -> Image:
+def zoom_image(image: PILImage.Image, zoom: float) -> PILImage.Image:
     """scale an image based on the zoom factor using cv2; the scaled image is post processed by
     dilation then erosion to improve edge sharpness for OCR tasks"""
     if zoom <= 0:
@@ -673,4 +673,4 @@ def zoom_image(image: Image, zoom: float) -> Image:
     new_image = cv2.dilate(new_image, kernel, iterations=1)
     new_image = cv2.erode(new_image, kernel, iterations=1)
 
-    return Image.fromarray(new_image)
+    return PILImage.fromarray(new_image)
6 changes: 3 additions & 3 deletions unstructured_inference/models/yolox.py
@@ -10,7 +10,7 @@
 import onnxruntime
 from huggingface_hub import hf_hub_download
 from onnxruntime.capi import _pybind_state as C
-from PIL import Image
+from PIL import Image as PILImage
 
 from unstructured_inference.constants import ElementType, Source
 from unstructured_inference.inference.layoutelement import LayoutElement
@@ -60,7 +60,7 @@
 
 
 class UnstructuredYoloXModel(UnstructuredObjectDetectionModel):
-    def predict(self, x: Image):
+    def predict(self, x: PILImage.Image):
         """Predict using YoloX model."""
         super().predict(x)
         return self.image_processing(x)
@@ -86,7 +86,7 @@ def initialize(self, model_path: str, label_map: dict):
 
     def image_processing(
         self,
-        image: Image = None,
+        image: PILImage.Image,
     ) -> List[LayoutElement]:
         """Method running YoloX for layout detection, returns a PageLayout
         parameters
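The yolox.py signature also drops the `= None` default: `image: Image = None` pairs a non-Optional annotation with a None default, which mypy flags as an incompatible default now that implicit Optional is off by default. Requiring the argument, as the commit does, or widening the annotation both satisfy the checker; a sketch of the two options (function names are illustrative, not from the repository):

    from typing import Optional

    from PIL import Image as PILImage


    def process(image: PILImage.Image) -> None:
        # Option taken in the commit: the caller must pass an image.
        image.load()


    def process_or_skip(image: Optional[PILImage.Image] = None) -> None:
        # Alternative: keep the None default but widen the annotation to Optional.
        if image is not None:
            image.load()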
