From 035417c90a867843a9d110252e46b6ced1e2ab7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kmiecik?= Date: Mon, 10 Jun 2024 11:45:04 +0200 Subject: [PATCH 1/7] feat: modified cells_to_html and fill_cells functions to generate html tables with correct syntax --- unstructured_inference/models/tables.py | 63 +++++++++++++++++++------ 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index 48f4c383..9be0264c 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -648,11 +648,8 @@ def structure_to_cells(table_structure, tokens): def fill_cells(cells: List[dict]) -> List[dict]: - """add empty cells to pad cells that spans multiple rows for html conversion - - For example if a cell takes row 0 and 1 and column 0, we add a new empty cell at row 1 and - column 0. This padding ensures the structure of the output table is intact. In this example the - cell data is {"row_nums": [0, 1], "column_nums": [0], ...} + """fills the missing cells in the table by adding a cells with empty text + where there are no cells detected by the model. 
A cell contains the following keys relevent to the html conversion: row_nums: List[int] @@ -663,25 +660,60 @@ def fill_cells(cells: List[dict]) -> List[dict]: than one numbers cell text: str the text in this cell + column header: bool + whether this cell is a column header """ - new_cells = cells.copy() + table_rows_no = max(set([row for cell in cells for row in cell["row_nums"]])) + table_cols_no = max(set([col for cell in cells for col in cell["column_nums"]])) + filled = np.zeros((table_rows_no + 1, table_cols_no + 1), dtype=bool) for cell in cells: - for extra_row in sorted(cell["row_nums"][1:]): - new_cell = cell.copy() - new_cell["row_nums"] = [extra_row] - new_cell["cell text"] = "" - new_cells.append(new_cell) + for row in cell["row_nums"]: + for col in cell["column_nums"]: + filled[row, col] = True + # add cells for which filled is false + header_rows = set([row for cell in cells if cell["column header"] for row in cell["row_nums"]]) + new_cells = cells.copy() + not_filled_idx = np.where(filled == False) + for row, col in zip(not_filled_idx[0], not_filled_idx[1]): + new_cell = { + "row_nums": [row], + "column_nums": [col], + "cell text": "", + "column header": row in header_rows + } + new_cells.append(new_cell) return new_cells -def cells_to_html(cells): - """Convert table structure to html format.""" +def cells_to_html(cells: List[dict]) -> str: + """Convert table structure to html format. 
+ + Args: + cells: List of dictionaries representing table cells, where each dictionary has the + following format: + { + "row_nums": List[int], + "column_nums": List[int], + "cell text": str, + "column header": bool, + } + Returns: + str: HTML table string + """ cells = sorted(fill_cells(cells), key=lambda k: (min(k["row_nums"]), min(k["column_nums"]))) + # cells = sorted(cells, key=lambda k: (min(k["row_nums"]), min(k["column_nums"]))) table = ET.Element("table") current_row = -1 + table_header = None + table_has_header = any(cell["column header"] for cell in cells) + if table_has_header: + table_header = ET.SubElement(table, "thead") + + table_body = ET.SubElement(table, "tbody") + table_subelement = None for cell in cells: this_row = min(cell["row_nums"]) @@ -695,11 +727,12 @@ def cells_to_html(cells): if this_row > current_row: current_row = this_row if cell["column header"]: + table_subelement = table_header cell_tag = "th" - row = ET.SubElement(table, "thead") else: cell_tag = "td" - row = ET.SubElement(table, "tr") + table_subelement = table_body + row = ET.SubElement(table_subelement, "tr") tcell = ET.SubElement(row, cell_tag, attrib=attrib) tcell.text = cell["cell text"] From 207b2d0fefc21fc62d98b90fb9b3adb0f5dbd9cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kmiecik?= Date: Mon, 10 Jun 2024 11:45:26 +0200 Subject: [PATCH 2/7] test: updated tests for table generation --- .../models/test_tables.py | 1042 ++++++++++------- 1 file changed, 615 insertions(+), 427 deletions(-) diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 10f845e7..184c8a63 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -596,40 +596,40 @@ def test_load_donut_model(model_path): ("input_test", "output_test"), [ ( - [ - { - "label": "table column header", - "score": 0.9349299073219299, - "bbox": [ - 47.83147430419922, - 116.8877944946289, 
- 2557.79296875, - 216.98883056640625, - ], - }, - { - "label": "table column header", - "score": 0.934, - "bbox": [ - 47.83147430419922, - 116.8877944946289, - 2557.79296875, - 216.98883056640625, - ], - }, - ], - [ - { - "label": "table column header", - "score": 0.9349299073219299, - "bbox": [ - 47.83147430419922, - 116.8877944946289, - 2557.79296875, - 216.98883056640625, - ], - }, - ], + [ + { + "label": "table column header", + "score": 0.9349299073219299, + "bbox": [ + 47.83147430419922, + 116.8877944946289, + 2557.79296875, + 216.98883056640625, + ], + }, + { + "label": "table column header", + "score": 0.934, + "bbox": [ + 47.83147430419922, + 116.8877944946289, + 2557.79296875, + 216.98883056640625, + ], + }, + ], + [ + { + "label": "table column header", + "score": 0.9349299073219299, + "bbox": [ + 47.83147430419922, + 116.8877944946289, + 2557.79296875, + 216.98883056640625, + ], + }, + ], ), ([], []), ], @@ -644,234 +644,234 @@ def test_nms(input_test, output_test): ("supercell1", "supercell2"), [ ( - { - "label": "table spanning cell", - "score": 0.526617169380188, - "bbox": [ - 1446.2801513671875, - 1023.817138671875, - 2114.3525390625, - 1099.20166015625, - ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4], - "column_numbers": [0, 4], - }, - { - "label": "table spanning cell", - "score": 0.5199193954467773, - "bbox": [ - 98.92312622070312, - 676.1566772460938, - 751.0982666015625, - 938.5986938476562, - ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4, 6], - "column_numbers": [0, 4], - }, - ), - ( - { - "label": "table spanning cell", - "score": 0.526617169380188, - "bbox": [ - 1446.2801513671875, - 1023.817138671875, - 2114.3525390625, - 1099.20166015625, - ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4], - "column_numbers": [0, 4], - }, - { - "label": "table spanning cell", - "score": 0.5199193954467773, - "bbox": [ - 98.92312622070312, - 676.1566772460938, - 
751.0982666015625, - 938.5986938476562, - ], - "projected row header": False, - "header": False, - "row_numbers": [4], - "column_numbers": [0, 4, 6], - }, - ), - ( - { - "label": "table spanning cell", - "score": 0.526617169380188, - "bbox": [ - 1446.2801513671875, - 1023.817138671875, - 2114.3525390625, - 1099.20166015625, - ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4], - "column_numbers": [1, 4], - }, - { - "label": "table spanning cell", - "score": 0.5199193954467773, - "bbox": [ - 98.92312622070312, - 676.1566772460938, - 751.0982666015625, - 938.5986938476562, - ], - "projected row header": False, - "header": False, - "row_numbers": [4], - "column_numbers": [0, 4, 6], - }, - ), - ( - { - "label": "table spanning cell", - "score": 0.526617169380188, - "bbox": [ - 1446.2801513671875, - 1023.817138671875, - 2114.3525390625, - 1099.20166015625, - ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4], - "column_numbers": [1, 4], - }, - { - "label": "table spanning cell", - "score": 0.5199193954467773, - "bbox": [ - 98.92312622070312, - 676.1566772460938, - 751.0982666015625, - 938.5986938476562, - ], - "projected row header": False, - "header": False, - "row_numbers": [2, 4, 5, 6, 7, 8], - "column_numbers": [0, 4, 6], - }, - ), - ], -) -def test_remove_supercell_overlap(supercell1, supercell2): - assert postprocess.remove_supercell_overlap(supercell1, supercell2) is None - - -@pytest.mark.parametrize( - ("supercells", "rows", "columns", "output_test"), - [ - ( - [ { "label": "table spanning cell", - "score": 0.9, + "score": 0.526617169380188, "bbox": [ - 98.92312622070312, - 143.11549377441406, - 2115.197265625, - 1238.27587890625, + 1446.2801513671875, + 1023.817138671875, + 2114.3525390625, + 1099.20166015625, ], - "projected row header": True, - "header": True, - "span": True, - }, - ], - [ - { - "label": "table row", - "score": 0.9299452900886536, - "bbox": [0, 0, 10, 10], - "column header": True, - 
"header": True, + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [0, 4], }, { - "label": "table row", - "score": 0.9299452900886536, + "label": "table spanning cell", + "score": 0.5199193954467773, "bbox": [ 98.92312622070312, - 143.11549377441406, - 2114.3525390625, - 193.67681884765625, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, ], - "column header": True, - "header": True, + "projected row header": False, + "header": False, + "row_numbers": [3, 4, 6], + "column_numbers": [0, 4], }, + ), + ( { - "label": "table row", - "score": 0.9299452900886536, + "label": "table spanning cell", + "score": 0.526617169380188, "bbox": [ - 98.92312622070312, - 143.11549377441406, + 1446.2801513671875, + 1023.817138671875, 2114.3525390625, - 193.67681884765625, + 1099.20166015625, ], - "column header": True, - "header": True, + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [0, 4], }, - ], - [ { - "label": "table column", - "score": 0.9996132254600525, + "label": "table spanning cell", + "score": 0.5199193954467773, "bbox": [ 98.92312622070312, - 143.11549377441406, - 517.6508178710938, - 1616.48779296875, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, ], + "projected row header": False, + "header": False, + "row_numbers": [4], + "column_numbers": [0, 4, 6], }, + ), + ( { - "label": "table column", - "score": 0.9935646653175354, + "label": "table spanning cell", + "score": 0.526617169380188, "bbox": [ - 520.0474853515625, - 143.11549377441406, - 751.0982666015625, - 1616.48779296875, + 1446.2801513671875, + 1023.817138671875, + 2114.3525390625, + 1099.20166015625, ], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [1, 4], }, - ], - [ { "label": "table spanning cell", - "score": 0.9, + "score": 0.5199193954467773, "bbox": [ 98.92312622070312, - 143.11549377441406, + 676.1566772460938, 751.0982666015625, - 
193.67681884765625, + 938.5986938476562, ], - "projected row header": True, - "header": True, - "span": True, - "row_numbers": [1, 2], - "column_numbers": [0, 1], + "projected row header": False, + "header": False, + "row_numbers": [4], + "column_numbers": [0, 4, 6], }, + ), + ( { - "row_numbers": [0], - "column_numbers": [0, 1], - "score": 0.9, - "propagated": True, + "label": "table spanning cell", + "score": 0.526617169380188, + "bbox": [ + 1446.2801513671875, + 1023.817138671875, + 2114.3525390625, + 1099.20166015625, + ], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [1, 4], + }, + { + "label": "table spanning cell", + "score": 0.5199193954467773, "bbox": [ 98.92312622070312, - 143.11549377441406, + 676.1566772460938, 751.0982666015625, - 193.67681884765625, + 938.5986938476562, ], + "projected row header": False, + "header": False, + "row_numbers": [2, 4, 5, 6, 7, 8], + "column_numbers": [0, 4, 6], }, - ], + ), + ], +) +def test_remove_supercell_overlap(supercell1, supercell2): + assert postprocess.remove_supercell_overlap(supercell1, supercell2) is None + + +@pytest.mark.parametrize( + ("supercells", "rows", "columns", "output_test"), + [ + ( + [ + { + "label": "table spanning cell", + "score": 0.9, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 2115.197265625, + 1238.27587890625, + ], + "projected row header": True, + "header": True, + "span": True, + }, + ], + [ + { + "label": "table row", + "score": 0.9299452900886536, + "bbox": [0, 0, 10, 10], + "column header": True, + "header": True, + }, + { + "label": "table row", + "score": 0.9299452900886536, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 2114.3525390625, + 193.67681884765625, + ], + "column header": True, + "header": True, + }, + { + "label": "table row", + "score": 0.9299452900886536, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 2114.3525390625, + 193.67681884765625, + ], + "column header": True, + "header": True, 
+ }, + ], + [ + { + "label": "table column", + "score": 0.9996132254600525, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 517.6508178710938, + 1616.48779296875, + ], + }, + { + "label": "table column", + "score": 0.9935646653175354, + "bbox": [ + 520.0474853515625, + 143.11549377441406, + 751.0982666015625, + 1616.48779296875, + ], + }, + ], + [ + { + "label": "table spanning cell", + "score": 0.9, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 751.0982666015625, + 193.67681884765625, + ], + "projected row header": True, + "header": True, + "span": True, + "row_numbers": [1, 2], + "column_numbers": [0, 1], + }, + { + "row_numbers": [0], + "column_numbers": [0, 1], + "score": 0.9, + "propagated": True, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 751.0982666015625, + 193.67681884765625, + ], + }, + ], ), ], ) @@ -889,26 +889,26 @@ def test_align_rows(rows, bbox, output): [ ("html", "Blind51434.5%, n=1"), ( - "cells", - { - "column_nums": [0], - "row_nums": [2], - "column header": False, - "cell text": "Blind", - }, + "cells", + { + "column_nums": [0], + "row_nums": [2], + "column header": False, + "cell text": "Blind", + }, ), ("dataframe", ["Blind", "5", "1", "4", "34.5%, n=1", "1199 sec, n=1"]), (None, "Blind51434.5%, n=1"), ], ) def test_table_prediction_output_format( - output_format, - expectation, - table_transformer, - example_image, - mocker, - example_table_cells, - mocked_ocr_tokens, + output_format, + expectation, + table_transformer, + example_image, + mocker, + example_table_cells, + mocked_ocr_tokens, ): mocker.patch.object(tables, "recognize", return_value=example_table_cells) mocker.patch.object( @@ -935,11 +935,11 @@ def test_table_prediction_output_format( def test_table_prediction_output_format_when_wrong_type_then_value_error( - table_transformer, - example_image, - mocker, - example_table_cells, - mocked_ocr_tokens, + table_transformer, + example_image, + mocker, + example_table_cells, + mocked_ocr_tokens, ): 
mocker.patch.object(tables, "recognize", return_value=example_table_cells) mocker.patch.object( @@ -954,10 +954,10 @@ def test_table_prediction_output_format_when_wrong_type_then_value_error( def test_table_prediction_runs_with_empty_recognize( - table_transformer, - example_image, - mocker, - mocked_ocr_tokens, + table_transformer, + example_image, + mocker, + mocked_ocr_tokens, ): mocker.patch.object(tables, "recognize", return_value=[]) mocker.patch.object( @@ -988,7 +988,7 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image): ], ) def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold( - thresholds, expected_object_number + thresholds, expected_object_number ): objects = [ {"label": "0", "score": 0.2}, @@ -1007,7 +1007,7 @@ def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_ ], ) def test_objects_are_filtered_based_on_class_thresholds_when_two_classes( - thresholds, expected_object_number + thresholds, expected_object_number ): objects = [ {"label": "0", "score": 0.2}, @@ -1043,98 +1043,98 @@ def test_include_rect(): ("spans", "join_with_space", "expected"), [ ( - [ - { - "flags": 2**0, - "text": "5", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - ], - True, - "", + [ + { + "flags": 2 ** 0, + "text": "5", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + ], + True, + "", ), ( - [ - { - "flags": 2**0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - ], - True, - "p", + [ + { + "flags": 2 ** 0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + ], + True, + "p", ), ( - [ - { - "flags": 2**0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - { - "flags": 2**0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - ], 
- True, - "p p", + [ + { + "flags": 2 ** 0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + { + "flags": 2 ** 0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + ], + True, + "p p", ), ( - [ - { - "flags": 2**0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - { - "flags": 2**0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 1, - }, - ], - True, - "p p", + [ + { + "flags": 2 ** 0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + { + "flags": 2 ** 0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 1, + }, + ], + True, + "p p", ), ( - [ - { - "flags": 2**0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - { - "flags": 2**0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 1, - }, - ], - False, - "p p", + [ + { + "flags": 2 ** 0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + { + "flags": 2 ** 0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 1, + }, + ], + False, + "p p", ), ], ) @@ -1152,62 +1152,62 @@ def test_extract_text_from_spans(spans, join_with_space, expected): [ ([{"header": "hi", "row_numbers": [0, 1, 2], "score": 0.9}], 1), ( - [ - { - "header": "hi", - "row_numbers": [0], - "column_numbers": [1, 2, 3], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1], - "column_numbers": [1], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1], - "column_numbers": [2], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1], - "column_numbers": [3], - "score": 0.9, - }, - ], - 4, + [ + { + "header": "hi", + "row_numbers": [0], + "column_numbers": [1, 2, 3], + "score": 0.9, + }, + { + 
"header": "hi", + "row_numbers": [1], + "column_numbers": [1], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [1], + "column_numbers": [2], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [1], + "column_numbers": [3], + "score": 0.9, + }, + ], + 4, ), ( - [ - { - "header": "hi", - "row_numbers": [0], - "column_numbers": [0], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1], - "column_numbers": [0], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1, 2], - "column_numbers": [0], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [3], - "column_numbers": [0], - "score": 0.9, - }, - ], - 3, + [ + { + "header": "hi", + "row_numbers": [0], + "column_numbers": [0], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [1], + "column_numbers": [0], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [1, 2], + "column_numbers": [0], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [3], + "column_numbers": [0], + "score": 0.9, + }, + ], + 3, ), ], ) @@ -1216,26 +1216,6 @@ def test_header_supercell_tree(supercells, expected_len): assert len(supercells) == expected_len -def test_cells_to_html(): - # example table - # +----------+---------------------+ - # | two | two columns | - # | |----------+----------| - # | rows |sub cell 1|sub cell 2| - # +----------+----------+----------+ - cells = [ - {"row_nums": [0, 1], "column_nums": [0], "cell text": "two row", "column header": False}, - {"row_nums": [0], "column_nums": [1, 2], "cell text": "two cols", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "sub cell 1", "column header": False}, - {"row_nums": [1], "column_nums": [2], "cell text": "sub cell 2", "column header": False}, - ] - expected = ( - '
two rowtwo ' - "cols
sub cell 1sub cell 2
" - ) - assert tables.cells_to_html(cells) == expected - - @pytest.mark.parametrize("zoom", [1, 0.1, 5, -1, 0]) def test_zoom_image(example_image, zoom): width, height = example_image.size @@ -1247,6 +1227,214 @@ def test_zoom_image(example_image, zoom): assert new_h == np.round(height * zoom, 0) +@pytest.mark.parametrize( + ("input_cells", "expected_html"), [ + # +----------+---------------------+ + # | row1col1 | row1col2 | row1col3 | + # |----------|----------+----------| + # | row2col1 | row2col2 | row2col3 | + # +----------+----------+----------+ + pytest.param( + [ + {"row_nums": [0], "column_nums": [0], "cell text": "row1col1", "column header": False}, + {"row_nums": [0], "column_nums": [1], "cell text": "row1col2", "column header": False}, + {"row_nums": [0], "column_nums": [2], "cell text": "row1col3", "column header": False}, + {"row_nums": [1], "column_nums": [0], "cell text": "row2col1", "column header": False}, + {"row_nums": [1], "column_nums": [1], "cell text": "row2col2", "column header": False}, + {"row_nums": [1], "column_nums": [2], "cell text": "row2col3", "column header": False}, + ], + ( + '' + '
row1col1row1col2row1col3
row2col1row2col2row2col3
' + ), + id="simple table without header", + ), + # +----------+---------------------+ + # | h1col1 | h1col2 | h1col3 | + # |----------|----------+----------| + # | row1col1 | row1col2 | row1col3 | + # |----------|----------+----------| + # | row2col1 | row2col2 | row2col3 | + # +----------+----------+----------+ + pytest.param( + [ + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, + {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, + {"row_nums": [1], "column_nums": [2], "cell text": "row1col3", "column header": False}, + {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, + {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, + {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + ], + ( + '' + '' + '
h1col1h1col2h1col2
row1col1row1col2row1col3
row2col1row2col2row2col3
' + ), + id="simple table with header", + ), + # +----------+---------------------+ + # | h1col1 | h1col2 | h1col3 | + # |----------|----------+----------| + # | row1col1 | row1col2 | row1col3 | + # |----------|----------+----------| + # | row2col1 | row2col2 | row2col3 | + # +----------+----------+----------+ + pytest.param( + [ + {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, + {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, + {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, + {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, + {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, + {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [1], "column_nums": [2], "cell text": "row1col3", "column header": False}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + ], + ( + '' + '' + '
h1col1h1col2h1col2
row1col1row1col2row1col3
row2col1row2col2row2col3
' + ), + id="simple table with header, mixed elements", + ), + # +----------+---------------------+ + # | two | two columns | + # | |----------+----------| + # | rows |sub cell 1|sub cell 2| + # +----------+----------+----------+ + pytest.param( + [ + {"row_nums": [0, 1], "column_nums": [0], "cell text": "two row", "column header": False}, + {"row_nums": [0], "column_nums": [1, 2], "cell text": "two cols", "column header": False}, + {"row_nums": [1], "column_nums": [1], "cell text": "sub cell 1", "column header": False}, + {"row_nums": [1], "column_nums": [2], "cell text": "sub cell 2", "column header": False}, + ], + ( + '
two rowtwo ' + "cols
sub cell 1sub cell 2
" + ), + id="various spans, no headers", + ), + # +----------+---------------------+----------+ + # | | h1col23 | h1col4 | + # | h12col1 |----------+----------+----------| + # | | h2col2 | h2col34 | + # |----------|----------+----------+----------+ + # | r3col1 | r3col2 | | + # |----------+----------| r34col34 | + # | r4col12 | | + # +----------+----------+----------+----------+ + pytest.param( + [ + {"row_nums": [0, 1], "column_nums": [0], "cell text": "h12col1", "column header": True}, + {"row_nums": [0], "column_nums": [1, 2], "cell text": "h1col23", "column header": True}, + {"row_nums": [0], "column_nums": [3], "cell text": "h1col4", "column header": True}, + {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True}, + {"row_nums": [1], "column_nums": [2, 3], "cell text": "h2col34", "column header": True}, + {"row_nums": [2], "column_nums": [0], "cell text": "r3col1", "column header": False}, + {"row_nums": [2], "column_nums": [1], "cell text": "r3col2", "column header": False}, + {"row_nums": [2, 3], "column_nums": [2, 3], "cell text": "r34col34", "column header": False}, + {"row_nums": [3], "column_nums": [0, 1], "cell text": "r4col12", "column header": False}, + ], + ( + '' + '' + '' + '
h12col1h1col23h1col4
h2col2h2col34
r3col1r3col2r34col34
r4col12
' + ), + id="various spans, no headers", + ), + ] +) +def test_cells_to_html(input_cells, expected_html): + assert tables.cells_to_html(input_cells) == expected_html + + +@pytest.mark.parametrize( + ("input_cells", "expected_cells"), [ + pytest.param( + [ + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, + {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, + {"row_nums": [1], "column_nums": [2], "cell text": "row1col3", "column header": False}, + {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, + {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, + {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + ], + [ + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, + {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, + {"row_nums": [1], "column_nums": [2], "cell text": "row1col3", "column header": False}, + {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, + {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, + {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + ], + id="identical tables, no changes expected" + ), + pytest.param( + [ + {"row_nums": [0], "column_nums": [0], "cell text": 
"h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, + {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, + {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, + {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, + {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + ], + [ + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [1], "cell text": "", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, + {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, + {"row_nums": [1], "column_nums": [2], "cell text": "", "column header": False}, + {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, + {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, + {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + ], + id="missing column in header and in the middle", + ), + pytest.param( + [ + {"row_nums": [0, 1], "column_nums": [0], "cell text": "h12col1", "column header": True}, + {"row_nums": [0], "column_nums": [1, 2], "cell text": "h1col23", "column header": True}, + {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True}, + {"row_nums": [1], "column_nums": [2, 3], "cell text": "h2col34", "column header": True}, + {"row_nums": [2], "column_nums": [0], "cell text": "r3col1", "column header": False}, + {"row_nums": [2, 3], "column_nums": [2, 3], "cell text": "r34col34", "column header": False}, + 
{"row_nums": [3], "column_nums": [0, 1], "cell text": "r4col12", "column header": False}, + ], + [ + {"row_nums": [0, 1], "column_nums": [0], "cell text": "h12col1", "column header": True}, + {"row_nums": [0], "column_nums": [1, 2], "cell text": "h1col23", "column header": True}, + {"row_nums": [0], "column_nums": [3], "cell text": "", "column header": True}, + {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True}, + {"row_nums": [1], "column_nums": [2, 3], "cell text": "h2col34", "column header": True}, + {"row_nums": [2], "column_nums": [0], "cell text": "r3col1", "column header": False}, + {"row_nums": [2], "column_nums": [1], "cell text": "", "column header": False}, + {"row_nums": [2, 3], "column_nums": [2, 3], "cell text": "r34col34", "column header": False}, + {"row_nums": [3], "column_nums": [0, 1], "cell text": "r4col12", "column header": False}, + ], + id="missing column in header and in the middle in table with spans" + ) +] +) +def test_fill_cells(input_cells, expected_cells): + def sort_cells(cells): + return sorted(cells, key=lambda x: (x["row_nums"], x["column_nums"])) + assert sort_cells(tables.fill_cells(input_cells)) == sort_cells(expected_cells) + + def test_padded_results_has_right_dimensions(table_transformer, example_image): str_class_name2idx = tables.get_class_map("structure") # a simpler mapping so we keep all structure in the returned objs below for test @@ -1292,31 +1480,31 @@ def test_compute_confidence_score_zero_division_error_handling(): "column_span_score, row_span_score, expected_text_to_indexes", [ ( - 0.9, - 0.8, - ( - { - "one three": {"row_nums": [0, 1], "column_nums": [0]}, - "two": {"row_nums": [0], "column_nums": [1]}, - "four": {"row_nums": [1], "column_nums": [1]}, - } - ), + 0.9, + 0.8, + ( + { + "one three": {"row_nums": [0, 1], "column_nums": [0]}, + "two": {"row_nums": [0], "column_nums": [1]}, + "four": {"row_nums": [1], "column_nums": [1]}, + } + ), ), ( - 0.8, - 0.9, - ( - { - "one two": 
{"row_nums": [0], "column_nums": [0, 1]}, - "three": {"row_nums": [1], "column_nums": [0]}, - "four": {"row_nums": [1], "column_nums": [1]}, - } - ), + 0.8, + 0.9, + ( + { + "one two": {"row_nums": [0], "column_nums": [0, 1]}, + "three": {"row_nums": [1], "column_nums": [0]}, + "four": {"row_nums": [1], "column_nums": [1]}, + } + ), ), ], ) def test_subcells_filtering_when_overlapping_spanning_cells( - column_span_score, row_span_score, expected_text_to_indexes + column_span_score, row_span_score, expected_text_to_indexes ): """ # table From bf96cf32c20f5301746727e37c5617ef9d04a0f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kmiecik?= Date: Mon, 10 Jun 2024 12:08:24 +0200 Subject: [PATCH 3/7] chore: style and typing fixes --- .../models/test_tables.py | 1302 ++++++++++------- unstructured_inference/models/tables.py | 17 +- 2 files changed, 818 insertions(+), 501 deletions(-) diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 184c8a63..6a218f9d 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -596,40 +596,40 @@ def test_load_donut_model(model_path): ("input_test", "output_test"), [ ( - [ - { - "label": "table column header", - "score": 0.9349299073219299, - "bbox": [ - 47.83147430419922, - 116.8877944946289, - 2557.79296875, - 216.98883056640625, - ], - }, - { - "label": "table column header", - "score": 0.934, - "bbox": [ - 47.83147430419922, - 116.8877944946289, - 2557.79296875, - 216.98883056640625, - ], - }, - ], - [ - { - "label": "table column header", - "score": 0.9349299073219299, - "bbox": [ - 47.83147430419922, - 116.8877944946289, - 2557.79296875, - 216.98883056640625, - ], - }, - ], + [ + { + "label": "table column header", + "score": 0.9349299073219299, + "bbox": [ + 47.83147430419922, + 116.8877944946289, + 2557.79296875, + 216.98883056640625, + ], + }, + { + "label": "table column header", + "score": 
0.934, + "bbox": [ + 47.83147430419922, + 116.8877944946289, + 2557.79296875, + 216.98883056640625, + ], + }, + ], + [ + { + "label": "table column header", + "score": 0.9349299073219299, + "bbox": [ + 47.83147430419922, + 116.8877944946289, + 2557.79296875, + 216.98883056640625, + ], + }, + ], ), ([], []), ], @@ -644,234 +644,234 @@ def test_nms(input_test, output_test): ("supercell1", "supercell2"), [ ( + { + "label": "table spanning cell", + "score": 0.526617169380188, + "bbox": [ + 1446.2801513671875, + 1023.817138671875, + 2114.3525390625, + 1099.20166015625, + ], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [0, 4], + }, + { + "label": "table spanning cell", + "score": 0.5199193954467773, + "bbox": [ + 98.92312622070312, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, + ], + "projected row header": False, + "header": False, + "row_numbers": [3, 4, 6], + "column_numbers": [0, 4], + }, + ), + ( + { + "label": "table spanning cell", + "score": 0.526617169380188, + "bbox": [ + 1446.2801513671875, + 1023.817138671875, + 2114.3525390625, + 1099.20166015625, + ], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [0, 4], + }, + { + "label": "table spanning cell", + "score": 0.5199193954467773, + "bbox": [ + 98.92312622070312, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, + ], + "projected row header": False, + "header": False, + "row_numbers": [4], + "column_numbers": [0, 4, 6], + }, + ), + ( + { + "label": "table spanning cell", + "score": 0.526617169380188, + "bbox": [ + 1446.2801513671875, + 1023.817138671875, + 2114.3525390625, + 1099.20166015625, + ], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [1, 4], + }, + { + "label": "table spanning cell", + "score": 0.5199193954467773, + "bbox": [ + 98.92312622070312, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, + ], + 
"projected row header": False, + "header": False, + "row_numbers": [4], + "column_numbers": [0, 4, 6], + }, + ), + ( + { + "label": "table spanning cell", + "score": 0.526617169380188, + "bbox": [ + 1446.2801513671875, + 1023.817138671875, + 2114.3525390625, + 1099.20166015625, + ], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [1, 4], + }, + { + "label": "table spanning cell", + "score": 0.5199193954467773, + "bbox": [ + 98.92312622070312, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, + ], + "projected row header": False, + "header": False, + "row_numbers": [2, 4, 5, 6, 7, 8], + "column_numbers": [0, 4, 6], + }, + ), + ], +) +def test_remove_supercell_overlap(supercell1, supercell2): + assert postprocess.remove_supercell_overlap(supercell1, supercell2) is None + + +@pytest.mark.parametrize( + ("supercells", "rows", "columns", "output_test"), + [ + ( + [ { "label": "table spanning cell", - "score": 0.526617169380188, + "score": 0.9, "bbox": [ - 1446.2801513671875, - 1023.817138671875, - 2114.3525390625, - 1099.20166015625, + 98.92312622070312, + 143.11549377441406, + 2115.197265625, + 1238.27587890625, ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4], - "column_numbers": [0, 4], + "projected row header": True, + "header": True, + "span": True, }, + ], + [ { - "label": "table spanning cell", - "score": 0.5199193954467773, - "bbox": [ - 98.92312622070312, - 676.1566772460938, - 751.0982666015625, - 938.5986938476562, - ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4, 6], - "column_numbers": [0, 4], + "label": "table row", + "score": 0.9299452900886536, + "bbox": [0, 0, 10, 10], + "column header": True, + "header": True, }, - ), - ( { - "label": "table spanning cell", - "score": 0.526617169380188, + "label": "table row", + "score": 0.9299452900886536, "bbox": [ - 1446.2801513671875, - 1023.817138671875, + 98.92312622070312, + 
143.11549377441406, 2114.3525390625, - 1099.20166015625, + 193.67681884765625, ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4], - "column_numbers": [0, 4], + "column header": True, + "header": True, }, { - "label": "table spanning cell", - "score": 0.5199193954467773, + "label": "table row", + "score": 0.9299452900886536, "bbox": [ 98.92312622070312, - 676.1566772460938, - 751.0982666015625, - 938.5986938476562, + 143.11549377441406, + 2114.3525390625, + 193.67681884765625, ], - "projected row header": False, - "header": False, - "row_numbers": [4], - "column_numbers": [0, 4, 6], + "column header": True, + "header": True, }, - ), - ( + ], + [ { - "label": "table spanning cell", - "score": 0.526617169380188, + "label": "table column", + "score": 0.9996132254600525, "bbox": [ - 1446.2801513671875, - 1023.817138671875, - 2114.3525390625, - 1099.20166015625, + 98.92312622070312, + 143.11549377441406, + 517.6508178710938, + 1616.48779296875, ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4], - "column_numbers": [1, 4], }, { - "label": "table spanning cell", - "score": 0.5199193954467773, + "label": "table column", + "score": 0.9935646653175354, "bbox": [ - 98.92312622070312, - 676.1566772460938, + 520.0474853515625, + 143.11549377441406, 751.0982666015625, - 938.5986938476562, + 1616.48779296875, ], - "projected row header": False, - "header": False, - "row_numbers": [4], - "column_numbers": [0, 4, 6], }, - ), - ( + ], + [ { "label": "table spanning cell", - "score": 0.526617169380188, + "score": 0.9, "bbox": [ - 1446.2801513671875, - 1023.817138671875, - 2114.3525390625, - 1099.20166015625, + 98.92312622070312, + 143.11549377441406, + 751.0982666015625, + 193.67681884765625, ], - "projected row header": False, - "header": False, - "row_numbers": [3, 4], - "column_numbers": [1, 4], + "projected row header": True, + "header": True, + "span": True, + "row_numbers": [1, 2], + "column_numbers": [0, 1], }, { - 
"label": "table spanning cell", - "score": 0.5199193954467773, + "row_numbers": [0], + "column_numbers": [0, 1], + "score": 0.9, + "propagated": True, "bbox": [ 98.92312622070312, - 676.1566772460938, + 143.11549377441406, 751.0982666015625, - 938.5986938476562, + 193.67681884765625, ], - "projected row header": False, - "header": False, - "row_numbers": [2, 4, 5, 6, 7, 8], - "column_numbers": [0, 4, 6], }, - ), - ], -) -def test_remove_supercell_overlap(supercell1, supercell2): - assert postprocess.remove_supercell_overlap(supercell1, supercell2) is None - - -@pytest.mark.parametrize( - ("supercells", "rows", "columns", "output_test"), - [ - ( - [ - { - "label": "table spanning cell", - "score": 0.9, - "bbox": [ - 98.92312622070312, - 143.11549377441406, - 2115.197265625, - 1238.27587890625, - ], - "projected row header": True, - "header": True, - "span": True, - }, - ], - [ - { - "label": "table row", - "score": 0.9299452900886536, - "bbox": [0, 0, 10, 10], - "column header": True, - "header": True, - }, - { - "label": "table row", - "score": 0.9299452900886536, - "bbox": [ - 98.92312622070312, - 143.11549377441406, - 2114.3525390625, - 193.67681884765625, - ], - "column header": True, - "header": True, - }, - { - "label": "table row", - "score": 0.9299452900886536, - "bbox": [ - 98.92312622070312, - 143.11549377441406, - 2114.3525390625, - 193.67681884765625, - ], - "column header": True, - "header": True, - }, - ], - [ - { - "label": "table column", - "score": 0.9996132254600525, - "bbox": [ - 98.92312622070312, - 143.11549377441406, - 517.6508178710938, - 1616.48779296875, - ], - }, - { - "label": "table column", - "score": 0.9935646653175354, - "bbox": [ - 520.0474853515625, - 143.11549377441406, - 751.0982666015625, - 1616.48779296875, - ], - }, - ], - [ - { - "label": "table spanning cell", - "score": 0.9, - "bbox": [ - 98.92312622070312, - 143.11549377441406, - 751.0982666015625, - 193.67681884765625, - ], - "projected row header": True, - "header": True, 
- "span": True, - "row_numbers": [1, 2], - "column_numbers": [0, 1], - }, - { - "row_numbers": [0], - "column_numbers": [0, 1], - "score": 0.9, - "propagated": True, - "bbox": [ - 98.92312622070312, - 143.11549377441406, - 751.0982666015625, - 193.67681884765625, - ], - }, - ], + ], ), ], ) @@ -889,26 +889,26 @@ def test_align_rows(rows, bbox, output): [ ("html", "Blind51434.5%, n=1"), ( - "cells", - { - "column_nums": [0], - "row_nums": [2], - "column header": False, - "cell text": "Blind", - }, + "cells", + { + "column_nums": [0], + "row_nums": [2], + "column header": False, + "cell text": "Blind", + }, ), ("dataframe", ["Blind", "5", "1", "4", "34.5%, n=1", "1199 sec, n=1"]), (None, "Blind51434.5%, n=1"), ], ) def test_table_prediction_output_format( - output_format, - expectation, - table_transformer, - example_image, - mocker, - example_table_cells, - mocked_ocr_tokens, + output_format, + expectation, + table_transformer, + example_image, + mocker, + example_table_cells, + mocked_ocr_tokens, ): mocker.patch.object(tables, "recognize", return_value=example_table_cells) mocker.patch.object( @@ -935,11 +935,11 @@ def test_table_prediction_output_format( def test_table_prediction_output_format_when_wrong_type_then_value_error( - table_transformer, - example_image, - mocker, - example_table_cells, - mocked_ocr_tokens, + table_transformer, + example_image, + mocker, + example_table_cells, + mocked_ocr_tokens, ): mocker.patch.object(tables, "recognize", return_value=example_table_cells) mocker.patch.object( @@ -954,10 +954,10 @@ def test_table_prediction_output_format_when_wrong_type_then_value_error( def test_table_prediction_runs_with_empty_recognize( - table_transformer, - example_image, - mocker, - mocked_ocr_tokens, + table_transformer, + example_image, + mocker, + mocked_ocr_tokens, ): mocker.patch.object(tables, "recognize", return_value=[]) mocker.patch.object( @@ -988,7 +988,7 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image): 
], ) def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold( - thresholds, expected_object_number + thresholds, expected_object_number ): objects = [ {"label": "0", "score": 0.2}, @@ -1007,7 +1007,7 @@ def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_ ], ) def test_objects_are_filtered_based_on_class_thresholds_when_two_classes( - thresholds, expected_object_number + thresholds, expected_object_number ): objects = [ {"label": "0", "score": 0.2}, @@ -1043,98 +1043,98 @@ def test_include_rect(): ("spans", "join_with_space", "expected"), [ ( - [ - { - "flags": 2 ** 0, - "text": "5", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - ], - True, - "", + [ + { + "flags": 2**0, + "text": "5", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + ], + True, + "", ), ( - [ - { - "flags": 2 ** 0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - ], - True, - "p", + [ + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + ], + True, + "p", ), ( - [ - { - "flags": 2 ** 0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - { - "flags": 2 ** 0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - ], - True, - "p p", + [ + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + ], + True, + "p p", ), ( - [ - { - "flags": 2 ** 0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - { - "flags": 2 ** 0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 1, - }, - ], - True, - "p p", + [ + { + 
"flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 1, + }, + ], + True, + "p p", ), ( - [ - { - "flags": 2 ** 0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 0, - }, - { - "flags": 2 ** 0, - "text": "p", - "superscript": False, - "span_num": 0, - "line_num": 0, - "block_num": 1, - }, - ], - False, - "p p", + [ + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 1, + }, + ], + False, + "p p", ), ], ) @@ -1152,62 +1152,62 @@ def test_extract_text_from_spans(spans, join_with_space, expected): [ ([{"header": "hi", "row_numbers": [0, 1, 2], "score": 0.9}], 1), ( - [ - { - "header": "hi", - "row_numbers": [0], - "column_numbers": [1, 2, 3], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1], - "column_numbers": [1], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1], - "column_numbers": [2], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1], - "column_numbers": [3], - "score": 0.9, - }, - ], - 4, + [ + { + "header": "hi", + "row_numbers": [0], + "column_numbers": [1, 2, 3], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [1], + "column_numbers": [1], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [1], + "column_numbers": [2], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [1], + "column_numbers": [3], + "score": 0.9, + }, + ], + 4, ), ( - [ - { - "header": "hi", - "row_numbers": [0], - "column_numbers": [0], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1], - "column_numbers": [0], - "score": 0.9, - }, - { - "header": "hi", - "row_numbers": [1, 2], - "column_numbers": [0], - "score": 0.9, - }, - { - 
"header": "hi", - "row_numbers": [3], - "column_numbers": [0], - "score": 0.9, - }, - ], - 3, + [ + { + "header": "hi", + "row_numbers": [0], + "column_numbers": [0], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [1], + "column_numbers": [0], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [1, 2], + "column_numbers": [0], + "score": 0.9, + }, + { + "header": "hi", + "row_numbers": [3], + "column_numbers": [0], + "score": 0.9, + }, + ], + 3, ), ], ) @@ -1228,7 +1228,8 @@ def test_zoom_image(example_image, zoom): @pytest.mark.parametrize( - ("input_cells", "expected_html"), [ + ("input_cells", "expected_html"), + [ # +----------+---------------------+ # | row1col1 | row1col2 | row1col3 | # |----------|----------+----------| @@ -1236,16 +1237,46 @@ def test_zoom_image(example_image, zoom): # +----------+----------+----------+ pytest.param( [ - {"row_nums": [0], "column_nums": [0], "cell text": "row1col1", "column header": False}, - {"row_nums": [0], "column_nums": [1], "cell text": "row1col2", "column header": False}, - {"row_nums": [0], "column_nums": [2], "cell text": "row1col3", "column header": False}, - {"row_nums": [1], "column_nums": [0], "cell text": "row2col1", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "row2col2", "column header": False}, - {"row_nums": [1], "column_nums": [2], "cell text": "row2col3", "column header": False}, + { + "row_nums": [0], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [0], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [0], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": 
"row2col3", + "column header": False, + }, ], ( - '' - '
row1col1row1col2row1col3
row2col1row2col2row2col3
' + "" + "
row1col1row1col2row1col3
row2col1row2col2row2col3
" ), id="simple table without header", ), @@ -1261,17 +1292,47 @@ def test_zoom_image(example_image, zoom): {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, - {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, - {"row_nums": [1], "column_nums": [2], "cell text": "row1col3", "column header": False}, - {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, - {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, - {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, ], ( - '' - '' - '
h1col1h1col2h1col2
row1col1row1col2row1col3
row2col1row2col2row2col3
' + "" + "" + "
h1col1h1col2h1col2
row1col1row1col2row1col3
row2col1row2col2row2col3
" ), id="simple table with header", ), @@ -1285,19 +1346,49 @@ def test_zoom_image(example_image, zoom): pytest.param( [ {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, - {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, - {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, - {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, - {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, - {"row_nums": [1], "column_nums": [2], "cell text": "row1col3", "column header": False}, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, ], ( - '' - '' - '
h1col1h1col2h1col2
row1col1row1col2row1col3
row2col1row2col2row2col3
' + "" + "" + "
h1col1h1col2h1col2
row1col1row1col2row1col3
row2col1row2col2row2col3
" ), id="simple table with header, mixed elements", ), @@ -1308,14 +1399,35 @@ def test_zoom_image(example_image, zoom): # +----------+----------+----------+ pytest.param( [ - {"row_nums": [0, 1], "column_nums": [0], "cell text": "two row", "column header": False}, - {"row_nums": [0], "column_nums": [1, 2], "cell text": "two cols", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "sub cell 1", "column header": False}, - {"row_nums": [1], "column_nums": [2], "cell text": "sub cell 2", "column header": False}, + { + "row_nums": [0, 1], + "column_nums": [0], + "cell text": "two row", + "column header": False, + }, + { + "row_nums": [0], + "column_nums": [1, 2], + "cell text": "two cols", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "sub cell 1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "sub cell 2", + "column header": False, + }, ], ( - '
two rowtwo ' - "cols
sub cell 1sub cell 2
" + '" + "
two rowtwo ' + "cols
sub cell 1sub cell 2
" ), id="various spans, no headers", ), @@ -1330,108 +1442,316 @@ def test_zoom_image(example_image, zoom): # +----------+----------+----------+----------+ pytest.param( [ - {"row_nums": [0, 1], "column_nums": [0], "cell text": "h12col1", "column header": True}, - {"row_nums": [0], "column_nums": [1, 2], "cell text": "h1col23", "column header": True}, + { + "row_nums": [0, 1], + "column_nums": [0], + "cell text": "h12col1", + "column header": True, + }, + { + "row_nums": [0], + "column_nums": [1, 2], + "cell text": "h1col23", + "column header": True, + }, {"row_nums": [0], "column_nums": [3], "cell text": "h1col4", "column header": True}, {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True}, - {"row_nums": [1], "column_nums": [2, 3], "cell text": "h2col34", "column header": True}, - {"row_nums": [2], "column_nums": [0], "cell text": "r3col1", "column header": False}, - {"row_nums": [2], "column_nums": [1], "cell text": "r3col2", "column header": False}, - {"row_nums": [2, 3], "column_nums": [2, 3], "cell text": "r34col34", "column header": False}, - {"row_nums": [3], "column_nums": [0, 1], "cell text": "r4col12", "column header": False}, + { + "row_nums": [1], + "column_nums": [2, 3], + "cell text": "h2col34", + "column header": True, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "r3col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "r3col2", + "column header": False, + }, + { + "row_nums": [2, 3], + "column_nums": [2, 3], + "cell text": "r34col34", + "column header": False, + }, + { + "row_nums": [3], + "column_nums": [0, 1], + "cell text": "r4col12", + "column header": False, + }, ], ( - '' - '' - '' - '
h12col1h1col23h1col4
h2col2h2col34
r3col1r3col2r34col34
r4col12
' + '' + '' + '' + '' + '
h12col1h1col23h1col4
h2col2h2col34
r3col1r3col2r34col34
r4col12
' ), id="various spans, no headers", ), - ] + ], ) def test_cells_to_html(input_cells, expected_html): assert tables.cells_to_html(input_cells) == expected_html @pytest.mark.parametrize( - ("input_cells", "expected_cells"), [ + ("input_cells", "expected_cells"), + [ pytest.param( [ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, - {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, - {"row_nums": [1], "column_nums": [2], "cell text": "row1col3", "column header": False}, - {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, - {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, - {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, ], [ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, - {"row_nums": [1], "column_nums": [0], 
"cell text": "row1col1", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, - {"row_nums": [1], "column_nums": [2], "cell text": "row1col3", "column header": False}, - {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, - {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, - {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, ], - id="identical tables, no changes expected" + id="identical tables, no changes expected", ), pytest.param( [ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, - {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, - {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, - {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, - {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], 
+ "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, ], [ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, {"row_nums": [0], "column_nums": [1], "cell text": "", "column header": True}, {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, - {"row_nums": [1], "column_nums": [0], "cell text": "row1col1", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "row1col2", "column header": False}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, {"row_nums": [1], "column_nums": [2], "cell text": "", "column header": False}, - {"row_nums": [2], "column_nums": [0], "cell text": "row2col1", "column header": False}, - {"row_nums": [2], "column_nums": [1], "cell text": "row2col2", "column header": False}, - {"row_nums": [2], "column_nums": [2], "cell text": "row2col3", "column header": False}, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, ], id="missing column in header and in the middle", ), pytest.param( [ - {"row_nums": [0, 1], "column_nums": [0], "cell text": "h12col1", "column header": True}, - {"row_nums": [0], "column_nums": [1, 2], "cell text": "h1col23", "column header": True}, + { + "row_nums": [0, 1], + "column_nums": [0], + "cell text": "h12col1", + 
"column header": True, + }, + { + "row_nums": [0], + "column_nums": [1, 2], + "cell text": "h1col23", + "column header": True, + }, {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True}, - {"row_nums": [1], "column_nums": [2, 3], "cell text": "h2col34", "column header": True}, - {"row_nums": [2], "column_nums": [0], "cell text": "r3col1", "column header": False}, - {"row_nums": [2, 3], "column_nums": [2, 3], "cell text": "r34col34", "column header": False}, - {"row_nums": [3], "column_nums": [0, 1], "cell text": "r4col12", "column header": False}, + { + "row_nums": [1], + "column_nums": [2, 3], + "cell text": "h2col34", + "column header": True, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "r3col1", + "column header": False, + }, + { + "row_nums": [2, 3], + "column_nums": [2, 3], + "cell text": "r34col34", + "column header": False, + }, + { + "row_nums": [3], + "column_nums": [0, 1], + "cell text": "r4col12", + "column header": False, + }, ], [ - {"row_nums": [0, 1], "column_nums": [0], "cell text": "h12col1", "column header": True}, - {"row_nums": [0], "column_nums": [1, 2], "cell text": "h1col23", "column header": True}, + { + "row_nums": [0, 1], + "column_nums": [0], + "cell text": "h12col1", + "column header": True, + }, + { + "row_nums": [0], + "column_nums": [1, 2], + "cell text": "h1col23", + "column header": True, + }, {"row_nums": [0], "column_nums": [3], "cell text": "", "column header": True}, {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True}, - {"row_nums": [1], "column_nums": [2, 3], "cell text": "h2col34", "column header": True}, - {"row_nums": [2], "column_nums": [0], "cell text": "r3col1", "column header": False}, + { + "row_nums": [1], + "column_nums": [2, 3], + "cell text": "h2col34", + "column header": True, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "r3col1", + "column header": False, + }, {"row_nums": [2], "column_nums": [1], "cell text": 
"", "column header": False}, - {"row_nums": [2, 3], "column_nums": [2, 3], "cell text": "r34col34", "column header": False}, - {"row_nums": [3], "column_nums": [0, 1], "cell text": "r4col12", "column header": False}, + { + "row_nums": [2, 3], + "column_nums": [2, 3], + "cell text": "r34col34", + "column header": False, + }, + { + "row_nums": [3], + "column_nums": [0, 1], + "cell text": "r4col12", + "column header": False, + }, ], - id="missing column in header and in the middle in table with spans" - ) -] + id="missing column in header and in the middle in table with spans", + ), + ], ) def test_fill_cells(input_cells, expected_cells): def sort_cells(cells): return sorted(cells, key=lambda x: (x["row_nums"], x["column_nums"])) + assert sort_cells(tables.fill_cells(input_cells)) == sort_cells(expected_cells) @@ -1480,31 +1800,31 @@ def test_compute_confidence_score_zero_division_error_handling(): "column_span_score, row_span_score, expected_text_to_indexes", [ ( - 0.9, - 0.8, - ( - { - "one three": {"row_nums": [0, 1], "column_nums": [0]}, - "two": {"row_nums": [0], "column_nums": [1]}, - "four": {"row_nums": [1], "column_nums": [1]}, - } - ), + 0.9, + 0.8, + ( + { + "one three": {"row_nums": [0, 1], "column_nums": [0]}, + "two": {"row_nums": [0], "column_nums": [1]}, + "four": {"row_nums": [1], "column_nums": [1]}, + } + ), ), ( - 0.8, - 0.9, - ( - { - "one two": {"row_nums": [0], "column_nums": [0, 1]}, - "three": {"row_nums": [1], "column_nums": [0]}, - "four": {"row_nums": [1], "column_nums": [1]}, - } - ), + 0.8, + 0.9, + ( + { + "one two": {"row_nums": [0], "column_nums": [0, 1]}, + "three": {"row_nums": [1], "column_nums": [0]}, + "four": {"row_nums": [1], "column_nums": [1]}, + } + ), ), ], ) def test_subcells_filtering_when_overlapping_spanning_cells( - column_span_score, row_span_score, expected_text_to_indexes + column_span_score, row_span_score, expected_text_to_indexes ): """ # table diff --git a/unstructured_inference/models/tables.py 
b/unstructured_inference/models/tables.py index 9be0264c..d639eb62 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -664,23 +664,23 @@ def fill_cells(cells: List[dict]) -> List[dict]: whether this cell is a column header """ - table_rows_no = max(set([row for cell in cells for row in cell["row_nums"]])) - table_cols_no = max(set([col for cell in cells for col in cell["column_nums"]])) + table_rows_no = max({row for cell in cells for row in cell["row_nums"]}) + table_cols_no = max({col for cell in cells for col in cell["column_nums"]}) filled = np.zeros((table_rows_no + 1, table_cols_no + 1), dtype=bool) for cell in cells: for row in cell["row_nums"]: for col in cell["column_nums"]: filled[row, col] = True # add cells for which filled is false - header_rows = set([row for cell in cells if cell["column header"] for row in cell["row_nums"]]) + header_rows = {row for cell in cells if cell["column header"] for row in cell["row_nums"]} new_cells = cells.copy() - not_filled_idx = np.where(filled == False) + not_filled_idx = np.where(filled == False) # noqa: E712 for row, col in zip(not_filled_idx[0], not_filled_idx[1]): new_cell = { "row_nums": [row], "column_nums": [col], "cell text": "", - "column header": row in header_rows + "column header": row in header_rows, } new_cells.append(new_cell) return new_cells @@ -702,7 +702,6 @@ def cells_to_html(cells: List[dict]) -> str: str: HTML table string """ cells = sorted(fill_cells(cells), key=lambda k: (min(k["row_nums"]), min(k["column_nums"]))) - # cells = sorted(cells, key=lambda k: (min(k["row_nums"]), min(k["column_nums"]))) table = ET.Element("table") current_row = -1 @@ -713,10 +712,8 @@ def cells_to_html(cells: List[dict]) -> str: table_header = ET.SubElement(table, "thead") table_body = ET.SubElement(table, "tbody") - table_subelement = None for cell in cells: this_row = min(cell["row_nums"]) - attrib = {} colspan = len(cell["column_nums"]) if colspan > 1: @@ -730,9 
+727,9 @@ def cells_to_html(cells: List[dict]) -> str: table_subelement = table_header cell_tag = "th" else: - cell_tag = "td" table_subelement = table_body - row = ET.SubElement(table_subelement, "tr") + cell_tag = "td" + row = ET.SubElement(table_subelement, "tr") # type: ignore tcell = ET.SubElement(row, cell_tag, attrib=attrib) tcell.text = cell["cell text"] From f3b1b8371d0c574632e87c7edb23a10cf0c59f25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kmiecik?= Date: Mon, 10 Jun 2024 12:08:50 +0200 Subject: [PATCH 4/7] chore: updated CHANGELOG and version --- CHANGELOG.md | 3 +++ unstructured_inference/__version__.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index af8711b0..b73df54b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +## 0.7.35-dev0 +Fix syntax for generated HTML tables + ## 0.7.34 * Reduce excessive logging diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index e6fd9f15..d838f872 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.34" # pragma: no cover +__version__ = "0.7.35-dev0" # pragma: no cover From f4664cf1897ccea8984bb946eac0ce196c95f3e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kmiecik?= Date: Mon, 10 Jun 2024 15:52:20 +0200 Subject: [PATCH 5/7] test: corrected some of table tests --- test_unstructured_inference/models/test_tables.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 6a218f9d..d0e7c976 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -970,7 +970,7 @@ def test_table_prediction_runs_with_empty_recognize( def test_table_prediction_with_ocr_tokens(table_transformer, example_image, mocked_ocr_tokens): prediction = 
table_transformer.predict(example_image, ocr_tokens=mocked_ocr_tokens) - assert '' '
' in prediction + assert '" in prediction @@ -1426,7 +1426,7 @@ def test_zoom_image(example_image, zoom): ], ( '
' in prediction assert "
Blind51434.5%, n=1
" + "cols" "
two rowtwo ' - "cols
sub cell 1sub cell 2
sub cell 1sub cell 2
" ), id="various spans, no headers", From 29bd8ed1614fab2b7e054f4dd762febe18bcb1a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kmiecik?= Date: Thu, 13 Jun 2024 11:56:20 +0200 Subject: [PATCH 6/7] chore: removed dev from version --- CHANGELOG.md | 2 +- unstructured_inference/__version__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b73df54b..b7241945 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## 0.7.35-dev0 +## 0.7.35 Fix syntax for generated HTML tables ## 0.7.34 diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index d838f872..d0586119 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.35-dev0" # pragma: no cover +__version__ = "0.7.35" # pragma: no cover From 50ad448e110871e7cba60e25b458a6519f36299d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kmiecik?= Date: Thu, 13 Jun 2024 11:56:44 +0200 Subject: [PATCH 7/7] chore: corrected test id --- test_unstructured_inference/models/test_tables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index d0e7c976..15c467cd 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -1494,7 +1494,7 @@ def test_zoom_image(example_image, zoom): '
r3col1r3col2r34col34
r4col12
' ), - id="various spans, no headers", + id="various spans, with 2 row header", ), ], )