fix - table transformer predictions are filtered if confidence is below threshold #338

Merged · 5 commits · Apr 24, 2024
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,8 @@
## 0.7.29

* fix: table transformer predictions are now removed if confidence is below threshold


## 0.7.28

* feat: allow table transformer agent to return table predictions in unparsed format
50 changes: 50 additions & 0 deletions test_unstructured_inference/models/test_tables.py
@@ -10,6 +10,7 @@

import unstructured_inference.models.table_postprocess as postprocess
from unstructured_inference.models import tables
from unstructured_inference.models.tables import apply_thresholds_on_objects

skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}

@@ -977,6 +978,55 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
table_transformer.predict(example_image)


@pytest.mark.parametrize(
("thresholds", "expected_object_number"),
[
({"0": 0.5}, 1),
({"0": 0.1}, 3),
({"0": 0.9}, 0),
],
)
def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold(
thresholds, expected_object_number
):
objects = [
{"label": "0", "score": 0.2},
{"label": "0", "score": 0.4},
{"label": "0", "score": 0.55},
]
assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number


@pytest.mark.parametrize(
("thresholds", "expected_object_number"),
[
({"0": 0.5, "1": 0.1}, 4),
({"0": 0.1, "1": 0.9}, 3),
({"0": 0.9, "1": 0.5}, 1),
],
)
def test_objects_are_filtered_based_on_class_thresholds_when_two_classes(
thresholds, expected_object_number
):
objects = [
{"label": "0", "score": 0.2},
{"label": "0", "score": 0.4},
{"label": "0", "score": 0.55},
{"label": "1", "score": 0.2},
{"label": "1", "score": 0.4},
{"label": "1", "score": 0.55},
]
assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number


def test_objects_filtering_when_missing_threshold():
class_name = "class_name"
objects = [{"label": class_name, "score": 0.2}]
thresholds = {"1": 0.5}
with pytest.raises(KeyError, match=class_name):
apply_thresholds_on_objects(objects, thresholds)
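
As the test above shows, `apply_thresholds_on_objects` raises a `KeyError` when a predicted label has no entry in the thresholds mapping. A caller that prefers a permissive fallback could wrap it along these lines (a hypothetical helper, not part of this PR; the `default` parameter is an assumption):

def apply_thresholds_with_default(objects, thresholds, default=0.0):
    """Variant of apply_thresholds_on_objects that falls back to `default`
    for labels missing from `thresholds` instead of raising KeyError."""
    return [obj for obj in objects if obj["score"] >= thresholds.get(obj["label"], default)]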


def test_intersect():
a = postprocess.Rect()
b = postprocess.Rect([1, 2, 3, 4])
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
@@ -1 +1 @@
__version__ = "0.7.28" # pragma: no cover
__version__ = "0.7.29" # pragma: no cover
18 changes: 0 additions & 18 deletions unstructured_inference/models/table_postprocess.py
@@ -80,24 +80,6 @@ def apply_threshold(objects, threshold):
return [obj for obj in objects if obj["score"] >= threshold]


-# def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
-#     """
-#     Filter out bounding boxes whose confidence is below the confidence threshold for
-#     its associated class label.
-#     """
-#     # Apply class-specific thresholds
-#     indices_above_threshold = [
-#         idx
-#         for idx, (score, label) in enumerate(zip(scores, labels))
-#         if score >= class_thresholds[class_names[label]]
-#     ]
-#     bboxes = [bboxes[idx] for idx in indices_above_threshold]
-#     scores = [scores[idx] for idx in indices_above_threshold]
-#     labels = [labels[idx] for idx in indices_above_threshold]
-
-#     return bboxes, scores, labels


def refine_rows(rows, tokens, score_threshold):
"""
Apply operations to the detected rows, such as
43 changes: 38 additions & 5 deletions unstructured_inference/models/tables.py
@@ -3,13 +3,16 @@
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import cv2
import numpy as np
import torch
from PIL import Image as PILImage
from transformers import DetrImageProcessor, TableTransformerForObjectDetection
from transformers.models.table_transformer.modeling_table_transformer import (
TableTransformerObjectDetectionOutput,
)

from unstructured_inference.config import inference_config
from unstructured_inference.inference.layoutelement import table_cells_to_dataframe
@@ -172,18 +175,22 @@ def recognize(outputs: dict, img: PILImage.Image, tokens: list):
"""Recognize table elements."""
str_class_name2idx = get_class_map("structure")
str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}
-str_class_thresholds = structure_class_thresholds
+class_thresholds = structure_class_thresholds

# Post-process detected objects, assign class labels
objects = outputs_to_objects(outputs, img.size, str_class_idx2name)

high_confidence_objects = apply_thresholds_on_objects(objects, class_thresholds)
# Further process the detected objects so they correspond to a consistent table
-tables_structure = objects_to_structures(objects, tokens, str_class_thresholds)
+tables_structure = objects_to_structures(high_confidence_objects, tokens, class_thresholds)
# Enumerate all table cells: grid cells and spanning cells
return [structure_to_cells(structure, tokens)[0] for structure in tables_structure]
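
For context, `structure_class_thresholds` maps table-transformer structure labels to per-class score cutoffs. A sketch of its shape (label names follow the table-transformer convention; the values are illustrative, not the library's actual configuration):

structure_class_thresholds = {
    "table": 0.5,
    "table row": 0.5,
    "table column": 0.5,
    "table spanning cell": 0.5,
    # Conventionally "no object" gets an unreachable threshold so that
    # background detections are always dropped.
    "no object": 10,
}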


-def outputs_to_objects(outputs, img_size, class_idx2name):
+def outputs_to_objects(
+    outputs: TableTransformerObjectDetectionOutput,
+    img_size: tuple[int, int],
+    class_idx2name: Mapping[int, str],
+):
"""Output table element types."""
m = outputs["logits"].softmax(-1).max(-1)
pred_labels = list(m.indices.detach().cpu().numpy())[0]
@@ -212,6 +219,32 @@ def outputs_to_objects(outputs, img_size, class_idx2name):
return objects


def apply_thresholds_on_objects(
objects: Sequence[Mapping[str, Any]], thresholds: Mapping[str, float]
) -> Sequence[Mapping[str, Any]]:
"""
Filters out predicted objects whose confidence scores are below the per-class thresholds.

Args:
objects: Sequence of mappings, for example:
[
{
"label": "table row",
"score": 0.55,
"bbox": [...],
},
...,
]
thresholds: Mapping from labels to thresholds

Returns:
Filtered list of objects

"""
objects = [obj for obj in objects if obj["score"] >= thresholds[obj["label"]]]
return objects
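
A minimal usage sketch of the new filter (the labels, scores, and threshold values here are illustrative; in `recognize` the thresholds come from `structure_class_thresholds`):

objects = [
    {"label": "table row", "score": 0.85, "bbox": [0.0, 0.0, 100.0, 20.0]},
    {"label": "table row", "score": 0.30, "bbox": [0.0, 20.0, 100.0, 40.0]},
    {"label": "table column", "score": 0.60, "bbox": [0.0, 0.0, 50.0, 40.0]},
]
thresholds = {"table row": 0.5, "table column": 0.5}
filtered = apply_thresholds_on_objects(objects, thresholds)
# Keeps the 0.85 row and the 0.60 column; the 0.30 row is dropped.
assert len(filtered) == 2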


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
"""Convert rectangle format from center-x, center-y, width, height to
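The truncated helper above converts boxes from center format (cx, cy, w, h) to corner format (x1, y1, x2, y2). The standard DETR-style conversion looks like this (a sketch consistent with the docstring, not necessarily this file's exact body):

import torch

def box_cxcywh_to_xyxy_sketch(x: torch.Tensor) -> torch.Tensor:
    """Convert a (..., 4) tensor of (cx, cy, w, h) boxes to (x1, y1, x2, y2)."""
    x_c, y_c, w, h = x.unbind(-1)
    return torch.stack(
        [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h],
        dim=-1,
    )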