support 4d mask
wangyems committed Nov 4, 2022
1 parent ecf7246 commit a44d023
Showing 3 changed files with 55 additions and 35 deletions.
18 changes: 12 additions & 6 deletions onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
@@ -139,13 +139,19 @@ class GenerateBase {
 
     if (attention_mask != nullptr) {
       const auto& dims_attn = attention_mask->Shape().GetDims();
-      if (dims_attn.size() != 2) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                               "Input 'attention_mask' is expected to have 2 dimensions, got ", dims_attn.size());
-      }
-      if (!SpanEq(dims_attn, dims)) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                               "Input 'attention_mask' is expected to have same shape as input_ids");
-      }
+      if (dims_attn.size() == 2) {
+        if (!SpanEq(dims_attn, dims)) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                                 "Input 'attention_mask' is expected to have same shape as input_ids");
+        }
+      } else if (dims_attn.size() == 4) {
+        if (dims_attn[0] != dims[0] || dims_attn[1] != 1) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                                 "Input 'attention_mask' is expected to have shape [batch_size, 1, max_sequence_length, max_sequence_length]");
+        }
+      } else {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Input 'attention_mask' is expected to have 2 or 4 dimensions, got ", dims_attn.size());
+      }
     }

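Note on the 4d path: the mask is expected to have shape [batch_size, 1, max_sequence_length, max_sequence_length]. As a rough sketch of what a caller might construct for this check, here is a minimal causal (lower-triangular) 4d mask builder; the helper name and the flat std::vector layout are illustrative assumptions, not onnxruntime API:

#include <cstdint>
#include <vector>

// Hypothetical helper: build a causal 4d mask laid out row-major as
// [batch_size, 1, max_seq_len, max_seq_len]. Entry (b, 0, i, j) is 1
// when token i may attend to token j (i.e. j <= i), else 0.
std::vector<int32_t> MakeCausal4dMask(int64_t batch_size, int64_t max_seq_len) {
  std::vector<int32_t> mask(batch_size * max_seq_len * max_seq_len, 0);
  for (int64_t b = 0; b < batch_size; ++b) {
    int32_t* plane = mask.data() + b * max_seq_len * max_seq_len;
    for (int64_t i = 0; i < max_seq_len; ++i) {
      for (int64_t j = 0; j <= i; ++j) {
        plane[i * max_seq_len + j] = 1;  // attend to self and earlier positions
      }
    }
  }
  return mask;
}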
@@ -138,7 +138,8 @@ Status CreateGptInputs(
   OrtValue attention_mask;
   if (attn_mask_value != nullptr) {
     const Tensor& attn_mask = attn_mask_value->Get<Tensor>();
-    Tensor::InitOrtValue(element_type, input_ids_shape, const_cast<Tensor*>(&attn_mask)->MutableData<int32_t>(),
+    const TensorShape& attn_mask_shape = attn_mask.Shape();  // 2d or 4d
+    Tensor::InitOrtValue(element_type, attn_mask_shape, const_cast<Tensor*>(&attn_mask)->MutableData<int32_t>(),
                          allocator->Info(), attention_mask);
   } else {
     auto mask_type = DataTypeImpl::GetType<int32_t>();
@@ -176,9 +177,16 @@ Status CreateGptInputs(
 
   // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
   // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
-  ExpandInputs<int32_t>(input_ids, num_beams, allocator, expanded_input_ids);
-  ExpandInputs<int32_t>(position_ids, num_beams, allocator, expanded_position_ids);
-  ExpandInputs<int32_t>(attention_mask, num_beams, allocator, expanded_attention_mask);
+  if (num_beams == 1) {
+    expanded_input_ids = input_ids;
+    expanded_position_ids = position_ids;
+    expanded_attention_mask = attention_mask;
+  } else {
+    // bugbug: 4d not supported here
+    ExpandInputs<int32_t>(input_ids, num_beams, allocator, expanded_input_ids);
+    ExpandInputs<int32_t>(position_ids, num_beams, allocator, expanded_position_ids);
+    ExpandInputs<int32_t>(attention_mask, num_beams, allocator, expanded_attention_mask);
+  }
 
   return Status::OK();
 }
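For context, ExpandInputs (not shown in this diff) repeats each batch row num_beams times, turning (batch_size, sequence_length) into (batch_size * num_beams, sequence_length); per the bugbug comment above, the 4d mask is not handled by this expansion yet. A self-contained sketch of the row-repeat semantics on plain buffers, with illustrative names:

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative only: repeat each of the batch_size rows num_beams times,
// turning (batch_size, seq_len) into (batch_size * num_beams, seq_len).
std::vector<int32_t> ExpandRows(const std::vector<int32_t>& input,
                                int64_t batch_size, int64_t seq_len,
                                int64_t num_beams) {
  std::vector<int32_t> expanded(batch_size * num_beams * seq_len);
  for (int64_t b = 0; b < batch_size; ++b) {
    for (int64_t beam = 0; beam < num_beams; ++beam) {
      std::copy(input.begin() + b * seq_len,
                input.begin() + (b + 1) * seq_len,
                expanded.begin() + (b * num_beams + beam) * seq_len);
    }
  }
  return expanded;
}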
@@ -637,19 +645,22 @@ Status UpdateGptFeeds(
   next_inputs[1] = position_ids;
   // Update attention mask
   const OrtValue& old_mask = next_inputs[2];
-  const int32_t* old_mask_data = old_mask.Get<Tensor>().Data<int32_t>();
-  int64_t mask_dims[] = {batch_beam_size, current_length};
-  TensorShape mask_shape(&mask_dims[0], 2);
-  OrtValue attention_mask;
-  Tensor::InitOrtValue(int32_type, mask_shape, allocator, attention_mask);
-  int32_t* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
-  for (int i = 0; i < batch_beam_size; i++) {
-    for (int j = 0; j < current_length - 1; j++) {
-      mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j];
-    }
-    mask_data[i * current_length + current_length - 1] = 1;
-  }
-  next_inputs[2] = attention_mask;
+  const auto& mask_dims = old_mask.Get<Tensor>().Shape().GetDims();
+  if (mask_dims.size() == 2) {
+    const int32_t* old_mask_data = old_mask.Get<Tensor>().Data<int32_t>();
+    int64_t mask_dims[] = {batch_beam_size, current_length};
+    TensorShape mask_shape(&mask_dims[0], 2);
+    OrtValue attention_mask;
+    Tensor::InitOrtValue(int32_type, mask_shape, allocator, attention_mask);
+    int32_t* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
+    for (int i = 0; i < batch_beam_size; i++) {
+      for (int j = 0; j < current_length - 1; j++) {
+        mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j];
+      }
+      mask_data[i * current_length + current_length - 1] = 1;
+    }
+    next_inputs[2] = attention_mask;
+  }  // if mask_dims.size() == 4, do nothing
 
   // Update past state
   if (num_beams == 1) {
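The 2d branch above grows the mask from (batch_beam_size, current_length - 1) to (batch_beam_size, current_length): each old row is copied and a 1 is appended for the token generated this step, while a 4d mask already spans max_sequence_length and is left untouched. The same copy-and-append logic as a standalone sketch on plain buffers:

#include <cstdint>
#include <vector>

// Sketch of the 2d mask update: copy each row of the old
// (batch_beam_size, current_length - 1) mask and append a 1.
std::vector<int32_t> GrowMask2d(const std::vector<int32_t>& old_mask,
                                int64_t batch_beam_size, int64_t current_length) {
  std::vector<int32_t> mask(batch_beam_size * current_length);
  for (int64_t i = 0; i < batch_beam_size; ++i) {
    for (int64_t j = 0; j < current_length - 1; ++j) {
      mask[i * current_length + j] = old_mask[i * (current_length - 1) + j];
    }
    mask[i * current_length + current_length - 1] = 1;  // newly generated token is visible
  }
  return mask;
}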
@@ -811,19 +811,22 @@ Status UpdateGptFeeds(
 
   // Update attention mask
   const OrtValue& old_mask = next_inputs[2];
-  const int32_t* old_mask_data = old_mask.Get<Tensor>().Data<int32_t>();
-  int64_t mask_dims[] = {batch_beam_size, current_length};
-  TensorShape mask_shape(&mask_dims[0], 2);
-  OrtValue attention_mask;
-  auto mask_type = DataTypeImpl::GetType<int32_t>();
-  Tensor::InitOrtValue(mask_type, mask_shape, allocator, attention_mask);
-  int32_t* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
-
-  // Launch kernel to update position_ids and attention_mask for next iteration
-  cuda::LaunchUpdateGptKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length,
-                              reinterpret_cast<cudaStream_t>(stream));
-
-  next_inputs[2] = attention_mask;
+  const auto& mask_dims = old_mask.Get<Tensor>().Shape().GetDims();
+  if (mask_dims.size() == 2) {
+    const int32_t* old_mask_data = old_mask.Get<Tensor>().Data<int32_t>();
+    int64_t mask_dims[] = {batch_beam_size, current_length};
+    TensorShape mask_shape(&mask_dims[0], 2);
+    OrtValue attention_mask;
+    auto mask_type = DataTypeImpl::GetType<int32_t>();
+    Tensor::InitOrtValue(mask_type, mask_shape, allocator, attention_mask);
+    int32_t* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
+
+    // Launch kernel to update position_ids and attention_mask for next iteration
+    cuda::LaunchUpdateGptKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length,
+                                reinterpret_cast<cudaStream_t>(stream));
+
+    next_inputs[2] = attention_mask;
+  }  // do nothing for mask_dims.size() == 4
 
   // Update past state
   if (num_beams == 1) {
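The body of cuda::LaunchUpdateGptKernel is not part of this diff; the following is a hedged sketch of a kernel with the same effect as the CPU loop above, one thread per element of the new (batch_beam_size, current_length) mask. The real kernel also updates position_ids, which this sketch omits, and all names and the launch configuration here are assumptions:

#include <cstdint>
#include <cuda_runtime.h>

// Assumed sketch, not the actual onnxruntime kernel: copy the old row
// and set the last column (the newly generated token) to 1.
__global__ void UpdateMaskKernel(const int32_t* old_mask, int32_t* mask,
                                 int batch_beam_size, int current_length) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= batch_beam_size * current_length) return;
  int i = idx / current_length;
  int j = idx % current_length;
  mask[idx] = (j < current_length - 1) ? old_mask[i * (current_length - 1) + j] : 1;
}

void LaunchUpdateMask(const int32_t* old_mask, int32_t* mask,
                      int batch_beam_size, int current_length, cudaStream_t stream) {
  int total = batch_beam_size * current_length;
  constexpr int kThreadsPerBlock = 256;
  int blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
  UpdateMaskKernel<<<blocks, kThreadsPerBlock, 0, stream>>>(
      old_mask, mask, batch_beam_size, current_length);
}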
