diff --git a/CHANGELOG.md b/CHANGELOG.md index d8cc5b24..0bad1ab0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.30 + +* fix: table transformer doesn't return multiple cells with same coordinates + ## 0.7.29 * fix: table transformer predictions are now removed if confidence is below threshold diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 945e5fbb..10f845e7 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -7,10 +7,11 @@ from transformers.models.table_transformer.modeling_table_transformer import ( TableTransformerDecoder, ) +from copy import deepcopy import unstructured_inference.models.table_postprocess as postprocess from unstructured_inference.models import tables -from unstructured_inference.models.tables import apply_thresholds_on_objects +from unstructured_inference.models.tables import apply_thresholds_on_objects, structure_to_cells skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"} @@ -1285,3 +1286,100 @@ def test_padded_results_has_right_dimensions(table_transformer, example_image): def test_compute_confidence_score_zero_division_error_handling(): assert tables.compute_confidence_score([]) == 0 + + +@pytest.mark.parametrize( + "column_span_score, row_span_score, expected_text_to_indexes", + [ + ( + 0.9, + 0.8, + ( + { + "one three": {"row_nums": [0, 1], "column_nums": [0]}, + "two": {"row_nums": [0], "column_nums": [1]}, + "four": {"row_nums": [1], "column_nums": [1]}, + } + ), + ), + ( + 0.8, + 0.9, + ( + { + "one two": {"row_nums": [0], "column_nums": [0, 1]}, + "three": {"row_nums": [1], "column_nums": [0]}, + "four": {"row_nums": [1], "column_nums": [1]}, + } + ), + ), + ], +) +def test_subcells_filtering_when_overlapping_spanning_cells( + column_span_score, row_span_score, expected_text_to_indexes +): + """ + # table + # +-----------+----------+ + # | one | two | 
+ # |-----------+----------| + # | three | four | + # +-----------+----------+ + + spanning cells over first row and over first column + """ + table_structure = { + "rows": [ + {"bbox": [0, 0, 10, 20]}, + {"bbox": [10, 0, 20, 20]}, + ], + "columns": [ + {"bbox": [0, 0, 20, 10]}, + {"bbox": [0, 10, 20, 20]}, + ], + "spanning cells": [ + {"bbox": [0, 0, 20, 10], "score": column_span_score}, + {"bbox": [0, 0, 10, 20], "score": row_span_score}, + ], + } + tokens = [ + { + "text": "one", + "bbox": [0, 0, 10, 10], + }, + { + "text": "two", + "bbox": [0, 10, 10, 20], + }, + { + "text": "three", + "bbox": [10, 0, 20, 10], + }, + {"text": "four", "bbox": [10, 10, 20, 20]}, + ] + token_args = {"span_num": 1, "line_num": 1, "block_num": 1} + for token in tokens: + token.update(token_args) + for spanning_cell in table_structure["spanning cells"]: + spanning_cell["projected row header"] = False + + # table structure is edited inside structure_to_cells, save copy for future runs + saved_table_structure = deepcopy(table_structure) + + predicted_cells, _ = structure_to_cells(table_structure, tokens=tokens) + predicted_text_to_indexes = { + cell["cell text"]: { + "row_nums": cell["row_nums"], + "column_nums": cell["column_nums"], + } + for cell in predicted_cells + } + assert predicted_text_to_indexes == expected_text_to_indexes + + # swap spanning cells to ensure the highest prob spanning cell is used + spans = saved_table_structure["spanning cells"] + spans[0], spans[1] = spans[1], spans[0] + saved_table_structure["spanning cells"] = spans + + predicted_cells_after_reorder, _ = structure_to_cells(saved_table_structure, tokens=tokens) + assert predicted_cells_after_reorder == predicted_cells diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 2ec9a1ca..813471ef 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.29" # pragma: no cover +__version__ = "0.7.30" # 
pragma: no cover diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index 1a81b354..5f0fb53e 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -484,6 +484,8 @@ def structure_to_cells(table_structure, tokens): columns = table_structure["columns"] rows = table_structure["rows"] spanning_cells = table_structure["spanning cells"] + spanning_cells = sorted(spanning_cells, reverse=True, key=lambda cell: cell["score"]) + cells = [] subcells = [] # Identify complete cells and subcells @@ -507,6 +509,7 @@ def structure_to_cells(table_structure, tokens): spanning_cell_rect.intersect(cell_rect).get_area() / cell_rect.get_area() ) > inference_config.TABLE_IOB_THRESHOLD: cell["subcell"] = True + cell["is_merged"] = False break if cell["subcell"]: @@ -528,7 +531,7 @@ def structure_to_cells(table_structure, tokens): subcell_rect_area = subcell_rect.get_area() if ( subcell_rect.intersect(spanning_cell_rect).get_area() / subcell_rect_area - ) > inference_config.TABLE_IOB_THRESHOLD: + ) > inference_config.TABLE_IOB_THRESHOLD and subcell["is_merged"] is False: if cell_rect is None: cell_rect = Rect(list(subcell["bbox"])) else: @@ -539,6 +542,8 @@ def structure_to_cells(table_structure, tokens): # as header cells for a spanning cell to be classified as a header cell; # otherwise, this could lead to a non-rectangular header region header = header and "column header" in subcell and subcell["column header"] + subcell["is_merged"] = True + if len(cell_rows) > 0 and len(cell_columns) > 0: cell = { "bbox": cell_rect.get_bbox(),