Skip to content

Commit

Permalink
Fix a potential race in the CUDA TopK kernel (#19917)
Browse files Browse the repository at this point in the history
### Description
If the `K` value is flowing through as a tensor, we are updating a
mutable member of the `TopK` class and basing the compute off that -
which is likely to cause data race issues with concurrent Run() calls
and `K` value changes.


### Motivation and Context
Fix potential race in CUDA TopK kernel
  • Loading branch information
hariharans29 authored Mar 15, 2024
1 parent bcf47d3 commit 42399df
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
24 changes: 17 additions & 7 deletions onnxruntime/core/providers/cuda/math/topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
info.GetAttrOrDefault<int64_t>("largest", &largest_, 1);
info.GetAttrOrDefault<int64_t>("sorted", &sorted_, 1);
if (!inputk) {
info.GetAttrOrDefault<int64_t>("k", &K_, 0);
info.GetAttrOrDefault<int64_t>("k", &attr_k_, 0);
}
}

Expand All @@ -67,7 +67,7 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
elem_nums_cuda, \
elem_nums.size(), \
axis, K_, largest_, sorted_, N, dimension)
axis, k_value, largest_, sorted_, N, dimension)

template <bool inputk>
Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
Expand All @@ -77,19 +77,29 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
int32_t axis = static_cast<int32_t>(axis_ < 0 ? rank + axis_ : axis_);
ORT_ENFORCE(axis > -1 && axis < rank);

int64_t k_value = 0;
if (inputk) {
auto tensor_K = ctx->Input<Tensor>(1);
ORT_ENFORCE(nullptr != tensor_K);
K_ = *tensor_K->Data<int64_t>();
ORT_ENFORCE(K_ >= 0 && K_ <= tensor_X->Shape().GetDims()[axis]);
k_value = *tensor_K->Data<int64_t>();
} else { // from attribute
k_value = attr_k_;
}

auto output_shape = tensor_X->Shape();
output_shape[axis] = K_;
// Now that we know the value of 'K' and the input shape,
// make a final validation before going to the implementation
const auto& input_shape = tensor_X->Shape();
if ((k_value < 0) || (k_value > input_shape.GetDims()[axis])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Value of K outside range. K value: ", k_value,
". Input shape: ", input_shape, " . Axis: ", axis);
}

auto output_shape = input_shape;
output_shape[axis] = k_value;
auto tensor_V = ctx->Output(0, output_shape);
auto tensor_I = ctx->Output(1, output_shape);

if (0 == K_) {
if (output_shape.Size() == 0) { // Bail out early if the output is going to be empty
return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/math/topk.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TopK final : public CudaKernel {
int64_t axis_;
int64_t largest_;
int64_t sorted_;
mutable int64_t K_;
int64_t attr_k_;
};
} // namespace cuda
} // namespace onnxruntime

0 comments on commit 42399df

Please sign in to comment.