
[Feature Request] Implement encoder_hidden_states as input in GPT2_BeamSearch Node #18050

Open
Borntowarn opened this issue Oct 22, 2023 · 2 comments
Labels
feature request (request for unsupported feature or enhancement), model:transformer (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.)

Comments

@Borntowarn

Describe the feature request

I am trying to use the convert_generation.py script to create a GPT2 code generation model with beam search that takes encoder_hidden_states (the TimeSformer output) as an input (my base model is Neleac/timesformer-gpt2-video-captioning), but there is no such flag in the scripts and no such node input in the graph. As a result, GPT2 is converted as a separate model with no link to the TimeSformer output.

So I was wondering if there are any plans to implement this option. I've tried manually manipulating the graph and script to no avail.
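For context, a minimal sketch (assuming the Hugging Face transformers and torch packages; everything beyond the model name is illustrative) of why the stock GPT2 path does not fit: the decoder in this checkpoint uses cross-attention, so its forward() expects encoder_hidden_states, which the converted graph never exposes.

# Minimal sketch, assuming the `transformers` and `torch` packages.
# Only the model name comes from this issue; the shapes below are illustrative.
import torch
from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained("Neleac/timesformer-gpt2-video-captioning")

# VisionEncoderDecoderModel enables cross-attention on its decoder, so the
# GPT2 decoder's forward() accepts encoder_hidden_states (the TimeSformer output).
print(model.decoder.config.add_cross_attention)  # expected: True

hidden = torch.zeros(1, 1568, model.decoder.config.n_embd)  # illustrative encoder output
out = model.decoder(input_ids=torch.tensor([[0]]), encoder_hidden_states=hidden)
print(out.logits.shape)  # (1, 1, vocab_size)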

Describe scenario use case

Usage of encoder-decoder models (such as SpeechEncoderDecoderModel or VisionEncoderDecoderModel from HF)

@Borntowarn added the feature request label on Oct 22, 2023
@github-actions bot added the model:transformer label on Oct 22, 2023
@Borntowarn changed the title from [Feature Request] to [Feature Request] Implement encoder_hidden_states as input in GPT2_BeamSearch Node on Oct 22, 2023
@tianleiwu
Contributor

convert_generation.py supports encoder-decoder models (we tested T5 and BART). See the comments in the script for example usage:

Example 5: convert T5 model with beam search. All in one step:
python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx
Example 6: convert T5 model with beam search containing specific cuda optimizations. All in one step:
python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx \
--use_gpu --past_present_share_buffer --use_decoder_masked_attention

ORT also supports Whisper with beam search. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/whisper/README.md for details.
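For what it's worth, a hedged sketch of running the T5 model produced by Example 5 (the input names and dtypes follow the BeamSearch schema quoted later in this thread; treat the exact feed as an assumption):

# Hedged sketch, assuming Example 5 above was run first and that the exported
# model uses the standard BeamSearch inputs; names/dtypes are an assumption.
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
sess = ort.InferenceSession("./models/t5/onnx_models/t5_small_beam_search.onnx")

input_ids = tokenizer("translate English to French: hello world", return_tensors="np").input_ids
sequences = sess.run(
    None,
    {
        "input_ids": input_ids.astype(np.int32),
        "max_length": np.array([32], dtype=np.int32),
        "min_length": np.array([1], dtype=np.int32),
        "num_beams": np.array([4], dtype=np.int32),
        "num_return_sequences": np.array([1], dtype=np.int32),
        "length_penalty": np.array([1.0], dtype=np.float32),
        "repetition_penalty": np.array([1.0], dtype=np.float32),
    },
)[0]
# sequences has shape (batch_size, num_return_sequences, max_length)
print(tokenizer.decode(sequences[0][0], skip_special_tokens=True))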

@Borntowarn
Author

Perhaps I didn't phrase it clearly, but I need the BERT/T5/GPT2 model to accept encoder_hidden_states from a vision encoder (image embeddings, for a captioning implementation) as an input, and the check at this line rejects that:

expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)]
if len(graph.input) != len(expected_inputs):
    raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
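A minimal sketch of what the requested change might look like at that check; the use_encoder_hidden_states flag is invented here for illustration and does not exist in convert_generation.py:

# Hypothetical sketch only: `use_encoder_hidden_states` is an invented flag;
# `layer_count` and `graph` come from the surrounding script.
expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)]
if use_encoder_hidden_states:
    # a cross-attention decoder (e.g. GPT2 inside VisionEncoderDecoderModel)
    # carries one extra graph input holding the encoder output
    expected_inputs.insert(3, "encoder_hidden_states")
if len(graph.input) != len(expected_inputs):
    raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")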

I guess you would also need to add it to the BeamSearch op schema here:

.Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F")
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
.Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
.Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I")
.Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I")
.Input(5, "length_penalty",
"Exponential penalty to the length. Default value 1.0 means no penalty."
"Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences."
"Shape is (1,)",
"T", OpSchema::Optional)
.Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
.Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional)
.Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
.Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
.Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional)
.Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional)
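If the schema were extended, the BeamSearch node in the converted graph might gain a trailing optional input, roughly as in this illustration built with onnx.helper (the final input is hypothetical; empty strings skip optional inputs):

# Illustration only: the last input below does not exist in today's
# com.microsoft BeamSearch schema; the others mirror inputs 0-11 quoted above.
from onnx import helper

beam_search_node = helper.make_node(
    "BeamSearch",
    inputs=[
        "input_ids",             # 0
        "max_length",            # 1
        "",                      # 2: min_length (optional, skipped)
        "num_beams",             # 3
        "num_return_sequences",  # 4
        "",                      # 5: length_penalty (optional, skipped)
        "",                      # 6: repetition_penalty (optional, skipped)
        "",                      # 7: vocab_mask (optional, skipped)
        "",                      # 8: prefix_vocab_mask (optional, skipped)
        "attention_mask",        # 9
        "",                      # 10: decoder_input_ids (optional, skipped)
        "",                      # 11: logits_processor (optional, skipped)
        "encoder_hidden_states", # 12: hypothetical new input from the vision encoder
    ],
    outputs=["sequences"],
    name="BeamSearch_GPT2",
    domain="com.microsoft",
)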
