Skip to content

Commit

Permalink
Optimize box_match_metrics with bounding box intersection test
Browse files Browse the repository at this point in the history
Replace precise intersection test with a cheap bounding box intersection test.
  • Loading branch information
robertknight committed Jan 30, 2024
1 parent c11284c commit 293c774
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion ocrs_models/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,29 @@ def expand_quads(quads: torch.Tensor, dist: float) -> torch.Tensor:
return torch.stack([expand_quad(quad, dist) for quad in quads])


def lines_intersect(a_start: float, a_end: float, b_start: float, b_end: float) -> bool:
"""
Return true if the lines (a_start, a_end) and (b_start, b_end) intersect.
"""
if a_start <= b_start:
return a_end > b_start
else:
return b_end > a_start


def bounds_intersect(
a: tuple[float, float, float, float], b: tuple[float, float, float, float]
) -> bool:
"""
Return true if the rects defined by two (minx, miny, maxx, maxy) tuples intersect.
"""
a_minx, a_miny, a_maxx, a_maxy = a
b_minx, b_miny, b_maxx, b_maxy = b
return lines_intersect(a_minx, a_maxx, b_minx, b_maxx) and lines_intersect(
a_miny, a_maxy, b_miny, b_maxy
)


def box_match_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict[str, float]:
"""
Compute metrics for quality of matches between two sets of rotated rects.
Expand All @@ -99,12 +122,22 @@ def box_match_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict[str, flo
# Areas of unions of predictions and targets
union = torch.zeros((len(pred), len(target)))

# Get bounding boxes of polys for a cheap intersection test.
pred_polys_bounds = [poly.bounds for poly in pred_polys]
target_polys_bounds = [poly.bounds for poly in target_polys]

pred_areas = torch.zeros((len(pred),))
for pred_index, pred_poly in enumerate(pred_polys):
pred_areas[pred_index] = pred_poly.area
pred_bounds = pred_polys_bounds[pred_index]

for target_index, target_poly in enumerate(target_polys):
if not pred_poly.intersects(target_poly):
# Do a cheap intersection test and skip computing the actual
# union/intersection of that fails.
target_bounds = target_polys_bounds[target_index]
if not bounds_intersect(pred_bounds, target_bounds):
continue

pt_intersection = pred_poly.intersection(target_poly)
intersection[pred_index, target_index] = pt_intersection.area

Expand Down

0 comments on commit 293c774

Please sign in to comment.