Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix OnnxConversion failure in GPTJ example #469

Merged
merged 2 commits into from
Aug 9, 2023

Conversation

yuwenzho
Copy link
Contributor

@yuwenzho yuwenzho commented Aug 8, 2023

Describe your changes

summary:
Skip removing past_key_values in dummy input in OnnxConversion, which is required for hf model with past.

Details:
GPTJ example failed in OnnxConversion with error: RuntimeError: The size of tensor a (8) must match the size of tensor b (18) at non-singleton dimension 3
To quickly reproduce the error:

from olive.model import PyTorchModel
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.conversion import OnnxConversion
from olive.systems.local import LocalSystem

local_system = LocalSystem()
hf_config={"model_name": "hf-internal-testing/tiny-random-gptj",
                   "task": "text-generation",
                   "feature": "causal-lm-with-past"}
pytorch_model = PyTorchModel(hf_config=hf_config)
onnx_conversion_config = {}

p = create_pass_from_dict(OnnxConversion, onnx_conversion_config, disable_search=True)
output_folder = str("onnx")

# execute
onnx_model = local_system.run_pass(p, pytorch_model, None, output_folder)

The error is caused in conversion.py#L115-L124.

# some dummy inputs might not be used in the model, so we need to remove them
# this can happen when we are using an hf dataset to generate dummy inputs
# only handle dict for now since we cannot get the name of the input from a list/tuple
if isinstance(dummy_inputs, dict):
    dummy_input_keys = set(dummy_inputs.keys())
    unused_keys = dummy_input_keys - set(input_names)
-    # print(dummy_input_keys): {'past_key_values', 'attention_mask', 'input_ids'}
-    # print(set(input_names)): {'past_key_values.3.value', 'past_key_values.2.key', 'attention_mask', 'past_key_values.0.value', 'past_key_values.4.key', 'past_key_values.3.key', 'past_key_values.1.value', 'input_ids', 'past_key_values.1.key', 'past_key_values.4.value', 'past_key_values.2.value', 'past_key_values.0.key'}
-    # print(unused_keys): {'past_key_values'}
-    # past key values in dummy input are removed here mistakenly, which will cause torch.onnx.export failure.
    if unused_keys:
        logger.debug(f"Removing unused dummy inputs: {unused_keys}")
    for key in unused_keys:
        del dummy_inputs[key]

Checklist before requesting a review

  • Add unit tests for this change.
  • Make sure all tests can pass.
  • Update documents if necessary.
  • Format your code by running pre-commit run --all-files
  • Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

(Optional) Issue link

@yuwenzho
Copy link
Contributor Author

yuwenzho commented Aug 8, 2023

My modification may not be robust. If you have any suggestions, please feel free to raise them.

@yuwenzho yuwenzho force-pushed the yuwenzho/fix_gptj branch 2 times, most recently from 2a6de0b to 507caa7 Compare August 9, 2023 02:09
@yuwenzho yuwenzho force-pushed the yuwenzho/fix_gptj branch from 507caa7 to daebe72 Compare August 9, 2023 06:06
@yuwenzho yuwenzho force-pushed the yuwenzho/fix_gptj branch from daebe72 to 293fcdb Compare August 9, 2023 07:28
@jambayk
Copy link
Contributor

jambayk commented Aug 9, 2023

/azp run

@azure-pipelines
Copy link

Azure Pipelines successfully started running 2 pipeline(s).

@jambayk
Copy link
Contributor

jambayk commented Aug 9, 2023

Thanks for the fix! The failing test is not because of this PR. We will look into solving it and then we can retrigger the ci after the fix.

@jambayk jambayk merged commit 33be852 into microsoft:main Aug 9, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants