diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index f9091650b20..bbf7421a86a 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -326,6 +326,7 @@ def forward(
         self,
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache_branch: None = None,
@@ -365,6 +366,9 @@ def forward(
             if "attention_mask" in self.input_names:
                 model_inputs.append(attention_mask)

+            if "position_ids" in self.input_names:
+                model_inputs.append(position_ids)
+
             if past_key_values is not None:
                 model_inputs += past_key_values

@@ -413,6 +417,9 @@ def forward(
                     "attention_mask": attention_mask.cpu().detach().numpy(),
                 }

+                if "position_ids" in self.input_names:
+                    onnx_inputs["position_ids"] = position_ids.cpu().detach().numpy()
+
                 if self.parent_model.use_merged is True:
                     onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy()

@@ -429,6 +436,9 @@ def forward(
                     "attention_mask": attention_mask,
                 }

+                if "position_ids" in self.input_names:
+                    onnx_inputs["position_ids"] = position_ids
+
                 if self.parent_model.use_merged is True:
                     onnx_inputs["use_cache_branch"] = use_cache_branch_tensor
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 1ffc81d8832..91f65153b40 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -642,6 +642,7 @@ def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
         **kwargs,
@@ -650,6 +651,7 @@ def forward(
             outputs = self.decoder(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
+                position_ids=position_ids,
                 past_key_values=past_key_values,
                 labels=labels,
             )
@@ -658,6 +660,7 @@ def forward(
                 input_ids=input_ids[:, -1:],
                 past_key_values=past_key_values,
                 attention_mask=attention_mask,
+                position_ids=position_ids,
             )
         else:
             outputs = self.decoder_with_past(
@@ -665,6 +668,7 @@ def forward(
                 past_key_values=past_key_values,
                 attention_mask=attention_mask,
                 labels=labels,
+                position_ids=position_ids,
             )

         return CausalLMOutputWithCrossAttentions(
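
For context: once "position_ids" is listed in self.input_names, the ONNX session will not synthesize it, so the caller (typically the model's prepare_inputs_for_generation) has to build it. Below is a minimal sketch of the usual derivation from the attention mask, assuming the cumsum-based convention used by transformers decoder models; the helper name build_position_ids and its use_cache_step flag are illustrative and not part of this diff.

import torch

# Illustrative helper (not part of this diff): derive position_ids from the
# attention mask, following the standard transformers decoder convention.
def build_position_ids(attention_mask: torch.LongTensor, use_cache_step: bool = False) -> torch.LongTensor:
    # Non-padding tokens get positions 0..n-1; padding slots are filled with 1
    # so they never produce an out-of-range positional index.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    if use_cache_step:
        # With past_key_values, only the position of the newest token is fed,
        # matching the input_ids[:, -1:] slicing in the diff above.
        position_ids = position_ids[:, -1:]
    return position_ids

# Example: one left-padded sequence of length 4.
# build_position_ids(torch.tensor([[0, 1, 1, 1]])) -> tensor([[1, 0, 1, 2]])

Deriving positions from the mask rather than from sequence length keeps rotary and learned position embeddings correct under left padding, which is why the new input is threaded through both the IO-binding path and the plain numpy/torch paths in base.py.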