Skip to content

Commit

Permalink
chore: added gpu id control to image generation script
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Jun 10, 2022
1 parent e10f690 commit 779c77b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions scripts/gen_single_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
parser.add_argument("--img-out", help="transformed image", required=True)
parser.add_argument("--bw", action="store_true", help="whether input/output is bw")
parser.add_argument("--cpu", action="store_true", help="whether to use CPU")
parser.add_argument("--gpuid", type=int, default=0, help="which GPU to use")
args = parser.parse_args()

if args.bw:
Expand Down Expand Up @@ -63,7 +64,8 @@
model.load_state_dict(torch.load(args.model_in_file))

if not args.cpu:
model = model.cuda()
device = torch.device("cuda:" + str(args.gpuid))
model = model.to(device)

# reading image
img = cv2.imread(args.img_in)
Expand All @@ -77,7 +79,7 @@
tran = transforms.Compose(tranlist)
img_tensor = tran(img)
if not args.cpu:
img_tensor = img_tensor.cuda()
img_tensor = img_tensor.to(device)

# run through model
out_tensor = model(img_tensor.unsqueeze(0))[0].detach()
Expand Down

0 comments on commit 779c77b

Please sign in to comment.