Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: calculate layouts iou to filter out incorrect extracted layout #334

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions unstructured_inference/inference/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,39 @@ def grow_region_to_match_region(region_to_grow: Rectangle, region_to_match: Rect
new_x2,
new_y2,
)

def regions_iou(
    regions_a: Collection[TextRegion],
    regions_b: Collection[TextRegion],
    page_image_size: tuple[int, int],
    show: bool = False,
) -> float:
    """Calculate intersection over union between two collections of regions.

    Each collection is rasterized onto its own binary mask the size of the
    page; the IoU is then the number of pixels covered by both masks divided
    by the number of pixels covered by either mask.

    Args:
        regions_a: First collection; every item must expose ``bbox.x1``,
            ``bbox.y1``, ``bbox.x2`` and ``bbox.y2``.
        regions_b: Second collection, same requirements as ``regions_a``.
        page_image_size: Page size as ``(width, height)`` in pixels.
        show: When true, print the name of and display (via PIL) each of the
            two masks, their intersection and their union. Debugging aid.

    Returns:
        The IoU as a builtin ``float`` in ``[0.0, 1.0]``. When neither
        collection covers any pixel the two layouts are treated as identical
        and ``1.0`` is returned instead of dividing zero by zero.
    """
    w, h = page_image_size
    # numpy shape order is (height, width)
    layout_a = np.zeros((h, w), dtype=np.uint8)
    layout_b = layout_a.copy()
    for regions, layout in ((regions_a, layout_a), (regions_b, layout_b)):
        for region in regions:
            # Clamp the lower bound to 0: a negative slice index would be
            # interpreted by numpy as an offset from the far edge, which
            # drops (or misplaces) regions that start off the page. Upper
            # bounds beyond the page are already clipped by slicing itself.
            x1 = max(0, int(region.bbox.x1))
            x2 = max(0, int(region.bbox.x2))
            y1 = max(0, int(region.bbox.y1))
            y2 = max(0, int(region.bbox.y2))
            # numpy shape order is (height, width)
            layout[y1:y2, x1:x2] = 1
    intersection = layout_a & layout_b
    union = layout_a | layout_b
    if show:
        # Imported lazily so PIL is only needed when debugging output is on.
        from PIL import Image

        arrays = {
            "layout_a": layout_a,
            "layout_b": layout_b,
            "intersection": intersection,
            "union": union,
        }
        for title, array in arrays.items():
            print(title)
            Image.fromarray(array * 255).convert("RGB").show(title=title)
    intersection_area = int(np.sum(intersection))
    union_area = int(np.sum(union))
    if union_area == 0:
        # Nothing is covered by either layout: avoid 0/0 (nan) and report the
        # two empty layouts as a perfect match.
        return 1.0
    return intersection_area / union_area
10 changes: 9 additions & 1 deletion unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,15 @@ def merge_inferred_layout_with_extracted_layout(
) -> List[LayoutElement]:
"""Merge two layouts to produce a single layout."""
extracted_elements_to_add: List[TextRegion] = []
inferred_regions_to_remove = []
inferred_regions_to_remove: List[LayoutElement] = []
w, h = page_image_size
full_page_region = Rectangle(0, 0, w, h)

from unstructured_inference.inference.elements import regions_iou
layouts_iou = regions_iou(extracted_layout, inferred_layout, page_image_size)
if layouts_iou < 0.75:
return list(inferred_layout)

for extracted_region in extracted_layout:
extracted_is_image = isinstance(extracted_region, ImageTextRegion)
if extracted_is_image:
Expand Down Expand Up @@ -140,6 +146,7 @@ def merge_inferred_layout_with_extracted_layout(
if extracted_is_image:
# keep extracted region, remove inferred region
inferred_regions_to_remove.append(inferred_region)
region_matched = False
else:
# keep inferred region, remove extracted region
grow_region_to_match_region(inferred_region.bbox, extracted_region.bbox)
Expand All @@ -159,6 +166,7 @@ def merge_inferred_layout_with_extracted_layout(
):
# keep extracted region, remove inferred region
inferred_regions_to_remove.append(inferred_region)
region_matched = False
if not region_matched:
extracted_elements_to_add.append(extracted_region)
# Need to classify the extracted layout elements we're keeping.
Expand Down
Loading