Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional position ids to forward pass #1289

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache_branch: None = None,
Expand Down Expand Up @@ -365,6 +366,9 @@ def forward(
if "attention_mask" in self.input_names:
model_inputs.append(attention_mask)

if "position_ids" in self.input_names:
model_inputs.append(position_ids)

if past_key_values is not None:
model_inputs += past_key_values

Expand Down Expand Up @@ -413,6 +417,9 @@ def forward(
"attention_mask": attention_mask.cpu().detach().numpy(),
}

if "position_ids" in self.input_names:
onnx_inputs["position_ids"] = position_ids.cpu().detach().numpy()

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy()

Expand All @@ -429,6 +436,9 @@ def forward(
"attention_mask": attention_mask,
}

if "position_ids" in self.input_names:
onnx_inputs["position_ids"] = position_ids

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch_tensor

Expand Down
4 changes: 4 additions & 0 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
Expand All @@ -650,6 +651,7 @@ def forward(
outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
labels=labels,
)
Expand All @@ -658,13 +660,15 @@ def forward(
input_ids=input_ids[:, -1:],
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
)
else:
outputs = self.decoder_with_past(
input_ids=input_ids[:, -1:],
past_key_values=past_key_values,
attention_mask=attention_mask,
labels=labels,
position_ids=position_ids,
)

return CausalLMOutputWithCrossAttentions(
Expand Down