From bf12b49ce2fddf2a5e17b1d653ecba2a839279d6 Mon Sep 17 00:00:00 2001
From: Stephan Tulkens
Date: Thu, 3 Oct 2024 16:30:17 +0200
Subject: [PATCH] enh: Add device selection mechanism (#60)

* enh: Add device selection mechanism

* Fix bug in util

* remove prefix in patch
---
Notes (free-form; text between the "---" line above and the first
"diff --git" line is not part of the commit message or the diff):
* The line structure, indentation and blank context lines of this patch
  were restored from a whitespace-collapsed copy. They are inferred from
  the hunk line counts and Python syntax; run `git apply --check` before
  applying.
* select_optimal_device(None) picks "cuda" if available, else "mps" if
  available, else "cpu", and logs the choice. Any explicit device string
  is returned unchanged without validation (the test passes "clown"
  through).
* The `device` default of distill_from_model() and distill() changes from
  "cpu" to None.

 model2vec/distill/distillation.py |  7 ++++---
 model2vec/distill/utils.py        | 26 ++++++++++++++++++++++++++
 tests/test_utils.py               | 27 +++++++++++++++++++++++++++
 3 files changed, 57 insertions(+), 3 deletions(-)
 create mode 100644 model2vec/distill/utils.py

diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py
index 28ddea6..35da025 100644
--- a/model2vec/distill/distillation.py
+++ b/model2vec/distill/distillation.py
@@ -12,6 +12,7 @@
     create_output_embeddings_from_model_name_and_tokens,
 )
 from model2vec.distill.tokenizer import add_tokens, preprocess_vocabulary, remove_tokens
+from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
 
 logger = logging.getLogger(__name__)
@@ -24,7 +25,7 @@ def distill_from_model(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizerFast,
     vocabulary: list[str] | None = None,
-    device: str = "cpu",
+    device: str | None = None,
     pca_dims: PCADimType = 256,
     apply_zipf: bool = True,
     use_subword: bool = True,
@@ -52,7 +53,7 @@ def distill_from_model(
     :return: A StaticModel
 
     """
-    """"""
+    device = select_optimal_device(device)
     if not use_subword and vocabulary is None:
         raise ValueError(
             "You must pass a vocabulary if you don't use subword tokens. Either pass a vocabulary, or set use_subword to True."
@@ -137,7 +138,7 @@ def distill_from_model(
 def distill(
     model_name: str,
     vocabulary: list[str] | None = None,
-    device: str = "cpu",
+    device: str | None = None,
     pca_dims: PCADimType = 256,
     apply_zipf: bool = True,
     use_subword: bool = True,
diff --git a/model2vec/distill/utils.py b/model2vec/distill/utils.py
new file mode 100644
index 0000000..cf16977
--- /dev/null
+++ b/model2vec/distill/utils.py
@@ -0,0 +1,26 @@
+from logging import getLogger
+
+import torch
+
+logger = getLogger(__name__)
+
+
+def select_optimal_device(device: str | None) -> str:
+    """
+    Guess what your optimal device should be based on backend availability.
+
+    If you pass a device, we just pass it through.
+
+    :param device: The device to use. If this is not None you get back what you passed.
+    :return: The selected device.
+    """
+    if device is None:
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
+        logger.info(f"Automatically selected device: {device}")
+
+    return device
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 7ac4c70..d6219e0 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,10 @@
 from pathlib import Path
 from tempfile import NamedTemporaryFile
+from unittest.mock import patch
 
+import pytest
+
+from model2vec.distill.utils import select_optimal_device
 from model2vec.utils import _get_metadata_from_readme
 
 
@@ -23,3 +27,26 @@ def test__get_metadata_from_readme_mocked_file_keys() -> None:
         f.write(b"")
         f.flush()
         assert set(_get_metadata_from_readme(Path(f.name))) == set()
+
+
+@pytest.mark.parametrize(
+    "device, expected, cuda, mps",
+    [
+        ("cpu", "cpu", True, True),
+        ("cpu", "cpu", True, False),
+        ("cpu", "cpu", False, True),
+        ("cpu", "cpu", False, False),
+        ("clown", "clown", False, False),
+        (None, "cuda", True, True),
+        (None, "cuda", True, False),
+        (None, "mps", False, True),
+        (None, "cpu", False, False),
+    ],
+)
+def test_select_optimal_device(device: str | None, expected: str, cuda: bool, mps: bool) -> None:
+    """Test whether the optimal device is selected."""
+    with (
+        patch("torch.cuda.is_available", return_value=cuda),
+        patch("torch.backends.mps.is_available", return_value=mps),
+    ):
+        assert select_optimal_device(device) == expected