Skip to content

Commit

Permalink
Fix `create_ort_value` so memory info is no longer hard-coded to CPU (#10510)
Browse files Browse the repository at this point in the history
* Fix `create_ort_value` so memory info is no longer hard-coded to CPU

* Remove unneeded check

* Remove unneeded header

* Remove unneeded header

* Update ort_ops.cpp

* Update ort_ops.cpp

* Update ort_ops.cpp

* Update ort_ops.cpp

Co-authored-by: root <[email protected]>
  • Loading branch information
anhnguyen7198 and root authored Feb 15, 2022
1 parent 1cdc23a commit 0c3e889
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 12 deletions.
6 changes: 5 additions & 1 deletion orttraining/orttraining/eager/ort_aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,15 @@ OrtValue create_ort_value(
return impl->tensor();
}

OrtMemoryInfo *mem_info;
Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info));

OrtValue ort_tensor;
CreateMLValue(
tensor.data_ptr(),
ort_scalar_type_from_aten(tensor.scalar_type()),
tensor.sizes().vec(),
*mem_info,
&ort_tensor);
return ort_tensor;
}
Expand Down Expand Up @@ -544,4 +548,4 @@ at::Tensor& add__Tensor(
//#pragma endregion

} // namespace eager
} // namespace torch_ort
} // namespace torch_ort
3 changes: 2 additions & 1 deletion orttraining/orttraining/eager/ort_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ void createInplaceOutputValue(OrtValue& input, V<int64_t> shape, OrtValue* p_mlv
onnxruntime::ReshapeHelper helper(input.Get<onnxruntime::Tensor>().Shape(), target_shape);
onnxruntime::TensorShape new_shape(target_shape);
CreateMLValue(input_ort_tensor->MutableDataRaw(),
input_ort_tensor->DataType(), new_shape, p_mlvalue);
input_ort_tensor->DataType(), new_shape,
input_ort_tensor->Location(), p_mlvalue);
}

template void createInplaceOutputValue<c10::ArrayRef>(OrtValue& input, c10::ArrayRef<int64_t> shape, OrtValue* p_mlvalue);
Expand Down
20 changes: 13 additions & 7 deletions orttraining/orttraining/eager/ort_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,28 @@ void CreateMLValue(onnxruntime::AllocatorPtr alloc,
onnxruntime::DataTypeImpl::GetType<onnxruntime::Tensor>()->GetDeleteFunc());
}

// NOTE(review): this span is a rendered diff hunk with no +/- markers. The next three
// lines are the REMOVED pre-change version, which created CPU memory info internally
// (the hard-coding this commit fixes).
void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, onnxruntime::TensorShape& shape, OrtValue* p_mlvalue){
OrtMemoryInfo *cpu_info;
Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &cpu_info));
// ADDED post-change version: wraps an existing data buffer in an OrtValue tensor.
// data_ptr:     buffer holding the tensor contents; presumably caller-owned — the
//               tensor views it rather than copying (verify against the
//               onnxruntime::Tensor pre-allocated-buffer constructor).
// element_type: ORT element type of the buffer.
// shape:        tensor shape describing the buffer layout.
// memory_info:  location of data_ptr, now supplied by the caller instead of being
//               hard-coded to CPU, so non-CPU buffers can be described correctly.
// p_mlvalue:    output OrtValue initialized to hold the new tensor.
void CreateMLValue(void* data_ptr,
onnxruntime::MLDataType element_type,
onnxruntime::TensorShape& shape,
const OrtMemoryInfo& memory_info,
OrtValue* p_mlvalue) {
std::unique_ptr<onnxruntime::Tensor> p_tensor = std::make_unique<onnxruntime::Tensor>(element_type,
shape,
data_ptr,
// REMOVED: *cpu_info (the internally created CPU info) was replaced by memory_info.
*cpu_info);
memory_info);

// Transfer tensor ownership into the OrtValue; GetDeleteFunc() deletes the
// Tensor object when the OrtValue is released.
p_mlvalue->Init(p_tensor.release(),
onnxruntime::DataTypeImpl::GetType<onnxruntime::Tensor>(),
onnxruntime::DataTypeImpl::GetType<onnxruntime::Tensor>()->GetDeleteFunc());
}

// NOTE(review): rendered diff hunk. The next line is the REMOVED pre-change
// signature (no memory_info parameter).
void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, const std::vector<int64_t>& dims, OrtValue* p_mlvalue) {
// ADDED post-change version: convenience overload that builds a TensorShape from
// a dims vector, then forwards to the TensorShape overload. memory_info is the
// caller-supplied location of data_ptr (previously hard-coded to CPU inside the
// callee).
void CreateMLValue(void* data_ptr,
onnxruntime::MLDataType element_type,
const std::vector<int64_t>& dims,
const OrtMemoryInfo& memory_info,
OrtValue* p_mlvalue) {
onnxruntime::TensorShape shape(dims);
// REMOVED pre-change call (no memory_info) followed by the ADDED replacement.
CreateMLValue(data_ptr, element_type, shape, p_mlvalue);
CreateMLValue(data_ptr, element_type, shape, memory_info, p_mlvalue);
}

std::vector<int64_t> GetStrides(gsl::span<const int64_t> shape) {
Expand All @@ -52,4 +58,4 @@ std::vector<int64_t> GetStrides(gsl::span<const int64_t> shape) {
}

} // namespace eager
} // namespace torch_ort
} // namespace torch_ort
15 changes: 12 additions & 3 deletions orttraining/orttraining/eager/ort_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,17 @@ void CreateMLValue(onnxruntime::AllocatorPtr alloc,
const std::vector<int64_t>& dims,
OrtValue* p_mlvalue);

void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, const std::vector<int64_t>& dims, OrtValue* p_mlvalue);
void CreateMLValue(void* data_ptr, onnxruntime::MLDataType element_type, onnxruntime::TensorShape& shape, OrtValue* p_mlvalue);
void CreateMLValue(void* data_ptr,
onnxruntime::MLDataType element_type,
const std::vector<int64_t>& dims,
const OrtMemoryInfo& memory_info,
OrtValue* p_mlvalue);

void CreateMLValue(void* data_ptr,
onnxruntime::MLDataType element_type,
onnxruntime::TensorShape& shape,
const OrtMemoryInfo& memory_info,
OrtValue* p_mlvalue);

template <typename T>
inline void CopyVectorToTensor(onnxruntime::ORTInvoker& invoker,
Expand Down Expand Up @@ -57,4 +66,4 @@ inline void CopyVectorToTensor<bool>(onnxruntime::ORTInvoker& /*invoker*/,
std::vector<int64_t> GetStrides(gsl::span<const int64_t> shape);

} // namespace eager
} // namespace torch_ort
} // namespace torch_ort

0 comments on commit 0c3e889

Please sign in to comment.