diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md
index eb25ddb63297..a54ac432006a 100644
--- a/docs/source/en/internal/generation_utils.md
+++ b/docs/source/en/internal/generation_utils.md
@@ -436,3 +436,9 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] SynthIDTextWatermarkDetector
     - __call__
+
+## Compile Utils
+
+[[autodoc]] CompileConfig
+    - __call__
+
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 1f69b76d7ac6..970f32b9a88d 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -122,6 +122,7 @@
     "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
     "file_utils": [],
     "generation": [
+        "CompileConfig",
         "GenerationConfig",
         "TextIteratorStreamer",
         "TextStreamer",
@@ -4990,7 +4991,7 @@
     from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
 
     # Generation
-    from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
+    from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
     from .hf_argparser import HfArgumentParser
 
     # Integrations
diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index e2ed48433b16..59d970db1541 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -20,6 +20,7 @@
 _import_structure = {
     "configuration_utils": [
         "BaseWatermarkingConfig",
+        "CompileConfig",
         "GenerationConfig",
         "GenerationMode",
         "SynthIDTextWatermarkingConfig",
@@ -192,6 +193,7 @@
 if TYPE_CHECKING:
     from .configuration_utils import (
         BaseWatermarkingConfig,
+        CompileConfig,
         GenerationConfig,
         GenerationMode,
         SynthIDTextWatermarkingConfig,
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 30a632aa8cca..486cd2336c3e 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -20,7 +20,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 from .. import __version__
 from ..configuration_utils import PretrainedConfig
@@ -371,6 +371,12 @@ class GenerationConfig(PushToHubMixin):
             to correctly align tokens. Can only be used with different tokenizers in speculative decoding. See this
             [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
 
+        > Parameters related to performances and compilation
+
+        compile_config (CompileConfig, *optional*):
+            If using a static cache, this controls how `generate` will `compile` the forward pass for performance
+            gains.
+
         > Wild card
 
         generation_kwargs:
@@ -474,6 +480,9 @@ def __init__(self, **kwargs):
         self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
         self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
 
+        # Performances
+        self.compile_config = kwargs.pop("compile_config", CompileConfig())
+
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
 
@@ -794,7 +803,13 @@ def validate(self, is_init=False):
             self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
             self.watermarking_config.validate()
 
-        # 7. other incorrect combinations
+        # 7. performances arguments
+        if not isinstance(self.compile_config, CompileConfig):
+            raise ValueError(
+                f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
+            )
+
+        # 8. other incorrect combinations
         if self.return_dict_in_generate is not True:
             for extra_output_flag in self.extra_output_flags:
                 if getattr(self, extra_output_flag) is True:
@@ -1175,6 +1190,8 @@ def to_dict(self) -> Dict[str, Any]:
             del output["_commit_hash"]
         if "_original_object_hash" in output:
             del output["_original_object_hash"]
+        if "compile_config" in output:
+            del output["compile_config"]
 
         # Transformers version when serializing this file
         output["transformers_version"] = __version__
@@ -1559,3 +1576,51 @@ def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProces
             skip_first_ngram_calls=self.skip_first_ngram_calls,
             debug_mode=self.debug_mode,
         )
+
+
+@dataclass
+class CompileConfig(object):
+    """
+    Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
+    See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.
+
+    Args:
+        fullgraph (`bool`, *optional*, defaults to `True`):
+            If `True`, requires that the whole forward be capturable in a single graph.
+        dynamic (`bool` or `None`, *optional*):
+            Whether to try to use dynamic shape graphs.
+        backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
+            Backend to be used.
+        mode (`str`, *optional*, defaults to `"reduce-overhead"`):
+            Controls balance between performance and overhead.
+        options (`dict`, *optional*):
+            A dictionary of options to pass to the backend.
+
+    Examples:
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig
+
+    >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')
+    >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda()
+
+    >>> # Automatic compile configuration, used with static cache
+    >>> compile_config = CompileConfig(dynamic=True)
+
+    >>> # Generation with static cache and compile config
+    >>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
+    >>> output = model.generate(
+    ...     input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config
+    ... )
+    >>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
+    ```
+    """
+
+    fullgraph: bool = True
+    dynamic: Optional[bool] = None
+    backend: Union[str, Callable] = "inductor"
+    mode: str = "reduce-overhead"
+    options: Optional[dict] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serializes this instance to a Python dictionary."""
+        return copy.deepcopy(self.__dict__)
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 05e39c4a9b56..1982841df667 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -3230,16 +3230,14 @@ def _sample(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
-        def model_forward(model, *args, **kwargs):
-            return model.forward(*args, **kwargs)
-
+        model_forward = self.__call__
         if isinstance(model_kwargs.get("past_key_values"), StaticCache):
             if self.device.type == "cuda":
                 logger.warning_once("Using `torch.compile`.")
                 os.environ["TOKENIZERS_PARALLELISM"] = "0"
-                model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
+                model_forward = self.get_compiled_call(generation_config.compile_config)
 
-        i = 0
+        is_prefill = True
         while self._has_unfinished_sequences(
             this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
         ):
@@ -3250,11 +3248,11 @@ def model_forward(model, *args, **kwargs):
             model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
             model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
 
-            if i == 0:
+            if is_prefill:
                 outputs = self(**model_inputs, return_dict=True)
-                i += 1
+                is_prefill = False
             else:
-                outputs = model_forward(self, return_dict=True, **model_inputs)
+                outputs = model_forward(**model_inputs, return_dict=True)
 
             # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
             model_kwargs = self._update_model_kwargs_for_generation(
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 0806c318e101..50622c9f5514 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -43,7 +43,7 @@
 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation import GenerationConfig, GenerationMixin
+from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .loss.loss_utils import LOSS_MAPPING
 from .pytorch_utils import (  # noqa: F401
@@ -5094,6 +5094,21 @@ def loss_function(self):
             loss_type = "ForCausalLM"
         return LOSS_MAPPING[loss_type]
 
+    def get_compiled_call(self, compile_config: CompileConfig):
+        """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
+        non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
+        want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
+        (where we want the speed-ups of compiled version with static shapes)."""
+        # Only reset it if not present or different from previous config
+        default_config = getattr(self.generation_config, "compile_config", CompileConfig())
+        if (
+            not hasattr(self, "_compiled_call")
+            or getattr(self, "_last_compile_config", default_config) != compile_config
+        ):
+            self._last_compile_config = compile_config
+            self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
+        return self._compiled_call
+
 
 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
 if PreTrainedModel.push_to_hub.__doc__ is not None:
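For reviewers, a minimal usage sketch of the API introduced by this diff. The checkpoint, dtype, prompt, and generation lengths below are illustrative choices, not part of the change; only `CompileConfig`, the `compile_config` argument accepted by `generate`, and `get_compiled_call` come from the hunks above.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

# Illustrative checkpoint and settings; only the CompileConfig plumbing is from this diff.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", torch_dtype=torch.bfloat16).to("cuda")

# Fields map one-to-one onto torch.compile kwargs.
compile_config = CompileConfig(fullgraph=True, mode="max-autotune")

inputs = tokenizer("Hello there, how", return_tensors="pt").to(model.device)

# The compiled forward is only used together with a static cache; the prefill step stays eager
# (see the `is_prefill` branch added in `_sample`).
output = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=64,
    cache_implementation="static",
    compile_config=compile_config,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))

# `get_compiled_call` caches the compiled callable and only re-compiles when the config changes.
compiled_forward = model.get_compiled_call(compile_config)
assert compiled_forward is model.get_compiled_call(compile_config)
```

Note that `GenerationConfig.to_dict` now strips `compile_config` before serialization, so the setting is runtime-only and is not written to `generation_config.json`.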