From 15173b26e8ba792bf094022f1595aecc5b7d563a Mon Sep 17 00:00:00 2001 From: Phaired <65019388+Phaired@users.noreply.github.com> Date: Tue, 20 Aug 2024 19:45:33 +0200 Subject: [PATCH] fix: error exporting when a gpu is available --- ocrs_models/train_detection.py | 2 +- ocrs_models/train_rec.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ocrs_models/train_detection.py b/ocrs_models/train_detection.py index e4cee09..ee71112 100644 --- a/ocrs_models/train_detection.py +++ b/ocrs_models/train_detection.py @@ -393,7 +393,7 @@ def main(): raise Exception("ONNX export requires a checkpoint to load") test_batch = next(iter(val_dataloader)) - test_image = test_batch["image"][0:1] + test_image = test_batch["image"][0:1].to(device) torch.onnx.export( model, diff --git a/ocrs_models/train_rec.py b/ocrs_models/train_rec.py index 7280d46..4566d50 100644 --- a/ocrs_models/train_rec.py +++ b/ocrs_models/train_rec.py @@ -397,7 +397,7 @@ def main(): test_batch = next(iter(val_dataloader)) torch.onnx.export( model, - test_batch["image"], + test_batch["image"].to(device), args.export, input_names=["line_image"], output_names=["chars"],