From 98ea3be01d2e0d65d7c08d2d813dfbd5f88700bb Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 8 Sep 2023 11:59:29 +0200 Subject: [PATCH] Enable non-cpu devices in `ketos test` Also fixes #510 --- kraken/ketos/recognition.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/kraken/ketos/recognition.py b/kraken/ketos/recognition.py index ffb529d2b..0b0f1490b 100644 --- a/kraken/ketos/recognition.py +++ b/kraken/ketos/recognition.py @@ -398,9 +398,13 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, nn = {} for p in model: message('Loading model {}\t'.format(p), nl=False) - nn[p] = models.load_any(p) + nn[p] = models.load_any(p, device) message('\u2713', fg='green') + pin_ds_mem = False + if device != 'cpu': + pin_ds_mem = True + test_set = list(test_set) # set number of OpenMP threads @@ -464,7 +468,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers, ds_loader = DataLoader(ds, batch_size=batch_size, num_workers=workers, - pin_memory=True, + pin_memory=pin_ds_mem, collate_fn=collate_sequences) with KrakenProgressBar() as progress: