From 42507b3a6127f555df71a885aa96cca0c3883f99 Mon Sep 17 00:00:00 2001
From: Kunal Vaishnavi
Date: Mon, 14 Aug 2023 02:01:08 +0000
Subject: [PATCH 1/3] Add position ids for LLaMA forward pass

---
 optimum/onnxruntime/base.py             | 10 ++++++++++
 optimum/onnxruntime/modeling_decoder.py |  3 +++
 2 files changed, 13 insertions(+)

diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index f9091650b20..fc8273b779f 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -329,6 +329,7 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache_branch: None = None,
+        **kwargs,
     ) -> CausalLMOutputWithCrossAttentions:
         # adding use_cache_branch in the signature here is just a hack for IO Binding
         use_torch = isinstance(input_ids, torch.Tensor)
@@ -365,6 +366,9 @@ def forward(
             if "attention_mask" in self.input_names:
                 model_inputs.append(attention_mask)
 
+            if "position_ids" in self.input_names:
+                model_inputs.append(kwargs["position_ids"])
+
             if past_key_values is not None:
                 model_inputs += past_key_values
 
@@ -413,6 +417,9 @@ def forward(
                 "attention_mask": attention_mask.cpu().detach().numpy(),
             }
 
+            if "position_ids" in self.input_names:
+                onnx_inputs["position_ids"] = kwargs["position_ids"].cpu().detach().numpy()
+
             if self.parent_model.use_merged is True:
                 onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy()
 
@@ -429,6 +436,9 @@ def forward(
                 "attention_mask": attention_mask,
             }
 
+            if "position_ids" in self.input_names:
+                onnx_inputs["position_ids"] = kwargs["position_ids"]
+
             if self.parent_model.use_merged is True:
                 onnx_inputs["use_cache_branch"] = use_cache_branch_tensor
 
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 1ffc81d8832..38fa80bc620 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -652,12 +652,14 @@ def forward(
                 attention_mask=attention_mask,
                 past_key_values=past_key_values,
                 labels=labels,
+                **kwargs,
             )
         elif self.use_merged is True:
             outputs = self.decoder(
                 input_ids=input_ids[:, -1:],
                 past_key_values=past_key_values,
                 attention_mask=attention_mask,
+                **kwargs,
             )
         else:
             outputs = self.decoder_with_past(
@@ -665,6 +667,7 @@ def forward(
                 past_key_values=past_key_values,
                 attention_mask=attention_mask,
                 labels=labels,
+                **kwargs,
             )
 
         return CausalLMOutputWithCrossAttentions(

From c1eae33fe7970626df067e8b23fb50530171793a Mon Sep 17 00:00:00 2001
From: Kunal Vaishnavi
Date: Tue, 15 Aug 2023 02:38:41 +0000
Subject: [PATCH 2/3] Add position ids to function signature

---
 optimum/onnxruntime/base.py             | 7 ++++---
 optimum/onnxruntime/modeling_decoder.py | 7 ++++---
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index fc8273b779f..040081e6ec7 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -326,6 +326,7 @@ def forward(
         self,
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache_branch: None = None,
@@ -367,7 +368,7 @@ def forward(
                 model_inputs.append(attention_mask)
 
             if "position_ids" in self.input_names:
-                model_inputs.append(kwargs["position_ids"])
+                model_inputs.append(position_ids)
 
             if past_key_values is not None:
                 model_inputs += past_key_values
@@ -418,7 +419,7 @@ def forward(
             }
 
             if "position_ids" in self.input_names:
-                onnx_inputs["position_ids"] = kwargs["position_ids"].cpu().detach().numpy()
+                onnx_inputs["position_ids"] = position_ids.cpu().detach().numpy()
 
             if self.parent_model.use_merged is True:
                 onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy()
@@ -437,7 +438,7 @@ def forward(
             }
 
             if "position_ids" in self.input_names:
-                onnx_inputs["position_ids"] = kwargs["position_ids"]
+                onnx_inputs["position_ids"] = position_ids
 
             if self.parent_model.use_merged is True:
                 onnx_inputs["use_cache_branch"] = use_cache_branch_tensor
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 38fa80bc620..91f65153b40 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -642,6 +642,7 @@ def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
         **kwargs,
@@ -650,16 +651,16 @@ def forward(
             outputs = self.decoder(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
+                position_ids=position_ids,
                 past_key_values=past_key_values,
                 labels=labels,
-                **kwargs,
             )
         elif self.use_merged is True:
             outputs = self.decoder(
                 input_ids=input_ids[:, -1:],
                 past_key_values=past_key_values,
                 attention_mask=attention_mask,
-                **kwargs,
+                position_ids=position_ids,
             )
         else:
             outputs = self.decoder_with_past(
@@ -667,7 +668,7 @@ def forward(
                 past_key_values=past_key_values,
                 attention_mask=attention_mask,
                 labels=labels,
-                **kwargs,
+                position_ids=position_ids,
             )
 
         return CausalLMOutputWithCrossAttentions(

From f4c0da915ee7fd9043c4b6a9901402601669f337 Mon Sep 17 00:00:00 2001
From: Kunal Vaishnavi
Date: Tue, 15 Aug 2023 02:40:42 +0000
Subject: [PATCH 3/3] Remove kwargs from function signature

---
 optimum/onnxruntime/base.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index 040081e6ec7..bbf7421a86a 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -330,7 +330,6 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache_branch: None = None,
-        **kwargs,
     ) -> CausalLMOutputWithCrossAttentions:
         # adding use_cache_branch in the signature here is just a hack for IO Binding
         use_torch = isinstance(input_ids, torch.Tensor)