Skip to content

Commit

Permalink
Set image_size configurable in export_model.py (#250)
Browse files Browse the repository at this point in the history
* Update export_model.py

* Fix arguments of image_size

Co-authored-by: Zhiqiang Wang <[email protected]>
  • Loading branch information
Yogurt-Peng and zhiqwang authored Dec 22, 2021
1 parent 5dd25c3 commit 75ef670
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ def get_parser():
)
parser.add_argument(
"--image_size",
default=640,
nargs="+",
type=int,
help="Image size for evaluation (default: 640).",
default=[640, 640],
help="Image size for evaluation (default: 640, 640).",
)
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
parser.add_argument("--opset", default=DEFAULT_OPSET, type=int, help="opset_version")
Expand Down Expand Up @@ -128,9 +129,10 @@ def cli_main():
checkpoint_path = Path(args.checkpoint_path)
assert checkpoint_path.exists(), f"Not found checkpoint file at '{checkpoint_path}'"

image_size = args.image_size * 2 if len(args.image_size) == 1 else 1 # expand
if args.skip_preprocess:
# input data
inputs = torch.rand(args.batch_size, 3, args.image_size, args.image_size)
inputs = torch.rand(args.batch_size, 3, *image_size)
dynamic_axes = {
"images_tensors": {0: "batch", 2: "height", 3: "width"},
"boxes": {0: "batch", 1: "num_objects"},
Expand All @@ -147,7 +149,7 @@ def cli_main():
model.eval()
else:
# input data
images = [torch.rand(3, args.image_size, args.image_size)]
images = [torch.rand(3, *image_size)]
inputs = (images,)
dynamic_axes = {
"images_tensors": {1: "height", 2: "width"},
Expand All @@ -159,7 +161,8 @@ def cli_main():
output_names = ["scores", "labels", "boxes"]
model = YOLOv5.load_from_yolov5(
checkpoint_path,
score_thresh=args.score_thresh,
size=tuple(image_size),
core_thresh=args.score_thresh,
version=args.version,
)
model.eval()
Expand Down

0 comments on commit 75ef670

Please sign in to comment.