diff --git a/tools/export_model.py b/tools/export_model.py index 243f87f9..1607c44e 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -47,9 +47,10 @@ def get_parser(): ) parser.add_argument( "--image_size", - default=640, + nargs="+", type=int, - help="Image size for evaluation (default: 640).", + default=[640, 640], + help="Image size for evaluation (default: 640, 640).", ) parser.add_argument("--batch_size", default=1, type=int, help="Batch size.") parser.add_argument("--opset", default=DEFAULT_OPSET, type=int, help="opset_version") @@ -128,9 +129,10 @@ def cli_main(): checkpoint_path = Path(args.checkpoint_path) assert checkpoint_path.exists(), f"Not found checkpoint file at '{checkpoint_path}'" + image_size = args.image_size * 2 if len(args.image_size) == 1 else 1 # expand if args.skip_preprocess: # input data - inputs = torch.rand(args.batch_size, 3, args.image_size, args.image_size) + inputs = torch.rand(args.batch_size, 3, *image_size) dynamic_axes = { "images_tensors": {0: "batch", 2: "height", 3: "width"}, "boxes": {0: "batch", 1: "num_objects"}, @@ -147,7 +149,7 @@ def cli_main(): model.eval() else: # input data - images = [torch.rand(3, args.image_size, args.image_size)] + images = [torch.rand(3, *image_size)] inputs = (images,) dynamic_axes = { "images_tensors": {1: "height", 2: "width"}, @@ -159,7 +161,8 @@ def cli_main(): output_names = ["scores", "labels", "boxes"] model = YOLOv5.load_from_yolov5( checkpoint_path, - score_thresh=args.score_thresh, + size=tuple(image_size), + core_thresh=args.score_thresh, version=args.version, ) model.eval()