Skip to content

Commit

Permalink
Feat/improve chipper bounding boxes (#292)
Browse files Browse the repository at this point in the history
Chipper bounding boxes largely exceed the area where the annotated text
is. This PR intends solving this problem.

There are several fixes applied:

* Improved cross attention processing. Maps are filtered at the head
level, which is normalised [0-1], this improves the bounding box
definition.
* The correlation between the token index and cross attention map has
been resolved. Before, with beam search size = 1 or 3 there were cases
in which the cross attention map did not match the token being
processed.

In addition, the empty areas of the bounding boxes have been cleaned and
the overlaps between bounding boxes are resolved by identifying the
largest margin that separates both bounding boxes, identified as the
largest gap without text either in the horizontal or vertical
directions.

Note: overlapping bounding boxes for child elements are not resolved for
now. In the case of a list with list items, the list element will be
affected by the overlapping resolution code.

---------

Co-authored-by: Antonio Jimeno Yepes <[email protected]>
  • Loading branch information
ajjimeno and ajjimeno authored Nov 29, 2023
1 parent 0f0c2be commit 631e6fb
Show file tree
Hide file tree
Showing 4 changed files with 668 additions and 31 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.7.16-dev1

* enhancement: improved Chipper bounding boxes

## 0.7.16

* bug: Allow supplied ONNX models to use label_map dictionary from json file
Expand Down
121 changes: 120 additions & 1 deletion test_unstructured_inference/models/test_chippermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,10 @@ def test_no_repeat_ngram_logits():
@pytest.mark.parametrize(
("decoded_str", "expected_classes"),
[
("<s><s_Misc> 1</s_Misc><s_Text>There is some text here.</s_Text></s>", ["Misc", "Text"]),
(
"<s><s_Misc> 1</s_Misc><s_Text>There is some text here.</s_Text></s>",
["Misc", "Text"],
),
(
"<s><s_List><s_List-item>Text here.</s_List-item><s_List><s_List-item>Final one",
["List", "List-item", "List", "List-item"],
Expand Down Expand Up @@ -245,3 +248,119 @@ def test_run_chipper_v2():
tables = [el for el in elements if el.type == "Table"]
assert all(table.text_as_html.startswith("<table>") for table in tables)
assert all("<table>" not in table.text for table in tables)


@pytest.mark.parametrize(
("bbox", "output"),
[
(
[0, 0, 0, 0],
None,
),
(
[0, 1, 1, -1],
None,
),
],
)
def test_largest_margin(bbox, output):
model = get_model("chipper")
img = Image.open("sample-docs/easy_table.jpg")
assert model.largest_margin(img, bbox) is output


@pytest.mark.parametrize(
("bbox", "output"),
[
(
[0, 1, 0, -1],
[0, 1, 0, -1],
),
(
[0, 1, 1, -1],
[0, 1, 1, -1],
),
(
[20, 10, 30, 40],
[20, 10, 30, 40],
),
],
)
def test_reduce_bbox_overlap(bbox, output):
model = get_model("chipper")
img = Image.open("sample-docs/easy_table.jpg")
assert model.reduce_bbox_overlap(img, bbox) == output


@pytest.mark.parametrize(
("bbox", "output"),
[
(
[20, 10, 30, 40],
[20, 10, 30, 40],
),
],
)
def test_reduce_bbox_no_overlap(bbox, output):
model = get_model("chipper")
img = Image.open("sample-docs/easy_table.jpg")
assert model.reduce_bbox_no_overlap(img, bbox) == output


@pytest.mark.parametrize(
("bbox1", "bbox2", "output"),
[
(
[0, 50, 20, 80],
[10, 10, 30, 30],
(
"horizontal",
[10, 10, 30, 30],
[0, 50, 20, 80],
[0, 50, 20, 80],
[10, 10, 30, 30],
None,
),
),
(
[10, 10, 30, 30],
[40, 10, 60, 30],
(
"vertical",
[40, 10, 60, 30],
[10, 10, 30, 30],
[10, 10, 30, 30],
[40, 10, 60, 30],
None,
),
),
(
[10, 80, 30, 100],
[40, 10, 60, 30],
(
"none",
[40, 10, 60, 30],
[10, 80, 30, 100],
[10, 80, 30, 100],
[40, 10, 60, 30],
None,
),
),
(
[40, 10, 60, 30],
[10, 10, 30, 30],
(
"vertical",
[10, 10, 30, 30],
[40, 10, 60, 30],
[10, 10, 30, 30],
[40, 10, 60, 30],
None,
),
),
],
)
def test_check_overlap(bbox1, bbox2, output):
model = get_model("chipper")

assert model.check_overlap(bbox1, bbox2) == output
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.16" # pragma: no cover
__version__ = "0.7.16-dev1" # pragma: no cover
Loading

0 comments on commit 631e6fb

Please sign in to comment.