diff --git a/yolort/runtime/yolo_graphsurgeon.py b/yolort/runtime/yolo_graphsurgeon.py index 4107a34fa..0807d9052 100644 --- a/yolort/runtime/yolo_graphsurgeon.py +++ b/yolort/runtime/yolo_graphsurgeon.py @@ -141,15 +141,15 @@ def register_nms( op = "BatchedNMS_TRT" attrs = { "plugin_version": "1", - 'shareLocation': True, - 'backgroundLabelId': -1, # no background class - 'numClasses': self.num_classes, - 'topK': 1024, - 'keepTopK': detections_per_img, - 'scoreThreshold': score_thresh, - 'iouThreshold': nms_thresh, - 'isNormalized': normalized, - 'clipBoxes': False, + "shareLocation": True, + "backgroundLabelId": -1, # no background class + "numClasses": self.num_classes, + "topK": 1024, + "keepTopK": detections_per_img, + "scoreThreshold": score_thresh, + "iouThreshold": nms_thresh, + "isNormalized": normalized, + "clipBoxes": False, } # NMS Outputs