Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Throw if unique_ptr or array allocation fails due to SafeInt overflow #18941

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 50 additions & 21 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "core/session/onnxruntime_c_api.h"
#include "ortdevice.h"
#include "ortmemoryinfo.h"
#include <cassert>

Check warning on line 11 in include/onnxruntime/core/framework/allocator.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] include/onnxruntime/core/framework/allocator.h#L11

Found C++ system header after other header. Should be: allocator.h, c system, c++ system, other. [build/include_order] [4]
Raw output
include/onnxruntime/core/framework/allocator.h:11:  Found C++ system header after other header. Should be: allocator.h, c system, c++ system, other.  [build/include_order] [4]
edgchen1 marked this conversation as resolved.
Show resolved Hide resolved
#include <map>

// This configures the arena based allocator used by ORT
Expand Down Expand Up @@ -100,7 +101,8 @@
* \param out Total size required after any alignment is applied
* \return true, successful. false, overflow
*/
[[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept;
[[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment,
size_t* out) noexcept;

/**
* https://cwe.mitre.org/data/definitions/190.html
Expand All @@ -120,8 +122,10 @@
*/
void* AllocArray(size_t nmemb, size_t size) {
size_t len;
if (!CalcMemSizeForArray(nmemb, size, &len))
return nullptr;
if (!CalcMemSizeForArray(nmemb, size, &len)) {
ORT_THROW("Invalid size requested for allocation: ", nmemb, " * ", size);
}

return Alloc(len);
}

Expand All @@ -131,8 +135,10 @@
template <size_t alignment>
void* AllocArrayWithAlignment(size_t nmemb, size_t size) {
size_t len;
if (!CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, &len))
return nullptr;
if (!CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, &len)) {
ORT_THROW("Invalid size requested for allocation: ", nmemb, " * ", size, " with alignment ", alignment);
}

return Alloc(len);
}

Expand All @@ -144,13 +150,14 @@
@param stream Which stream instance allocated chunk will be used with.
@param wait_fn If the allocator want to dynamic reuse a chunk from another stream, use this wait_fn to sync on
the target stream to make the reuse safe.
@returns std::unique_ptr with allocated memory and deleter.
@returns std::unique_ptr with allocated memory and deleter. Throws if it cannot allocate memory.
skottmckay marked this conversation as resolved.
Show resolved Hide resolved
*/
template <typename T>
static IAllocatorUniquePtr<T> MakeUniquePtr(std::shared_ptr<IAllocator> allocator, size_t count_or_bytes,
bool use_reserve = false,
Stream* stream = nullptr, WaitNotificationFn wait_fn = nullptr) {
if (allocator == nullptr) return nullptr;
ValidateAllocator(allocator);

// for now limit to fundamental types. we could support others, but to do so either we or the caller
// needs to call the dtor for the objects, for buffers allocated on device we don't have destructor
// static_assert(std::is_fundamental<T>::value, "Fundamental type required as no destructors are called.");
Expand All @@ -161,38 +168,60 @@
if constexpr (!std::is_void<T>::value) {
// sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't
// reachable if T is void. use std::conditional to 'use' void* in the sizeof call
if (!CalcMemSizeForArray(
count_or_bytes, sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type), &alloc_size)) {
return nullptr;
}
const auto size = sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type);
skottmckay marked this conversation as resolved.
Show resolved Hide resolved
alloc_size = CheckedCalcMemSizeForArray(count_or_bytes, size);
}

// allocate
T* p = static_cast<T*>(AllocateBufferWithOptions(*allocator, alloc_size, use_reserve, stream, std::move(wait_fn)));
return IAllocatorUniquePtr<T>{
p,
[allocator = std::move(allocator)](T* p) { allocator->Free(p); }};
return IAllocatorUniquePtr<T>{p,
[allocator = std::move(allocator)](T* p) {

Check warning on line 178 in include/onnxruntime/core/framework/allocator.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] include/onnxruntime/core/framework/allocator.h#L178

Add #include <utility> for move [build/include_what_you_use] [4]
Raw output
include/onnxruntime/core/framework/allocator.h:178:  Add #include <utility> for move  [build/include_what_you_use] [4]
allocator->Free(p);
}};
}

/**
Create a std::unique_ptr that is allocated and freed by the provided OrtAllocator.
@param ort_allocator The allocator.
@param count_or_bytes The exact bytes to allocate if T is void, otherwise the number of elements to allocate.
@returns std::unique_ptr with allocated memory and deleter. Throws if it cannot allocate memory.
*/
template <typename T>
static IAllocatorUniquePtr<T> MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes) {
if (!ort_allocator) return nullptr;
ValidateAllocator(ort_allocator);

size_t alloc_size = count_or_bytes;
// if T is not void, 'count_or_bytes' == number of items so allow for that
if constexpr (!std::is_void<T>::value) {
// sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't
// reachable if T is void. use std::conditional to 'use' void* in the sizeof call
if (!CalcMemSizeForArray(
count_or_bytes, sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type), &alloc_size)) {
return nullptr;
}
const auto size = sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type);
alloc_size = CheckedCalcMemSizeForArray(count_or_bytes, size);
}
T* p = static_cast<T*>(ort_allocator->Alloc(ort_allocator, count_or_bytes));
return IAllocatorUniquePtr<T>{p, [ort_allocator](T* p) { ort_allocator->Free(ort_allocator, p); }};

T* p = static_cast<T*>(ort_allocator->Alloc(ort_allocator, alloc_size));
return IAllocatorUniquePtr<T>{p,
[ort_allocator](T* p) {
ort_allocator->Free(ort_allocator, p);
}};
}

private:
// validation functions. split out from methods that are templatized on the data type to minimize binary size.
template <typename T>
static void ValidateAllocator(const T& allocator) {
ORT_ENFORCE(allocator != nullptr);
}

static size_t CheckedCalcMemSizeForArray(size_t count, size_t size) {
size_t alloc_size = 0;
if (!CalcMemSizeForArray(count, size, &alloc_size)) {
ORT_THROW("Invalid size requested for allocation: ", count, " * ", size);
}

return alloc_size;
}

OrtMemoryInfo memory_info_;
};

Expand Down
3 changes: 0 additions & 3 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
if (packed_b_size_ == 0) return Status::OK();
auto qptr = tensor.Data<uint8_t>();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
if (packed_b_ == nullptr) {
return Status::OK();
}
std::memset(packed_b_.get(), 0, packed_b_size_);
MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
is_asym_, false, compt_type, pool);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz
ORT_CATCH(const OnnxRuntimeException& ex) {
// overflow in calculating the size thrown by SafeInt.
ORT_HANDLE_EXCEPTION([&]() {
LOGS_DEFAULT(ERROR) << ex.what();
LOGS_DEFAULT(ERROR) << ex.what() << " nmemb=" << nmemb << " size=" << size << " alignment=" << alignment;
ok = false;
});
}
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/framework/sparse_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ Status SparseTensor::AllocateBuffer(int64_t buffer_size, size_t num_values) {
ORT_RETURN_IF_NOT(buffer_size_t > values_bytes,
"Values size ", static_cast<size_t>(values_bytes), " must be less than total buffer size: ", buffer_size);
auto data_ptr = IAllocator::MakeUniquePtr<void>(allocator_, buffer_size_t);
ORT_RETURN_IF(data_ptr == nullptr, "SparseTensor Allocation failed for size: ", buffer_size);
if (IsDataTypeString()) {
// We own the buffer, so we must properly construct strings. Neither of the Tensors
// we construct on top of the buffer own it. We are constructing empty strings, hopefully
Expand Down Expand Up @@ -592,4 +591,4 @@ Status SparseTensor::Copy(const IDataTransfer& data_transfer, SparseTensor& dst_

} // namespace onnxruntime

#endif // !defined(DISABLE_SPARSE_TENSORS)
#endif // !defined(DISABLE_SPARSE_TENSORS)
Loading