
[ENHANCEMENT] use torch.onnx.export instead of convert_graph_to_onnx #347

Open
HenryL27 opened this issue Nov 28, 2023 · 1 comment
Labels
enhancement New feature or request

Comments

@HenryL27

Is your feature request related to a problem?
This is related to compiling ONNX models for upload to OpenSearch.

There are two problems with the status quo:

  1. transformers.convert_graph_to_onnx.convert will be deprecated in the next major version of Hugging Face Transformers.
  2. transformers.convert_graph_to_onnx.convert only grabs the base model, so the head that sits on top of the base model's outputs is left off. For embedding models we got around this by implementing the pooling layer in ml-commons for ONNX models, but for other pretrained classification heads (e.g. cross-encoders) this is simply impossible.

What solution would you like?
Instead, use torch.onnx.export. An example that implements this for cross-encoders:

import torch

# `model` is the loaded cross-encoder, `features` is a tokenized input batch,
# and `mname` is the model name used for the output file.
torch.onnx.export(
    model=model,
    args=(features['input_ids'], features['attention_mask'], features['token_type_ids']),
    f=f"/tmp/{mname}.onnx",
    input_names=['input_ids', 'attention_mask', 'token_type_ids'],
    output_names=['output'],
    # Mark batch and sequence dimensions as dynamic so the exported model
    # accepts any batch size and sequence length at inference time.
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'token_type_ids': {0: 'batch_size', 1: 'sequence_length'},
        'output': {0: 'batch_size'}
    }
)

Usage is similar to torch.jit.trace, which we already use for TorchScript compilation.
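As an aside, the dynamic_axes mapping in the snippet above follows a regular pattern, so it could be built programmatically. This is a hypothetical helper (not part of the issue's proposal, and build_dynamic_axes is an invented name), assuming every tokenizer input varies in batch size and sequence length while outputs vary only in batch size:

```python
# Hypothetical helper: build the dynamic_axes dict passed to
# torch.onnx.export from the input and output names, instead of
# writing each entry by hand.
def build_dynamic_axes(input_names, output_names):
    # Tokenizer inputs vary in both batch size and sequence length.
    axes = {name: {0: "batch_size", 1: "sequence_length"} for name in input_names}
    # Model outputs vary only in batch size.
    axes.update({name: {0: "batch_size"} for name in output_names})
    return axes

dynamic_axes = build_dynamic_axes(
    ["input_ids", "attention_mask", "token_type_ids"], ["output"]
)
```

The resulting dict can be passed directly as the dynamic_axes argument of torch.onnx.export.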

This will also simplify the code in ml-commons that drives ONNX models.

What alternatives have you considered?
There are probably other ways to export a complete model to ONNX (and if we want to support TensorFlow we may need to look at options for that), but this approach seems pretty clean.

Do you have any additional context?
Original comment

We should probably invest in supporting all the new kinds of models that will be coming from [RFC] Support more local model types in opensearch-py-ml.

@dblock (Member)

dblock commented Jun 6, 2024

[Triage -- attendees 1, 2, 3, 4, 5, 6, 7]

Can this be closed with #1615?
