diff --git a/README.md b/README.md index 5cd9222a6c0..3e77c3550c5 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # Hugging Face Optimum -🤗 Optimum is an extension of 🤗 Transformers, providing a set of performance optimization tools enabling maximum efficiency to train and run models on targeted hardware. +🤗 Optimum is an extension of 🤗 Transformers, providing a set of optimization tools enabling maximum efficiency to train and run models on targeted hardware. The AI ecosystem evolves quickly and more and more specialized hardware along with their own optimizations are emerging every day. As such, Optimum enables users to efficiently use any of these platforms with the same ease inherent to transformers. @@ -10,12 +10,13 @@ As such, Optimum enables users to efficiently use any of these platforms with th ## Integration with Hardware Partners -🤗 Optimum aims at providing more diversity towards the kind of hardware users can target to train and finetune their models. +Optimum aims at providing more diversity towards the kind of hardware users can target to train and finetune their models. To achieve this, we are collaborating with the following hardware manufacturers in order to provide the best transformers integration: - [Graphcore IPUs](https://github.com/huggingface/optimum-graphcore) - IPUs are a completely new kind of massively parallel processor to accelerate machine intelligence. More information [here](https://www.graphcore.ai/products/ipu). - [Habana Gaudi Processor (HPU)](https://github.com/huggingface/optimum-habana) - [HPUs](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html) are designed to maximize training throughput and efficiency. More information [here](https://habana.ai/training/). -- [Intel](https://github.com/huggingface/optimum-intel) - Enabling the usage of Intel tools to accelerate end-to-end pipelines on Intel architectures. More information about [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) and [OpenVINO](https://docs.openvino.ai/latest/index.html). + +- [Intel](https://github.com/huggingface/optimum-intel) - Enabling the usage of Intel tools to accelerate inference on Intel architectures. More information about [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) and [OpenVINO](https://docs.openvino.ai/latest/index.html). - More to come soon! :star: ## Optimizing models towards inference @@ -23,14 +24,14 @@ To achieve this, we are collaborating with the following hardware manufacturers Along with supporting dedicated AI hardware for training, Optimum also provides inference optimizations towards various frameworks and platforms. -Optimum enables the usage of popular compression techniques such as quantization and pruning by supporting [ONNX Runtime](https://onnxruntime.ai/docs/) along with [Intel Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html). +Optimum enables the usage of popular compression techniques such as quantization and pruning by supporting [ONNX Runtime](https://onnxruntime.ai/docs/) along with Intel [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) and OpenVINO [NNCF](https://docs.openvino.ai/latest/tmo_introduction.html). 
-| Features | ONNX Runtime | Intel Neural Compressor | -|:----------------------------------:|:---------------------:|:-----------------------:| -| Post-training Dynamic Quantization | :heavy_check_mark: | :heavy_check_mark: | -| Post-training Static Quantization | :heavy_check_mark: | :heavy_check_mark: | -| Quantization Aware Training (QAT) | Stay tuned! :star: | :heavy_check_mark: | -| Pruning | N/A | :heavy_check_mark: | +| Features | ONNX Runtime | Neural Compressor | OpenVINO | +|:----------------------------------:|:---------------------:|:-----------------------:|:-----------------------:| +| Post-training Dynamic Quantization | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| Post-training Static Quantization | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| Quantization Aware Training (QAT) | Stay tuned! :star: | :heavy_check_mark: | N/A | +| Pruning | N/A | :heavy_check_mark: | Stay tuned! :star: | ## Installation @@ -68,82 +69,161 @@ python -m pip install git+https://github.com/huggingface/optimum.git#egg=optimum Check out the examples below to see how 🤗 Optimum can be used to train and run inference on various hardware accelerators. -### Accelerated training +## Accelerated inference + +#### ONNX Runtime + +To accelerate inference with ONNX Runtime, 🤗 Optimum uses _configuration objects_ to define parameters for graph optimization and quantization. These objects are then used to instantiate dedicated _optimizers_ and _quantizers_. + +Before applying quantization or optimization, first we need to load our model. To load a model and run inference with ONNX Runtime, you can just replace the canonical Transformers [`AutoModelForXxx`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModel) class with the corresponding [`ORTModelForXxx`](https://huggingface.co/docs/optimum/onnxruntime/package_reference/modeling_ort#optimum.onnxruntime.ORTModel) class. If you want to load from a PyTorch checkpoint, set `from_transformers=True` to export your model to the ONNX format. + +```python +from optimum.onnxruntime import ORTModelForSequenceClassification +from transformers import AutoTokenizer + +model_checkpoint = "distilbert-base-uncased-finetuned-sst-2-english" +save_directory = "tmp/onnx/" +# Load a model from transformers and export it to ONNX +tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) +ort_model = ORTModelForSequenceClassification.from_pretrained(model_checkpoint, from_transformers=True) +# Save the onnx model and tokenizer +ort_model.save_pretrained(save_directory) +tokenizer.save_pretrained(save_directory) +``` + +Let's see now how we can apply dynamic quantization with ONNX Runtime: + +```python +from optimum.onnxruntime.configuration import AutoQuantizationConfig +from optimum.onnxruntime import ORTQuantizer + +# Define the quantization methodology +qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False) +quantizer = ORTQuantizer.from_pretrained(ort_model) +# Apply dynamic quantization on the model +quantizer.quantize(save_dir=save_directory, quantization_config=qconfig) +``` + +In this example, we've quantized a model from the Hugging Face Hub, in the same manner we can quantize a model hosted locally by providing the path to the directory containing the model weights. The result from applying the `quantize()` method is a `model_quantized.onnx` file that can be used to run inference. 
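The quantization example above covers one half of the configuration-object pattern described in this README section; the same flow applies to graph optimization. Below is a minimal sketch of that optimizer path, reusing `ort_model` and `save_directory` from the snippet above and assuming the `ORTOptimizer`/`OptimizationConfig` pairing mirrors the quantizer API (exact configuration fields can differ between Optimum releases):

```python
from optimum.onnxruntime.configuration import OptimizationConfig
from optimum.onnxruntime import ORTOptimizer

# Define the graph optimization strategy (ONNX Runtime optimization levels: 0, 1, 2 or 99)
optimization_config = OptimizationConfig(optimization_level=2)
optimizer = ORTOptimizer.from_pretrained(ort_model)
# Apply graph optimization on the model and save the result in save_directory
optimizer.optimize(save_dir=save_directory, optimization_config=optimization_config)
```

As with quantization, the optimized graph is written alongside the original export and can be loaded back through the same `ORTModelForXxx` classes by pointing `file_name` at the optimized ONNX file.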
Here's an example of how to load an ONNX Runtime model and generate predictions with it: + +```python +from optimum.onnxruntime import ORTModelForSequenceClassification +from transformers import pipeline, AutoTokenizer + +model = ORTModelForSequenceClassification.from_pretrained(save_directory, file_name="model_quantized.onnx") +tokenizer = AutoTokenizer.from_pretrained(save_directory) +classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) +results = classifier("I love burritos!") +``` + +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/onnxruntime/overview) and in the [examples](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime). + + +#### Intel + +To load a model and run inference with OpenVINO Runtime, you can just replace your `AutoModelForXxx` class with the corresponding `OVModelForXxx` class. +If you want to load a PyTorch checkpoint, set `from_transformers=True` to convert your model to the OpenVINO IR (Intermediate Representation). + +```diff +- from transformers import AutoModelForSequenceClassification ++ from optimum.intel.openvino import OVModelForSequenceClassification + from transformers import AutoTokenizer, pipeline + + # Download a tokenizer and model from the Hub and convert to OpenVINO format + tokenizer = AutoTokenizer.from_pretrained(model_id) + model_id = "distilbert-base-uncased-finetuned-sst-2-english" +- model = AutoModelForSequenceClassification.from_pretrained(model_id) ++ model = OVModelForSequenceClassification.from_pretrained(model_id, from_transformers=True) + + # Run inference! + classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) + results = classifier("He's a dreadful magician.") +``` + +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/intel/index) and in the [examples](https://github.com/huggingface/optimum-intel/tree/main/examples/openvino). + + +## Accelerated training -#### Optimum Graphcore +#### Habana -To train transformers on Graphcore's IPUs, 🤗 Optimum provides a `IPUTrainer` that is very similar to the [🤗 Transformers trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: +To train transformers on Habana's Gaudi processors, 🤗 Optimum provides a `GaudiTrainer` that is very similar to the 🤗 Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: ```diff - from transformers import Trainer, TrainingArguments -+ from optimum.graphcore import IPUConfig, IPUTrainer, IPUTrainingArguments ++ from optimum.habana import GaudiTrainer, GaudiTrainingArguments # Download a pretrained model from the Hub model = AutoModelForXxx.from_pretrained("bert-base-uncased") # Define the training arguments - training_args = TrainingArguments( -+ training_args = IPUTrainingArguments( ++ training_args = GaudiTrainingArguments( output_dir="path/to/save/folder/", -+ ipu_config_name="Graphcore/bert-base-ipu", # Any IPUConfig on the Hub or stored locally ++ use_habana=True, ++ use_lazy_mode=True, ++ gaudi_config_name="Habana/bert-base-uncased", ... ) - # Define the configuration to compile and put the model on the IPU -+ ipu_config = IPUConfig.from_pretrained(training_args.ipu_config_name) - # Initialize the trainer - trainer = Trainer( -+ trainer = IPUTrainer( ++ trainer = GaudiTrainer( model=model, -+ ipu_config=ipu_config args=training_args, - train_dataset=train_dataset + train_dataset=train_dataset, ... 
) - # Use Graphcore IPU for training! + # Use Habana Gaudi processor for training! trainer.train() ``` +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/habana/index) and in the [examples](https://github.com/huggingface/optimum-habana/tree/main/examples). -#### Optimum Habana -To train transformers on Habana's Gaudi processors, 🤗 Optimum provides a `GaudiTrainer` that is very similar to the [🤗 Transformers trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: +#### Graphcore + +To train transformers on Graphcore's IPUs, 🤗 Optimum provides a `IPUTrainer` that is very similar to the 🤗 Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: ```diff - from transformers import Trainer, TrainingArguments -+ from optimum.habana import GaudiTrainer, GaudiTrainingArguments ++ from optimum.graphcore import IPUConfig, IPUTrainer, IPUTrainingArguments # Download a pretrained model from the Hub model = AutoModelForXxx.from_pretrained("bert-base-uncased") # Define the training arguments - training_args = TrainingArguments( -+ training_args = GaudiTrainingArguments( ++ training_args = IPUTrainingArguments( output_dir="path/to/save/folder/", -+ use_habana=True, -+ use_lazy_mode=True, -+ gaudi_config_name="Habana/bert-base-uncased", ++ ipu_config_name="Graphcore/bert-base-ipu", # Any IPUConfig on the Hub or stored locally ... ) + # Define the configuration to compile and put the model on the IPU ++ ipu_config = IPUConfig.from_pretrained(training_args.ipu_config_name) + # Initialize the trainer - trainer = Trainer( -+ trainer = GaudiTrainer( ++ trainer = IPUTrainer( model=model, ++ ipu_config=ipu_config args=training_args, - train_dataset=train_dataset, + train_dataset=train_dataset ... ) - # Use Habana Gaudi processor for training! + # Use Graphcore IPU for training! trainer.train() ``` +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/graphcore/index) and in the [examples](https://github.com/huggingface/optimum-graphcore/tree/main/examples). + + #### ONNX Runtime -To train transformers with ONNX Runtime's acceleration features, 🤗 Optimum provides a `ORTTrainer` that is very similar to the [🤗 Transformers trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: +To train transformers with ONNX Runtime's acceleration features, 🤗 Optimum provides a `ORTTrainer` that is very similar to the 🤗 Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: ```diff - from transformers import Trainer, TrainingArguments @@ -174,71 +254,4 @@ To train transformers with ONNX Runtime's acceleration features, 🤗 Optimum pr trainer.train() ``` - -### Accelerated inference - -#### ONNX Runtime - -To accelerate inference with ONNX Runtime, 🤗 Optimum uses _configuration objects_ to define parameters for optimization. These objects are then used to instantiate dedicated _optimizers_ and _quantizers_. 
- -Before applying quantization or optimization, first export our model to the ONNX format: - -```python -from optimum.onnxruntime import ORTModelForSequenceClassification -from transformers import AutoTokenizer - -model_checkpoint = "distilbert-base-uncased-finetuned-sst-2-english" -save_directory = "tmp/onnx/" -# Load a model from transformers and export it to ONNX -tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) -ort_model = ORTModelForSequenceClassification.from_pretrained(model_checkpoint, from_transformers=True) -# Save the onnx model and tokenizer -ort_model.save_pretrained(save_directory) -tokenizer.save_pretrained(save_directory) -``` - -Let's see now how we can apply dynamic quantization with ONNX Runtime: - -```python -from optimum.onnxruntime.configuration import AutoQuantizationConfig -from optimum.onnxruntime import ORTQuantizer - -# Define the quantization methodology -qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False) -quantizer = ORTQuantizer.from_pretrained(ort_model) -# Apply dynamic quantization on the model -quantizer.quantize(save_dir=save_directory, quantization_config=qconfig) -``` - -In this example, we've quantized a model from the Hugging Face Hub, but it could also be a path to a local model directory. The result from applying the `quantize()` method is a `model_quantized.onnx` file that can be used to run inference. Here's an example of how to load an ONNX Runtime model and generate predictions with it: - -```python -from optimum.onnxruntime import ORTModelForSequenceClassification -from transformers import pipeline, AutoTokenizer - -model = ORTModelForSequenceClassification.from_pretrained(save_directory, file_name="model_quantized.onnx") -tokenizer = AutoTokenizer.from_pretrained(save_directory) -classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) -results = classifier("I love burritos!") -``` - -#### Optimum Intel - -Here is an example on how to perform inference with the OpenVINO Runtime: - -```diff -- from transformers import AutoModelForSequenceClassification -+ from optimum.intel.openvino import OVModelForSequenceClassification - from transformers import AutoTokenizer, pipeline - - # Download a tokenizer and model from the Hub and convert to OpenVINO format - tokenizer = AutoTokenizer.from_pretrained(model_id) - model_id = "distilbert-base-uncased-finetuned-sst-2-english" -- model = AutoModelForSequenceClassification.from_pretrained(model_id) -+ model = OVModelForSequenceClassification.from_pretrained(model_id, from_transformers=True) - - # Run inference! - classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) - results = classifier("He's a dreadful magician.") -``` - +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/onnxruntime/overview) and in the [examples](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training). diff --git a/docs/combine_docs.py b/docs/combine_docs.py index 052bfcc5068..6f75bfe9d6a 100755 --- a/docs/combine_docs.py +++ b/docs/combine_docs.py @@ -6,6 +6,9 @@ import yaml +SECTIONS_AT_THE_END = ["Utilities"] + + parser = argparse.ArgumentParser( description="Script to combine doc builds from subpackages with base doc build of Optimum. " "Assumes all subpackage doc builds are present in the root of the `optimum` repo." 
@@ -83,6 +86,20 @@ def rename_copy_subpackage_html_paths(subpackage: str, subpackage_path: Path, op def main(): args = parser.parse_args() optimum_path = Path("optimum-doc-build") + # Load optimum table of contents + base_toc_path = next(optimum_path.rglob("_toctree.yml")) + with open(base_toc_path, "r") as f: + base_toc = yaml.safe_load(f) + + # Pop specific sections to add them after subpackages + sections_to_pop = {title: None for title in SECTIONS_AT_THE_END} + for i, section in enumerate(base_toc[:]): + if section["title"] in SECTIONS_AT_THE_END: + sections_to_pop[section["title"]] = base_toc.pop(i) + # Raise an error if a section was not found + for key, value in sections_to_pop.items(): + if value is None: + raise ValueError(f"No section was found for title '{key}'.") # Copy and rename all files from subpackages' docs to Optimum doc for subpackage in args.subpackages: @@ -96,10 +113,6 @@ def main(): args.version, ) - # Load optimum table of contents - base_toc_path = next(optimum_path.rglob("_toctree.yml")) - with open(base_toc_path, "r") as f: - base_toc = yaml.safe_load(f) # Load subpackage table of contents subpackage_toc_path = next(subpackage_path.rglob("_toctree.yml")) with open(subpackage_toc_path, "r") as f: @@ -108,8 +121,12 @@ def main(): rename_subpackage_toc(subpackage, subpackage_toc) # Update optimum table of contents base_toc.extend(subpackage_toc) - with open(base_toc_path, "w") as f: - yaml.safe_dump(base_toc, f, allow_unicode=True) + + # Add popped sections at the end + base_toc.extend(sections_to_pop.values()) + # Write final table of contents + with open(base_toc_path, "w") as f: + yaml.safe_dump(base_toc, f, allow_unicode=True) if __name__ == "__main__": diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 631bd01a13a..670de862447 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -114,7 +114,7 @@ - local: bettertransformer/tutorials/contribute title: How to add support for new architectures? title: Tutorials - title: BetterTransformer integration + title: BetterTransformer isExpanded: false - sections: - local: utils/dummy_input_generators diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx index 0f0866bd027..da071545899 100644 --- a/docs/source/bettertransformer/overview.mdx +++ b/docs/source/bettertransformer/overview.mdx @@ -39,6 +39,7 @@ The list of supported model below: - [MarkupLM](https://arxiv.org/abs/2110.08518) - [MBart](https://arxiv.org/abs/2001.08210) - [M2M100](https://arxiv.org/abs/2010.11125) +- [RemBERT](https://arxiv.org/abs/2010.12821) - [RoBERTa](https://arxiv.org/abs/1907.11692) - [RoCBert](https://aclanthology.org/2022.acl-long.65.pdf) - [Splinter](https://arxiv.org/abs/2101.00438) diff --git a/docs/source/onnxruntime/package_reference/modeling_ort.mdx b/docs/source/onnxruntime/package_reference/modeling_ort.mdx index 08896b42790..66786b58cf2 100644 --- a/docs/source/onnxruntime/package_reference/modeling_ort.mdx +++ b/docs/source/onnxruntime/package_reference/modeling_ort.mdx @@ -20,6 +20,10 @@ specific language governing permissions and limitations under the License. 
[[autodoc]] onnxruntime.ORTModelForCausalLM +## ORTModelForCustomTasks + +[[autodoc]] onnxruntime.ORTModelForCustomTasks + ## ORTModelForFeatureExtraction [[autodoc]] onnxruntime.ORTModelForFeatureExtraction diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py index 4c715a347ba..255d0ef736b 100644 --- a/optimum/bettertransformer/models/__init__.py +++ b/optimum/bettertransformer/models/__init__.py @@ -29,17 +29,18 @@ BETTER_TRANFORMER_LAYERS_MAPPING_DICT = { # Bert Family - "TapasLayer": BertLayerBetterTransformer, + "BertGenerationLayer": BertLayerBetterTransformer, "BertLayer": BertLayerBetterTransformer, - "ElectraLayer": BertLayerBetterTransformer, - "Data2VecTextLayer": BertLayerBetterTransformer, "CamembertLayer": BertLayerBetterTransformer, + "Data2VecTextLayer": BertLayerBetterTransformer, + "ElectraLayer": BertLayerBetterTransformer, + "ErnieLayer": BertLayerBetterTransformer, + "LayoutLMLayer": BertLayerBetterTransformer, "MarkupLMLayer": BertLayerBetterTransformer, + "RemBertLayer": BertLayerBetterTransformer, "RobertaLayer": BertLayerBetterTransformer, "SplinterLayer": BertLayerBetterTransformer, - "ErnieLayer": BertLayerBetterTransformer, - "LayoutLMLayer": BertLayerBetterTransformer, - "BertGenerationLayer": BertLayerBetterTransformer, + "TapasLayer": BertLayerBetterTransformer, "XLMRobertaLayer": BertLayerBetterTransformer, "RoCBertLayer": BertLayerBetterTransformer, # Albert Family @@ -62,13 +63,13 @@ # WhisperModel "WhisperEncoderLayer": WhisperEncoderLayerBetterTransformer, # Wav2vec2 family: - "Wav2Vec2EncoderLayer": Wav2Vec2EncoderLayerBetterTransformer, "HubertEncoderLayer": Wav2Vec2EncoderLayerBetterTransformer, + "Wav2Vec2EncoderLayer": Wav2Vec2EncoderLayerBetterTransformer, # "UniSpeechEncoderLayer": Wav2Vec2EncoderLayerBetterTransformer, # "Data2VecAudioEncoderLayer": Wav2Vec2EncoderLayerBetterTransformer, # ViT Family: - "ViTLayer": ViTLayerBetterTransformer, "DeiTLayer": ViTLayerBetterTransformer, + "ViTLayer": ViTLayerBetterTransformer, "ViTMAELayer": ViTLayerBetterTransformer, "ViTMSNLayer": ViTLayerBetterTransformer, "YolosLayer": ViTLayerBetterTransformer, diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py index 8550d427f14..8340b552e16 100644 --- a/optimum/bettertransformer/models/encoder_models.py +++ b/optimum/bettertransformer/models/encoder_models.py @@ -302,6 +302,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): """ super().forward_checker() + if not hasattr(hidden_states, "original_shape"): + original_shape = hidden_states.shape + else: + original_shape = hidden_states.original_shape + if hidden_states.is_nested: attention_mask = None @@ -339,8 +344,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): self.linear2_bias, attention_mask, ) - if hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0) + + if not self.is_last_layer: + hidden_states.original_shape = original_shape + elif hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) return (hidden_states,) @@ -412,6 +420,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): """ super().forward_checker() + if not hasattr(hidden_states, "original_shape"): + original_shape = hidden_states.shape + else: + original_shape = hidden_states.original_shape + if hidden_states.is_nested: 
attention_mask = None @@ -449,8 +462,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): self.linear2_bias, attention_mask, ) - if hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0) + + if not self.is_last_layer: + hidden_states.original_shape = original_shape + elif hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) return (hidden_states,) @@ -1026,6 +1042,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): """ super().forward_checker() + if not hasattr(hidden_states, "original_shape"): + original_shape = hidden_states.shape + else: + original_shape = hidden_states.original_shape + if hidden_states.is_nested: attention_mask = None @@ -1037,8 +1058,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): seqlen = attention_mask.shape[1] lengths = torch.sum(~attention_mask, 1) + # FSMT swaps the first two axis before calling the encoder stack + # Reference: https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/fsmt/modeling_fsmt.py#L508 if hidden_states.shape[0] != attention_mask.shape[0]: hidden_states = hidden_states.transpose(1, 0) + original_shape = hidden_states.shape if not all([l == seqlen for l in lengths]): hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask) @@ -1065,6 +1089,9 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): self.linear2_bias, attention_mask, ) - if hidden_states.is_nested and self.is_last_layer: - hidden_states = hidden_states.to_padded_tensor(0.0) + + if not self.is_last_layer: + hidden_states.original_shape = original_shape + elif hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0, original_shape) return (hidden_states, attention_mask) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 5b1c1d42393..96c35161c76 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -17,9 +17,10 @@ from argparse import ArgumentParser from pathlib import Path -from transformers import AutoFeatureExtractor, AutoTokenizer +from transformers import AutoTokenizer from ...utils import logging +from ...utils.save_utils import maybe_save_preprocessors from ..tasks import TasksManager from .base import OnnxConfigWithPast from .convert import ( @@ -30,7 +31,7 @@ ) -logger = logging.get_logger() # pylint: disable=invalid-name +logger = logging.get_logger() logger.setLevel(logging.INFO) @@ -143,18 +144,7 @@ def main(): # Saving the model config as this is needed sometimes. model.config.save_pretrained(args.output.parent) - # Saving the tokenizer / feature extractor as well. 
- try: - tokenizer = AutoTokenizer.from_pretrained(args.model) - tokenizer.save_pretrained(args.output.parent) - except Exception: - pass - - try: - feature_extractor = AutoFeatureExtractor.from_pretrained(args.model) - feature_extractor.save_pretrained(args.output.parent) - except Exception: - pass + maybe_save_preprocessors(args.model, args.output.parent) if args.atol is None: args.atol = onnx_config.ATOL_FOR_VALIDATION diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index a4ae0773c6f..8a1599236c2 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -344,7 +344,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC): Inherits from [`~exporters.onnx.OnnxConfig`]. A base class to handle the ONNX configuration of decoder-only models. """ - PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True + PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH: bool = True + USE_PAST_IN_INPUTS: Optional[bool] = None + USE_PRESENT_IN_OUTPUTS: Optional[bool] = None def __init__( self, @@ -352,8 +354,26 @@ def __init__( task: str = "default", patching_specs: List[PatchingSpec] = None, use_past: bool = False, + use_past_in_inputs: Optional[bool] = None, + use_present_in_outputs: Optional[bool] = None, ): self.use_past = use_past + if use_past_in_inputs is None: + use_past_in_inputs = self.USE_PAST_IN_INPUTS + if use_present_in_outputs is None: + use_present_in_outputs = self.USE_PRESENT_IN_OUTPUTS + self.use_past_in_inputs = use_past if use_past_in_inputs is None else use_past_in_inputs + self.use_present_in_outputs = use_past if use_present_in_outputs is None else use_present_in_outputs + if use_past != self.use_past_in_inputs: + logger.warning( + f"use_past = {use_past} is different than use_past_in_inputs = {use_past_in_inputs}, the value of " + "use_past_in_inputs will used for the inputs." + ) + if use_past != self.use_present_in_outputs: + logger.warning( + f"use_past = {use_past} is different than use_present_in_outputs = {use_present_in_outputs}, the value " + "of use_present_in_outputs value will used for the outputs." 
+ ) super().__init__(config, task=task, patching_specs=patching_specs) @classmethod @@ -375,15 +395,14 @@ def with_past(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxCo @property def outputs(self) -> Mapping[str, Mapping[int, str]]: common_outputs = super().outputs - if self.use_past: + if self.use_present_in_outputs: self.add_past_key_values(common_outputs, direction="outputs") - return common_outputs @property def values_override(self) -> Optional[Mapping[str, Any]]: if hasattr(self._config, "use_cache"): - return {"use_cache": self.use_past} + return {"use_cache": self.use_past_in_inputs or self.use_present_in_outputs} @add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES) def generate_dummy_inputs(self, framework: str = "pt", **kwargs): @@ -407,7 +426,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): if ( self.PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH - and self.use_past + and self.use_past_in_inputs and "attention_mask" in dummy_inputs ): past_length = dummy_inputs["past_key_values"][0][0].shape[2] @@ -473,7 +492,7 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]: # We reset the value as the order in common_outputs (OrderedDict) is lost otherwise else: axes_names[axis_idx] = name - if self.use_past: + if self.use_present_in_outputs: self.add_past_key_values(common_outputs, direction="outputs") return common_outputs diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 8e0be8ec972..5ec28fccfe7 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -53,7 +53,7 @@ class TextDecoderOnnxConfig(OnnxConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}} - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"} else: @@ -79,7 +79,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: "input_ids": {0: "batch_size", 1: "encoder_sequence_length"}, "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length" common_inputs["decoder_input_ids"] = {0: "batch_size"} # common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} @@ -87,7 +87,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} # common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs @@ -97,7 +97,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen self.task, self._normalized_config, **kwargs ) - if self.use_past is True: + if self.use_past_in_inputs is True: if "sequence_length" in kwargs and kwargs["sequence_length"] != 1: logger.warning( f"Asked a sequence length of {kwargs['sequence_length']}, but expecting a sequence length of 1 with use_past == True. Overriding the sequence length to 1." 
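To make the intent of the new `use_past_in_inputs` / `use_present_in_outputs` flags added to `OnnxConfigWithPast` above more concrete, here is a small hedged sketch of the case they were introduced for: exporting a decoder that takes no cached key/values as inputs but must still return the present key/values. It mirrors the `TasksManager` call pattern and the `use_present_in_outputs=True` constructor call used in `modeling_decoder.py` further down in this diff; the `gpt2` checkpoint and the `causal-lm` task are illustrative assumptions only.

```python
from transformers import AutoConfig

from optimum.exporters import TasksManager

# Illustrative decoder-only checkpoint (not part of this patch)
config = AutoConfig.from_pretrained("gpt2")
onnx_config_constructor = TasksManager.get_exporter_config_constructor("gpt2", "onnx", task="causal-lm")

# First pass of a cached-generation export: no past key/values are fed in,
# but the "present" key/values must still appear in the outputs.
onnx_config = onnx_config_constructor(config, use_present_in_outputs=True)

# use_past stays at its default (False), so the inputs remain cache-free while the
# outputs expose the present tensors; a warning about the mismatch with use_past is expected.
assert onnx_config.use_past_in_inputs is False
assert onnx_config.use_present_in_outputs is True
```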
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index c5789cb6d12..f6d248a31f0 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -73,6 +73,8 @@ class Seq2SeqDecoderOnnxConfig(TextSeq2SeqOnnxConfig): DummySeq2SeqPastKeyValuesGenerator, ) + USE_PRESENT_IN_OUTPUTS = True + @property def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = { @@ -81,7 +83,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs @@ -94,21 +96,6 @@ def torch_to_onnx_input_map(self) -> Mapping[str, str]: "attention_mask": "encoder_attention_mask", } - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = super().outputs - self.add_past_key_values(common_outputs, direction="outputs") - return common_outputs - - @property - def values_override(self) -> Optional[Mapping[str, Any]]: - # Needed here because the configuration will actually be used with both use_past = True and use_past = False, - # but the cache must always be used regardless. - if hasattr(self._config, "use_cache"): - return {"use_cache": True} - - return None - def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]: reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids") reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] @@ -413,7 +400,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen self.task, self._normalized_config, **kwargs ) - if self.use_past is True: + if self.use_past_in_inputs is True: if "sequence_length" in kwargs and kwargs["sequence_length"] != 1: logger.warning( f"Asked a sequence length of {kwargs['sequence_length']}, but expecting a sequence length of 1 with use_past == True. Overriding the sequence length to 1." 
@@ -445,14 +432,14 @@ def inputs_for_default_and_seq2seq_lm(self): "input_ids": {0: "batch_size", 1: "encoder_sequence_length"}, "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: common_inputs["decoder_input_ids"] = {0: "batch_size"} # common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} else: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} # common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs @@ -462,7 +449,7 @@ def inputs_for_causal_lm(self): "input_ids": {0: "batch_size", 1: "encoder_sequence_length"}, "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: for i in range(self._normalized_config.decoder_num_layers): common_inputs[f"past_key_values.{i}.key"] = { 0: "batch_size", @@ -498,7 +485,7 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]: common_outputs = super().outputs else: common_outputs = super(OnnxConfigWithPast, self).outputs - if self.use_past: + if self.use_present_in_outputs: for i in range(self._normalized_config.encoder_num_layers): common_outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"} common_outputs[f"present.{i}.value"] = { @@ -796,7 +783,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: "encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs @@ -817,12 +804,12 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = { "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}, } - if self.use_past: + if self.use_past_in_inputs: common_inputs["decoder_input_ids"] = {0: "batch_size"} else: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} - if self.use_past: + if self.use_past_in_inputs: self.add_past_key_values(common_inputs, direction="inputs") return common_inputs diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 247f1e6729d..8becc42e79c 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -14,8 +14,6 @@ # limitations under the License. 
"""Utility functions.""" -from ctypes import c_float, sizeof -from enum import Enum from typing import TYPE_CHECKING, Dict, Tuple, Union import packaging diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py index 19f80b8a007..285843dedaf 100644 --- a/optimum/modeling_base.py +++ b/optimum/modeling_base.py @@ -75,6 +75,7 @@ def __init__(self, model: Union["PreTrainedModel", "TFPreTrainedModel"], config: super().__init__() self.model = model self.config = config + self._preprocessors = [] def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) @@ -117,6 +118,8 @@ def save_pretrained( os.makedirs(save_directory, exist_ok=True) self.config.save_pretrained(save_directory) + for preprocessor in self._preprocessors: + preprocessor.save_pretrained(save_directory) self._save_pretrained(save_directory, **kwargs) if push_to_hub: @@ -132,10 +135,10 @@ def _save_pretrained(self, save_directory, **kwargs): def push_to_hub( self, - save_directory: str = None, - repository_id: Optional[str] = None, + save_directory: str, + repository_id: str, private: Optional[bool] = None, - use_auth_token: Optional[Union[bool, str]] = None, + use_auth_token: Union[bool, str] = True, ) -> str: if isinstance(use_auth_token, str): huggingface_token = use_auth_token diff --git a/optimum/onnx/configuration.py b/optimum/onnx/configuration.py index 40c6877feb2..01a51ad5003 100644 --- a/optimum/onnx/configuration.py +++ b/optimum/onnx/configuration.py @@ -340,61 +340,6 @@ def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int inputs_or_outputs[f"{name}_key_values_{i}"] = {0: "batch", 2: decoder_sequence} -class DecoderOnnxConfigWithPast(OnnxConfigWithPast): - @property - def inputs(self) -> Mapping[str, Mapping[int, str]]: - common_inputs = OrderedDict([("input_ids", {0: "batch", 1: "sequence"})]) - if self.use_past: - self.fill_with_past_key_values_(common_inputs, direction="inputs") - common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} - else: - common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} - - return common_inputs - - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = super().outputs - if not self.use_past: - self.fill_with_past_key_values_(common_outputs, direction="outputs") - return common_outputs - - @property - def num_layers(self) -> Tuple[int]: - num_layers_names = {"decoder_layers", "n_layer", "num_layers"} - for num_layers_name in num_layers_names: - if hasattr(self._config, num_layers_name): - return getattr(self._config, num_layers_name) - raise AttributeError( - "Could not find the number of decoder layers attributes in the model configuration, override the " - "num_layers property to solve this" - ) - - @property - def num_attention_heads(self) -> int: - num_heads_names = {"num_attention_head", "n_head", "num_heads"} - for num_heads_name in num_heads_names: - if hasattr(self._config, num_heads_name): - return getattr(self._config, num_heads_name) - raise AttributeError( - "Could not find the number of decoder attention heads attributes in the model configuration, override the " - "num_heads property to solve this" - ) - - def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): - num_pkv_per_layer = 2 - name = "past" if direction == "inputs" else "present" - decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence" - for i in range(self.num_layers * num_pkv_per_layer): - 
inputs_or_outputs[f"{name}_key_values_{i}"] = {0: "batch", 2: decoder_sequence} - - @property - def values_override(self) -> Optional[Mapping[str, Any]]: - if hasattr(self._config, "use_cache"): - return {"use_cache": True} - return None - - class OnnxSeq2SeqConfigWithPastAndLoss(DecoderOnnxConfig): def __init__(self, config: DecoderOnnxConfig): self.__dict__ = copy.deepcopy(config.__dict__) diff --git a/optimum/onnxruntime/io_binding/__init__.py b/optimum/onnxruntime/io_binding/__init__.py index e0810d5e807..d218d7a700d 100644 --- a/optimum/onnxruntime/io_binding/__init__.py +++ b/optimum/onnxruntime/io_binding/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .io_binding_helper import TypeHelper +from .io_binding_helper import IOBindingHelper, TypeHelper diff --git a/optimum/onnxruntime/io_binding/io_binding_helper.py b/optimum/onnxruntime/io_binding/io_binding_helper.py index e8005188bee..1911b1f8794 100644 --- a/optimum/onnxruntime/io_binding/io_binding_helper.py +++ b/optimum/onnxruntime/io_binding/io_binding_helper.py @@ -11,11 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import traceback + import numpy as np import torch +import onnxruntime as ort +from onnxruntime.capi.onnxruntime_inference_collection import OrtValue from onnxruntime.transformers.io_binding_helper import TypeHelper as ORTTypeHelper +from ..utils import is_cupy_available, is_onnxruntime_training_available + + +if is_cupy_available(): + import cupy as cp + # Adapted from https://github.com/microsoft/onnxruntime/blob/93e0a151177ad8222c2c95f814342bfa27f0a64d/onnxruntime/python/tools/transformers/io_binding_helper.py#L12 class TypeHelper(ORTTypeHelper): @@ -58,3 +69,81 @@ def ort_type_to_torch_type(ort_type: str): raise ValueError( f"{ort_type} is not supported. Here is a list of supported data type: {ort_type_to_torch_type_map.keys()}" ) + + +# Adapted from https://github.com/microsoft/onnxruntime/blob/1ab11a111ce0717bfbfaca964d04a017cb9b1752/onnxruntime/python/tools/transformers/io_binding_helper.py#L97 +class IOBindingHelper: + """ + A helper class to enable `ORTModel` instances to prepare IO binding with dynamic shaped outputs for an inference session and transfer + tensors from ONNX Runtime to other frameworks on device. It helps reduce memory copy between the host and device. + """ + + def __init__(self, model: ort.InferenceSession, device, **kwargs): + self.model = model + self.device = device + # Create {name:idx} dict for model inputs and outputs + self.model_inputs = {output_key.name: idx for idx, output_key in enumerate(model.get_inputs())} + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(model.get_outputs())} + self.model_input_names = list(self.model_inputs.keys()) + self.model_output_names = list(self.model_outputs.keys()) + + @staticmethod + def to_pytorch(ort_value: OrtValue) -> torch.Tensor: + """ + Converts tensors held by OrtValues in ONNX runtime memory buffer to torch tensor. 
+ """ + + if is_onnxruntime_training_available(): + return IOBindingHelper.to_pytorch_via_dlpack(ort_value) + else: + try: + return IOBindingHelper.to_pytorch_via_cupy(ort_value) + except Exception as e: + logging.error(traceback.format_exc()) + logging.info("Unable to access output memory in CUDA, will offload to CPU") + return IOBindingHelper.to_pytorch_via_numpy(ort_value) + + @staticmethod + def to_pytorch_via_numpy(ort_value: OrtValue) -> torch.Tensor: + ort_device = ort_value.device_name().lower() + return torch.tensor(ort_value.numpy()).to(ort_device) + + @staticmethod + def to_pytorch_via_cupy(ort_value: OrtValue) -> torch.Tensor: + ort_device = ort_value.device_name().lower() + if ort_device != "cuda": + raise RuntimeError(f"Exchange tensors to PyTorch via CuPy only when device is CUDA, got: {ort_device}") + + ort_type = ort_value.data_type() + numpy_type = TypeHelper.ort_type_to_numpy_type(ort_type) + + # Access CUDA memory via CuPy + memory = cp.cuda.UnownedMemory(ort_value.data_ptr(), 0, None) + memory_ptr = cp.cuda.MemoryPointer(memory, 0) + cp_array = cp.ndarray(shape=ort_value.shape(), memptr=memory_ptr, dtype=numpy_type) + torch_tensor = torch.from_dlpack(cp_array.toDlpack()) + + # If is boolean, the dtype will be uint8 and need to be convert back to bool. + if "bool" in ort_type: + torch_tensor = torch_tensor.to(torch.bool) + + torch_tensor = torch_tensor.clone() + + return torch_tensor + + @staticmethod + # dlpack support is available for OrtValue only when `onnxruntime-training` is installed + def to_pytorch_via_dlpack(ort_value: OrtValue) -> torch.Tensor: + from torch._C import _from_dlpack + + torch_tensor = ort_value.to_dlpacks(_from_dlpack) + return torch_tensor + + @staticmethod + def get_device_index(device): + if isinstance(device, str): + # could be 'cuda:0', 'cuda:1', or 'cpu'. 
with cpu, set index=0 + device = torch.device(device) + elif isinstance(device, int): + return device + return 0 if device.index is None else device.index diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 38720f3d953..084776c1aee 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -14,30 +14,34 @@ """Classes handling causal-lm related architectures in ONNX Runtime.""" import logging -import os import shutil from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -import transformers -from transformers import AutoModelForCausalLM, PretrainedConfig -from transformers.file_utils import add_start_docstrings_to_model_forward, default_cache_path +from transformers import AutoModelForCausalLM +from transformers.file_utils import add_start_docstrings_to_model_forward from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions -from transformers.onnx import FeaturesManager, export -from transformers.onnx.utils import get_preprocessor import onnxruntime from huggingface_hub import hf_hub_download -from ..onnx.configuration import DecoderOnnxConfigWithPast +from ..exporters import TasksManager +from ..exporters.onnx import export from ..utils import NormalizedConfigManager, check_if_transformers_greater +from ..utils.file_utils import validate_file_exists +from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .io_binding import TypeHelper from .modeling_ort import ORTModel from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, get_provider_for_device, parse_device +if TYPE_CHECKING: + from transformers import PretrainedConfig + + if check_if_transformers_greater("4.25.0"): from transformers.generation import GenerationMixin else: @@ -106,6 +110,9 @@ ``` """ +DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!with_past).)*?\.onnx" +DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx" + class ORTDecoder: """ @@ -115,7 +122,7 @@ class ORTDecoder: def __init__( self, session: onnxruntime.InferenceSession, - config: transformers.PretrainedConfig, + config: "PretrainedConfig", device: torch.device, use_io_binding: bool = True, ): @@ -130,8 +137,11 @@ def __init__( self.session_outputs = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} self.session_input_names = list(self.session_inputs.keys()) self.session_output_names = list(self.session_outputs.keys()) - self.key_value_input_names = [key for key in self.session_input_names if "key_values" in key] - self.key_value_output_names = [key for key in self.session_output_names if "key_values" in key] + # TODO: make this less hacky. 
+ self.key_value_input_names = [key for key in self.session_input_names if (".key" in key) or (".value" in key)] + self.key_value_output_names = [ + key for key in self.session_output_names if (".key" in key) or (".value" in key) + ] self.name_to_np_type = TypeHelper.get_io_numpy_type_map(self.session) if self.use_io_binding else None def prepare_output_buffer( @@ -149,7 +159,7 @@ def prepare_output_buffer( if output_name == "logits": output_shape = (batch_size, sequence_length, self.normalized_config.vocab_size) output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self._device).contiguous() - elif "key_values" in output_name: + elif ".key" in output_name or ".value" in output_name: num_attention_heads = self.normalized_config.num_attention_heads hidden_size = self.normalized_config.hidden_size embed_size_per_head = hidden_size // num_attention_heads @@ -321,19 +331,19 @@ class ORTModelDecoder(ORTModel): def __init__( self, - config: transformers.PretrainedConfig, decoder_session: onnxruntime.InferenceSession, + config: "PretrainedConfig", decoder_with_past_session: Optional[onnxruntime.InferenceSession] = None, use_io_binding: bool = True, - model_save_dir: str = "", - last_decoder_model_name: str = ONNX_DECODER_NAME, - last_decoder_with_past_model_name: str = ONNX_DECODER_WITH_PAST_NAME, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + **kwargs ): """ Args: decoder_session (`onnxruntime.InferenceSession`): The ONNX Runtime inference session associated to the decoder. - config (`transformers.PretrainedConfig`): + config ([~`transformers.PretrainedConfig`]): An instance of the configuration associated to the model. Initializing with a config file does not load the weights associated with the model, only the configuration. decoder_with_past_session (`Optional[onnxruntime.InferenceSession]`, *optional*): @@ -343,26 +353,40 @@ def __init__( `True` if the device is CUDA, otherwise defaults to `False`. model_save_dir (`str`, *optional*, defaults to `""`): The directory under which the model exported to ONNX was saved. - last_decoder_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`): - The name of the ONNX file containing the decoder part of the model. - last_decoder_with_past_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_WITH_PAST_NAME`): - The name of the ONNX file containing the decoder with past key/values part of the model. + preprocessors (`Optional[List]`, defaults to `None`): + The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel. """ + # TODO: remove at version 2.0 + def show_deprecated_argument(arg_name): + if kwargs.pop(arg_name, None) is not None: + logger.warning( + f"The {arg_name} argument to create an {self.__class__.__name__} is deprecated, and not used " + "anymore." + ) + + show_deprecated_argument("last_decoder_model_name") + show_deprecated_argument("last_decoder_with_past_model_name") + if kwargs: + raise ValueError( + f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments." 
+ ) + super().__init__( decoder_session, config, use_io_binding=use_io_binding, model_save_dir=model_save_dir, - latest_model_name=last_decoder_model_name, ) - self.decoder_file_name = last_decoder_model_name - self.decoder_file_with_past_name = last_decoder_with_past_model_name - self.use_cache = decoder_with_past_session is not None self.decoder = ORTDecoder( - session=self.model, config=self.config, device=self._device, use_io_binding=self.use_io_binding + session=decoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding ) + self.decoder_model_path = Path(decoder_session._model_path) + self.decoder_model_name = self.decoder_model_path.name + self.decoder_with_past = None + self.decoder_with_past_model_path = None + self.decoder_with_past_model_name = None if self.use_cache: self.decoder_with_past = ORTDecoder( session=decoder_with_past_session, @@ -370,6 +394,8 @@ def __init__( device=self._device, use_io_binding=self.use_io_binding, ) + self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path) + self.decoder_with_past_model_name = self.decoder_with_past_model_path.name @staticmethod def load_model( @@ -447,14 +473,13 @@ def _save_pretrained( The decoder with past key values model file name overwriting the default file name, allowing to save the decoder model with a different name. """ - src_file_names = [self.decoder_file_name] + src_paths = [self.decoder_model_path] dst_file_names = [decoder_file_name] if self.use_cache: - src_file_names.append(self.decoder_file_with_past_name) + src_paths.append(self.decoder_with_past_model_path) dst_file_names.append(decoder_with_past_file_name) - for src_file_name, dst_file_name in zip(src_file_names, dst_file_names): - src_path = self.model_save_dir.joinpath(src_file_name) + for src_path, dst_file_name in zip(src_paths, dst_file_names): dst_path = Path(save_directory).joinpath(dst_file_name) shutil.copyfile(src_path, dst_path) @@ -465,74 +490,115 @@ def _from_pretrained( config: "PretrainedConfig", use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, - force_download: bool = True, + force_download: bool = False, cache_dir: Optional[str] = None, decoder_file_name: str = ONNX_DECODER_NAME, decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, subfolder: str = "", local_files_only: bool = False, use_cache: bool = True, - use_io_binding: bool = True, provider: str = "CPUExecutionProvider", session_options: Optional[onnxruntime.SessionOptions] = None, provider_options: Optional[Dict[str, Any]] = None, + use_io_binding: bool = True, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, ): - file_names = {} - # Load model from a local directory - if os.path.isdir(os.path.join(model_id, subfolder)): - decoder_with_past_path = ( - os.path.join(model_id, subfolder, decoder_with_past_file_name) if use_cache else None + model_path = Path(model_id) + + if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision): + decoder_file_name = ORTModelDecoder.infer_onnx_filename( + model_id, + DECODER_ONNX_FILE_PATTERN, + "decoder_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision): + decoder_with_past_file_name = ORTModelDecoder.infer_onnx_filename( + model_id, + DECODER_WITH_PAST_ONNX_FILE_PATTERN, + "decoder_with_past_file_name", + subfolder=subfolder, + 
use_auth_token=use_auth_token, + revision=revision, + fail_if_not_found=use_cache, ) + + decoder_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(ONNX_DECODER_NAME) + decoder_with_past_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename( + ONNX_DECODER_WITH_PAST_NAME + ) + + if decoder_file_name not in decoder_regular_onnx_filenames: + logger.warning( + f"The ONNX file {decoder_file_name} is not a regular name used in optimum.onnxruntime, the " + "ORTModelForConditionalGeneration might not behave as expected." + ) + if decoder_with_past_file_name not in decoder_with_past_regular_onnx_filenames: + logger.warning( + f"The ONNX file {decoder_with_past_file_name} is not a regular name used in optimum.onnxruntime, " + "the ORTModelForConditionalGeneration might not behave as expected." + ) + + decoder_with_past_path = model_path / decoder_with_past_file_name if use_cache else None + + preprocessors = None + if model_path.is_dir(): model = cls.load_model( - decoder_path=os.path.join(model_id, subfolder, decoder_file_name), + decoder_path=model_path / decoder_file_name, decoder_with_past_path=decoder_with_past_path, provider=provider, session_options=session_options, provider_options=provider_options, ) - model_save_dir = Path(model_id).joinpath(subfolder) - file_names["last_decoder_model_name"] = decoder_file_name - file_names["last_decoder_with_past_model_name"] = decoder_with_past_file_name - # Load model from hub + new_model_save_dir = model_path + preprocessors = maybe_load_preprocessors(model_id) else: - default_file_names = [ONNX_DECODER_NAME] - model_file_names = [decoder_file_name] - if use_cache: - default_file_names.append(ONNX_DECODER_WITH_PAST_NAME) - model_file_names.append(decoder_with_past_file_name) - # Download the decoder and decoder_with_past forming the model - for file_name, default_file_name in zip(model_file_names, default_file_names): + attribute_name_to_filename = { + "last_decoder_model_name": decoder_file_name, + "last_decoder_with_past_model_name": decoder_with_past_file_name if use_cache else None, + } + paths = {} + for attr_name, filename in attribute_name_to_filename.items(): + if filename is None: + continue model_cache_path = hf_hub_download( repo_id=model_id, subfolder=subfolder, - filename=file_name, + filename=filename, use_auth_token=use_auth_token, revision=revision, cache_dir=cache_dir, force_download=force_download, local_files_only=local_files_only, ) - file_names[f"last_{default_file_name.split('.')[0]}_name"] = Path(model_cache_path).name - model_save_dir = Path(model_cache_path).parent + paths[attr_name] = Path(model_cache_path).name + new_model_save_dir = Path(model_cache_path).parent + preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) - last_decoder_with_past_name = file_names.get("last_decoder_with_past_model_name", None) + last_decoder_with_past_name = paths.get("last_decoder_with_past_model_name", None) if last_decoder_with_past_name is not None: - last_decoder_with_past_name = model_save_dir.joinpath(last_decoder_with_past_name) + last_decoder_with_past_name = new_model_save_dir / last_decoder_with_past_name + model = cls.load_model( - decoder_path=model_save_dir.joinpath(file_names["last_decoder_model_name"]), + decoder_path=new_model_save_dir / paths["last_decoder_model_name"], decoder_with_past_path=last_decoder_with_past_name, provider=provider, session_options=session_options, provider_options=provider_options, ) + if model_save_dir is None: + model_save_dir = 
new_model_save_dir + return cls( + model[0], config, - *model, + decoder_with_past_session=model[1], use_io_binding=use_io_binding, model_save_dir=model_save_dir, - last_decoder_model_name=file_names["last_decoder_model_name"], - last_decoder_with_past_model_name=file_names.get("last_decoder_with_past_model_name", None), + preprocessors=preprocessors, ) @classmethod @@ -540,46 +606,74 @@ def _from_transformers( cls, model_id: str, config: "PretrainedConfig", - subfolder: Optional[str] = "", - save_dir: Union[str, Path] = default_cache_path, use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, + revision: str = "main", force_download: bool = True, cache_dir: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, use_cache: bool = True, - **kwargs, - ): - # Create local save dir in cache dir - save_dir = Path(save_dir).joinpath(model_id) - save_dir.mkdir(parents=True, exist_ok=True) - preprocessor = get_preprocessor(model_id) - framework = FeaturesManager.determine_framework(os.path.join(model_id, subfolder)) - model_class = FeaturesManager.get_model_class_for_feature(cls.export_feature, framework) - model = model_class.from_pretrained(model_id, subfolder=subfolder, config=config, cache_dir=cache_dir) - - # Export the decoder without the past key values - onnx_config = DecoderOnnxConfigWithPast(model.config, task=cls.export_feature, use_past=False) - onnx_opset = onnx_config.default_onnx_opset + provider: str = "CPUExecutionProvider", + session_options: Optional[onnxruntime.SessionOptions] = None, + provider_options: Optional[Dict[str, Any]] = None, + use_io_binding: bool = True, + task: Optional[str] = None, + ) -> "ORTModelDecoder": + if task is None: + task = cls._AUTOMODELS_TO_TASKS[cls.auto_model_class] + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + model = TasksManager.get_model_from_task( + task, + model_id, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + config=config, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + force_download=force_download, + ) + + model_type = model.config.model_type.replace("_", "-") + model_name = getattr(model, "name", None) + + onnx_config_constructor = TasksManager.get_exporter_config_constructor( + model_type, "onnx", task=task, model_name=model_name + ) + onnx_config = onnx_config_constructor(model.config, use_present_in_outputs=True) + export( - preprocessor=preprocessor, - model=model, - config=onnx_config, - opset=onnx_opset, - output=save_dir.joinpath(ONNX_DECODER_NAME), + model, + onnx_config, + onnx_config.DEFAULT_ONNX_OPSET, + save_dir_path.joinpath(ONNX_DECODER_NAME), ) - # Export the decoder with the past key values if use_cache: - onnx_config_with_past = DecoderOnnxConfigWithPast(model.config, task=cls.export_feature, use_past=True) + onnx_config = onnx_config_constructor(model.config, use_past=True) export( - preprocessor=preprocessor, - model=model, - config=onnx_config_with_past, - opset=onnx_opset, - output=save_dir.joinpath(ONNX_DECODER_WITH_PAST_NAME), + model, + onnx_config, + onnx_config.DEFAULT_ONNX_OPSET, + save_dir_path.joinpath(ONNX_DECODER_WITH_PAST_NAME), ) - return cls._from_pretrained(save_dir, config=config, use_cache=use_cache, **kwargs) + config.save_pretrained(save_dir_path) + maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder) + + return cls._from_pretrained( + save_dir_path, + config, + use_cache=use_cache, + provider=provider, + session_options=session_options, + 
provider_options=provider_options, + use_io_binding=use_io_binding, + model_save_dir=save_dir, + ) def to(self, device: Union[torch.device, str, int]): """ @@ -612,8 +706,6 @@ class ORTModelForCausalLM(ORTModelDecoder, GenerationMixin): ONNX model with a causal language modeling head for ONNX Runtime inference. """ - # Used to export the model to ONNX - export_feature = "causal-lm" auto_model_class = AutoModelForCausalLM main_input_name = "input_ids" diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 94a66680076..f696346d1d8 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -14,10 +14,10 @@ """ORTModelForXXX classes, allowing to run ONNX Models with ONNX Runtime using the same API as Transformers.""" import logging -import os import shutil from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import numpy as np import torch @@ -30,7 +30,7 @@ AutoModelForSequenceClassification, AutoModelForTokenClassification, ) -from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, default_cache_path +from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.modeling_outputs import ( BaseModelOutput, ImageClassifierOutput, @@ -42,12 +42,14 @@ ) import onnxruntime as ort -from huggingface_hub import hf_hub_download +from huggingface_hub import HfApi, HfFolder, hf_hub_download from ..exporters import TasksManager from ..exporters.onnx import export from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel -from .io_binding import TypeHelper +from ..utils.file_utils import find_files_matching_pattern +from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors +from .io_binding import IOBindingHelper, TypeHelper from .utils import ( ONNX_WEIGHTS_NAME, get_device_for_provider, @@ -107,6 +109,14 @@ """ +class classproperty: + def __init__(self, getter): + self.getter = getter + + def __get__(self, instance, owner): + return self.getter(owner) + + class ORTModel(OptimizedModel): """ Base class for implementing models using ONNX Runtime. @@ -125,32 +135,63 @@ class ORTModel(OptimizedModel): - config ([`~transformers.PretrainedConfig`] -- The configuration of the model. - use_io_binding (`bool`, *optional*, defaults to `True`) -- Whether to use I/O bindings with **ONNX Runtime with the CUDAExecutionProvider**, this can significantly speedup inference depending on the task. - - model_save_dir (`Optional[str]`, *optional*) -- The directory where the model exported to ONNX will be saved. + - model_save_dir (`Path`) -- The directory where the model exported to ONNX is saved. By defaults, if the loaded model is local, the directory where the original model will be used. Otherwise, the cache directory is used. - - latest_model_name (`str`, *optional*, defaults to `"model.onnx"` -- The name of the last ONNX model file. - providers (`List[str]) -- The list of execution providers available to ONNX Runtime. 
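The reworked `_from_transformers` above exports decoder-only models into a temporary directory and immediately reloads them through `_from_pretrained`, passing the `TemporaryDirectory` along as `model_save_dir`. A minimal usage sketch of the public path built on it (the checkpoint name is illustrative, and `use_cache` is assumed to be forwarded as in the signature above):

```python
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

# Export a decoder-only checkpoint to ONNX on the fly; use_cache=True also exports the
# decoder-with-past variant so past key/values are reused during generation.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = ORTModelForCausalLM.from_pretrained("gpt2", from_transformers=True, use_cache=True)

inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
```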
""" + _AUTOMODELS_TO_TASKS = {cls_: task for task, cls_ in TasksManager._TASKS_TO_AUTOMODELS.items()} model_type = "onnx_model" auto_model_class = AutoModel + @classproperty + def export_feature(cls): + logger.warning(f"{cls.__name__}.export_feature is deprecated, and will be removed in optimum 2.0.") + return cls._AUTOMODELS_TO_TASKS.get(cls.auto_model_class, None) + def __init__( self, model: ort.InferenceSession, config: "PretrainedConfig", use_io_binding: bool = True, - model_save_dir: Optional[str] = None, - latest_model_name: str = "model.onnx", + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + **kwargs, ): - self.model = model - self.config = config + # TODO: remove at version 2.0 + if kwargs.pop("latest_model_name", None) is not None: + logger.warning( + f"The latest_model_name argument to create an {self.__class__.__name__} is deprecated, and not used " + "anymore." + ) + if kwargs: + raise ValueError( + f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments." + ) + + super().__init__(model, config) self.use_io_binding = use_io_binding - self.model_save_dir = model_save_dir - self.latest_model_name = latest_model_name self.providers = model.get_providers() self._device = get_device_for_provider(self.providers[0]) + # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it + # would end-up removing the directory containing the underlying ONNX model. + self._model_save_dir_tempdirectory_instance = None + if model_save_dir is None: + self.model_save_dir = Path(model._model_path).parent + elif isinstance(model_save_dir, TemporaryDirectory): + self._model_save_dir_tempdirectory_instance = model_save_dir + self.model_save_dir = Path(model_save_dir.name) + elif isinstance(model_save_dir, str): + self.model_save_dir = Path(model_save_dir) + else: + self.model_save_dir = model_save_dir + self.model_path = Path(model._model_path) + self.model_name = self.model_path.name + + self._preprocessors = preprocessors if preprocessors is not None else [] + if self._device is None: logger.warning( f"ORTModel outputs will be sent to CPU as the device could not be inferred from the execution provider {self.providers[0]}." @@ -237,6 +278,9 @@ def load_model( # Follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python providers.append("CUDAExecutionProvider") + if not isinstance(path, str): + path = str(path) + # `providers` list must of be of the same length as `provider_options` list return ort.InferenceSession( path, @@ -257,9 +301,49 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: str = ON file_name (`str`, *optional*, defaults to the value of `optimum.onnxruntime.utils.ONNX_WEIGHTS_NAME`): The filename to use when saving the model. 
""" - src_path = self.model_save_dir.joinpath(self.latest_model_name) + # TODO: support models with external data dst_path = Path(save_directory).joinpath(file_name) - shutil.copyfile(src_path, dst_path) + shutil.copyfile(self.model_path, dst_path) + + @staticmethod + def _generate_regular_names_for_filename(filename: str): + name, extension = filename.rsplit(".", maxsplit=1) + return [filename, f"{name}_quantized.{extension}", f"{name}_optimized.{extension}"] + + @staticmethod + def infer_onnx_filename( + model_name_or_path: Union[str, Path], + pattern: str, + argument_name: str, + subfolder: str = "", + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + fail_if_not_found: bool = True, + ) -> str: + onnx_files = find_files_matching_pattern( + model_name_or_path, + pattern, + glob_pattern="**/*.onnx", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + + path = model_name_or_path + if subfolder != "": + path = f"{path}/{subfolder}" + + if len(onnx_files) == 0: + if fail_if_not_found: + raise FileNotFoundError(f"Could not find any ONNX model file in {path}") + return None + elif len(onnx_files) > 1: + if argument_name is not None: + raise RuntimeError( + f"Too many ONNX model files were found in {path}, specify which one to load by using the " + f"{argument_name} argument." + ) + return onnx_files[0].name @classmethod def _from_pretrained( @@ -270,23 +354,56 @@ def _from_pretrained( revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, - file_name: str = ONNX_WEIGHTS_NAME, + file_name: Optional[str] = None, subfolder: str = "", local_files_only: bool = False, provider: str = "CPUExecutionProvider", session_options: Optional[ort.SessionOptions] = None, provider_options: Optional[Dict[str, Any]] = None, - **kwargs, + use_io_binding: bool = True, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, ) -> "ORTModel": - if os.path.isdir(os.path.join(model_id, subfolder)): + model_path = Path(model_id) + regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(ONNX_WEIGHTS_NAME) + + if file_name is None: + if model_path.is_dir(): + onnx_files = list(model_path.glob("*.onnx")) + else: + if isinstance(use_auth_token, bool): + token = HfFolder().get_token() + else: + token = use_auth_token + repo_files = map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)) + pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx" + onnx_files = [p for p in repo_files if p.match(pattern)] + + if len(onnx_files) == 0: + raise FileNotFoundError(f"Could not find any ONNX model file in {model_path}") + elif len(onnx_files) > 1: + raise RuntimeError( + f"Too many ONNX model files were found in {model_path}, specify which one to load by using the " + "file_name argument." + ) + else: + file_name = onnx_files[0].name + + if file_name not in regular_onnx_filenames: + logger.warning( + f"The ONNX file {file_name} is not a regular name used in optimum.onnxruntime, the ORTModel might " + "not behave as expected." 
+ ) + + preprocessors = None + if model_path.is_dir(): model = ORTModel.load_model( - os.path.join(model_id, subfolder, file_name), + model_path / file_name, provider=provider, session_options=session_options, provider_options=provider_options, ) - kwargs["model_save_dir"] = Path(model_id).joinpath(subfolder) - kwargs["latest_model_name"] = file_name + new_model_save_dir = model_path + preprocessors = maybe_load_preprocessors(model_id) else: model_cache_path = hf_hub_download( repo_id=model_id, @@ -301,17 +418,27 @@ def _from_pretrained( model = ORTModel.load_model( model_cache_path, provider=provider, session_options=session_options, provider_options=provider_options ) - kwargs["model_save_dir"] = Path(model_cache_path).parent - kwargs["latest_model_name"] = Path(model_cache_path).name + new_model_save_dir = Path(model_cache_path).parent + preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) - return cls(model=model, config=config, **kwargs) + # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it + # instead of the path only. + if model_save_dir is None: + model_save_dir = new_model_save_dir + + return cls( + model=model, + config=config, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + preprocessors=preprocessors, + ) @classmethod def _from_transformers( cls, model_id: str, config: "PretrainedConfig", - save_dir: Union[str, Path] = default_cache_path, use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -321,23 +448,11 @@ def _from_transformers( provider: str = "CPUExecutionProvider", session_options: Optional[ort.SessionOptions] = None, provider_options: Optional[Dict[str, Any]] = None, - **kwargs, + use_io_binding: bool = True, + task: Optional[str] = None, ) -> "ORTModel": - save_dir = Path(save_dir).joinpath(model_id, subfolder) - save_dir.mkdir(parents=True, exist_ok=True) - - # Reads pipeline task from ORTModelForXXX class if available else tries to extract from hub - if cls.export_feature is not None: - task = cls.export_feature - else: - # TODO: Do we want to actually support that? - # TODO: load from subfolder? - task = TasksManager.infer_task_from_model(model_id, revision=revision) - # TODO: is it still needed? - if task in ["sentiment-analysis", "text-classification", "zero-shot-classification"]: - task = "sequence-classification" - elif task in ["feature-extraction", "fill-mask"]: - task = "default" + if task is None: + task = cls._AUTOMODELS_TO_TASKS[cls.auto_model_class] kwargs_to_get_model = { "subfolder": subfolder, @@ -352,14 +467,26 @@ def _from_transformers( onnx_config = onnx_config_class(model.config) + tmp_dir = TemporaryDirectory() + tmp_dir_path = Path(tmp_dir.name) export( model=model, config=onnx_config, opset=onnx_config.DEFAULT_ONNX_OPSET, - output=save_dir.joinpath(ONNX_WEIGHTS_NAME), + output=tmp_dir_path / ONNX_WEIGHTS_NAME, + ) + config.save_pretrained(tmp_dir_path) + maybe_save_preprocessors(model_id, tmp_dir_path, src_subfolder=subfolder) + + return cls._from_pretrained( + tmp_dir_path, + config, + use_io_binding=use_io_binding, + model_save_dir=tmp_dir, + provider=provider, + session_options=session_options, + provider_options=provider_options, ) - - return cls._from_pretrained(save_dir.as_posix(), config, **kwargs) @classmethod @add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING) @@ -454,8 +581,6 @@ class ORTModelForFeatureExtraction(ORTModel): Feature Extraction model for ONNX. 
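Since `file_name` now defaults to `None`, `_from_pretrained` has to discover which ONNX file to load, either by globbing a local directory or by listing the repository files on the Hub. A condensed sketch of that discovery step in isolation (the helper name is made up; token handling and error messages are trimmed):

```python
from pathlib import Path
from typing import List, Optional

from huggingface_hub import HfApi


def find_onnx_files(model_id: str, subfolder: str = "", revision: Optional[str] = None) -> List[Path]:
    """List candidate *.onnx files in a local directory or in a Hub repository."""
    model_path = Path(model_id)
    if model_path.is_dir():
        return list(model_path.glob("*.onnx"))
    repo_files = map(Path, HfApi().list_repo_files(model_id, revision=revision))
    pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx"
    return [p for p in repo_files if p.match(pattern)]


# One candidate is loaded directly; zero raises FileNotFoundError and more than one
# requires an explicit file_name argument, mirroring the error handling above.
```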
""" - # used in from_transformers to export model to onnx - export_feature = "default" auto_model_class = AutoModel def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -627,7 +752,6 @@ class ORTModelForQuestionAnswering(ORTModel): Question Answering model for ONNX. """ - export_feature = "question-answering" auto_model_class = AutoModelForQuestionAnswering def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -827,7 +951,6 @@ class ORTModelForSequenceClassification(ORTModel): Sequence Classification model for ONNX. """ - export_feature = "sequence-classification" auto_model_class = AutoModelForSequenceClassification def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -998,7 +1121,6 @@ class ORTModelForTokenClassification(ORTModel): Token Classification model for ONNX. """ - export_feature = "token-classification" auto_model_class = AutoModelForTokenClassification def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -1164,7 +1286,6 @@ class ORTModelForMultipleChoice(ORTModel): Multiple choice model for ONNX. """ - export_feature = "multiple-choice" auto_model_class = AutoModelForMultipleChoice def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -1334,7 +1455,6 @@ class ORTModelForImageClassification(ORTModel): Image Classification model for ONNX. """ - export_feature = "image-classification" auto_model_class = AutoModelForImageClassification def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): @@ -1462,21 +1582,47 @@ def forward( ) class ORTModelForCustomTasks(ORTModel): """ - Onnx Model for any custom tasks. + Onnx Model for any custom tasks using encoder or decoder-only models. """ - export_feature = "default" - auto_model_class = AutoModel + def __init__(self, model=None, config=None, use_io_binding=True, **kwargs): + super().__init__(model, config, use_io_binding=True, **kwargs) + self.model_inputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_inputs())} + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} + self.model_input_names = list(self.model_inputs.keys()) + self.model_output_names = list(self.model_outputs.keys()) - def __init__(self, model=None, config=None, **kwargs): - super().__init__(model, config, **kwargs) - if kwargs.pop("use_io_binding", False): - logger.warning( - "ORTModelForCustomTasks doesn't support IO Binding yet, and the inference will be done without IO binding which could cause" - " significant overhead on data copying. If you want us to enable IO binding for custom use case, please open an issue in " - "Optimum: https://github.com/huggingface/optimum." + def prepare_io_binding(self, **kwargs) -> ort.IOBinding: + """ + Returns IOBinding object for an inference session. This method is created for general purpose, if the inputs and outputs + are determined, you can prepare data buffers directly to avoid tensor transfers across frameworks. 
+ """ + + name_to_np_type = TypeHelper.get_io_numpy_type_map(self.model) + + # Bind inputs and outputs to onnxruntime session + io_binding = self.model.io_binding() + + # Bind inputs + for input_name in self.model_input_names: + onnx_input = kwargs.pop(input_name) + onnx_input = onnx_input.contiguous() + + io_binding.bind_input( + input_name, + onnx_input.device.type, + self.device.index, + name_to_np_type[input_name], + list(onnx_input.size()), + onnx_input.data_ptr(), ) + # Bind outputs + for name in self.model_output_names: + io_binding.bind_output(name, self.device.type, device_id=self.device.index) + + return io_binding + @add_start_docstrings_to_model_forward( CUSTOM_TASKS_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, @@ -1485,13 +1631,30 @@ def __init__(self, model=None, config=None, **kwargs): ) ) def forward(self, **kwargs): - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = self._prepare_onnx_inputs(**kwargs) - # run inference - onnx_outputs = self.model.run(None, onnx_inputs) - outputs = self._prepare_onnx_outputs(onnx_outputs) - # converts outputs to namedtuple for pipelines post-processing if applicable - return ModelOutput(outputs) + if self.device.type == "cuda" and self.use_io_binding: + io_binding = self.prepare_io_binding(**kwargs) + + # run inference with binding + io_binding.synchronize_inputs() + self.model.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + outputs = {} + for name, output in zip(self.model_output_names, io_binding._iobinding.get_outputs()): + outputs[name] = IOBindingHelper.to_pytorch(output) + + # converts output to namedtuple for pipelines post-processing + return ModelOutput(**outputs) + else: + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = self._prepare_onnx_inputs(**kwargs) + + # run inference + onnx_outputs = self.model.run(None, onnx_inputs) + outputs = self._prepare_onnx_outputs(onnx_outputs) + + # converts output to namedtuple for pipelines post-processing + return ModelOutput(outputs) def _prepare_onnx_inputs(self, **kwargs): model_inputs = {input_key.name: idx for idx, input_key in enumerate(self.model.get_inputs())} diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index b7a51558077..ad1c7489780 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -17,32 +17,34 @@ """ import logging -import os +import re import shutil from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoTokenizer -from transformers.file_utils import add_start_docstrings_to_model_forward, default_cache_path -from transformers.generation_utils import GenerationMixin +from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq +from transformers.file_utils import add_start_docstrings_to_model_forward from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput import onnxruntime as ort -from huggingface_hub import hf_hub_download +from huggingface_hub import HfApi, HfFolder, get_hf_file_metadata, hf_hub_download, hf_hub_url from ..exporters.onnx.convert import export_encoder_decoder_model as export from ..exporters.tasks import TasksManager from ..utils import 
NormalizedConfigManager, check_if_transformers_greater +from ..utils.file_utils import validate_file_exists +from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .io_binding import TypeHelper +from .modeling_decoder import ORTDecoder from .modeling_ort import ORTModel from .utils import ( ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME, - get_device_for_provider, get_provider_for_device, parse_device, validate_provider_availability, @@ -197,475 +199,161 @@ ``` """ +ENCODER_ONNX_FILE_PATTERN = r"(.*)?encoder(.*)?\.onnx" +DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!with_past).)*?\.onnx" +DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx" -class ORTModelForConditionalGeneration(ORTModel, ABC): - """ - Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. - - Important attributes: - config ([`PretrainedConfig`]): - Instance of the configuration associated to the model. Initializing with a config file does - not load the weights associated with the model, only the configuration. - use_io_binding (`bool`): - Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to `True` - if the device is CUDA, otherwise defaults to `False`. - use_cache (`bool`): - Whether or not past key/values cache should be used. It is determined by whether an InferenceSession for - that was provided or not. - providers (`List[str`]): - The list of execution providers the model is running on. - encoder (`ORTEncoder`): - The encoder model. - decoder (`ORTDecoder`): - The decoder model. - decoder_with_past (`Optional[ORTDecoder]`): - The decoder model handling the past key/values if `use_cache=True`, else `None`. - - Other attributes: - encoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_ENCODER_NAME`): - The name of the ONNX file containing the encoder part of the model. - decoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`): - The name of the ONNX file containing the decoder part of the model. - decoder_file_with_past_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_WITH_PAST_NAME`): - The name of the ONNX file containing the decoder with past key/values part of the model. - model_save_dir (`str`, defaults to `""`): - The directory under which the model exported to ONNX was saved. +class ORTEncoder: + """ + Encoder part of the encoder-decoder model for ONNX Runtime inference. """ - - # Used in from_transformers to export model to onnxORTEncoder - base_model_prefix = "onnx_model" def __init__( self, - encoder_session: ort.InferenceSession, - decoder_session: ort.InferenceSession, + session: ort.InferenceSession, config: "PretrainedConfig", - decoder_with_past_session: Optional[ort.InferenceSession] = None, + device: torch.device, use_io_binding: bool = True, - model_save_dir: str = "", - last_encoder_model_name: str = ONNX_ENCODER_NAME, - last_decoder_model_name: str = ONNX_DECODER_NAME, - last_decoder_with_past_model_name: str = ONNX_DECODER_WITH_PAST_NAME, + main_input_name: str = "input_ids", ): - """ - Args: - encoder_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the encoder. - decoder_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the decoder. - config ([`PretrainedConfig`]): - `config` is an instance of the configuration associated to the model. 
Initializing with a config file - does not load the weights associated with the model, only the configuration. - decoder_with_past_session (`Optional[ort.InferenceSession]`, *optional*): - The ONNX Runtime inference session associated to the decoder with past key values. - use_io_binding (`bool`, *optional*, defaults to `True`): - Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to - `True` if the device is CUDA, otherwise defaults to `False`. - model_save_dir (`str`, *optional*, defaults to `""`): - The directory under which the model exported to ONNX was saved. - last_encoder_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_ENCODER_NAME`): - The name of the ONNX file containing the encoder part of the model. - last_decoder_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`): - The name of the ONNX file containing the decoder part of the model. - last_decoder_with_past_model_name (`str`, *optional*, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_WITH_PAST_NAME`): - The name of the ONNX file containing the decoder with past key/values part of the model. - """ - ABC.__init__(self) + self.session = session + self.config = config + self._device = device + self.use_io_binding = use_io_binding + self.main_input_name = main_input_name + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.config.model_type)( + self.config + ) + self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} + self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + self.name_to_np_type = TypeHelper.get_io_numpy_type_map(self.session) if self.use_io_binding else None - self.encoder_file_name = last_encoder_model_name - self.decoder_file_name = last_decoder_model_name - self.decoder_file_with_past_name = last_decoder_with_past_model_name + def prepare_output_buffer(self, batch_size, sequence_length): + """Prepare the buffer of output(`last_hidden_state`) with a 1D tensor on shape: (batch_size, sequence_length, hidden_size).""" + ort_type = TypeHelper.get_output_type(self.session, "last_hidden_state") + torch_type = TypeHelper.ort_type_to_torch_type(ort_type) - self.config = config + hidden_size = self.normalized_config.hidden_size + output_shape = (batch_size, sequence_length, hidden_size) + output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self._device).contiguous() - self.use_io_binding = use_io_binding - self.model_save_dir = model_save_dir + return output_shape, output_buffer - self.providers = encoder_session.get_providers() - self._device = get_device_for_provider(encoder_session.get_providers()[0]) + def prepare_io_binding( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + io_binding = self.session.io_binding() - if "TensorrtExecutionProvider" in self.providers and self.use_io_binding: - logger.warning( - "There is no need to do IO binding for TensorrtExecutionProvider, `use_io_binding` will be set to False." 
+ # bind input ids + input_ids = input_ids.contiguous() + io_binding.bind_input( + "input_ids", + input_ids.device.type, + self._device.index, + self.name_to_np_type["input_ids"], + tuple(input_ids.shape), + input_ids.data_ptr(), + ) + if "attention_mask" in self.input_names: + # bind attention mask + attention_mask = attention_mask.contiguous() + io_binding.bind_input( + "attention_mask", + attention_mask.device.type, + self._device.index, + self.name_to_np_type["attention_mask"], + tuple(attention_mask.shape), + attention_mask.data_ptr(), ) - self.use_io_binding = False - self.encoder = self._initialize_encoder( - session=encoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding + # bind last_hidden_state + output_shape, output_buffer = self.prepare_output_buffer( + batch_size=input_ids.size(0), + sequence_length=input_ids.size(1), ) - self.decoder = ORTDecoder( - session=decoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding + io_binding.bind_output( + "last_hidden_state", + output_buffer.device.type, + self._device.index, + self.name_to_np_type["last_hidden_state"], + output_shape, + output_buffer.data_ptr(), ) + output_shapes = {"last_hidden_state": output_shape} + output_buffers = {"last_hidden_state": output_buffer} - self.use_cache = decoder_with_past_session is not None + return io_binding, output_shapes, output_buffers - # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs - # will be enabled - self.decoder_with_past = None - if self.use_cache: - self.decoder_with_past = ORTDecoder( - session=decoder_with_past_session, - config=self.config, - device=self._device, - use_io_binding=self.use_io_binding, - ) + @add_start_docstrings_to_model_forward(SEQ2SEQ_ENCODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: torch.LongTensor, + **kwargs, + ) -> BaseModelOutput: - # Registers the ORTModelForXXX classes into the transformers AutoModel classes - # to avoid warnings when create a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863 - AutoConfig.register(self.base_model_prefix, AutoConfig) - self.auto_model_class.register(AutoConfig, self.__class__) + if self._device.type == "cuda" and self.use_io_binding: + io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_ids, attention_mask) - @abstractmethod - def _initialize_encoder( - self, - session: ort.InferenceSession, - config: "PretrainedConfig", - device: torch.device, - use_io_binding: bool = True, - ) -> "ORTEncoder": - pass + # run inference with binding & synchronize in case of multiple CUDA streams + io_binding.synchronize_inputs() + self.session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() - @staticmethod - def load_model( - encoder_path: Union[str, Path], - decoder_path: Union[str, Path], - decoder_with_past_path: Optional[Union[str, Path]] = None, - provider: str = "CPUExecutionProvider", - session_options: Optional[ort.SessionOptions] = None, - provider_options: Optional[Dict] = None, - ): - """ - Creates an instance of [`~optimum.onnxruntime.modeling_seq2seq.ORTModelForConditionalGeneration`]. - Three inference sessions will be created for respectively the encoder, decoder and decoder with past key values - models. The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX. 
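The `prepare_io_binding` helpers added for the encoder follow the standard ONNX Runtime IOBinding pattern: torch tensors are bound by data pointer so inputs are read and outputs written directly on the CUDA device, avoiding host round-trips. A schematic, self-contained sketch of that pattern (model path, input name, and shapes are illustrative; a CUDA-enabled onnxruntime build is assumed):

```python
import numpy as np
import torch
import onnxruntime as ort

# Hypothetical encoder exported to ONNX.
session = ort.InferenceSession("encoder_model.onnx", providers=["CUDAExecutionProvider"])
io_binding = session.io_binding()

input_ids = torch.ones(1, 8, dtype=torch.int64, device="cuda").contiguous()
io_binding.bind_input(
    "input_ids",             # input name expected by the session (assumed)
    input_ids.device.type,   # "cuda"
    0,                       # device id
    np.int64,                # numpy element type matching the ONNX input
    tuple(input_ids.shape),
    input_ids.data_ptr(),    # raw device pointer: no host copy of the input
)
# Let ONNX Runtime allocate the output on the same device instead of copying it back.
io_binding.bind_output("last_hidden_state", "cuda", device_id=0)

io_binding.synchronize_inputs()
session.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

last_hidden_state = io_binding.copy_outputs_to_cpu()[0]  # numpy array, for inspection only
```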
+ # converts output to namedtuple for pipelines post-processing + return BaseModelOutput( + last_hidden_state=output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) + ) + else: + onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()} - Args: - encoder_path (`Union[str, Path]`): - The path of the encoder ONNX model. - decoder_path (`Union[str, Path]`): - The path of the decoder ONNX model. - decoder_with_past_path (`Optional[Union[str, Path]]`, *optional*): - The path of the decoder with past key values ONNX model. - provider (`str`, *optional*, defaults to `"CPUExecutionProvider"`): - ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ - for possible providers. - session_options (`Optional[ort.SessionOptions]`, *optional*),: - ONNX Runtime session options to use for loading the model. Defaults to `None`. - provider_options (`Optional[Dict]`, *optional*): - Provider option dictionary corresponding to the provider used. See available options - for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`. - """ - validate_provider_availability(provider) # raise error if the provider is not available + # Add the attention_mask inputs when needed + if "attention_mask" in self.input_names: + onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() - providers = [provider] - if provider == "TensorrtExecutionProvider": - # follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python - providers.append("CUDAExecutionProvider") + # Run inference + outputs = self.session.run(None, onnx_inputs) + last_hidden_state = torch.from_numpy(outputs[self.output_names["last_hidden_state"]]).to(self._device) - encoder_session = ort.InferenceSession( - str(encoder_path), - providers=providers, - sess_options=session_options, - provider_options=None if provider_options is None else [provider_options], - ) - decoder_session = ort.InferenceSession( - str(decoder_path), - providers=providers, - sess_options=session_options, - provider_options=None if provider_options is None else [provider_options], - ) + return BaseModelOutput(last_hidden_state=last_hidden_state) - decoder_with_past_session = None - # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs - # will be enabled - if decoder_with_past_path is not None: - decoder_with_past_session = ort.InferenceSession( - str(decoder_with_past_path), - providers=providers, - sess_options=session_options, - provider_options=None if provider_options is None else [provider_options], - ) + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) - return encoder_session, decoder_session, decoder_with_past_session - def _save_pretrained( +class ORTEncoderForWhisper(ORTEncoder): + """ + Encoder model for ONNX Runtime inference for Whisper model. + + Args: + session (`ort.InferenceSession`): + The ONNX Runtime inference session associated to the encoder. + """ + + def prepare_io_binding( self, - save_directory: Union[str, Path], - # TODO: should we make the default values available here? 
- encoder_file_name: Optional[str] = None, - decoder_file_name: Optional[str] = None, - decoder_with_past_file_name: Optional[str] = None, + input_features: torch.FloatTensor = None, ): - """ - Saves the model encoder, decoder and decoder with past key values as well as its configuration file to a - directory, so that it can be re-loaded using the - [`~optimum.onnxruntime.modeling_seq2seq.ORTModelForSeq2SeqLM.from_pretrained`] class method. + io_binding = self.session.io_binding() - Args: - save_directory (`Union[str, Path`]): - The directory where to save the model files. - encoder_file_name(`Optional[str]`, *optional*): - The encoder model file name. Overwrites the default file name and allows one to save the encoder model - with a different name. - decoder_file_name(`Optional[str]`, *optional*): - The decoder model file name. Overwrites the default file name and allows one to save the decoder model - with a different name. - decoder_with_past_file_name(`Optional[str]`, *optional*): - The decoder with past key values model file name overwriting the default file name, allowing to save - the decoder model with a different name. - """ - src_file_names = [self.encoder_file_name, self.decoder_file_name] - dst_file_names = [encoder_file_name or ONNX_ENCODER_NAME, decoder_file_name or ONNX_DECODER_NAME] - if self.use_cache: - src_file_names.append(self.decoder_file_with_past_name) - dst_file_names.append(decoder_with_past_file_name or ONNX_DECODER_WITH_PAST_NAME) - - for src_file_name, dst_file_name in zip(src_file_names, dst_file_names): - src_path = self.model_save_dir.joinpath(src_file_name) - dst_path = Path(save_directory).joinpath(dst_file_name) - shutil.copyfile(src_path, dst_path) - - @classmethod - def _from_pretrained( - cls, - model_id: Union[str, Path], - config: "PretrainedConfig", - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - force_download: bool = False, - cache_dir: Optional[str] = None, - encoder_file_name: str = ONNX_ENCODER_NAME, - decoder_file_name: str = ONNX_DECODER_NAME, - decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, - subfolder: str = "", - local_files_only: bool = False, - use_cache: bool = True, - provider: str = "CPUExecutionProvider", - session_options: Optional[ort.SessionOptions] = None, - provider_options: Optional[Dict[str, Any]] = None, - use_io_binding: bool = True, - ): - kwargs = {"use_io_binding": use_io_binding} - - # Load model from a local directory - if os.path.isdir(os.path.join(model_id, subfolder)): - decoder_with_past_path = ( - os.path.join(model_id, subfolder, decoder_with_past_file_name) if use_cache else None - ) - model = cls.load_model( - encoder_path=os.path.join(model_id, subfolder, encoder_file_name), - decoder_path=os.path.join(model_id, subfolder, decoder_file_name), - decoder_with_past_path=decoder_with_past_path, - provider=provider, - session_options=session_options, - provider_options=provider_options, - ) - kwargs["model_save_dir"] = Path(model_id).joinpath(subfolder) - kwargs["last_encoder_model_name"] = encoder_file_name - kwargs["last_decoder_model_name"] = decoder_file_name - kwargs["last_decoder_with_past_model_name"] = decoder_with_past_file_name - # Load model from hub - else: - default_file_names = [ONNX_ENCODER_NAME, ONNX_DECODER_NAME] - model_file_names = [encoder_file_name, decoder_file_name] - if use_cache: - default_file_names.append(ONNX_DECODER_WITH_PAST_NAME) - model_file_names.append(decoder_with_past_file_name) - # Download the encoder, decoder and 
decoder_with_past forming the model - for file_name, default_file_name in zip(model_file_names, default_file_names): - model_cache_path = hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=file_name, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - kwargs[f"last_{default_file_name.split('.')[0]}_name"] = Path(model_cache_path).name - kwargs["model_save_dir"] = Path(model_cache_path).parent - - last_decoder_with_past_name = kwargs.get("last_decoder_with_past_model_name", None) - if last_decoder_with_past_name is not None: - last_decoder_with_past_name = kwargs["model_save_dir"].joinpath(last_decoder_with_past_name) - model = cls.load_model( - encoder_path=kwargs["model_save_dir"].joinpath(kwargs["last_encoder_model_name"]), - decoder_path=kwargs["model_save_dir"].joinpath(kwargs["last_decoder_model_name"]), - decoder_with_past_path=last_decoder_with_past_name, - provider=provider, - session_options=session_options, - provider_options=provider_options, - ) - - return cls(*model[:2], config, decoder_with_past_session=model[2], **kwargs) - - @classmethod - def _from_transformers( - cls, - model_id: str, - config: "PretrainedConfig", - save_dir: Union[str, Path] = default_cache_path, - use_auth_token: Optional[Union[bool, str]] = None, - revision: str = "main", - force_download: bool = True, - cache_dir: Optional[str] = None, - subfolder: str = "", - local_files_only: bool = False, - use_cache: bool = True, - provider: str = "CPUExecutionProvider", - session_options: Optional[ort.SessionOptions] = None, - provider_options: Optional[Dict[str, Any]] = None, - use_io_binding: bool = True, - ): - # Create local save dir in cache dir - save_dir = Path(save_dir).joinpath(model_id, subfolder) - save_dir.mkdir(parents=True, exist_ok=True) - - model = TasksManager.get_model_from_task( - cls.export_feature, - model_id, - subfolder=subfolder, - revision=revision, - cache_dir=cache_dir, - config=config, - use_auth_token=use_auth_token, - local_files_only=local_files_only, - force_download=force_download, - ) - - model_type = model.config.model_type.replace("_", "-") - model_name = getattr(model, "name", None) - - onnx_config_constructor = TasksManager.get_exporter_config_constructor( - model_type, "onnx", task=cls.export_feature, model_name=model_name - ) - onnx_config = onnx_config_constructor(model.config, use_past=use_cache) - onnx_opset = onnx_config.DEFAULT_ONNX_OPSET - - export( - model, - onnx_config, - onnx_opset, - save_dir.joinpath(ONNX_ENCODER_NAME), - save_dir.joinpath(ONNX_DECODER_NAME), - save_dir.joinpath(ONNX_DECODER_WITH_PAST_NAME), - ) - - return cls._from_pretrained( - save_dir, - config=config, - use_cache=use_cache, - provider=provider, - session_options=session_options, - provider_options=provider_options, - use_io_binding=use_io_binding, - ) - - def to(self, device: Union[torch.device, str, int]): - """ - Changes the ONNX Runtime provider according to the device. - - Args: - device (`torch.device` or `str` or `int`): - Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run - the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too. - - Returns: - `ORTModel`: the model placed on the requested device. 
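The `to()` implementation being moved along with the class is the public device-placement API: it switches every underlying `InferenceSession` (encoder, decoder, and decoder with past when present) to the provider matching the requested device. A brief usage sketch, assuming an illustrative model id and a CUDA-enabled onnxruntime build:

```python
from optimum.onnxruntime import ORTModelForSeq2SeqLM

model = ORTModelForSeq2SeqLM.from_pretrained("t5-small", from_transformers=True)

# Accepts a torch.device, a string, or an integer device ordinal (-1 meaning CPU).
model = model.to("cuda")

# All sessions now run on CUDAExecutionProvider (CPUExecutionProvider stays as fallback).
print(model.providers)
```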
- """ - device, provider_options = parse_device(device) - - provider = get_provider_for_device(device) - validate_provider_availability(provider) # raise error if the provider is not available - - self.device = device - self.encoder._device = device - self.encoder.session.set_providers([provider], provider_options=[provider_options]) - self.decoder._device = device - self.decoder.session.set_providers([provider], provider_options=[provider_options]) - if self.decoder_with_past is not None: - self.decoder_with_past._device = device - self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options]) - self.providers = self.encoder.session.get_providers() - - return self - - -class ORTEncoder: - """ - Encoder model for ONNX Runtime inference. - - Args: - session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the encoder. - """ - - def __init__( - self, - session: ort.InferenceSession, - config: "PretrainedConfig", - device: torch.device, - use_io_binding: bool = True, - main_input_name: str = "input_ids", - ): - self.session = session - self.config = config - self._device = device - self.use_io_binding = use_io_binding - self.main_input_name = main_input_name - self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.config.model_type)( - self.config - ) - self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} - self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} - self.name_to_np_type = TypeHelper.get_io_numpy_type_map(self.session) if self.use_io_binding else None - - def prepare_output_buffer(self, batch_size, sequence_length): - """Prepare the buffer of output(`last_hidden_state`) with a 1D tensor on shape: (batch_size, sequence_length, hidden_size).""" - ort_type = TypeHelper.get_output_type(self.session, "last_hidden_state") - torch_type = TypeHelper.ort_type_to_torch_type(ort_type) - - hidden_size = self.normalized_config.hidden_size - output_shape = (batch_size, sequence_length, hidden_size) - output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self._device).contiguous() - - return output_shape, output_buffer - - def prepare_io_binding( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ): - io_binding = self.session.io_binding() - - # bind input ids - input_ids = input_ids.contiguous() + # bind input features + input_features = input_features.contiguous() io_binding.bind_input( - "input_ids", - input_ids.device.type, + "input_features", + input_features.device.type, self._device.index, - self.name_to_np_type["input_ids"], - tuple(input_ids.shape), - input_ids.data_ptr(), + self.name_to_np_type["input_features"], + tuple(input_features.shape), + input_features.data_ptr(), ) - if "attention_mask" in self.input_names: - # bind attention mask - attention_mask = attention_mask.contiguous() - io_binding.bind_input( - "attention_mask", - attention_mask.device.type, - self._device.index, - self.name_to_np_type["attention_mask"], - tuple(attention_mask.shape), - attention_mask.data_ptr(), - ) - # bind last_hidden_state + # bind logits output_shape, output_buffer = self.prepare_output_buffer( - batch_size=input_ids.size(0), - sequence_length=input_ids.size(1), + batch_size=input_features.size(0), + sequence_length=input_features.size(2) // 2, ) io_binding.bind_output( "last_hidden_state", @@ -680,16 +368,14 @@ def prepare_io_binding( return 
io_binding, output_shapes, output_buffers - @add_start_docstrings_to_model_forward(SEQ2SEQ_ENCODER_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) def forward( self, - input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, + input_features: torch.FloatTensor, **kwargs, ) -> BaseModelOutput: - if self._device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_ids, attention_mask) + io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_features) # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() @@ -701,11 +387,7 @@ def forward( last_hidden_state=output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) ) else: - onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()} - - # Add the attention_mask inputs when needed - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() + onnx_inputs = {"input_features": input_features.cpu().detach().numpy()} # Run inference outputs = self.session.run(None, onnx_inputs) @@ -713,121 +395,20 @@ def forward( return BaseModelOutput(last_hidden_state=last_hidden_state) - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - -class ORTEncoderForWhisper(ORTEncoder): +class ORTDecoderForSeq2Seq(ORTDecoder): """ - Encoder model for ONNX Runtime inference for Whisper model. - - Args: - session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the encoder. + Decoder model with a language modeling head on top for ONNX Runtime inference. """ - def prepare_io_binding( + def prepare_output_buffer( self, - input_features: torch.FloatTensor = None, - ): - io_binding = self.session.io_binding() - - # bind input features - input_features = input_features.contiguous() - io_binding.bind_input( - "input_features", - input_features.device.type, - self._device.index, - self.name_to_np_type["input_features"], - tuple(input_features.shape), - input_features.data_ptr(), - ) - - # bind logits - output_shape, output_buffer = self.prepare_output_buffer( - batch_size=input_features.size(0), - sequence_length=input_features.size(2) // 2, - ) - io_binding.bind_output( - "last_hidden_state", - output_buffer.device.type, - self._device.index, - self.name_to_np_type["last_hidden_state"], - output_shape, - output_buffer.data_ptr(), - ) - output_shapes = {"last_hidden_state": output_shape} - output_buffers = {"last_hidden_state": output_buffer} - - return io_binding, output_shapes, output_buffers - - @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) - def forward( - self, - input_features: torch.FloatTensor, - **kwargs, - ) -> BaseModelOutput: - if self._device.type == "cuda" and self.use_io_binding: - io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_features) - - # run inference with binding & synchronize in case of multiple CUDA streams - io_binding.synchronize_inputs() - self.session.run_with_iobinding(io_binding) - io_binding.synchronize_outputs() - - # converts output to namedtuple for pipelines post-processing - return BaseModelOutput( - last_hidden_state=output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) - ) - else: - onnx_inputs = {"input_features": input_features.cpu().detach().numpy()} - - # Run inference - outputs = self.session.run(None, onnx_inputs) - last_hidden_state = 
torch.from_numpy(outputs[self.output_names["last_hidden_state"]]).to(self._device) - - return BaseModelOutput(last_hidden_state=last_hidden_state) - - -class ORTDecoder: - """ - Decoder model with a language modeling head on top for ONNX Runtime inference. - - Args: - session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the decoder. - """ - - def __init__( - self, - session: ort.InferenceSession, - config: "PretrainedConfig", - device: torch.device, - use_io_binding: bool = True, - ): - self.session = session - self.config = config - self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.config.model_type)( - self.config - ) - self._device = device - self.use_io_binding = use_io_binding - self.session_inputs = {output_key.name: idx for idx, output_key in enumerate(self.session.get_inputs())} - self.session_outputs = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} - self.session_input_names = list(self.session_inputs.keys()) - self.session_output_names = list(self.session_outputs.keys()) - self.key_value_input_names = [key for key in self.session_input_names if (".key" in key or ".value" in key)] - self.key_value_output_names = [key for key in self.session_output_names if (".key" in key or ".value" in key)] - self.name_to_np_type = TypeHelper.get_io_numpy_type_map(self.session) if self.use_io_binding else None - - def prepare_output_buffer( - self, - output_name, - batch_size=None, - sequence_length=None, - encoder_sequence_length=None, - past_sequence_length=None, - is_self_attn=False, + output_name, + batch_size=None, + sequence_length=None, + encoder_sequence_length=None, + past_sequence_length=None, + is_self_attn=False, ): """ Prepare the buffer of outputs(`logits`/`key_values`/`loss`) with 1D tensors. @@ -1072,6 +653,7 @@ def forward( # Run inference outputs = self.session.run(None, onnx_inputs) + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the # self-attention layer and 2 to the cross-attention layer) past_key_values = tuple( @@ -1096,12 +678,483 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) +class ORTModelForConditionalGeneration(ORTModel, ABC): + """ + Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. + + Important attributes: + config ([`PretrainedConfig`]): + Instance of the configuration associated to the model. Initializing with a config file does + not load the weights associated with the model, only the configuration. + use_io_binding (`bool`): + Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to `True` + if the device is CUDA, otherwise defaults to `False`. + use_cache (`bool`): + Whether or not past key/values cache should be used. It is determined by whether an InferenceSession for + that was provided or not. + providers (`List[str`]): + The list of execution providers the model is running on. + encoder (`ORTEncoder`): + The encoder model. + decoder (`ORTDecoderForSeq2Seq`): + The decoder model. + decoder_with_past (`Optional[ORTDecoderForSeq2Seq]`): + The decoder model handling the past key/values if `use_cache=True`, else `None`. + + Other attributes: + encoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_ENCODER_NAME`): + The name of the ONNX file containing the encoder part of the model. 
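`ORTEncoderForWhisper` replaces `input_ids` with `input_features` and derives the encoder sequence length as half the number of mel frames. A hedged usage sketch of the speech seq2seq model class that relies on it (the checkpoint name is illustrative; real features would come from the processor rather than the zero tensor used here):

```python
import torch
from transformers import WhisperProcessor

from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

model_id = "openai/whisper-tiny.en"
processor = WhisperProcessor.from_pretrained(model_id)
model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True)

# Whisper expects log-mel features of shape (batch, num_mel_bins, num_frames); zeros
# stand in for processor(audio, return_tensors="pt").input_features in this sketch.
input_features = torch.zeros(1, 80, 3000)
generated_ids = model.generate(input_features)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```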
+ decoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`): + The name of the ONNX file containing the decoder part of the model. + decoder_file_with_past_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_WITH_PAST_NAME`): + The name of the ONNX file containing the decoder with past key/values part of the model. + model_save_dir (`str`, defaults to `""`): + The directory under which the model exported to ONNX was saved. + + """ + + # Used in from_transformers to export model to onnxORTEncoder + base_model_prefix = "onnx_model" + + def __init__( + self, + encoder_session: ort.InferenceSession, + decoder_session: ort.InferenceSession, + config: "PretrainedConfig", + decoder_with_past_session: Optional[ort.InferenceSession] = None, + use_io_binding: bool = True, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + **kwargs, + ): + """ + Args: + encoder_session (`ort.InferenceSession`): + The ONNX Runtime inference session associated to the encoder. + decoder_session (`ort.InferenceSession`): + The ONNX Runtime inference session associated to the decoder. + config ([`PretrainedConfig`]): + `config` is an instance of the configuration associated to the model. Initializing with a config file + does not load the weights associated with the model, only the configuration. + decoder_with_past_session (`Optional[ort.InferenceSession]`, *optional*): + The ONNX Runtime inference session associated to the decoder with past key values. + use_io_binding (`bool`, *optional*, defaults to `True`): + Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to + `True` if the device is CUDA, otherwise defaults to `False`. + model_save_dir (`str`, *optional*, defaults to `""`): + The directory under which the model exported to ONNX was saved. + preprocessors (`Optional[List]`, defaults to `None`): + The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel. + """ + # TODO: remove at version 2.0 + def show_deprecated_argument(arg_name): + if kwargs.pop(arg_name, None) is not None: + logger.warning( + f"The {arg_name} argument to create an {self.__class__.__name__} is deprecated, and not used " + "anymore." + ) + + show_deprecated_argument("last_encoder_model_name") + show_deprecated_argument("last_decoder_model_name") + show_deprecated_argument("last_decoder_with_past_model_name") + if kwargs: + raise ValueError( + f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments." 
+ ) + + ABC.__init__(self) + + ORTModel.__init__( + self, + encoder_session, + config, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + preprocessors=preprocessors, + ) + self.encoder = self._initialize_encoder( + session=encoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding + ) + self.encoder_model_path = Path(encoder_session._model_path) + self.encoder_model_name = self.encoder_model_path.name + + self.decoder = ORTDecoderForSeq2Seq( + session=decoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding + ) + self.decoder_model_path = Path(decoder_session._model_path) + self.decoder_model_name = self.decoder_model_path.name + + self.use_cache = decoder_with_past_session is not None + + # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs + # will be enabled + self.decoder_with_past = None + self.decoder_with_past_model_path = None + self.decoder_with_past_model_name = None + if self.use_cache: + self.decoder_with_past = ORTDecoderForSeq2Seq( + session=decoder_with_past_session, + config=self.config, + device=self._device, + use_io_binding=self.use_io_binding, + ) + self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path) + self.decoder_with_past_model_name = self.decoder_with_past_model_path.name + + @abstractmethod + def _initialize_encoder( + self, + session: ort.InferenceSession, + config: "PretrainedConfig", + device: torch.device, + use_io_binding: bool = True, + ) -> "ORTEncoder": + pass + + @staticmethod + def load_model( + encoder_path: Union[str, Path], + decoder_path: Union[str, Path], + decoder_with_past_path: Optional[Union[str, Path]] = None, + provider: str = "CPUExecutionProvider", + session_options: Optional[ort.SessionOptions] = None, + provider_options: Optional[Dict] = None, + ): + """ + Creates an instance of [`~optimum.onnxruntime.modeling_seq2seq.ORTModelForConditionalGeneration`]. + Three inference sessions will be created for respectively the encoder, decoder and decoder with past key values + models. The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX. + + Args: + encoder_path (`Union[str, Path]`): + The path of the encoder ONNX model. + decoder_path (`Union[str, Path]`): + The path of the decoder ONNX model. + decoder_with_past_path (`Optional[Union[str, Path]]`, *optional*): + The path of the decoder with past key values ONNX model. + provider (`str`, *optional*, defaults to `"CPUExecutionProvider"`): + ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ + for possible providers. + session_options (`Optional[ort.SessionOptions]`, *optional*),: + ONNX Runtime session options to use for loading the model. Defaults to `None`. + provider_options (`Optional[Dict]`, *optional*): + Provider option dictionary corresponding to the provider used. See available options + for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`. 
+ """ + validate_provider_availability(provider) # raise error if the provider is not available + + providers = [provider] + if provider == "TensorrtExecutionProvider": + # follow advice in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#python + providers.append("CUDAExecutionProvider") + + encoder_session = ort.InferenceSession( + str(encoder_path), + providers=providers, + sess_options=session_options, + provider_options=None if provider_options is None else [provider_options], + ) + decoder_session = ort.InferenceSession( + str(decoder_path), + providers=providers, + sess_options=session_options, + provider_options=None if provider_options is None else [provider_options], + ) + + decoder_with_past_session = None + # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs + # will be enabled + if decoder_with_past_path is not None: + decoder_with_past_session = ort.InferenceSession( + str(decoder_with_past_path), + providers=providers, + sess_options=session_options, + provider_options=None if provider_options is None else [provider_options], + ) + + return encoder_session, decoder_session, decoder_with_past_session + + def _save_pretrained( + self, + save_directory: Union[str, Path], + # TODO: should we make the default values available here? + encoder_file_name: str = ONNX_ENCODER_NAME, + decoder_file_name: str = ONNX_DECODER_NAME, + decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, + ): + """ + Saves the model encoder, decoder and decoder with past key values as well as its configuration file to a + directory, so that it can be re-loaded using the + [`~optimum.onnxruntime.modeling_seq2seq.ORTModelForSeq2SeqLM.from_pretrained`] class method. + + Args: + save_directory (`Union[str, Path`]): + The directory where to save the model files. + encoder_file_name(`str`, defaults to `optimum.onnxruntime.utils.ONNX_ENCODER_NAME`): + The encoder model file name. Overwrites the default file name and allows one to save the encoder model + with a different name. + decoder_file_name(`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`): + The decoder model file name. Overwrites the default file name and allows one to save the decoder model + with a different name. + decoder_with_past_file_name(`str`, defaults to `optimum.onnxruntime.ONNX_DECODER_WITH_PAST_NAME`): + The decoder with past key values model file name overwriting the default file name, allowing to save + the decoder model with a different name. 
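Since the encoder, decoder and optional decoder-with-past are handled as three separate ONNX files and inference sessions, a minimal sketch of the expected workflow may help; the checkpoint name and output directory below are illustrative, and the file names come from `ONNX_ENCODER_NAME`, `ONNX_DECODER_NAME` and `ONNX_DECODER_WITH_PAST_NAME`.

```python
from optimum.onnxruntime import ORTModelForSeq2SeqLM

# Export a PyTorch seq2seq checkpoint to the three-file ONNX layout
# (encoder_model.onnx, decoder_model.onnx, decoder_with_past_model.onnx).
model = ORTModelForSeq2SeqLM.from_pretrained("t5-small", from_transformers=True)
model.save_pretrained("t5_onnx/")

# Reloading from the saved directory recreates one InferenceSession per file.
model = ORTModelForSeq2SeqLM.from_pretrained("t5_onnx/")
print(model.encoder_model_name, model.decoder_model_name, model.use_cache)
```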
+ """ + src_file_names = [self.encoder_model_path, self.decoder_model_path] + dst_file_names = [encoder_file_name, decoder_file_name] + if self.use_cache: + src_file_names.append(self.decoder_with_past_model_path) + dst_file_names.append(decoder_with_past_file_name) + + for src_path, dst_file_name in zip(src_file_names, dst_file_names): + dst_path = Path(save_directory) / dst_file_name + shutil.copyfile(src_path, dst_path) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + encoder_file_name: str = ONNX_ENCODER_NAME, + decoder_file_name: str = ONNX_DECODER_NAME, + decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, + subfolder: str = "", + local_files_only: bool = False, + use_cache: bool = True, + provider: str = "CPUExecutionProvider", + session_options: Optional[ort.SessionOptions] = None, + provider_options: Optional[Dict[str, Any]] = None, + use_io_binding: bool = True, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + ): + model_path = Path(model_id) + + if not validate_file_exists(model_id, encoder_file_name, subfolder=subfolder, revision=revision): + encoder_file_name = ORTModelForConditionalGeneration.infer_onnx_filename( + model_id, + ENCODER_ONNX_FILE_PATTERN, + "encoder_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision): + decoder_file_name = ORTModelForConditionalGeneration.infer_onnx_filename( + model_id, + DECODER_ONNX_FILE_PATTERN, + "decoder_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision): + decoder_with_past_file_name = ORTModelForConditionalGeneration.infer_onnx_filename( + model_id, + DECODER_WITH_PAST_ONNX_FILE_PATTERN, + "decoder_with_past_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + fail_if_not_found=use_cache, + ) + + encoder_regular_onnx_filenames = ORTModelForConditionalGeneration._generate_regular_names_for_filename( + ONNX_ENCODER_NAME + ) + decoder_regular_onnx_filenames = ORTModelForConditionalGeneration._generate_regular_names_for_filename( + ONNX_DECODER_NAME + ) + decoder_with_past_regular_onnx_filenames = ( + ORTModelForConditionalGeneration._generate_regular_names_for_filename(ONNX_DECODER_WITH_PAST_NAME) + ) + + if encoder_file_name not in encoder_regular_onnx_filenames: + logger.warning( + f"The ONNX file {encoder_file_name} is not a regular name used in optimum.onnxruntime, the " + "ORTModelForConditionalGeneration might not behave as expected." + ) + + if decoder_file_name not in decoder_regular_onnx_filenames: + logger.warning( + f"The ONNX file {decoder_file_name} is not a regular name used in optimum.onnxruntime, the " + "ORTModelForConditionalGeneration might not behave as expected." + ) + if decoder_with_past_file_name not in decoder_with_past_regular_onnx_filenames: + logger.warning( + f"The ONNX file {decoder_with_past_file_name} is not a regular name used in optimum.onnxruntime, " + "the ORTModelForConditionalGeneration might not behave as expected." 
+ ) + + decoder_with_past_path = model_path / decoder_with_past_file_name if use_cache else None + + preprocessors = None + if model_path.is_dir(): + model = cls.load_model( + encoder_path=model_path / encoder_file_name, + decoder_path=model_path / decoder_file_name, + decoder_with_past_path=decoder_with_past_path, + provider=provider, + session_options=session_options, + provider_options=provider_options, + ) + new_model_save_dir = model_path + preprocessors = maybe_load_preprocessors(model_id) + else: + attribute_name_to_filename = { + "last_encoder_model_name": encoder_file_name, + "last_decoder_model_name": decoder_file_name, + "last_decoder_with_past_model_name": decoder_with_past_file_name if use_cache else None, + } + paths = {} + for attr_name, filename in attribute_name_to_filename.items(): + if filename is None: + continue + model_cache_path = hf_hub_download( + repo_id=model_id, + subfolder=subfolder, + filename=filename, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + paths[attr_name] = Path(model_cache_path).name + new_model_save_dir = Path(model_cache_path).parent + preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + + last_decoder_with_past_name = paths.get("last_decoder_with_past_model_name", None) + if last_decoder_with_past_name is not None: + last_decoder_with_past_name = new_model_save_dir / last_decoder_with_past_name + + model = cls.load_model( + encoder_path=new_model_save_dir / paths["last_encoder_model_name"], + decoder_path=new_model_save_dir / paths["last_decoder_model_name"], + decoder_with_past_path=last_decoder_with_past_name, + provider=provider, + session_options=session_options, + provider_options=provider_options, + ) + + if model_save_dir is None: + model_save_dir = new_model_save_dir + + return cls( + *model[:2], + config, + decoder_with_past_session=model[2], + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + preprocessors=preprocessors, + ) + + @classmethod + def _from_transformers( + cls, + model_id: str, + config: "PretrainedConfig", + use_auth_token: Optional[Union[bool, str]] = None, + revision: str = "main", + force_download: bool = True, + cache_dir: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, + use_cache: bool = True, + provider: str = "CPUExecutionProvider", + session_options: Optional[ort.SessionOptions] = None, + provider_options: Optional[Dict[str, Any]] = None, + use_io_binding: bool = True, + task: Optional[str] = None, + ) -> "ORTModelForConditionalGeneration": + if task is None: + task = cls._AUTOMODELS_TO_TASKS[cls.auto_model_class] + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + model = TasksManager.get_model_from_task( + task, + model_id, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + config=config, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + force_download=force_download, + ) + + model_type = model.config.model_type.replace("_", "-") + model_name = getattr(model, "name", None) + + onnx_config_constructor = TasksManager.get_exporter_config_constructor( + model_type, "onnx", task=task, model_name=model_name + ) + onnx_config = onnx_config_constructor(model.config, use_past=use_cache) + onnx_opset = onnx_config.DEFAULT_ONNX_OPSET + + export( + model, + onnx_config, + onnx_opset, + save_dir_path.joinpath(ONNX_ENCODER_NAME), + save_dir_path.joinpath(ONNX_DECODER_NAME), + 
save_dir_path.joinpath(ONNX_DECODER_WITH_PAST_NAME), + ) + + config.save_pretrained(save_dir_path) + maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder) + + return cls._from_pretrained( + save_dir_path, + config, + use_cache=use_cache, + provider=provider, + session_options=session_options, + provider_options=provider_options, + use_io_binding=use_io_binding, + model_save_dir=save_dir, + ) + + def to(self, device: Union[torch.device, str, int]): + """ + Changes the ONNX Runtime provider according to the device. + + Args: + device (`torch.device` or `str` or `int`): + Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run + the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too. + + Returns: + `ORTModel`: the model placed on the requested device. + """ + device, provider_options = parse_device(device) + + provider = get_provider_for_device(device) + validate_provider_availability(provider) # raise error if the provider is not available + + self.device = device + self.encoder._device = device + self.encoder.session.set_providers([provider], provider_options=[provider_options]) + self.decoder._device = device + self.decoder.session.set_providers([provider], provider_options=[provider_options]) + if self.decoder_with_past is not None: + self.decoder_with_past._device = device + self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options]) + self.providers = self.encoder.session.get_providers() + + return self + + class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin): """ Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. """ - export_feature = "seq2seq-lm" auto_model_class = AutoModelForSeq2SeqLM main_input_name = "input_ids" @@ -1210,7 +1263,6 @@ class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin Speech Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. """ - export_feature = "speech2seq-lm" auto_model_class = AutoModelForSpeechSeq2Seq main_input_name = "input_features" diff --git a/optimum/onnxruntime/optimization.py b/optimum/onnxruntime/optimization.py index 015e7e765d0..527ec56115f 100644 --- a/optimum/onnxruntime/optimization.py +++ b/optimum/onnxruntime/optimization.py @@ -25,6 +25,7 @@ from onnxruntime.transformers.optimizer import optimize_model from ..utils import CONFIG_NAME, NormalizedConfigManager +from ..utils.save_utils import maybe_save_preprocessors from .configuration import OptimizationConfig, ORTConfig from .modeling_ort import ORTModel from .modeling_seq2seq import ORTModelForSeq2SeqLM @@ -48,7 +49,7 @@ def __init__(self, onnx_model_path: List[os.PathLike], config: "PretrainedConfig Args: onnx_model_path (`List[os.PathLike]`): The paths of the onnx models to optimize. - config ([`~PretrainedConfig`]): + config ([`~transformers.PretrainedConfig`]): An instance of the configuration associated to the model to optimize. """ super().__init__() @@ -67,24 +68,23 @@ def from_pretrained( The path to a local directory hosting the model to optimize or an instance of an `ORTModel` to quantize. Can be either: - A path to a local *directory* containing the model to optimize. - - An instance of ORTModel. - file_names(`List[str]`, *optional*): + - An instance of [`~optimum.onnxruntime.ORTModel`]. + file_names(`Optional[List[str]]`, *optional*): The list of file names of the models to optimize. 
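With the `ORTOptimizer.from_pretrained` changes here, a seq2seq `ORTModel` can be passed directly and its encoder/decoder/decoder-with-past files are picked up from their actual paths, while `optimize()` now also writes the config and any preprocessors to `save_dir`. A hedged sketch, with an illustrative checkpoint and output directory:

```python
from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

model = ORTModelForSeq2SeqLM.from_pretrained("t5-small", from_transformers=True)
optimizer = ORTOptimizer.from_pretrained(model)

# Apply basic graph optimizations; the optimized ONNX files, the model config
# and any preprocessors all end up in save_dir.
optimization_config = OptimizationConfig(optimization_level=1)
optimizer.optimize(optimization_config=optimization_config, save_dir="t5_optimized/")

# The optimized files are inferred from their names, no explicit file names needed.
optimized_model = ORTModelForSeq2SeqLM.from_pretrained("t5_optimized/", use_cache=True)
```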
""" onnx_model_path = [] config = None if isinstance(model_or_path, ORTModel): if isinstance(model_or_path, ORTModelForSeq2SeqLM): - model_save_dir = model_or_path.model_save_dir - onnx_model_path = [ - model_save_dir.joinpath(model_or_path.encoder_file_name), - model_save_dir.joinpath(model_or_path.decoder_file_name), + onnx_model_path += [ + model_or_path.encoder_model_path, + model_or_path.decoder_model_path, ] # Add the decoder with past key/values if present if model_or_path.use_cache: - onnx_model_path.append(model_save_dir.joinpath(model_or_path.decoder_file_with_past_name)) + onnx_model_path.append(model_or_path.decoder_with_past_model_path) else: - onnx_model_path = [model_or_path.model_save_dir.joinpath(model_or_path.latest_model_name)] + onnx_model_path.append(model_or_path.model_path) config = model_or_path.config elif os.path.isdir(model_or_path): file_names = [ONNX_WEIGHTS_NAME] if file_names is None else file_names @@ -110,7 +110,7 @@ def optimize( Optimizes a model given the optimization specifications defined in `optimization_config`. Args: - optimization_config (`OptimizationConfig`): + optimization_config ([`~optimum.onnxruntime.OptimizationConfig`]): The configuration containing the parameters related to optimization. save_dir (`Union[str, os.PathLike]`): The path used to save the optimized model. @@ -127,6 +127,9 @@ def optimize( save_dir.mkdir(parents=True, exist_ok=True) ORTConfigManager.check_optimization_supported_model(self.model_type) + self.config.save_pretrained(save_dir) + maybe_save_preprocessors(self.onnx_model_path[0].parent, save_dir) + # Create and save the configuration summarizing all the parameters related to optimization ort_config = ORTConfig( optimization=optimization_config, diff --git a/optimum/onnxruntime/quantization.py b/optimum/onnxruntime/quantization.py index 39d64e0e183..ddf4fdd7f6b 100644 --- a/optimum/onnxruntime/quantization.py +++ b/optimum/onnxruntime/quantization.py @@ -11,37 +11,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Classes handling quantization with ONNX Runtime.""" import logging import os -from abc import ABC from collections import defaultdict from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union from datasets import Dataset, load_dataset from packaging.version import Version, parse +from transformers import AutoConfig import onnx from onnxruntime import __version__ as ort_version from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantizationMode, QuantType from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer from onnxruntime.quantization.qdq_quantizer import QDQQuantizer -from optimum.onnxruntime import ORTQuantizableOperator -from optimum.onnxruntime.configuration import CalibrationConfig, NodeName, NodeType, ORTConfig, QuantizationConfig -from optimum.onnxruntime.modeling_ort import ORTModel -from optimum.onnxruntime.modeling_seq2seq import ORTModelForConditionalGeneration -from optimum.onnxruntime.preprocessors import QuantizationPreprocessor -from optimum.onnxruntime.utils import ONNX_WEIGHTS_NAME -from optimum.quantization_base import OptimumQuantizer +from ..quantization_base import OptimumQuantizer +from ..utils.save_utils import maybe_save_preprocessors +from . 
import ORTQuantizableOperator +from .configuration import CalibrationConfig, NodeName, NodeType, ORTConfig, QuantizationConfig +from .modeling_ort import ORTModel +from .modeling_seq2seq import ORTModelForConditionalGeneration +from .preprocessors import QuantizationPreprocessor + + +if TYPE_CHECKING: + from transformers import PretrainedConfig LOGGER = logging.getLogger(__name__) class ORTCalibrationDataReader(CalibrationDataReader): - """ """ - __slots__ = ["batch_size", "dataset", "_dataset_iter"] def __init__(self, dataset: Dataset, batch_size: int = 1): @@ -83,65 +86,79 @@ class ORTQuantizer(OptimumQuantizer): Handles the ONNX Runtime quantization process for models shared on huggingface.co/models. """ - def __init__(self, onnx_model_path: List[Path]): + def __init__(self, onnx_model_path: Path, config: Optional["PretrainedConfig"] = None): """ Args: onnx_model_path (`Path`): Path to the onnx model files you want to quantize. + config (`Optional[PretrainedConfig]`, *optional*): + The configuration of the model. """ super().__init__() self.onnx_model_path = onnx_model_path + self.config = config + if self.config is None: + try: + self.config = AutoConfig.from_pretrained(self.onnx_model_path.parent) + except OSError: + LOGGER.warning( + f"Could not load the config for {self.onnx_model_path} automatically, this might make " + "the quantized model harder to use because it will not be able to be loaded by an ORTModel without " + "having to specify the configuration explicitly." + ) self._calibrator = None @classmethod def from_pretrained( cls, - model_or_path: Union[str, Path], + model_or_path: Union["ORTModel", str, Path], file_name: Optional[str] = None, ) -> "ORTQuantizer": """ - Instantiate a `ORTQuantizer` from a pretrained pytorch model and preprocessor. + Instantiates a `ORTQuantizer` from a an ONNX model file or an `ORTModel`. Args: - model_or_path (`Union[str, Path]`): + model_or_path (`Union[ORTModel, str, Path]`): Can be either: - A path to a saved exported ONNX Intermediate Representation (IR) model, e.g., `./my_model_directory/. - - Or a `ORTModelForXX` class, e.g., `ORTModelForQuestionAnswering`. - file_name(`Union[str, List[str]]`, *optional*): + - Or an `ORTModelForXX` class, e.g., `ORTModelForQuestionAnswering`. + file_name(`Optional[str]`, *optional*): Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load different model files from the same repository or directory. Returns: An instance of `ORTQuantizer`. """ - # define the file name for the quantizable models - if file_name is None: - if isinstance(model_or_path, ORTModel): - if isinstance(model_or_path, ORTModelForConditionalGeneration): - raise ValueError( - "ORTQuantizer does not support multi-file quantization. Please create separate ORTQuantizer instances for each model/file." - ) - model_file_name = model_or_path.latest_model_name - else: - model_file_name = ONNX_WEIGHTS_NAME - else: - model_file_name = file_name + ort_quantizer_error_message = "ORTQuantizer does not support multi-file quantization. Please create separate ORTQuantizer instances for each model/file." 
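With the new resolution logic in `ORTQuantizer.from_pretrained`, a directory is expected to contain exactly one ONNX file, and the model configuration is loaded automatically when a `config.json` sits next to it; multi-file seq2seq models are rejected. A hedged sketch (the local path is hypothetical, `optimum/t5-small` is the Hub repo used in the tests):

```python
from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTQuantizer

# A directory with a single .onnx file (plus config.json) resolves automatically.
quantizer = ORTQuantizer.from_pretrained("local_onnx_model/")
print(quantizer.onnx_model_path, quantizer.config is not None)

# Multi-file seq2seq models raise: one ORTQuantizer is needed per ONNX file.
seq2seq_model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
try:
    ORTQuantizer.from_pretrained(seq2seq_model)
except ValueError as error:
    print(error)  # "ORTQuantizer does not support multi-file quantization. ..."
```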
+ + if isinstance(model_or_path, str): + model_or_path = Path(model_or_path) + + if isinstance(model_or_path, ORTModelForConditionalGeneration): + raise ValueError(ort_quantizer_error_message) + elif isinstance(model_or_path, Path): + onnx_files = list(model_or_path.glob("*.onnx")) + if len(onnx_files) == 0: + raise FileNotFoundError(f"Could not find any ONNX model file in {model_or_path}") + elif len(onnx_files) > 1: + raise RuntimeError( + f"Found too many ONNX model files in {model_or_path}. {ort_quantizer_error_message}" + ) + file_name = onnx_files[0].name - # create ORTQuantizer based on the provided input + path = None if isinstance(model_or_path, ORTModel): - return cls(model_or_path.model_save_dir.joinpath(model_file_name)) - # load from local path + path = Path(model_or_path.model._model_path) elif os.path.isdir(model_or_path): - if not isinstance(model_or_path, Path): - model_or_path = Path(model_or_path) - return cls(model_or_path.joinpath(model_file_name)) + path = Path(model_or_path) / file_name else: raise ValueError(f"Unable to load model from {model_or_path}.") + return cls(path) def fit( self, dataset: Dataset, calibration_config: CalibrationConfig, - onnx_augmented_model_name: str = "augmented_model.onnx", + onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx", operators_to_quantize: Optional[List[NodeType]] = None, batch_size: int = 1, use_external_data_format: bool = False, @@ -149,24 +166,24 @@ def fit( force_symmetric_range: bool = False, ) -> Dict[str, Tuple[float, float]]: """ - Perform the calibration step and collect the quantization ranges. + Performs the calibration step and collect the quantization ranges. Args: dataset (`Dataset`): The dataset to use when performing the calibration step. calibration_config (`CalibrationConfig`): The configuration containing the parameters related to the calibration step. - onnx_augmented_model_name (`Union[str, os.PathLike]`): + onnx_augmented_model_name (`Union[str, Path]`, *optional*, defaults to `"augmented_model.onnx"`): The path used to save the augmented model used to collect the quantization ranges. - operators_to_quantize (`list`, *optional*): + operators_to_quantize (`Optional[List[NodeType]]`, *optional*): List of the operators types to quantize. - batch_size (`int`, defaults to 1): + batch_size (`int`, *optional*, defaults to 1): The batch size to use when collecting the quantization ranges values. use_external_data_format (`bool`, defaults to `False`): Whether to use external data format to store model which size is >= 2Gb. use_gpu (`bool`, defaults to `False`): Whether to use the GPU when collecting the quantization ranges values. - force_symmetric_range (`bool`, defaults to `False`): + force_symmetric_range (`bool`, *optional*, defaults to `False`): Whether to make the quantization ranges symmetric. Returns: @@ -195,7 +212,7 @@ def partial_fit( self, dataset: Dataset, calibration_config: CalibrationConfig, - onnx_augmented_model_name: str = "augmented_model.onnx", + onnx_augmented_model_name: Union[str, Path] = "augmented_model.onnx", operators_to_quantize: Optional[List[NodeType]] = None, batch_size: int = 1, use_external_data_format: bool = False, @@ -203,24 +220,24 @@ def partial_fit( force_symmetric_range: bool = False, ): """ - Perform the calibration step and collect the quantization ranges. + Performs the calibration step and collect the quantization ranges. Args: dataset (`Dataset`): The dataset to use when performing the calibration step. 
calibration_config (`CalibrationConfig`): The configuration containing the parameters related to the calibration step. - onnx_augmented_model_name (`Union[str, os.PathLike]`): + onnx_augmented_model_name (`Union[str, Path]`, *optional*, defaults to `"augmented_model.onnx"`): The path used to save the augmented model used to collect the quantization ranges. - operators_to_quantize (`list`, *optional*): + operators_to_quantize (`Optional[List[NodeType]]`, *optional*): List of the operators types to quantize. - batch_size (`int`, defaults to 1): + batch_size (`int`, *optional*, defaults to 1): The batch size to use when collecting the quantization ranges values. - use_external_data_format (`bool`, defaults to `False`): + use_external_data_format (`bool`, *optional*, defaults to `False`): Whether uto se external data format to store model which size is >= 2Gb. - use_gpu (`bool`, defaults to `False`): + use_gpu (`bool`, *optional*, defaults to `False`): Whether to use the GPU when collecting the quantization ranges values. - force_symmetric_range (`bool`, defaults to `False`): + force_symmetric_range (`bool`, *optional*, defaults to `False`): Whether to make the quantization ranges symmetric. Returns: @@ -267,21 +284,21 @@ def quantize( preprocessor: Optional[QuantizationPreprocessor] = None, ) -> Path: """ - Quantize a model given the optimization specifications defined in `quantization_config`. + Quantizes a model given the optimization specifications defined in `quantization_config`. Args: quantization_config (`QuantizationConfig`): The configuration containing the parameters related to quantization. save_dir (`Union[str, Path]`): The directory where the quantized model should be saved. - file_suffix (`str`, *optional*, defaults to `"quantized"`): + file_suffix (`Optional[str]`, *optional*, defaults to `"quantized"`): The file_suffix used to save the quantized model. - calibration_tensors_range (`Dict[NodeName, Tuple[float, float]]`, *optional*): + calibration_tensors_range (`Optional[Dict[NodeName, Tuple[float, float]]]`, *optional*): The dictionary mapping the nodes name to their quantization ranges, used and required only when applying static quantization. - use_external_data_format (`bool`, defaults to `False`): + use_external_data_format (`bool`, *optional*, defaults to `False`): Whether to use external data format to store model which size is >= 2Gb. - preprocessor (`QuantizationPreprocessor`, *optional*): + preprocessor (`Optional[QuantizationPreprocessor]`, *optional*): The preprocessor to use to collect the nodes to include or exclude from quantization. Returns: @@ -388,6 +405,11 @@ def quantize( ort_config = ORTConfig(quantization=quantization_config, use_external_data_format=use_external_data_format) ort_config.save_pretrained(save_dir) + if self.config is not None: + self.config.save_pretrained(save_dir) + + maybe_save_preprocessors(self.onnx_model_path.parent, save_dir) + return Path(save_dir) def get_calibration_dataset( @@ -402,25 +424,25 @@ def get_calibration_dataset( use_auth_token: bool = False, ) -> Dataset: """ - Create the calibration `datasets.Dataset` to use for the post-training static quantization calibration step + Creates the calibration `datasets.Dataset` to use for the post-training static quantization calibration step. Args: dataset_name (`str`): The dataset repository name on the Hugging Face Hub or path to a local directory containing data files to load to use for the calibration step. 
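Pulling the calibration-related methods together (`get_calibration_dataset`, `fit`, `quantize`), here is a hedged sketch of static quantization; the local directory, dataset, preprocessing function and quantization configuration are illustrative choices, and the tokenizer is assumed to have been saved alongside the ONNX file.

```python
from functools import partial

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoCalibrationConfig, AutoQuantizationConfig

quantizer = ORTQuantizer.from_pretrained("local_onnx_model/")
tokenizer = AutoTokenizer.from_pretrained("local_onnx_model/")

def preprocess_fn(examples, tokenizer):
    return tokenizer(examples["sentence"], padding="max_length", truncation=True)

# Collect a small calibration set and compute the quantization ranges.
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
    num_samples=50,
    dataset_split="train",
)
calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)
ranges = quantizer.fit(dataset=calibration_dataset, calibration_config=calibration_config)

# Apply static quantization using the collected ranges; the config and
# preprocessors found next to the ONNX file are saved along with it.
qconfig = AutoQuantizationConfig.avx512(is_static=True, per_channel=False)
quantizer.quantize(
    save_dir="quantized_model/",
    quantization_config=qconfig,
    calibration_tensors_range=ranges,
)
```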
- num_samples (`int`, defaults to 100): + num_samples (`int`, *optional*, defaults to 100): The maximum number of samples composing the calibration dataset. - dataset_config_name (`str`, *optional*): + dataset_config_name (`Optional[str]`, *optional*): The name of the dataset configuration. - dataset_split (`str`, *optional*): + dataset_split (`Optional[str]`, *optional*): Which split of the dataset to use to perform the calibration step. - preprocess_function (`Callable`, *optional*): + preprocess_function (`Optional[Callable]`, *optional*): Processing function to apply to each example after loading dataset. - preprocess_batch (`bool`, defaults to `True`): + preprocess_batch (`bool`, *optional*, defaults to `True`): Whether the `preprocess_function` should be batched. - seed (`int`, defaults to 2016): + seed (`int`, *optional*, defaults to 2016): The random seed to use when shuffling the calibration dataset. - use_auth_token (`bool`, defaults to `False`): + use_auth_token (`bool`, *optional*, defaults to `False`): Whether to use the token generated when running `transformers-cli login` (necessary for some datasets like ImageNet). Returns: diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 81375c97909..a0cfa2e2ae8 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -13,9 +13,10 @@ # limitations under the License. """Utility functions, classes and constants for ONNX Runtime.""" +import importlib.util import os from enum import Enum -from typing import Dict, Tuple, Type, Union +from typing import Dict, Tuple, Union import torch from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast @@ -23,16 +24,14 @@ import onnx import onnxruntime as ort +import pkg_resources from ..onnx import OnnxConfigWithLoss, OnnxConfigWithPastAndLoss, OnnxSeq2SeqConfigWithPastAndLoss -from ..utils import NormalizedTextConfig logger = logging.get_logger(__name__) ONNX_WEIGHTS_NAME = "model.onnx" -OPTIMIZED_ONNX_WEIGHTS_NAME = "optimized_model.onnx" -QUANTIZED_ONNX_WEIGHTS_NAME = "q8_model.onnx" ONNX_ENCODER_NAME = "encoder_model.onnx" ONNX_DECODER_NAME = "decoder_model.onnx" @@ -41,7 +40,7 @@ def _is_gpu_available(): """ - checks if a gpu is available. + Checks if a gpu is available. """ available_providers = ort.get_available_providers() if "CUDAExecutionProvider" in available_providers and torch.cuda.is_available(): @@ -50,6 +49,24 @@ def _is_gpu_available(): return False +def is_onnxruntime_training_available(): + """ + Checks if onnxruntime-training is available. + """ + path_training_dependecy = os.path.join(ort.__path__[0], "training") + if os.path.exists(path_training_dependecy): + return True + else: + return False + + +def is_cupy_available(): + """ + Checks if onnxruntime-training is available. + """ + return importlib.util.find_spec("cupy") is not None + + class ORTConfigManager: """ A class that contains all the information needed by ONNX Runtime optimization for a given model type. diff --git a/optimum/utils/file_utils.py b/optimum/utils/file_utils.py new file mode 100644 index 00000000000..bfd62d9f3ca --- /dev/null +++ b/optimum/utils/file_utils.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions related to both local files and files on the Hugging Face Hub.""" + +import re +from pathlib import Path +from typing import List, Optional, Union + +from huggingface_hub import HfApi, HfFolder, get_hf_file_metadata, hf_hub_url + + +def validate_file_exists( + model_name_or_path: Union[str, Path], filename: str, subfolder: str = "", revision: Optional[str] = None +) -> bool: + """ + Checks that the file called `filename` exists in the `model_name_or_path` directory or model repo. + """ + model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path + if model_path.is_dir(): + return (model_path / subfolder / filename).is_file() + succeeded = True + try: + get_hf_file_metadata(hf_hub_url(model_name_or_path, filename, subfolder=subfolder, revision=revision)) + except Exception: + succeeded = False + return succeeded + + +def find_files_matching_pattern( + model_name_or_path: Union[str, Path], + pattern: str, + glob_pattern: str = "**/*", + subfolder: str = "", + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, +) -> List[Path]: + """ + Scans either a model repo or a local directory to find filenames matching the pattern. + + Args: + model_name_or_path (`Union[str, Path]`): + The name of the model repo on the Hugging Face Hub or the path to a local directory. + pattern (`str`): + The pattern to use to look for files. + glob_pattern (`str`, defaults to `"**/*"`): + The pattern to use to list all the files that need to be checked. + subfolder (`str`, defaults to `""`): + In case the model files are located inside a subfolder of the model directory / repo on the Hugging + Face Hub, you can specify the subfolder name here. + use_auth_token (`Optional[bool, str]`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`Optional[str]`, defaults to `None`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + + Returns: + `List[Path]` + """ + model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path + pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern) + if model_path.is_dir(): + path = model_path + files = model_path.glob("**/*.onnx") + files = [p for p in files if re.search(pattern, str(p))] + else: + path = model_name_or_path + if isinstance(use_auth_token, bool): + token = HfFolder().get_token() + else: + token = use_auth_token + repo_files = map(Path, HfApi().list_repo_files(model_name_or_path, revision=revision, token=token)) + if subfolder != "": + path = f"{path}/{subfolder}" + files = [Path(p) for p in repo_files if re.match(pattern, str(p))] + + return files diff --git a/optimum/utils/save_utils.py b/optimum/utils/save_utils.py new file mode 100644 index 00000000000..3d5550a2fd9 --- /dev/null +++ b/optimum/utils/save_utils.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. 
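The helpers in the new `optimum/utils/file_utils.py` can also be used on their own; a small sketch, where the directory name and the regular expression are purely illustrative:

```python
from optimum.utils.file_utils import find_files_matching_pattern, validate_file_exists

# Check whether a given file exists in a local directory or in a model repo on the Hub.
has_encoder = validate_file_exists("local_onnx_model/", "encoder_model.onnx")

# List the ONNX files whose names match a regular expression.
decoder_files = find_files_matching_pattern("local_onnx_model/", pattern=r"decoder(.*)\.onnx")
print(has_encoder, decoder_files)
```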
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities related to saving files.""" + +import logging +from pathlib import Path +from typing import List, Union + +from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer + + +logger = logging.getLogger(__name__) + + +def maybe_load_preprocessors(src_name_or_path: Union[str, Path], subfolder: str = "") -> List: + preprocessors = [] + try: + preprocessors.append(AutoTokenizer.from_pretrained(src_name_or_path, subfolder=subfolder)) + except Exception: + pass + + try: + preprocessors.append(AutoProcessor.from_pretrained(src_name_or_path, subfolder=subfolder)) + except Exception: + pass + + try: + preprocessors.append(AutoFeatureExtractor.from_pretrained(src_name_or_path, subfolder=subfolder)) + except Exception: + pass + return preprocessors + + +def maybe_save_preprocessors(src_name_or_path: Union[str, Path], dest_dir: Union[str, Path], src_subfolder: str = ""): + """ + Saves the tokenizer, the processor and the feature extractor when found in `src_dir` in `dest_dir`. + + Args: + src_dir (`Union[str, Path]`): + The source directory from which to copy the files. + dest_dir (`Union[str, Path]`): + The destination directory to copy the files to. + src_subfolder (`str`, defaults to `""`): + In case the preprocessor files are located inside a subfolder of the model directory / repo on the Hugging + Face Hub, you can specify the subfolder name here. 
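`maybe_load_preprocessors` and `maybe_save_preprocessors` are intentionally permissive: each preprocessor type is tried in turn and silently skipped when unavailable. A minimal sketch (the destination directory is illustrative):

```python
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors

# Copy whichever of tokenizer / processor / feature extractor exist for the source checkpoint.
maybe_save_preprocessors("distilbert-base-uncased-finetuned-sst-2-english", "local_onnx_model/")

# Later, load back whatever was saved (possibly an empty list).
preprocessors = maybe_load_preprocessors("local_onnx_model/")
```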
+ """ + if not isinstance(dest_dir, Path): + dest_dir = Path(dest_dir) + + dest_dir.mkdir(exist_ok=True) + for preprocessor in maybe_load_preprocessors(src_name_or_path, subfolder=src_subfolder): + preprocessor.save_pretrained(dest_dir) diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py index 6509624f862..15c56d863ba 100644 --- a/optimum/utils/testing_utils.py +++ b/optimum/utils/testing_utils.py @@ -1,8 +1,10 @@ import importlib.util +import itertools import os import subprocess import sys import unittest +from typing import Any, Dict, Iterable from packaging import version @@ -105,3 +107,11 @@ def convert_to_hf_classes(mapping_dict): hf_names_dict[fast_layer_key] = hf_class return hf_names_dict + + +def grid_parameters(parameters: Dict[str, Iterable[Any]]) -> Iterable[Dict[str, Any]]: + """ + Generate an iterable over the grid of all combinations of parameters + """ + for params in itertools.product(*parameters.values()): + yield list(params) diff --git a/setup.py b/setup.py index 9274fe674d8..c1e8f9cbe83 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ "huggingface_hub>=0.8.0", ] -TESTS_REQUIRE = ["pytest", "requests", "parameterized", "pytest-xdist", "Pillow"] +TESTS_REQUIRE = ["pytest", "requests", "parameterized", "pytest-xdist", "Pillow", "sacremoses"] QUALITY_REQUIRE = ["black~=22.0", "flake8>=3.8.3", "isort>=5.5.4"] diff --git a/tests/bettertransformer/test_bettertransformer_encoder.py b/tests/bettertransformer/test_bettertransformer_encoder.py index 32a976e0446..1cac4635296 100644 --- a/tests/bettertransformer/test_bettertransformer_encoder.py +++ b/tests/bettertransformer/test_bettertransformer_encoder.py @@ -18,40 +18,43 @@ import torch import transformers -from transformers import AutoModel +from transformers import AutoModel, AutoTokenizer from optimum.bettertransformer import BETTER_TRANFORMER_LAYERS_MAPPING_DICT, BetterTransformer from optimum.utils.testing_utils import ( convert_to_hf_classes, + grid_parameters, is_torch_greater_than_113, require_accelerate, require_torch_gpu, ) +from parameterized import parameterized from testing_bettertransformer_utils import BetterTransformersTestMixin ALL_ENCODER_MODELS_TO_TEST = [ - "hf-internal-testing/tiny-random-DistilBertModel", "hf-internal-testing/tiny-random-AlbertModel", - "hf-internal-testing/tiny-random-RobertaModel", - "hf-internal-testing/tiny-xlm-roberta", - "hf-internal-testing/tiny-random-SplinterModel", - "hf-internal-testing/tiny-random-ErnieModel", + "hf-internal-testing/tiny-random-BertModel", "hf-internal-testing/tiny-random-camembert", + "hf-internal-testing/tiny-random-Data2VecTextModel", + "hf-internal-testing/tiny-random-DistilBertModel", "hf-internal-testing/tiny-random-ElectraModel", + "hf-internal-testing/tiny-random-ErnieModel", "hf-internal-testing/tiny-random-LayoutLMModel", - "hf-internal-testing/tiny-random-Data2VecTextModel", "hf-internal-testing/tiny-random-MarkupLMModel", - "hf-internal-testing/tiny-random-BertModel", - "ybelkada/random-tiny-BertGenerationModel", + "hf-internal-testing/tiny-random-rembert", + "hf-internal-testing/tiny-random-RobertaModel", + "hf-internal-testing/tiny-random-SplinterModel", "hf-internal-testing/tiny-random-TapasModel", "hf-internal-testing/tiny-random-RoCBertModel", + "hf-internal-testing/tiny-xlm-roberta", + "ybelkada/random-tiny-BertGenerationModel", ] ALL_ENCODER_DECODER_MODELS_TO_TEST = [ + "hf-internal-testing/tiny-random-bart", "hf-internal-testing/tiny-random-FSMTModel", - "hf-internal-testing/tiny-random-BartModel", - 
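`grid_parameters` from `optimum.utils.testing_utils` simply yields every combination of the supplied parameter values as a list, which pairs with `parameterized.expand` in the test changes below. For illustration:

```python
from optimum.utils.testing_utils import grid_parameters

cases = {
    "model_id": ["hf-internal-testing/tiny-random-bart", "hf-internal-testing/tiny-random-mbart"],
    "padding": ["max_length", True],
}
for combination in grid_parameters(cases):
    # e.g. ['hf-internal-testing/tiny-random-bart', 'max_length'], then [..., True], ...
    print(combination)
```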
"hf-internal-testing/tiny-random-MBartModel", + "hf-internal-testing/tiny-random-mbart", "hf-internal-testing/tiny-random-nllb", ] @@ -278,13 +281,24 @@ class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest def tearDown(self): gc.collect() - def prepare_inputs_for_class(self, model_id=None): - input_dict = { - "input_ids": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]), - "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]), - "decoder_input_ids": torch.LongTensor([[0], [0]]), - } - return input_dict + def prepare_inputs_for_class(self, model_id, **preprocessor_kwargs): + tokenizer = AutoTokenizer.from_pretrained(model_id) + padding = preprocessor_kwargs.pop("padding", True) + inputs = tokenizer(["a dummy input", "and two"], return_tensors="pt", padding=padding, **preprocessor_kwargs) + inputs["decoder_input_ids"] = inputs["input_ids"] # just a hack for m2m100 + return inputs + + # run the test over all possible combinations of `model_id` and `padding` + @parameterized.expand( + grid_parameters( + { + "model_id": ALL_ENCODER_DECODER_MODELS_TO_TEST, + "padding": ["max_length", True], + } + ) + ) + def test_logits(self, model_id, padding, max_length=20): + super().test_logits([model_id], padding=padding, max_length=max_length) def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_size, pad_idx=0): diff --git a/tests/bettertransformer/testing_bettertransformer_utils.py b/tests/bettertransformer/testing_bettertransformer_utils.py index 901fb72f963..e1876442d7e 100644 --- a/tests/bettertransformer/testing_bettertransformer_utils.py +++ b/tests/bettertransformer/testing_bettertransformer_utils.py @@ -14,6 +14,7 @@ # limitations under the License. import tempfile import unittest +from typing import List, Optional import torch from transformers import AutoModel @@ -34,21 +35,24 @@ class BetterTransformersTestMixin: """ all_models_to_test = [] - def prepare_inputs_for_class(self): + def prepare_inputs_for_class(self, models_to_test=None): raise NotImplementedError - def test_logits(self): + def test_logits(self, models_to_test: Optional[List] = None, **preprocessor_kwargs): r""" This tests if the converted model produces the same logits than the original model. """ # The first row of the attention mask needs to be all ones -> check: https://github.com/pytorch/pytorch/blob/19171a21ee8a9cc1a811ac46d3abd975f0b6fc3b/test/test_nn.py#L5283 - for model_to_test in self.all_models_to_test: - inputs = self.prepare_inputs_for_class(model_to_test) + if models_to_test is None: + models_to_test = self.all_models_to_test + + for model_id in models_to_test: + inputs = self.prepare_inputs_for_class(model_id=model_id, **preprocessor_kwargs) torch.manual_seed(0) - hf_random_model = AutoModel.from_pretrained(model_to_test).eval() + hf_random_model = AutoModel.from_pretrained(model_id).eval() random_config = hf_random_model.config torch.manual_seed(0) @@ -75,14 +79,31 @@ def test_logits(self): # discrepency. tol = 4e-2 else: - tol = 1e-3 - - self.assertTrue( - torch.allclose(hf_hidden_states[:, :3, :], bt_hidden_states[:, :3, :], atol=tol), - "The BetterTransformers Converted model does not produce the same logits as the original model. 
Failed for the model {}".format( - hf_random_model.__class__.__name__ - ), - ) + tol = 2e-3 + + if "attention_mask" in inputs: + for i, attention_mask in enumerate(inputs["attention_mask"]): + length = torch.argwhere(attention_mask != 0).max().item() + self.assert_equal( + tensor1=hf_hidden_states[i, : length + 1, :], + tensor2=bt_hidden_states[i, : length + 1, :], + atol=tol, + model_name=hf_random_model.__class__.__name__, + ) + else: + self.assert_equal( + tensor1=hf_hidden_states[:, :3, :], + tensor2=bt_hidden_states[:, :3, :], + atol=tol, + model_name=hf_random_model.__class__.__name__, + ) + + def assert_equal(self, tensor1, tensor2, atol: float, model_name: str): + self.assertTrue( + torch.allclose(tensor1, tensor2, atol=atol), + f"The BetterTransformer converted model does not produce the same logits as the original model. Failed for the model {model_name}." + f" Maxdiff: {torch.abs(tensor1 - tensor2).max()}", + ) def test_raise_on_save(self): r""" diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 380ff99d80c..4841d01efad 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -41,7 +41,6 @@ import onnxruntime import requests from huggingface_hub.constants import default_cache_path -from huggingface_hub.utils import EntryNotFoundError from optimum.onnxruntime import ( ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, @@ -226,7 +225,7 @@ def test_load_seq2seq_model_unknown_provider(self): ORTModelForSeq2SeqLM.from_pretrained(self.ONNX_SEQ2SEQ_MODEL_ID, provider="FooExecutionProvider") def test_load_model_from_hub_without_onnx_model(self): - with self.assertRaises(EntryNotFoundError): + with self.assertRaises(FileNotFoundError): ORTModel.from_pretrained(self.FAIL_ONNX_MODEL_ID) def test_model_on_cpu(self): @@ -451,7 +450,7 @@ def test_save_model_with_different_name(self): model = ORTModel.from_pretrained(tmpdirname, file_name=test_model_name) - self.assertEqual(model.latest_model_name, test_model_name) + self.assertEqual(model.model_name, test_model_name) @require_hf_token def test_save_model_from_hub(self): @@ -1685,3 +1684,24 @@ def test_default_pipeline_and_model_device(self, *args, **kwargs): tokenizer = get_preprocessor(model_id) pipe = pipeline("feature-extraction", model=onnx_model, tokenizer=tokenizer) self.assertEqual(pipe.device, onnx_model.device) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + @require_torch_gpu + def test_compare_to_io_binding(self, *args, **kwargs): + model_arch, model_id = args + set_seed(SEED) + onnx_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=False) + set_seed(SEED) + io_model = ORTModelForCustomTasks.from_pretrained(model_id, use_io_binding=True) + tokenizer = get_preprocessor(model_id) + tokens = tokenizer("This is a sample output", return_tensors="pt") + onnx_outputs = onnx_model(**tokens) + io_outputs = io_model(**tokens) + + self.assertTrue("pooler_output" in io_outputs) + self.assertIsInstance(io_outputs.pooler_output, torch.Tensor) + + # compare tensor outputs + self.assertTrue(torch.equal(onnx_outputs.pooler_output, io_outputs.pooler_output)) + + gc.collect() diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py index 9f05cba9051..47aa65ab196 100644 --- a/tests/onnxruntime/test_optimization.py +++ b/tests/onnxruntime/test_optimization.py @@ -18,15 +18,12 @@ import unittest from pathlib import Path -import numpy as np import torch -from transformers import AutoConfig, 
AutoTokenizer +from transformers import AutoTokenizer import onnx -from onnxruntime import InferenceSession from optimum.onnxruntime import ORTConfig, ORTModelForSequenceClassification, ORTOptimizer -from optimum.onnxruntime.configuration import AutoQuantizationConfig, OptimizationConfig -from optimum.onnxruntime.modeling_ort import ORTModelForSequenceClassification +from optimum.onnxruntime.configuration import OptimizationConfig from optimum.onnxruntime.modeling_seq2seq import ORTModelForSeq2SeqLM from parameterized import parameterized @@ -96,9 +93,6 @@ def test_compare_original_seq2seq_model_with_optimized_model(self, model_cls, mo optimizer.optimize(optimization_config=optimization_config, save_dir=tmp_dir) optimized_model = model_cls.from_pretrained( tmp_dir, - encoder_file_name="encoder_model_optimized.onnx", - decoder_file_name="decoder_model_optimized.onnx", - decoder_with_past_file_name="decoder_with_past_model_optimized.onnx" if use_cache else None, from_transformers=False, use_cache=use_cache, ) diff --git a/tests/onnxruntime/test_quantization.py b/tests/onnxruntime/test_quantization.py index 8ebc0fdacc4..b3c83acafa1 100644 --- a/tests/onnxruntime/test_quantization.py +++ b/tests/onnxruntime/test_quantization.py @@ -54,7 +54,7 @@ def test_from_pretrained_method(self, *args): def test_fail_from_pretrained_method(self): with self.assertRaises(Exception) as context: ORTQuantizer.from_pretrained("bert-base-cased") - self.assertIn("Unable to load model from bert-base-cased", str(context.exception)) + self.assertIn("Could not find any ONNX model file in bert-base-cased", str(context.exception)) with self.assertRaises(Exception) as context: model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small") diff --git a/tests/utils/test_dummpy_input_generators.py b/tests/utils/test_dummpy_input_generators.py index 88f8c5d4e14..58dd0ca85ac 100644 --- a/tests/utils/test_dummpy_input_generators.py +++ b/tests/utils/test_dummpy_input_generators.py @@ -16,8 +16,12 @@ from contextlib import nullcontext from unittest import TestCase +import torch from transformers import AutoConfig +from optimum.utils import DummyAudioInputGenerator, DummyTextInputGenerator, DummyVisionInputGenerator +from optimum.utils.normalized_config import NormalizedConfigManager +from optimum.utils.testing_utils import grid_parameters from parameterized import parameterized @@ -43,23 +47,6 @@ "audio_sequence_length": [16000, 8000], } -import itertools -from typing import Any, Dict, Iterable - -import torch -from transformers import AutoConfig - -from optimum.utils import DummyAudioInputGenerator, DummyTextInputGenerator, DummyVisionInputGenerator -from optimum.utils.normalized_config import NormalizedConfigManager - - -def grid_parameters(parameters: Dict[str, Iterable[Any]]) -> Iterable[Dict[str, Any]]: - """ - Generate an iterable over the grid of all combinations of parameters - """ - for params in itertools.product(*parameters.values()): - yield list(params) - class GenerateDummy(TestCase): @parameterized.expand(