Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for google ocr #662

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docling/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
GoogleOcrOptions,
OcrEngine,
OcrMacOptions,
OcrOptions,
Expand Down Expand Up @@ -335,6 +336,8 @@ def convert(
ocr_options = OcrMacOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.RAPIDOCR:
ocr_options = RapidOcrOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.GOOGLEOCR:
ocr_options = GoogleOcrOptions(force_full_page_ocr=force_ocr)
else:
raise RuntimeError(f"Unexpected OCR engine type {ocr_engine}")

Expand Down
15 changes: 15 additions & 0 deletions docling/datamodel/pipeline_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,19 @@ class EasyOcrOptions(OcrOptions):
)


class GoogleOcrOptions(OcrOptions):
    """Options for the dense GoogleOcr engine."""

    kind: Literal["googleocr"] = "googleocr"
    # Languages handed to the OCR engine (the model forwards them as
    # "language_hints" in the request's image context).
    lang: List[str] = ["en", "de"]
    # Path to the Google credentials file. The GOOGLE_CONFIG_FILE_PATH
    # environment variable is read when the options object is created, not
    # when this module is imported, so a variable exported after import
    # (e.g. by a test fixture or a dotenv loader) is still honoured.
    google_ocr_config_file_path: Optional[str] = Field(
        default_factory=lambda: os.getenv("GOOGLE_CONFIG_FILE_PATH")
    )
    # Despite the name, this value is passed as the client's "api_endpoint".
    google_ocr_region: Optional[str] = "eu-vision.googleapis.com"

    model_config = ConfigDict(
        extra="forbid",
    )


class TesseractCliOcrOptions(OcrOptions):
"""Options for the TesseractCli engine."""

Expand Down Expand Up @@ -205,6 +218,7 @@ class OcrEngine(str, Enum):
TESSERACT = "tesseract"
OCRMAC = "ocrmac"
RAPIDOCR = "rapidocr"
GOOGLEOCR = "googleocr"


class PipelineOptions(BaseModel):
Expand All @@ -231,6 +245,7 @@ class PdfPipelineOptions(PipelineOptions):
TesseractOcrOptions,
OcrMacOptions,
RapidOcrOptions,
GoogleOcrOptions,
] = Field(EasyOcrOptions(), discriminator="kind")

images_scale: float = 1.0
Expand Down
180 changes: 180 additions & 0 deletions docling/models/google_ocr_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import io
import logging
from typing import Iterable

from docling_core.types.doc import BoundingBox, CoordOrigin

from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import GoogleOcrOptions
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)


class GoogleOcrModel(BaseOcrModel):
    """OCR model that sends page regions to the Google Cloud Vision API.

    For every OCR rectangle of a page, the region is rendered at
    ``self.scale`` times the base resolution, encoded as PNG and submitted to
    ``text_detection``. Each word of the returned full-text annotation becomes
    one ``OcrCell`` whose bounding box is mapped back into page coordinates.
    """

    def __init__(self, enabled: bool, options: GoogleOcrOptions):
        """Create the model and, when enabled, the Vision API client.

        Args:
            enabled: Whether OCR should run; when False no client is created
                and ``__call__`` passes pages through untouched.
            options: Google OCR settings (language hints, credentials file
                path and API endpoint).

        Raises:
            ImportError: If the Google client libraries are not installed.
            FileNotFoundError: If no credentials file path is configured.
        """
        super().__init__(enabled=enabled, options=options)
        self.options: GoogleOcrOptions

        self.scale = 3  # multiplier for 72 dpi == 216 dpi.
        self.reader = None

        if self.enabled:
            # Keep the try body limited to the imports so that only a missing
            # dependency is reported as such.
            try:
                from google.cloud import vision
                from google.oauth2 import service_account
            except ImportError as err:
                raise ImportError(
                    "Failed to import required libraries for Google OCR. Ensure that the "
                    "'google-cloud-vision' and 'google-auth' packages are installed. "
                    "You can install them using 'pip install google-cloud-vision google-auth'."
                ) from err

            # Initialize the Google Cloud Vision client
            _log.debug("Initializing Google OCR ")
            self.image_context = {"language_hints": self.options.lang}
            client_options = {"api_endpoint": self.options.google_ocr_region}
            if self.options.google_ocr_config_file_path is None:
                raise FileNotFoundError(
                    "Google OCR Config File is missing. Please provide a valid file path "
                    "via the GOOGLE_CONFIG_FILE_PATH environment variable."
                )
            google_creds = service_account.Credentials.from_service_account_file(
                self.options.google_ocr_config_file_path
            )
            self.reader = vision.ImageAnnotatorClient(
                credentials=google_creds, client_options=client_options
            )

    @staticmethod
    def _iter_words(google_response):
        """Yield every word of the response's full-text annotation in order."""
        for file_page in google_response.full_text_annotation.pages:
            for block in file_page.blocks:
                for paragraph in block.paragraphs:
                    yield from paragraph.words

    def _cells_from_response(self, google_response, ocr_rect) -> Iterable[OcrCell]:
        """Convert one Vision response into OCR cells in page coordinates.

        Args:
            google_response: Result of ``text_detection`` for the image
                rendered from ``ocr_rect``.
            ocr_rect: The page region that was rendered; its ``l`` and ``t``
                are used as the offset of the image inside the page.

        Yields:
            One ``OcrCell`` per detected word, with ids counted from 0 for
            this response.
        """
        for ix, word in enumerate(self._iter_words(google_response)):
            vertices = word.bounding_box.vertices
            xs = [vertex.x for vertex in vertices]
            ys = [vertex.y for vertex in vertices]

            # Vertices are in pixels of the up-scaled crop: undo the scale,
            # then shift by the crop's top-left corner.
            left = min(xs) / self.scale + ocr_rect.l
            right = max(xs) / self.scale + ocr_rect.l
            top = min(ys) / self.scale + ocr_rect.t
            bottom = max(ys) / self.scale + ocr_rect.t

            yield OcrCell(
                id=ix,
                text="".join(symbol.text for symbol in word.symbols),
                # NOTE(review): confidence is scaled by 100 as in the original
                # implementation — confirm the scale downstream code expects.
                confidence=word.confidence * 100,
                bbox=BoundingBox.from_tuple(
                    coord=(left, top, right, bottom),
                    origin=CoordOrigin.TOPLEFT,
                ),
            )

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Run OCR on each page of the batch and yield the pages.

        Args:
            conv_res: Conversion result used for timing and debug drawing.
            page_batch: Pages to process.

        Yields:
            Every input page, in order. Pages with a valid backend have their
            ``cells`` replaced by the post-processed OCR cells; when the model
            is disabled or the backend is invalid the page is yielded as is.
        """
        if not self.enabled:
            yield from page_batch
            return

        # Imported once per call instead of once per rectangle. __init__ has
        # already verified that the package is importable when enabled.
        from google.cloud import vision

        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
                continue

            with TimeRecorder(conv_res, "ocr"):
                assert self.reader is not None

                ocr_rects = self.get_ocr_rects(page)
                all_ocr_cells = []
                for ocr_rect in ocr_rects:
                    # Skip zero area boxes
                    if ocr_rect.area() == 0:
                        continue
                    high_res_image = page._backend.get_page_image(
                        scale=self.scale, cropbox=ocr_rect
                    )
                    # Serialize the Pillow image to PNG bytes for the request.
                    with io.BytesIO() as buffer:
                        high_res_image.save(buffer, "PNG")
                        content = buffer.getvalue()
                    del high_res_image

                    google_response = self.reader.text_detection(
                        image=vision.Image(content=content),
                        image_context=self.image_context,
                    )
                    del content

                    # Per-image failures are reported inside the response
                    # rather than raised; surface them instead of silently
                    # producing no cells for this region.
                    if google_response.error.message:
                        _log.warning(
                            "Google OCR returned an error for a page region: %s",
                            google_response.error.message,
                        )
                        continue

                    all_ocr_cells.extend(
                        self._cells_from_response(google_response, ocr_rect)
                    )

                # Post-process the cells
                page.cells = self.post_process_cells(all_ocr_cells, page.cells)

            # DEBUG code:
            if settings.debug.visualize_ocr:
                self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects, show=True)

            yield page
7 changes: 7 additions & 0 deletions docling/pipeline/standard_pdf_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
EasyOcrOptions,
GoogleOcrOptions,
OcrMacOptions,
PdfPipelineOptions,
RapidOcrOptions,
Expand All @@ -20,6 +21,7 @@
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.ds_glm_model import GlmModel, GlmOptions
from docling.models.easyocr_model import EasyOcrModel
from docling.models.google_ocr_model import GoogleOcrModel
from docling.models.layout_model import LayoutModel
from docling.models.ocr_mac_model import OcrMacModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
Expand Down Expand Up @@ -143,6 +145,11 @@ def get_ocr_model(self) -> Optional[BaseOcrModel]:
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
elif isinstance(self.pipeline_options.ocr_options, GoogleOcrOptions):
return GoogleOcrModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
return None

def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
Expand Down
Loading