Automatic compilation in generate: do not rely on inner function #34923

Merged (25 commits) on Dec 3, 2024
6 changes: 6 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -436,3 +436,9 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] SynthIDTextWatermarkDetector
- __call__

## Compile Utils

[[autodoc]] CompileConfig
- __call__

3 changes: 2 additions & 1 deletion src/transformers/__init__.py
@@ -122,6 +122,7 @@
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": [
"CompileConfig",
"GenerationConfig",
"TextIteratorStreamer",
"TextStreamer",
@@ -4977,7 +4978,7 @@
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

# Generation
- from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
+ from .generation import CompileConfig, GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
from .hf_argparser import HfArgumentParser

# Integrations
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
@@ -20,6 +20,7 @@
_import_structure = {
"configuration_utils": [
"BaseWatermarkingConfig",
"CompileConfig",
"GenerationConfig",
"GenerationMode",
"SynthIDTextWatermarkingConfig",
@@ -192,6 +193,7 @@
if TYPE_CHECKING:
from .configuration_utils import (
BaseWatermarkingConfig,
CompileConfig,
GenerationConfig,
GenerationMode,
SynthIDTextWatermarkingConfig,
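These export entries, together with the top-level `__init__.py` change above, make `CompileConfig` importable from both the root package and the `generation` submodule. A quick sanity check, assuming an installed transformers build that includes this PR:

```python
from transformers import CompileConfig
from transformers.generation import CompileConfig as GenerationCompileConfig

# Both import paths resolve to the same class defined in generation/configuration_utils.py.
assert CompileConfig is GenerationCompileConfig
print(CompileConfig())
# CompileConfig(fullgraph=True, dynamic=None, backend='inductor', mode='reduce-overhead', options=None)
```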
69 changes: 67 additions & 2 deletions src/transformers/generation/configuration_utils.py
@@ -20,7 +20,7 @@
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from .. import __version__
from ..configuration_utils import PretrainedConfig
@@ -361,6 +361,12 @@ class GenerationConfig(PushToHubMixin):
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).

> Parameters related to performances and compilation

compile_config (CompileConfig, *optional*):
If using a static cache, this controls how `generate` will `compile` the forward pass for performance
gains.

> Wild card

generation_kwargs:
@@ -461,6 +467,9 @@ def __init__(self, **kwargs):
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)

# Performances
self.compile_config = kwargs.pop("compile_config", CompileConfig())

# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})

@@ -781,7 +790,13 @@ def validate(self, is_init=False):
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()

- # 7. other incorrect combinations
+ # 7. performances arguments
+ if not isinstance(self.compile_config, CompileConfig):
+     raise ValueError(
+         f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an instance of `CompileConfig`."
+     )
+
+ # 8. other incorrect combinations
if self.return_dict_in_generate is not True:
for extra_output_flag in self.extra_output_flags:
if getattr(self, extra_output_flag) is True:
@@ -1162,6 +1177,8 @@ def to_dict(self) -> Dict[str, Any]:
del output["_commit_hash"]
if "_original_object_hash" in output:
del output["_original_object_hash"]
if "compile_config" in output:
del output["compile_config"]

# Transformers version when serializing this file
output["transformers_version"] = __version__
@@ -1546,3 +1563,51 @@ def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProces
skip_first_ngram_calls=self.skip_first_ngram_calls,
debug_mode=self.debug_mode,
)


@dataclass
class CompileConfig(object):
"""
Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.

Args:
fullgraph (`bool`, *optional*, defaults to `True`):
If `True`, requires that the whole forward be capturable in a single graph.
dynamic (`bool` or `None`, *optional*):
Whether to try to use dynamic shape graphs.
backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
Backend to be used.
mode (`str`, *optional*, defaults to `"reduce-overhead"`):
Controls balance between performance and overhead.
options (`dict`, *optional*):
A dictionary of options to pass to the backend.

Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig

>>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b')
>>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda()

>>> # Automatic compile configuration, used with static cache
>>> compile_config = CompileConfig(dynamic=True)

>>> # Generation with static cache and compile config
>>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda()
>>> output = model.generate(
... input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config
... )
>>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
```
"""

fullgraph: bool = True
dynamic: Optional[bool] = None
backend: Union[str, Callable] = "inductor"
mode: str = "reduce-overhead"
options: Optional[dict] = None

def to_dict(self) -> Dict[str, Any]:
"""Serializes this instance to a Python dictionary."""
return copy.deepcopy(self.__dict__)
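Putting the `GenerationConfig` changes together: `compile_config` defaults to a fresh `CompileConfig()`, `validate` rejects anything that is not a `CompileConfig` instance, and `to_dict` drops the field so it is never written to a serialized `generation_config.json`. A small sketch of that behavior; the generation parameters here are illustrative:

```python
from transformers import CompileConfig, GenerationConfig

compile_config = CompileConfig(fullgraph=True, mode="reduce-overhead")
generation_config = GenerationConfig(
    max_new_tokens=64,
    cache_implementation="static",
    compile_config=compile_config,
)

generation_config.validate()  # passes: compile_config is a CompileConfig instance

# The compile options are excluded from serialization of the generation config...
assert "compile_config" not in generation_config.to_dict()

# ...while the dataclass itself serializes its fields, ready to be splatted into torch.compile(**...).
print(compile_config.to_dict())
# {'fullgraph': True, 'dynamic': None, 'backend': 'inductor', 'mode': 'reduce-overhead', 'options': None}
```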
14 changes: 6 additions & 8 deletions src/transformers/generation/utils.py
@@ -3230,16 +3230,14 @@ def _sample(
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

- def model_forward(model, *args, **kwargs):
-     return model.forward(*args, **kwargs)
-
+ model_forward = self.__call__
if isinstance(model_kwargs.get("past_key_values"), StaticCache):
if self.device.type == "cuda":
logger.warning_once("Using `torch.compile`.")
os.environ["TOKENIZERS_PARALLELISM"] = "0"
- model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
+ model_forward = self.get_compiled_call(generation_config.compile_config)

- i = 0
+ is_prefill = True
while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
@@ -3250,11 +3248,11 @@ def model_forward(model, *args, **kwargs):
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

- if i == 0:
+ if is_prefill:
outputs = self(**model_inputs, return_dict=True)
- i += 1
+ is_prefill = False
else:
- outputs = model_forward(self, return_dict=True, **model_inputs)
+ outputs = model_forward(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
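To make the prefill/decode split concrete: the prompt forward pass sees a sequence length that occurs only once, so it stays eager, while every subsequent decoding step reuses the same shapes and goes through the compiled call. A toy, model-agnostic sketch of the same pattern (not transformers code; the module and shapes are made up for illustration):

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


model = TinyModel()
# Same idea as get_compiled_call: build the compiled callable once, up front.
compiled_forward = torch.compile(model.__call__, mode="reduce-overhead", fullgraph=True)

x = torch.randn(1, 7, 16)  # "prompt" of length 7: a shape we will only ever see once
is_prefill = True
for step in range(5):
    if is_prefill:
        out = model(x)             # eager: no point compiling a graph for a one-off shape
        is_prefill = False
    else:
        out = compiled_forward(x)  # fixed decoding shape: the compiled graph is reused every step
    x = torch.randn(1, 1, 16)      # later steps all use the same shape, as with static-cache decoding
```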
17 changes: 16 additions & 1 deletion src/transformers/modeling_utils.py
@@ -43,7 +43,7 @@
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
- from .generation import GenerationConfig, GenerationMixin
+ from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
@@ -5083,6 +5083,21 @@ def loss_function(self):
loss_type = "ForCausalLM"
return LOSS_MAPPING[loss_type]

def get_compiled_call(self, compile_config: CompileConfig):
"""Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
(where we want the speed-ups of compiled version with static shapes)."""
# Only reset it if not present or different from previous config
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
if (
not hasattr(self, "_compiled_call")
or getattr(self, "_last_compile_config", default_config) != compile_config
):
self._last_compile_config = compile_config
self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
@ydshieh (Collaborator) commented on this line, Nov 27, 2024:

I am wondering if it would (probably) make more sense to do

self._compiled_call = None

at this line, and delegate the actual compile to `compiled_call` (as you already have it there), only when it is called (like you do in generate).

Having two places that call `torch.compile` seems slightly strange to me (but it's okay).

And if you decide to take my suggestion, the name `_set_compile_call` should probably be changed to `_set_compile_config`.

Collaborator:

`torch.compile` does not compile; calling the compiled function is what triggers the compilation.

@ydshieh (Collaborator), Nov 27, 2024:

I know. What I mean here is not the underlying compilation happening, but rather the call to `torch.compile` in two places. Two different methods calling `torch.compile` seems like not very good style to me (for me, it's better if one method only sets the config and another one takes the job of calling `torch.compile`).

It's no big deal though, just a habit of making each method do its own job.

@Cyrilvallez (Member, Author), Nov 27, 2024:

You're right, but maybe then use `_set_compile_call` inside `compiled_call` instead? That way, `_set_compile_call` is the only place using `torch.compile`, and `compiled_call` is still only used for accessing the function (with the drawback of a call to `_set_compile_call` in case it does not already exist).

Collaborator:

Even with that, in generate we would still need to call `_set_compile_call` and then `compiled_call`, right? (since `compiled_call` doesn't take the argument). If that is the case, then it's odd to have `compiled_call` calling `_set_compile_call`.

Otherwise, if you think it's fine to change `compiled_call` to accept a `compile_config` argument, then your suggested change sounds good to me.

@ydshieh (Collaborator), Nov 27, 2024:

Oh, `compiled_call` is a property, so it can't take an argument... so the concern from the first paragraph of my comment above still stands.

Collaborator:

Anyway, it's an implementation detail and doesn't affect users. Don't take it too seriously if the changes would take too much time.

@Cyrilvallez (Member, Author), Nov 27, 2024:

Yes, I thought it made sense to make it a property, so that `model.compiled_call(**inputs)` could always be used directly as an alternative to `model.forward(**inputs)`.

return self._compiled_call


PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
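Finally, a sketch of how `get_compiled_call` behaves from the outside, based on the implementation above: the compiled callable is created lazily, reused while the config is unchanged, and rebuilt when a different `CompileConfig` is passed. The `gpt2` checkpoint is just an example, and the real compilation work only happens once the returned callable is invoked:

```python
from transformers import AutoModelForCausalLM, CompileConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")

config_a = CompileConfig(mode="reduce-overhead")
config_b = CompileConfig(mode="max-autotune")

call_1 = model.get_compiled_call(config_a)
call_2 = model.get_compiled_call(config_a)  # same config: the cached _compiled_call is returned
call_3 = model.get_compiled_call(config_b)  # config changed: a new torch.compile wrapper is built

assert call_1 is call_2
assert call_2 is not call_3

# The result is a drop-in replacement for model(...) / model.__call__(...):
# outputs = call_3(input_ids=some_input_ids)
```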