diff --git a/ocrs_models/train_detection.py b/ocrs_models/train_detection.py index e4cee09..ee71112 100644 --- a/ocrs_models/train_detection.py +++ b/ocrs_models/train_detection.py @@ -393,7 +393,7 @@ def main(): raise Exception("ONNX export requires a checkpoint to load") test_batch = next(iter(val_dataloader)) - test_image = test_batch["image"][0:1] + test_image = test_batch["image"][0:1].to(device) torch.onnx.export( model, diff --git a/ocrs_models/train_rec.py b/ocrs_models/train_rec.py index 7280d46..4566d50 100644 --- a/ocrs_models/train_rec.py +++ b/ocrs_models/train_rec.py @@ -397,7 +397,7 @@ def main(): test_batch = next(iter(val_dataloader)) torch.onnx.export( model, - test_batch["image"], + test_batch["image"].to(device), args.export, input_names=["line_image"], output_names=["chars"],