Skip to content

Commit

Permalink
enh: Add device selection mechanism (#60)
Browse files Browse the repository at this point in the history
* enh: Add device selection mechanism

* Fix bug in util

* Remove prefix in patch
  • Loading branch information
stephantul authored Oct 3, 2024
1 parent 78fce2e commit bf12b49
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
7 changes: 4 additions & 3 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
create_output_embeddings_from_model_name_and_tokens,
)
from model2vec.distill.tokenizer import add_tokens, preprocess_vocabulary, remove_tokens
from model2vec.distill.utils import select_optimal_device
from model2vec.model import StaticModel

logger = logging.getLogger(__name__)
Expand All @@ -24,7 +25,7 @@ def distill_from_model(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerFast,
vocabulary: list[str] | None = None,
device: str = "cpu",
device: str | None = None,
pca_dims: PCADimType = 256,
apply_zipf: bool = True,
use_subword: bool = True,
Expand Down Expand Up @@ -52,7 +53,7 @@ def distill_from_model(
:return: A StaticModel
"""
""""""
device = select_optimal_device(device)
if not use_subword and vocabulary is None:
raise ValueError(
"You must pass a vocabulary if you don't use subword tokens. Either pass a vocabulary, or set use_subword to True."
Expand Down Expand Up @@ -137,7 +138,7 @@ def distill_from_model(
def distill(
model_name: str,
vocabulary: list[str] | None = None,
device: str = "cpu",
device: str | None = None,
pca_dims: PCADimType = 256,
apply_zipf: bool = True,
use_subword: bool = True,
Expand Down
26 changes: 26 additions & 0 deletions model2vec/distill/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from logging import getLogger

import torch

logger = getLogger(__name__)


def select_optimal_device(device: str | None) -> str:
"""
Guess what your optimal device should be based on backend availability.
If you pass a device, we just pass it through.
:param device: The device to use. If this is not None you get back what you passed.
:return: The selected device.
"""
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
logger.info(f"Automatically selected device: {device}")

return device
27 changes: 27 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest.mock import patch

import pytest

from model2vec.distill.utils import select_optimal_device
from model2vec.utils import _get_metadata_from_readme


Expand All @@ -23,3 +27,26 @@ def test__get_metadata_from_readme_mocked_file_keys() -> None:
f.write(b"")
f.flush()
assert set(_get_metadata_from_readme(Path(f.name))) == set()


@pytest.mark.parametrize(
    ("device", "expected", "cuda", "mps"),
    [
        # An explicit device is returned as-is, whatever the backends report.
        ("cpu", "cpu", True, True),
        ("cpu", "cpu", True, False),
        ("cpu", "cpu", False, True),
        ("cpu", "cpu", False, False),
        # Unknown device names are passed through as well.
        ("clown", "clown", False, False),
        # No device given: CUDA is preferred, then MPS, then CPU.
        (None, "cuda", True, True),
        (None, "cuda", True, False),
        (None, "mps", False, True),
        (None, "cpu", False, False),
    ],
)
def test_select_optimal_device(device: str | None, expected: str, cuda: bool, mps: bool) -> None:
    """Check pass-through and auto-detection for each combination of backend availability."""
    # Patchers only take effect once entered by the `with` statement below.
    cuda_patch = patch("torch.cuda.is_available", return_value=cuda)
    mps_patch = patch("torch.backends.mps.is_available", return_value=mps)

    with cuda_patch, mps_patch:
        selected = select_optimal_device(device)

    assert selected == expected

0 comments on commit bf12b49

Please sign in to comment.