From 293c7748a14c31f33440a71693b404a393d5a721 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 29 Jan 2024 08:27:10 +0000 Subject: [PATCH] Optimize `box_match_metrics` with bounding box intersection test Replace precise intersection test with a cheap bounding box intersection test. --- ocrs_models/postprocess.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/ocrs_models/postprocess.py b/ocrs_models/postprocess.py index 4755702..205e227 100644 --- a/ocrs_models/postprocess.py +++ b/ocrs_models/postprocess.py @@ -79,6 +79,29 @@ def expand_quads(quads: torch.Tensor, dist: float) -> torch.Tensor: return torch.stack([expand_quad(quad, dist) for quad in quads]) +def lines_intersect(a_start: float, a_end: float, b_start: float, b_end: float) -> bool: + """ + Return true if the lines (a_start, a_end) and (b_start, b_end) intersect. + """ + if a_start <= b_start: + return a_end > b_start + else: + return b_end > a_start + + +def bounds_intersect( + a: tuple[float, float, float, float], b: tuple[float, float, float, float] +) -> bool: + """ + Return true if the rects defined by two (minx, miny, maxx, maxy) tuples intersect. + """ + a_minx, a_miny, a_maxx, a_maxy = a + b_minx, b_miny, b_maxx, b_maxy = b + return lines_intersect(a_minx, a_maxx, b_minx, b_maxx) and lines_intersect( + a_miny, a_maxy, b_miny, b_maxy + ) + + def box_match_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict[str, float]: """ Compute metrics for quality of matches between two sets of rotated rects. @@ -99,12 +122,22 @@ def box_match_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict[str, flo # Areas of unions of predictions and targets union = torch.zeros((len(pred), len(target))) + # Get bounding boxes of polys for a cheap intersection test. + pred_polys_bounds = [poly.bounds for poly in pred_polys] + target_polys_bounds = [poly.bounds for poly in target_polys] + pred_areas = torch.zeros((len(pred),)) for pred_index, pred_poly in enumerate(pred_polys): pred_areas[pred_index] = pred_poly.area + pred_bounds = pred_polys_bounds[pred_index] + for target_index, target_poly in enumerate(target_polys): - if not pred_poly.intersects(target_poly): + # Do a cheap intersection test and skip computing the actual + # union/intersection of that fails. + target_bounds = target_polys_bounds[target_index] + if not bounds_intersect(pred_bounds, target_bounds): continue + pt_intersection = pred_poly.intersection(target_poly) intersection[pred_index, target_index] = pt_intersection.area