Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add adaptive OCR, factor out treatment of OCR areas and cell filtering #38

Merged
merged 5 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docling/backend/abstract_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def get_text_in_rect(self, bbox: "BoundingBox") -> str:
def get_text_cells(self) -> Iterable["Cell"]:
pass

@abstractmethod
def get_bitmap_rects(self, scale: int = 1) -> Iterable["BoundingBox"]:
    """Return the bounding boxes of the bitmap images found on the page.

    Abstract hook with no implementation of its own; it is meant to be
    overridden by the concrete page backends.

    Args:
        scale: Factor handed to the implementation for scaling the
            returned boxes. Defaults to 1.
    """

@abstractmethod
def get_page_image(
self, scale: int = 1, cropbox: Optional["BoundingBox"] = None
Expand Down
26 changes: 23 additions & 3 deletions docling/backend/docling_parse_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from io import BytesIO
from pathlib import Path
from typing import Iterable, List, Optional, Union
from typing import Iterable, Optional, Union

import pypdfium2 as pdfium
from docling_parse.docling_parse import pdf_parser
Expand Down Expand Up @@ -43,7 +43,7 @@ def get_text_in_rect(self, bbox: BoundingBox) -> str:
r=x1 * scale * page_size.width / parser_width,
t=y1 * scale * page_size.height / parser_height,
coord_origin=CoordOrigin.BOTTOMLEFT,
).to_top_left_origin(page_size.height * scale)
).to_top_left_origin(page_height=page_size.height * scale)

overlap_frac = cell_bbox.intersection_area_with(bbox) / cell_bbox.area()

Expand All @@ -66,6 +66,12 @@ def get_text_cells(self) -> Iterable[Cell]:
for i in range(len(self._dpage["cells"])):
rect = self._dpage["cells"][i]["box"]["device"]
x0, y0, x1, y1 = rect

if x1 < x0:
x0, x1 = x1, x0
if y1 < y0:
y0, y1 = y1, y0

text_piece = self._dpage["cells"][i]["content"]["rnormalized"]
cells.append(
Cell(
Expand Down Expand Up @@ -108,6 +114,20 @@ def draw_clusters_and_cells():

return cells

def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]:
    """Yield top-left-origin bounding boxes of the page's bitmap images.

    Each entry of ``self._dpage["images"]`` carries a ``"box"`` tuple in
    bottom-left-origin coordinates. The box is flipped to top-left origin
    using the page height, and only boxes whose area exceeds 32 * 32 are
    yielded. The area check happens before scaling.

    Args:
        scale: Factor applied to every yielded box via ``scaled``.
    """
    min_area = 32 * 32

    for bitmap in self._dpage["images"]:
        bottom_left_box = BoundingBox.from_tuple(
            bitmap["box"], origin=CoordOrigin.BOTTOMLEFT
        )
        top_left_box = bottom_left_box.to_top_left_origin(self.get_size().height)

        # NOTE(review): the diff view lost indentation; this assumes the
        # yield sits under the area check, so small bitmaps are dropped.
        if top_left_box.area() > min_area:
            yield top_left_box.scaled(scale=scale)

def get_page_image(
self, scale: int = 1, cropbox: Optional[BoundingBox] = None
) -> Image.Image:
Expand Down Expand Up @@ -173,7 +193,7 @@ def __init__(self, path_or_stream: Union[BytesIO, Path]):
def page_count(self) -> int:
    """Return how many entries the parsed document lists under "pages"."""
    pages = self._parser_doc["pages"]
    return len(pages)

def load_page(self, page_no: int) -> PdfPage:
def load_page(self, page_no: int) -> DoclingParsePageBackend:
return DoclingParsePageBackend(
self._pdoc[page_no], self._parser_doc["pages"][page_no]
)
Expand Down
16 changes: 15 additions & 1 deletion docling/backend/pypdfium2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Iterable, List, Optional, Union

import pypdfium2 as pdfium
import pypdfium2.raw as pdfium_c
from PIL import Image, ImageDraw
from pypdfium2 import PdfPage

Expand All @@ -17,6 +18,19 @@ def __init__(self, page_obj: PdfPage):
self._ppage = page_obj
self.text_page = None

def get_bitmap_rects(self, scale: int = 1) -> Iterable[BoundingBox]:
    """Yield top-left-origin bounding boxes of the page's image objects.

    Walks the page objects filtered to ``FPDF_PAGEOBJ_IMAGE``, reads each
    object's position as a bottom-left-origin box, flips it to top-left
    origin using the page height, and yields only boxes whose area exceeds
    32 * 32. The area check happens before scaling.

    Args:
        scale: Factor applied to every yielded box via ``scaled``.
    """
    min_area = 32 * 32
    image_objects = self._ppage.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_IMAGE])

    for image_obj in image_objects:
        bottom_left_box = BoundingBox.from_tuple(
            image_obj.get_pos(), origin=CoordOrigin.BOTTOMLEFT
        )
        top_left_box = bottom_left_box.to_top_left_origin(
            page_height=self.get_size().height
        )

        # NOTE(review): the diff view lost indentation; this assumes the
        # yield sits under the area check, so small bitmaps are dropped.
        if top_left_box.area() > min_area:
            yield top_left_box.scaled(scale=scale)

def get_text_in_rect(self, bbox: BoundingBox) -> str:
if not self.text_page:
self.text_page = self._ppage.get_textpage()
Expand Down Expand Up @@ -208,7 +222,7 @@ def __init__(self, path_or_stream: Union[BytesIO, Path]):
def page_count(self) -> int:
    """Return the length of the underlying ``_pdoc`` document object."""
    total_pages = len(self._pdoc)
    return total_pages

def load_page(self, page_no: int) -> PdfPage:
def load_page(self, page_no: int) -> PyPdfiumPageBackend:
return PyPdfiumPageBackend(self._pdoc[page_no])

def is_valid(self) -> bool:
Expand Down
22 changes: 15 additions & 7 deletions docling/datamodel/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,21 @@ def as_tuple(self):
@classmethod
def from_tuple(cls, coord: Tuple[float], origin: CoordOrigin):
if origin == CoordOrigin.TOPLEFT:
return BoundingBox(
l=coord[0], t=coord[1], r=coord[2], b=coord[3], coord_origin=origin
)
l, t, r, b = coord[0], coord[1], coord[2], coord[3]
if r < l:
l, r = r, l
if b < t:
b, t = t, b

return BoundingBox(l=l, t=t, r=r, b=b, coord_origin=origin)
elif origin == CoordOrigin.BOTTOMLEFT:
return BoundingBox(
l=coord[0], b=coord[1], r=coord[2], t=coord[3], coord_origin=origin
)
l, b, r, t = coord[0], coord[1], coord[2], coord[3]
if r < l:
l, r = r, l
if b > t:
b, t = t, b

return BoundingBox(l=l, t=t, r=r, b=b, coord_origin=origin)

def area(self) -> float:
    """Return the area of the box, respecting its coordinate origin.

    With a top-left origin ``b >= t``, so the height is ``b - t``. With a
    bottom-left origin the vertical axis is flipped (``t >= b``), so the
    height is ``t - b``. The previous single formula ``(r - l) * (b - t)``
    returned a negative area for every well-formed bottom-left box.
    """
    width = self.r - self.l
    if self.coord_origin == CoordOrigin.BOTTOMLEFT:
        height = self.t - self.b
    else:
        height = self.b - self.t
    return width * height
Expand Down Expand Up @@ -280,7 +288,7 @@ class TableStructureOptions(BaseModel):

class PipelineOptions(BaseModel):
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = False # True: perform OCR, replace programmatic PDF text
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text

table_structure_options: TableStructureOptions = TableStructureOptions()

Expand Down
2 changes: 0 additions & 2 deletions docling/document_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@


class DocumentConverter:
_layout_model_path = "model_artifacts/layout/beehive_v0.0.5"
_table_model_path = "model_artifacts/tableformer"
_default_download_filename = "file.pdf"

def __init__(
Expand Down
124 changes: 124 additions & 0 deletions docling/models/base_ocr_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import copy
import logging
from abc import abstractmethod
from typing import Iterable, List, Tuple

import numpy
import numpy as np
from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import find_objects, label

from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page

_log = logging.getLogger(__name__)


class BaseOcrModel:
    """Shared base for OCR models.

    Provides the pieces that do not depend on a specific OCR engine:
    choosing which page regions to OCR, discarding OCR cells that collide
    with programmatic text cells, and a debug visualisation. Subclasses
    supply ``__call__``.
    """

    def __init__(self, config):
        """Store ``config`` and read its ``"enabled"`` entry.

        Raises:
            KeyError: if ``config`` has no ``"enabled"`` key.
        """
        self.config = config
        self.enabled = config["enabled"]

    def get_ocr_rects(self, page: "Page") -> List["BoundingBox"]:
        """Compute the rectangles of ``page`` that should be sent to OCR.

        The bitmap rectangles reported by the page backend are rasterised
        into a binary mask, and touching rectangles are merged into
        connected components. If the mask covers more than 75% of the page
        area, a single full-page rectangle is returned. Otherwise the
        enclosing box of each connected component is returned.

        Returns:
            A list of top-left-origin bounding boxes. It is empty when the
            backend reports no bitmaps.
        """
        BITMAP_COVERAGE_THRESHOLD = 0.75

        def find_ocr_rects(size, bitmap_rects):
            # '1' mode is a binary image; every pixel starts at 0.
            image = Image.new("1", (round(size.width), round(size.height)))

            # Draw all bitmap rects into the binary image.
            draw = ImageDraw.Draw(image)
            for rect in bitmap_rects:
                x0, y0, x1, y1 = rect.as_tuple()
                x0, y0, x1, y1 = round(x0), round(y0), round(x1), round(y1)
                draw.rectangle([(x0, y0), (x1, y1)], fill=1)

            covered = np.array(image) > 0

            # Label connected regions of covered (non-zero) pixels.
            labeled_image, _ = label(covered)

            # Enclosing box per component. Slice stops are exclusive, hence
            # the "- 1" to land on the last covered pixel.
            bounding_boxes = [
                BoundingBox(
                    l=slc[1].start,
                    t=slc[0].start,
                    r=slc[1].stop - 1,
                    b=slc[0].stop - 1,
                    coord_origin=CoordOrigin.TOPLEFT,
                )
                for slc in find_objects(labeled_image)
            ]

            # Fraction of the page area covered by bitmaps.
            area_frac = np.sum(covered) / (size.width * size.height)

            return (area_frac, bounding_boxes)

        bitmap_rects = page._backend.get_bitmap_rects()
        coverage, ocr_rects = find_ocr_rects(page.size, bitmap_rects)

        # Full-page rectangle if the page is sufficiently covered by bitmaps.
        if coverage > BITMAP_COVERAGE_THRESHOLD:
            return [
                BoundingBox(
                    l=0,
                    t=0,
                    r=page.size.width,
                    b=page.size.height,
                    coord_origin=CoordOrigin.TOPLEFT,
                )
            ]

        # Otherwise the individual rectangles. This also handles coverage
        # exactly equal to the threshold, which previously fell through both
        # branches and returned None.
        return ocr_rects

    def filter_ocr_cells(self, ocr_cells, programmatic_cells):
        """Drop every OCR cell that intersects a programmatic cell.

        Args:
            ocr_cells: Cells produced by OCR; each exposes ``bbox.as_tuple()``.
            programmatic_cells: Cells already present on the page.

        Returns:
            A new list holding the OCR cells that intersect no
            programmatic cell.
        """
        # R-tree index over the programmatic cells for fast box queries.
        p = index.Property()
        p.dimension = 2
        idx = index.Index(properties=p)
        for i, cell in enumerate(programmatic_cells):
            idx.insert(i, cell.bbox.as_tuple())

        def is_overlapping_with_existing_cells(ocr_cell):
            # Any hit in the index counts as an overlap. This is a weak
            # criterion, but it works.
            possible_matches_index = list(idx.intersection(ocr_cell.bbox.as_tuple()))
            return len(possible_matches_index) > 0

        return [
            cell for cell in ocr_cells if not is_overlapping_with_existing_cells(cell)
        ]

    def draw_ocr_rects_and_cells(self, page, ocr_rects):
        """Debug helper: show a copy of the page image with overlays.

        OCR rectangles are shaded transparent yellow. Cells in
        ``page.cells`` are outlined in magenta when they are ``OcrCell``
        instances and in red otherwise. ``page.image`` itself is not
        modified.
        """
        image = copy.deepcopy(page.image)
        draw = ImageDraw.Draw(image, "RGBA")

        # Draw OCR rectangles as yellow filled rects.
        shade_color = (255, 255, 0, 40)  # transparent yellow
        for rect in ocr_rects:
            x0, y0, x1, y1 = rect.as_tuple()
            draw.rectangle([(x0, y0), (x1, y1)], fill=shade_color, outline=None)

        # Draw OCR and programmatic cells.
        for tc in page.cells:
            x0, y0, x1, y1 = tc.bbox.as_tuple()
            color = "magenta" if isinstance(tc, OcrCell) else "red"
            draw.rectangle([(x0, y0), (x1, y1)], outline=color)
        image.show()

    @abstractmethod
    def __call__(self, page_batch: Iterable["Page"]) -> Iterable["Page"]:
        """Run OCR over a batch of pages; implemented by subclasses.

        NOTE(review): this class does not derive from ``abc.ABC``, so the
        decorator alone does not prevent instantiation.
        """
85 changes: 39 additions & 46 deletions docling/models/easyocr_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import copy
import logging
import random
from typing import Iterable

import numpy
from PIL import ImageDraw

from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
from docling.models.base_ocr_model import BaseOcrModel

_log = logging.getLogger(__name__)


class EasyOcrModel:
class EasyOcrModel(BaseOcrModel):
def __init__(self, config):
self.config = config
self.enabled = config["enabled"]
super().__init__(config)

self.scale = 3 # multiplier for 72 dpi == 216 dpi.

if self.enabled:
Expand All @@ -29,49 +27,44 @@ def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
return

for page in page_batch:
# rects = page._fpage.
high_res_image = page.get_image(scale=self.scale)
im = numpy.array(high_res_image)
result = self.reader.readtext(im)

del high_res_image
del im

cells = [
OcrCell(
id=ix,
text=line[1],
confidence=line[2],
bbox=BoundingBox.from_tuple(
coord=(
line[0][0][0] / self.scale,
line[0][0][1] / self.scale,
line[0][2][0] / self.scale,
line[0][2][1] / self.scale,
),
origin=CoordOrigin.TOPLEFT,
),
ocr_rects = self.get_ocr_rects(page)

all_ocr_cells = []
for ocr_rect in ocr_rects:
high_res_image = page._backend.get_page_image(
scale=self.scale, cropbox=ocr_rect
)
for ix, line in enumerate(result)
]
im = numpy.array(high_res_image)
result = self.reader.readtext(im)

del high_res_image
del im

cells = [
OcrCell(
id=ix,
text=line[1],
confidence=line[2],
bbox=BoundingBox.from_tuple(
coord=(
(line[0][0][0] / self.scale) + ocr_rect.l,
(line[0][0][1] / self.scale) + ocr_rect.t,
(line[0][2][0] / self.scale) + ocr_rect.l,
(line[0][2][1] / self.scale) + ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
),
)
for ix, line in enumerate(result)
]
all_ocr_cells.extend(cells)

page.cells = cells # For now, just overwrites all digital cells.
## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self.filter_ocr_cells(all_ocr_cells, page.cells)

# DEBUG code:
def draw_clusters_and_cells():
image = copy.deepcopy(page.image)
draw = ImageDraw.Draw(image)

cell_color = (
random.randint(30, 140),
random.randint(30, 140),
random.randint(30, 140),
)
for tc in cells:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
image.show()
page.cells.extend(filtered_ocr_cells)

# draw_clusters_and_cells()
# DEBUG code:
# self.draw_ocr_rects_and_cells(page, ocr_rects)

yield page
1 change: 0 additions & 1 deletion docling/models/table_structure_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import random
from typing import Iterable, List

import numpy
Expand Down
Loading
Loading