Skip to content

Commit

Permalink
Merge pull request #379 from markstur/sentence-transformers-3
Browse files Browse the repository at this point in the history
Update sentence-transformers and allow setting trust_remote_code
  • Loading branch information
evaline-ju authored Aug 12, 2024
2 parents b28c2b1 + bcc1233 commit 989cd45
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 4 deletions.
2 changes: 2 additions & 0 deletions caikit_nlp/config/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ training_data_limit:

# Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32
embedding:
# Allow models with remote code.
trust_remote_code: false
# Number of times to retry on error. Most deployments should use 0 retries.
retries: 0
# Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used
Expand Down
42 changes: 40 additions & 2 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
TypeVar,
Expand Down Expand Up @@ -82,6 +83,8 @@
sentence_transformers = importlib.import_module("sentence_transformers")
# Third Party
from sentence_transformers import SentenceTransformer
from sentence_transformers.model_card import SentenceTransformerModelCardData
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.util import batch_to_device, cos_sim, dot_score
from sentence_transformers.util import (
normalize_embeddings as normalize, # avoid parameter shadowing
Expand All @@ -107,6 +110,7 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
val=embedding_cfg.get("implicit_truncation_errors", True)
)
DEVICE = embedding_cfg.get("device", "")
TRUST_REMOTE_CODE = embedding_cfg.get("trust_remote_code")

RT = TypeVar("RT") # return type

Expand Down Expand Up @@ -183,7 +187,9 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
ipex = cls._get_ipex(IPEX)
device = cls._select_device(ipex, DEVICE)
model = SentenceTransformerWithTruncate(
model_name_or_path=artifacts_path, device=device
model_name_or_path=artifacts_path,
device=device,
trust_remote_code=TRUST_REMOTE_CODE,
)
model.eval() # required for IPEX at least
if device is not None:
Expand Down Expand Up @@ -719,7 +725,12 @@ def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule":
model_name_or_path: str
Model name (Hugging Face hub) or path to model to load.
"""
return cls(model=SentenceTransformer(model_name_or_path=model_name_or_path))
return cls(
model=SentenceTransformer(
model_name_or_path=model_name_or_path,
trust_remote_code=TRUST_REMOTE_CODE,
)
)

def save(self, model_path: str, *args, **kwargs):
"""Save model using config in model_path
Expand Down Expand Up @@ -875,21 +886,39 @@ def __init__(
model_name_or_path: Optional[str] = None,
modules: Optional[Iterable[nn.Module]] = None,
device: Optional[str] = None,
prompts: Optional[Dict[str, str]] = None,
default_prompt_name: Optional[str] = None,
similarity_fn_name: Optional[Union[str, SimilarityFunction]] = None,
cache_folder: Optional[str] = None,
trust_remote_code: bool = False,
revision: Optional[str] = None,
local_files_only: bool = False,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
model_card_data: Optional[SentenceTransformerModelCardData] = None,
):
super().__init__(
model_name_or_path,
modules,
device,
prompts,
default_prompt_name,
similarity_fn_name,
cache_folder,
trust_remote_code,
revision,
local_files_only,
token,
use_auth_token,
truncate_dim,
model_kwargs,
tokenizer_kwargs,
config_kwargs,
model_card_data,
)
self.tokenizers = {}

Expand Down Expand Up @@ -1014,9 +1043,12 @@ def _get_tokenized(self, texts):
def encode(
self,
sentences: Union[str, List[str]],
prompt_name: Optional[str] = None,
prompt: Optional[str] = None,
batch_size: int = 32,
show_progress_bar: bool = None,
output_value: str = "sentence_embedding",
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
device: str = None,
Expand All @@ -1029,9 +1061,12 @@ def encode(
Computes sentence embeddings
:param sentences: the sentences to embed
:param prompt_name: Ignored here. Added for compatibility with super API.
:param prompt: Ignored here. Added for compatibility with super API.
:param batch_size: the batch size used for the computation
:param show_progress_bar: Ignored here. Added for compatibility with super API.
:param output_value: Ignored here. Added for compatibility with super API.
:param precision: Ignored here. Added for compatibility with super API.
:param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list
of pytorch tensors.
:param convert_to_tensor: If true, you get one large tensor as return. Overwrites any
Expand All @@ -1057,8 +1092,11 @@ def encode(

# These args are for API compatibility, but are currently ignored in our version of encode()
_ = (
prompt_name,
prompt,
show_progress_bar,
output_value,
precision,
normalize_embeddings,
)

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ dependencies = [
"pandas>=1.5.0",
"scikit-learn>=1.1",
"scipy>=1.8.1",
"sentence-transformers>=2.3.1,<2.4.0",
"sentence-transformers>=3.0.0,<3.1.0",
"tokenizers>=0.13.3",
"torch>=2.3.1,<2.4.0",
"tqdm>=4.65.0",
"transformers>=4.32.0",
"transformers>=4.38.0,<4.44.0",
"peft==0.6.0",
]

Expand Down
2 changes: 2 additions & 0 deletions runtime_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ model_management:

# Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32
embedding:
# Allow models with remote code.
trust_remote_code: false
# Number of times to retry on error. Most deployments should use 0 retries.
retries: 0
# Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ passenv =
LOG_FORMATTER
LOG_THREAD_ID
LOG_CHANNEL_WIDTH
PYTORCH_ENABLE_MPS_FALLBACK
commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests}

; Unclear: We probably want to test wheel packaging
Expand Down

0 comments on commit 989cd45

Please sign in to comment.