diff --git a/imutils/object_detection.py b/imutils/object_detection.py index cca12ef..e4a9696 100644 --- a/imutils/object_detection.py +++ b/imutils/object_detection.py @@ -1,7 +1,7 @@ # import the necessary packages import numpy as np -def non_max_suppression(boxes, probs=None, overlapThresh=0.3): +def non_max_suppression(boxes, probs=None, iouThresh=0.3): # if there are no boxes, return an empty list if len(boxes) == 0: return [] @@ -53,13 +53,15 @@ def non_max_suppression(boxes, probs=None, overlapThresh=0.3): w = np.maximum(0, xx2 - xx1 + 1) h = np.maximum(0, yy2 - yy1 + 1) - # compute the ratio of overlap - overlap = (w * h) / area[idxs[:last]] + # compute the IoU + intersections = w * h + IoU = intersections / (area[idxs[:last]] + area[i] - intersections) # delete all indexes from the index list that have overlap greater # than the provided overlap threshold idxs = np.delete(idxs, np.concatenate(([last], - np.where(overlap > overlapThresh)[0]))) + np.where(IoU > iouThresh)[0]))) # return only the bounding boxes that were picked - return boxes[pick].astype("int") \ No newline at end of file + return boxes[pick].astype("int") +