From a9b812ee48fa0e5079d72ce0352ae1a3233b6e6a Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 18 Oct 2024 20:59:29 +0200 Subject: [PATCH] Add auto and multi-device training to CLI --- kraken/ketos/util.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/kraken/ketos/util.py b/kraken/ketos/util.py index 70bc28c16..d53db0aed 100644 --- a/kraken/ketos/util.py +++ b/kraken/ketos/util.py @@ -62,11 +62,18 @@ def message(msg, **styles): def to_ptl_device(device: str) -> Tuple[str, Optional[List[int]]]: - if device in ['cpu', 'mps']: - return device, 'auto' - elif any([device.startswith(x) for x in ['tpu', 'cuda', 'hpu', 'ipu']]): - dev, idx = device.split(':') + if device.strip() == 'auto': + return 'auto', 'auto' + devices = device.split(',') + if devices[0] in ['cpu', 'mps']: + return devices[0], 'auto' + elif any([devices[0].startswith(x) for x in ['tpu', 'cuda', 'hpu', 'ipu']]): + devices = [device.split(':') for device in devices] + devices = [(x[0].strip(), x[1].strip()) for x in devices] + if len(set(x[0] for x in devices)) > 1: + raise Exception('Can only use a single type of device at a time.') + dev, _ = devices[0] if dev == 'cuda': dev = 'gpu' - return dev, [int(idx)] + return dev, [int(x[1]) for x in devices] raise Exception(f'Invalid device {device} specified')