Support the merge of decoder without/with past for encoder-decoder models in the ONNX export #926

Merged
1 change: 1 addition & 0 deletions optimum/exporters/onnx/__main__.py
@@ -476,6 +476,7 @@ def main_export(
# TODO: treating stable diffusion separately is quite ugly
if not no_post_process and task != "stable-diffusion":
try:
logger.info("Post-processing the exported models...")
models_and_onnx_configs, onnx_files_subpaths = onnx_config.post_process_exported_models(
output, models_and_onnx_configs, onnx_files_subpaths
)
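For context, a rough sketch of how this post-processing step is reached programmatically; main_export is the entry point behind optimum-cli export onnx, and the model name and output directory below are placeholders:

from optimum.exporters.onnx import main_export

# Exports the encoder, decoder and decoder-with-past ONNX files, then runs the
# post-processing step logged above, which (with this PR) merges the two decoders.
main_export("t5-small", output="t5_onnx", task="seq2seq-lm-with-past")

# Passing no_post_process=True would skip the merge and keep the two decoders separate.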
89 changes: 85 additions & 4 deletions optimum/exporters/onnx/base.py
@@ -22,28 +22,30 @@
import re
from abc import ABC, abstractmethod
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from transformers.utils import is_torch_available

from ...onnx import merge_decoders
from ...utils import (
DEFAULT_DUMMY_SHAPES,
DummyInputGenerator,
DummyLabelsGenerator,
DummySeq2SeqPastKeyValuesGenerator,
is_diffusers_available,
logging,
)
from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
from ...utils.doc import add_dynamic_docstring
from ...utils.import_utils import is_onnx_available, is_onnxruntime_available
from ..base import ExportConfig
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME
from .model_patcher import ModelPatcher, Seq2SeqModelPatcher


if TYPE_CHECKING:
from pathlib import Path

from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel

if is_diffusers_available():
@@ -614,6 +616,22 @@ def flatten_output_collection_property(self, name: str, field: Iterable[Any]) ->

return flattened_output

def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_merged is True and self.use_cache_branch is True:
reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True)
elif self.is_merged is True and self.use_cache_branch is False:
reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=False)
Review comment from a Member on lines +620 to +623, suggesting to collapse the two branches above into:

    if self.is_merged:
        reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=self.use_cache_branch)

Reply from @fxmarty (Contributor, Author), Mar 27, 2023: "edit: actually this is less explicit"

# We don't support optional inputs for now, so even though the non-cache branch is used,
# dummy past key values are necessary
batch_size = reference_model_inputs["input_ids"].shape[0]
pkv_generator = self.DUMMY_PKV_GENERATOR_CLASS(
task=self.task, normalized_config=self._normalized_config, sequence_length=1, batch_size=batch_size
)
reference_model_inputs["past_key_values"] = pkv_generator.generate("past_key_values", framework="pt")

return reference_model_inputs


class ConfigBehavior(str, enum.Enum):
"""
@@ -633,6 +651,8 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
Inherits from [`~exporters.onnx.OnnxConfigWithPast`]. A base class to handle the ONNX configuration of encoder-decoder models.
"""

DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator

def __init__(
self,
config: "PretrainedConfig",
@@ -699,7 +719,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
new_axes_names = {}
for axis_idx, axis_name in axes_names.items():
if "sequence" in axis_name:
if not self.use_past_in_inputs:
if self.use_past_in_inputs is False or self.is_merged is True:
new_axes_names[axis_idx] = sequence_name
else:
# Trick to force it since ONNX sometimes infer a dynamic axis where it's not.
@@ -731,7 +751,11 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch_size", 2: decoder_sequence_name}

if direction == "inputs" or (self._behavior is ConfigBehavior.DECODER and self.use_past is False):
if (
self.is_merged is True
or (self._behavior is ConfigBehavior.DECODER and self.use_past is False)
or direction == "inputs"
):
inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch_size", 2: "encoder_sequence_length"}
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length"}

@@ -747,6 +771,63 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
def patch_model_for_export(self, model: Union["PreTrainedModel", "TFPreTrainedModel"]) -> ModelPatcher:
return Seq2SeqModelPatcher(self, model)

def post_process_exported_models(
self,
path: Path,
models_and_onnx_configs: Dict[
str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
],
onnx_files_subpaths: List[str],
):
# Attempt to merge only if the decoder was exported without/with past
if self.use_past is True and len(models_and_onnx_configs) == 3:
if onnx_files_subpaths is not None:
decoder_path = Path(path, onnx_files_subpaths[1])
decoder_with_past_path = Path(path, onnx_files_subpaths[2])
else:
decoder_path = Path(path, ONNX_DECODER_NAME + ".onnx")
decoder_with_past_path = Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx")
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
try:
# The decoder with past does not output the cross attention past key values as they are constant,
# hence the need for strict=False
merge_decoders(
decoder=decoder_path,
decoder_with_past=decoder_with_past_path,
save_path=decoder_merged_path,
strict=False,
)
except Exception as e:
raise Exception(f"Unable to merge decoders. Detailed error: {e}")

# In order to do the validation of the two branches on the same file
if onnx_files_subpaths is not None:
encoder_path = onnx_files_subpaths[0]
else:
encoder_path = ONNX_ENCODER_NAME + ".onnx"

onnx_files_subpaths = [encoder_path, decoder_merged_path.name, decoder_merged_path.name]

# We validate the two branches of the decoder model then
models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False

# Past key values won't be generated by default, but added in the input
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past = False
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True

models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True

return models_and_onnx_configs, onnx_files_subpaths

def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
if self._behavior is ConfigBehavior.DECODER:
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]

return super().generate_dummy_inputs_for_validation(reference_model_inputs)


class OnnxConfigWithLoss(OnnxConfig, ABC):
"""
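As a hand-written illustration of what post_process_exported_models above does, the same merge can be performed directly with optimum.onnx.merge_decoders. The export directory and file names are assumptions matching the ONNX_DECODER_NAME / ONNX_DECODER_WITH_PAST_NAME / ONNX_DECODER_MERGED_NAME defaults:

from pathlib import Path

from optimum.onnx import merge_decoders

output = Path("t5_onnx")  # hypothetical directory produced by a previous export

merge_decoders(
    decoder=output / "decoder_model.onnx",
    decoder_with_past=output / "decoder_with_past_model.onnx",
    save_path=output / "decoder_model_merged.onnx",
    # The decoder with past does not output the constant cross-attention key/values,
    # hence strict=False, exactly as in the post-processing code above.
    strict=False,
)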
68 changes: 32 additions & 36 deletions optimum/exporters/onnx/config.py
@@ -16,7 +16,7 @@

from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

from transformers.utils import is_tf_available

@@ -132,22 +132,6 @@ def post_process_exported_models(

return models_and_onnx_configs, onnx_files_subpaths

def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
if self.is_merged is True and self.use_cache_branch is True:
reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True)
elif self.is_merged is True and self.use_cache_branch is False:
reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=False)

# We don't support optional inputs for now, so even though the non-cache branch is used,
# dummy past key values are necessary
batch_size = reference_model_inputs["input_ids"].shape[0]
pkv_generator = self.DUMMY_PKV_GENERATOR_CLASS(
task=self.task, normalized_config=self._normalized_config, sequence_length=1, batch_size=batch_size
)
reference_model_inputs["past_key_values"] = pkv_generator.generate("past_key_values", framework="pt")

return reference_model_inputs


class TextSeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast):
"""
@@ -221,14 +205,6 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen

return dummy_inputs_generators

def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
if self._behavior is ConfigBehavior.DECODER:
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
# TODO: validate that it should be removed.
# reference_model_inputs["encoder_attention_mask"] = reference_model_inputs.pop("attention_mask")
return reference_model_inputs


class VisionOnnxConfig(OnnxConfig):
"""
@@ -296,13 +272,6 @@ def torch_to_onnx_input_map(self) -> Dict[str, str]:
}
return {}

def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
if self._behavior is ConfigBehavior.DECODER:
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]

return reference_model_inputs


class EncoderDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)
@@ -391,11 +360,38 @@ def flatten_output_collection_property(self, name: str, field: Iterable[Any]) ->
return self._decoder_onnx_config.flatten_output_collection_property(name, field)

def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
if self._behavior is ConfigBehavior.DECODER:
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
if self._behavior is ConfigBehavior.ENCODER:
return self._encoder_onnx_config.generate_dummy_inputs_for_validation(reference_model_inputs)
else:
if self._behavior is ConfigBehavior.DECODER:
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]

return self._decoder_onnx_config.generate_dummy_inputs_for_validation(reference_model_inputs)

return reference_model_inputs
def post_process_exported_models(
self,
path: Path,
models_and_onnx_configs: Dict[
str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
],
onnx_files_subpaths: List[str],
):
models_and_onnx_configs, onnx_files_subpaths = super().post_process_exported_models(
path, models_and_onnx_configs, onnx_files_subpaths
)
if self.use_past is True and len(models_and_onnx_configs) == 3:
models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.is_merged = True
models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.use_cache_branch = False

# Past key values won't be generated by default, but added in the input
models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.use_past = False
models_and_onnx_configs[ONNX_DECODER_NAME][1]._decoder_onnx_config.use_past_in_inputs = True

models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1]._decoder_onnx_config.use_cache_branch = True
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1]._decoder_onnx_config.is_merged = True

return models_and_onnx_configs, onnx_files_subpaths

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
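To make the runtime contract of the merged decoder concrete, below is a hedged onnxruntime sketch. The merged graph exposes a boolean use_cache_branch input choosing between the without-past and with-past branches, and, since optional inputs are not supported, past_key_values tensors must still be fed on the no-cache pass. The model path, dynamic-axis sizes and zero-filled dummy values are assumptions for illustration only:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("t5_onnx/decoder_model_merged.onnx", providers=["CPUExecutionProvider"])

# Assumed concrete sizes for the dynamic axes of the exported graph.
dims = {"batch_size": 1, "decoder_sequence_length": 1, "encoder_sequence_length": 8, "past_decoder_sequence_length": 1}

def dummy_feed(node):
    # Replace every symbolic dimension with a small concrete size and fill with zeros.
    shape = [d if isinstance(d, int) else dims.get(d, 1) for d in node.shape]
    if node.type == "tensor(int64)":
        return np.zeros(shape, dtype=np.int64)
    if node.type == "tensor(bool)":
        return np.zeros(shape, dtype=bool)
    return np.zeros(shape, dtype=np.float32)

feeds = {inp.name: dummy_feed(inp) for inp in sess.get_inputs()}

feeds["use_cache_branch"] = np.array([False])  # first decoding step: compute the cache
outputs_no_cache = sess.run(None, feeds)

feeds["use_cache_branch"] = np.array([True])   # later steps: take the with-past branch
outputs_with_cache = sess.run(None, feeds)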
10 changes: 8 additions & 2 deletions tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -56,7 +56,7 @@ def _get_models_to_test(export_models_dict: Dict):
for task in tasks:
models_to_test.append((f"{model_type}_{task}", model_type, model_name, task, False, False))

# -with-past and monolith case are absurd, so we don't test them as not supported
# -with-past and monolith cases are absurd, so we don't test them as not supported
if any(
task == ort_special_task
for ort_special_task in ["causal-lm", "seq2seq-lm", "speech2seq-lm", "vision2seq-lm"]
@@ -66,7 +66,13 @@
)

# For other tasks, we don't test --no-post-process as there is none anyway
if task == "causal-lm-with-past":
if task in [
"default-with-past",
"causal-lm-with-past",
"speech2seq-lm-with-past",
"vision2seq-lm-with-past",
"seq2seq-lm-with-past",
]:
models_to_test.append(
(f"{model_type}_{task}_no_postprocess", model_type, model_name, task, False, True)
)
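For reference, a sketch of the kind of CLI invocation these parametrized tests now cover for the extra -with-past tasks, wrapped in subprocess to stay in Python; the model id and the exact flag spellings are assumptions based on the test matrix above:

import subprocess

# Speech seq2seq export with past; by default post-processing merges the two decoders.
subprocess.run(
    ["optimum-cli", "export", "onnx", "--model", "openai/whisper-tiny.en",
     "--task", "speech2seq-lm-with-past", "whisper_onnx"],
    check=True,
)

# The *_no_postprocess test variants pass --no-post-process, skipping the decoder merge.
subprocess.run(
    ["optimum-cli", "export", "onnx", "--model", "openai/whisper-tiny.en",
     "--task", "speech2seq-lm-with-past", "--no-post-process", "whisper_onnx_raw"],
    check=True,
)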