Support the merge of decoder without/with past for encoder-decoder models in the ONNX export #926

Merged
Changes from 1 commit
fix
fxmarty committed Mar 27, 2023
commit 3bcc4cd34e2250bab6f05891688cf27b9a432ee3
1 change: 1 addition & 0 deletions optimum/exporters/onnx/__main__.py
@@ -476,6 +476,7 @@ def main_export(
     # TODO: treating stable diffusion separately is quite ugly
     if not no_post_process and task != "stable-diffusion":
         try:
+            logger.info("Post-processing the exported models...")
             models_and_onnx_configs, onnx_files_subpaths = onnx_config.post_process_exported_models(
                 output, models_and_onnx_configs, onnx_files_subpaths
             )
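For context, this is the code path that merges the two decoder subgraphs after export. A minimal sketch of triggering it from Python; the checkpoint name and the resulting file names are illustrative assumptions, not taken from this diff:

from optimum.exporters.onnx import main_export

# With post-processing enabled (the default), the decoder without past and the
# decoder with past are merged into a single ONNX file after export.
main_export(
    "t5-small",        # any encoder-decoder checkpoint; the task is inferred
    output="t5_onnx",
)

# Assumed resulting layout (file names follow the usual optimum constants):
#   t5_onnx/encoder_model.onnx
#   t5_onnx/decoder_model_merged.onnx
# Passing no_post_process=True would instead keep decoder_model.onnx and
# decoder_with_past_model.onnx as separate files.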
59 changes: 36 additions & 23 deletions optimum/exporters/onnx/base.py
@@ -33,15 +33,15 @@
     DEFAULT_DUMMY_SHAPES,
     DummyInputGenerator,
     DummyLabelsGenerator,
     DummyPastKeyValuesGenerator,
     DummySeq2SeqPastKeyValuesGenerator,
     is_diffusers_available,
     logging,
 )
 from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION
 from ...utils.doc import add_dynamic_docstring
 from ...utils.import_utils import is_onnx_available, is_onnxruntime_available
 from ..base import ExportConfig
-from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
+from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME
 from .model_patcher import ModelPatcher, Seq2SeqModelPatcher


@@ -456,7 +456,6 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
     PAD_ATTENTION_MASK_TO_PAST: bool = False
     USE_PAST_IN_INPUTS: Optional[bool] = None
     USE_PRESENT_IN_OUTPUTS: Optional[bool] = None
-    DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator

     def __init__(
         self,
@@ -617,6 +616,26 @@ def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]:

         return flattened_output

+    def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if self._behavior is ConfigBehavior.DECODER:
+            reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
+            reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
+
+        if self.is_merged is True and self.use_cache_branch is True:
+            reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True)
+        elif self.is_merged is True and self.use_cache_branch is False:
+            reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=False)
[Review comment from a Member, on lines +620 to +623]

Suggested change:
-        if self.is_merged is True and self.use_cache_branch is True:
-            reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True)
-        elif self.is_merged is True and self.use_cache_branch is False:
-            reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=False)
+        if self.is_merged:
+            reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=self.use_cache_branch)

[@fxmarty (Contributor Author), Mar 27, 2023]
edit: actually this is less explicit
+            # We don't support optional inputs for now, so even though the non-cache branch is used,
+            # dummy past key values are necessary
+            batch_size = reference_model_inputs["input_ids"].shape[0]
+            pkv_generator = self.DUMMY_PKV_GENERATOR_CLASS(
+                task=self.task, normalized_config=self._normalized_config, sequence_length=1, batch_size=batch_size
+            )
+            reference_model_inputs["past_key_values"] = pkv_generator.generate("past_key_values", framework="pt")
+
+        return reference_model_inputs
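To make the branch selection concrete, here is a sketch of the extra validation input this method builds. It is not code from the PR; it only assumes that DummyInputGenerator is importable from optimum.utils (matching the imports above) and that constant_tensor defaults to PyTorch tensors:

from optimum.utils import DummyInputGenerator

# The merged decoder is validated twice on the same ONNX file; this boolean
# tensor tells the exported model which subgraph to run.
use_cache_branch = DummyInputGenerator.constant_tensor(shape=[1], value=False)
print(use_cache_branch)  # tensor([False]), i.e. take the no-cache branch

# Even then, dummy past_key_values with sequence_length=1 must be fed, because
# the merged graph has no optional inputs: both branches share one signature.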


class ConfigBehavior(str, enum.Enum):
"""
@@ -636,6 +655,8 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
     Inherits from [`~exporters.onnx.OnnxConfigWithPast`]. A base class to handle the ONNX configuration of encoder-decoder models.
     """

+    DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator
+
     def __init__(
         self,
         config: "PretrainedConfig",
@@ -702,7 +723,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             new_axes_names = {}
             for axis_idx, axis_name in axes_names.items():
                 if "sequence" in axis_name:
-                    if not self.use_past_in_inputs:
+                    if self.use_past_in_inputs is False or self.is_merged is True:
                         new_axes_names[axis_idx] = sequence_name
                     else:
                         # Trick to force it since ONNX sometimes infers a dynamic axis where it's not.
@@ -734,7 +755,11 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
             inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch_size", 2: decoder_sequence_name}
             inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch_size", 2: decoder_sequence_name}

-            if direction == "inputs" or (self._behavior is ConfigBehavior.DECODER and self.use_past is False):
+            if (
+                direction == "inputs"
+                or (self._behavior is ConfigBehavior.DECODER and self.use_past is False)
+                or self.is_merged is True
+            ):
                 inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch_size", 2: "encoder_sequence_length"}
                 inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length"}
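The effect of the added `or self.is_merged is True` clause is that the cross-attention (encoder) key/value axes are declared for the merged decoder as well, so both branches expose one common signature. For layer 0 with direction="outputs" the resulting mapping would look roughly like this; the "present" output prefix is the usual optimum naming and is an assumption here:

# Illustrative only, not code from this diff.
decoder_sequence_name = "decoder_sequence_length"  # placeholder; depends on use_past_in_inputs
axes = {
    "present.0.decoder.key":   {0: "batch_size", 2: decoder_sequence_name},
    "present.0.decoder.value": {0: "batch_size", 2: decoder_sequence_name},
    "present.0.encoder.key":   {0: "batch_size", 2: "encoder_sequence_length"},
    "present.0.encoder.value": {0: "batch_size", 2: "encoder_sequence_length"},
}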

@@ -780,8 +805,12 @@ def post_process_exported_models(
                 raise Exception(f"Unable to merge decoders. Detailed error: {e}")

             # In order to do the validation of the two branches on the same file
-            onnx_files_subpaths[1] = decoder_merged_path.name
-            onnx_files_subpaths[2] = decoder_merged_path.name
+            if onnx_files_subpaths is not None:
+                encoder_path = onnx_files_subpaths[0]
+            else:
+                encoder_path = ONNX_ENCODER_NAME + ".onnx"
+
+            onnx_files_subpaths = [encoder_path, decoder_merged_path.name, decoder_merged_path.name]

             # We validate the two branches of the decoder model then
             models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
@@ -796,22 +825,6 @@

         return models_and_onnx_configs, onnx_files_subpaths

-    def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if self._behavior == "decoder" and self.is_merged is True and self.use_cache_branch is True:
-            reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True)
-        elif self._behavior == "decoder" and self.is_merged is True and self.use_cache_branch is False:
-            reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=False)
-
-            # We don't support optional inputs for now, so even though the non-cache branch is used,
-            # dummy past key values are necessary
-            batch_size = reference_model_inputs["input_ids"].shape[0]
-            pkv_generator = self.DUMMY_PKV_GENERATOR_CLASS(
-                task=self.task, normalized_config=self._normalized_config, sequence_length=1, batch_size=batch_size
-            )
-            reference_model_inputs["past_key_values"] = pkv_generator.generate("past_key_values", framework="pt")
-
-        return reference_model_inputs


class OnnxConfigWithLoss(OnnxConfig, ABC):
"""
39 changes: 1 addition & 38 deletions optimum/exporters/onnx/config.py
@@ -64,6 +64,7 @@ class TextDecoderOnnxConfig(OnnxConfigWithPast):

     PAD_ATTENTION_MASK_TO_PAST = True
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -131,22 +132,6 @@ def post_process_exported_models(

         return models_and_onnx_configs, onnx_files_subpaths

-    def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if self.is_merged is True and self.use_cache_branch is True:
-            reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=True)
-        elif self.is_merged is True and self.use_cache_branch is False:
-            reference_model_inputs["use_cache_branch"] = DummyInputGenerator.constant_tensor(shape=[1], value=False)
-
-            # We don't support optional inputs for now, so even though the non-cache branch is used,
-            # dummy past key values are necessary
-            batch_size = reference_model_inputs["input_ids"].shape[0]
-            pkv_generator = self.DUMMY_PKV_GENERATOR_CLASS(
-                task=self.task, normalized_config=self._normalized_config, sequence_length=1, batch_size=batch_size
-            )
-            reference_model_inputs["past_key_values"] = pkv_generator.generate("past_key_values", framework="pt")
-
-        return reference_model_inputs


class TextSeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast):
"""
@@ -220,14 +205,6 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:

         return dummy_inputs_generators

-    def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if self._behavior is ConfigBehavior.DECODER:
-            reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
-            reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
-            # TODO: validate that it should be removed.
-            # reference_model_inputs["encoder_attention_mask"] = reference_model_inputs.pop("attention_mask")
-        return reference_model_inputs


class VisionOnnxConfig(OnnxConfig):
"""
@@ -295,13 +272,6 @@ def torch_to_onnx_input_map(self) -> Dict[str, str]:
             }
         return {}

-    def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if self._behavior is ConfigBehavior.DECODER:
-            reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
-            reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
-
-        return reference_model_inputs


class EncoderDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)
@@ -389,13 +359,6 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):

     def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]:
         return self._decoder_onnx_config.flatten_output_collection_property(name, field)

-    def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if self._behavior is ConfigBehavior.DECODER:
-            reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
-            reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
-
-        return reference_model_inputs

     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
         common_outputs = super().outputs
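As a closing illustration of what this PR enables for encoder-decoder models, both decoder branches can be exercised against the single merged file with ONNX Runtime. This is a hedged sketch rather than the repository's validation code; the file name and the use_cache_branch semantics follow the diff above:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "t5_onnx/decoder_model_merged.onnx", providers=["CPUExecutionProvider"]
)

# The merged model exposes one extra boolean input compared to the two
# unmerged decoders; all other inputs keep the unified per-branch signature.
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)

# Validating both branches then amounts to running the same session twice:
no_cache = np.array([False])    # first step, no past key values yet
with_cache = np.array([True])   # subsequent steps, reusing the KV cache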