Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pass model parameters to HFLocalInvocationLayer via model_kwargs, enabling direct model usage #4956

Merged
merged 15 commits into from
Jun 7, 2023
Merged
162 changes: 88 additions & 74 deletions haystack/nodes/prompt/invocation_layer/hugging_face.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, List, Dict
from typing import Optional, Union, List, Dict, Any
import logging
import os

Expand All @@ -11,6 +11,7 @@
PreTrainedTokenizer,
PreTrainedTokenizerFast,
GenerationConfig,
Pipeline,
)
from transformers.pipelines import get_task

Expand Down Expand Up @@ -50,7 +51,8 @@ def __init__(
:param kwargs: Additional keyword arguments passed to the underlying model. Due to reflective construction of
all PromptModelInvocationLayer instances, this instance of HFLocalInvocationLayer might receive some unrelated
kwargs. Only kwargs relevant to the HFLocalInvocationLayer are considered. The list of supported kwargs
includes: task_name, trust_remote_code, revision, feature_extractor, tokenizer, config, use_fast, torch_dtype, device_map.
includes: "task", "model", "config", "tokenizer", "feature_extractor", "revision", "use_auth_token",
"device_map", "device", "torch_dtype", "trust_remote_code", "model_kwargs", and "pipeline_class".
For more details about pipeline kwargs in general, see
Hugging Face [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).

Expand All @@ -72,81 +74,41 @@ def __init__(
self.__class__.__name__,
self.devices[0],
)

# Due to reflective construction of all invocation layers we might receive some
# unknown kwargs, so we need to take only the relevant.
# For more details refer to Hugging Face pipeline documentation
# Do not use `device_map` AND `device` at the same time as they will conflict
model_input_kwargs = {
key: kwargs[key]
for key in [
"model_kwargs",
"trust_remote_code",
"revision",
"feature_extractor",
"tokenizer",
"config",
"use_fast",
"torch_dtype",
"device_map",
"generation_kwargs",
"model_max_length",
"stream",
"stream_handler",
]
if key in kwargs
}
# flatten model_kwargs one level
if "model_kwargs" in model_input_kwargs:
mkwargs = model_input_kwargs.pop("model_kwargs")
model_input_kwargs.update(mkwargs)
if "device" not in kwargs:
kwargs["device"] = self.devices[0]

# save stream settings and stream_handler for pipeline invocation
self.stream_handler = model_input_kwargs.pop("stream_handler", None)
self.stream = model_input_kwargs.pop("stream", False)
self.stream_handler = kwargs.get("stream_handler", None)
self.stream = kwargs.get("stream", False)

# save generation_kwargs for pipeline invocation
self.generation_kwargs = model_input_kwargs.pop("generation_kwargs", {})
model_max_length = model_input_kwargs.pop("model_max_length", None)

torch_dtype = model_input_kwargs.get("torch_dtype")
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if "torch." in torch_dtype:
torch_dtype_resolved = getattr(torch, torch_dtype.strip("torch."))
elif torch_dtype == "auto":
torch_dtype_resolved = torch_dtype
else:
raise ValueError(
f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}"
)
elif isinstance(torch_dtype, torch.dtype):
torch_dtype_resolved = torch_dtype
else:
raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
model_input_kwargs["torch_dtype"] = torch_dtype_resolved

if len(model_input_kwargs) > 0:
logger.info("Using model input kwargs %s in %s", model_input_kwargs, self.__class__.__name__)
self.generation_kwargs = kwargs.get("generation_kwargs", {})

# If task_name is not provided, get the task name from the model name or path (uses HFApi)
if "task_name" in kwargs:
self.task_name = kwargs.get("task_name")
else:
self.task_name = get_task(model_name_or_path, use_auth_token=use_auth_token)

self.pipe = pipeline(
task=self.task_name, # task_name is used to determine the pipeline type
model=model_name_or_path,
device=self.devices[0] if "device_map" not in model_input_kwargs else None,
use_auth_token=self.use_auth_token,
model_kwargs=model_input_kwargs,
self.task_name = (
kwargs.get("task_name")
if "task_name" in kwargs
else get_task(model_name_or_path, use_auth_token=use_auth_token)
)
# we check in supports class method if task_name is supported but here we check again as
# we could have gotten the task_name from kwargs
if self.task_name not in ["text2text-generation", "text-generation"]:
raise ValueError(
f"Task name {self.task_name} is not supported. "
f"We only support text2text-generation and text-generation tasks."
)
pipeline_kwargs = self._prepare_pipeline_kwargs(
task=self.task_name, model_name_or_path=model_name_or_path, use_auth_token=use_auth_token, **kwargs
)
# create the transformer pipeline
self.pipe: Pipeline = pipeline(**pipeline_kwargs)

# This is how the default max_length is determined for Text2TextGenerationPipeline shown here
# https://huggingface.co/transformers/v4.6.0/_modules/transformers/pipelines/text2text_generation.html
# max_length must be set otherwise HFLocalInvocationLayer._ensure_token_limit will fail.
self.max_length = max_length or self.pipe.model.config.max_length

model_max_length = kwargs.get("model_max_length", None)
# we allow users to override the tokenizer's model_max_length because models like T5 have relative positional
# embeddings and can accept sequences of more than 512 tokens
if model_max_length is not None:
Expand All @@ -160,6 +122,37 @@ def __init__(
self.pipe.tokenizer.model_max_length,
)

def _prepare_pipeline_kwargs(self, **kwargs) -> Dict[str, Any]:
"""
Sanitizes and prepares the kwargs passed to the transformers pipeline function.
For more details about pipeline kwargs in general, see Hugging Face
[documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline).
"""
# as device and device_map are mutually exclusive, we set device to None if device_map is provided
device_map = kwargs.get("device_map", None)
device = kwargs.get("device") if device_map is None else None
# prepare torch_dtype for pipeline invocation
torch_dtype = self._extract_torch_dtype(**kwargs)
# and the model (prefer model instance over model_name_or_path str identifier)
model = kwargs.get("model") or kwargs.get("model_name_or_path")

pipeline_kwargs = {
"task": kwargs.get("task", None),
"model": model,
"config": kwargs.get("config", None),
"tokenizer": kwargs.get("tokenizer", None),
"feature_extractor": kwargs.get("feature_extractor", None),
"revision": kwargs.get("revision", None),
"use_auth_token": kwargs.get("use_auth_token", None),
"device_map": device_map,
"device": device,
"torch_dtype": torch_dtype,
"trust_remote_code": kwargs.get("trust_remote_code", False),
"model_kwargs": kwargs.get("model_kwargs", {}),
"pipeline_class": kwargs.get("pipeline_class", None),
}
return pipeline_kwargs

def invoke(self, *args, **kwargs):
"""
It takes a prompt and returns a list of generated texts using the local Hugging Face transformers model
Expand All @@ -172,14 +165,15 @@ def invoke(self, *args, **kwargs):
stop_words = kwargs.pop("stop_words", None)
top_k = kwargs.pop("top_k", None)
# either stream is True (will use default handler) or stream_handler is provided for custom handler
stream = kwargs.get("stream", self.stream) or kwargs.get("stream_handler", self.stream_handler) is not None
stream = kwargs.get("stream", self.stream)
stream_handler = kwargs.get("stream_handler", self.stream_handler)
stream = stream or stream_handler is not None
if kwargs and "prompt" in kwargs:
prompt = kwargs.pop("prompt")

# Consider only Text2TextGenerationPipeline and TextGenerationPipeline relevant, ignore others
# For more details refer to Hugging Face Text2TextGenerationPipeline and TextGenerationPipeline
# documentation
# TODO resolve these kwargs from the pipeline signature
model_input_kwargs = {
key: kwargs[key]
for key in [
Expand Down Expand Up @@ -227,7 +221,7 @@ def invoke(self, *args, **kwargs):
model_input_kwargs["max_length"] = self.max_length

if stream:
stream_handler: TokenStreamingHandler = kwargs.pop("stream_handler", DefaultTokenStreamingHandler())
stream_handler: TokenStreamingHandler = stream_handler or DefaultTokenStreamingHandler()
model_input_kwargs["streamer"] = HFTokenStreamingHandler(self.pipe.tokenizer, stream_handler)

output = self.pipe(prompt, **model_input_kwargs)
Expand All @@ -248,7 +242,8 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
:param prompt: Prompt text to be sent to the generative model.
"""
model_max_length = self.pipe.tokenizer.model_max_length
n_prompt_tokens = len(self.pipe.tokenizer.tokenize(prompt))
tokenized_prompt = self.pipe.tokenizer.tokenize(prompt)
n_prompt_tokens = len(tokenized_prompt)
n_answer_tokens = self.max_length
if (n_prompt_tokens + n_answer_tokens) <= model_max_length:
return prompt
Expand All @@ -263,20 +258,38 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
model_max_length,
)

tokenized_payload = self.pipe.tokenizer.tokenize(prompt)
decoded_string = self.pipe.tokenizer.convert_tokens_to_string(
tokenized_payload[: model_max_length - n_answer_tokens]
tokenized_prompt[: model_max_length - n_answer_tokens]
)
return decoded_string

def _extract_torch_dtype(self, **kwargs) -> Optional[torch.dtype]:
torch_dtype_resolved = None
torch_dtype = kwargs.get("torch_dtype", None)
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if "torch." in torch_dtype:
torch_dtype_resolved = getattr(torch, torch_dtype.strip("torch."))
elif torch_dtype == "auto":
torch_dtype_resolved = torch_dtype
else:
raise ValueError(
f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}"
)
elif isinstance(torch_dtype, torch.dtype):
torch_dtype_resolved = torch_dtype
else:
raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
return torch_dtype_resolved

@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
task_name: Optional[str] = None
task_name: Optional[str] = kwargs.get("task_name", None)
if os.path.exists(model_name_or_path):
return True

try:
task_name = get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
task_name = task_name or get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None))
except RuntimeError:
# This will fail for all non-HF models
return False
Expand All @@ -300,4 +313,5 @@ def __init__(
self.stop_words = tokenizer(stop_words, add_special_tokens=False, return_tensors="pt").to(device)

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(torch.isin(input_ids[-1], self.stop_words["input_ids"]))
stop_result = torch.isin(self.stop_words["input_ids"], input_ids[-1])
return any(all(stop_word) for stop_word in stop_result)
Loading