Implement numpy array over CPU OrtValues on return
yuslepukhin committed May 2, 2024
1 parent dfd4bce commit 3429155
Showing 5 changed files with 162 additions and 105 deletions.
15 changes: 8 additions & 7 deletions onnxruntime/python/onnxruntime_pybind_iobinding.cc
@@ -161,23 +161,24 @@ void addIoBindingMethods(pybind11::module& m) {
        return io_binding->Get()->GetOutputs();
      },
      py::return_value_policy::reference_internal)
-      .def("copy_outputs_to_cpu", [](const SessionIOBinding* io_binding) -> std::vector<py::object> {
+      .def("copy_outputs_to_cpu", [](const SessionIOBinding* io_binding) -> py::list {
        const std::vector<OrtValue>& outputs = io_binding->Get()->GetOutputs();
-        std::vector<py::object> rfetch;
-        rfetch.reserve(outputs.size());

        size_t pos = 0;
        const auto& dtm = io_binding->GetInferenceSession()->GetDataTransferManager();

+        py::list result;
        for (const auto& ort_value : outputs) {
          if (ort_value.IsTensor()) {
-            rfetch.push_back(AddTensorAsPyObj(ort_value, &dtm, nullptr));
+            result.append(AddTensorAsPyObj(ort_value, &dtm, nullptr));
          } else if (ort_value.IsSparseTensor()) {
-            rfetch.push_back(GetPyObjectFromSparseTensor(pos, ort_value, &dtm));
+            result.append(GetPyObjectFromSparseTensor(pos, ort_value, &dtm));
          } else {
-            rfetch.push_back(AddNonTensorAsPyObj(ort_value, &dtm, nullptr));
+            result.append(AddNonTensorAsPyObj(ort_value, &dtm, nullptr));
          }
          ++pos;
        }
-        return rfetch;
+        return result;
      });
}
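
The switch to a py::list return type above follows a standard pybind11 pattern: building the Python list directly avoids the std::vector-to-list conversion (via the <pybind11/stl.h> type caster) that returning std::vector<py::object> would trigger at the binding boundary. A minimal, self-contained sketch of the pattern; the MakeOutputs name is illustrative:

#include <vector>
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Build the Python list directly; append() just takes a new reference
// to each element, so no extra container conversion happens on return.
py::list MakeOutputs(const std::vector<py::object>& converted) {
  py::list result;
  for (const auto& obj : converted) {
    result.append(obj);
  }
  return result;
}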

42 changes: 39 additions & 3 deletions onnxruntime/python/onnxruntime_pybind_mlvalue.h
@@ -117,9 +117,45 @@ void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const
                          const std::string& name_input, const pybind11::object& value, OrtValue* p_mlvalue,
                          bool accept_only_numpy_array = false, bool use_numpy_data_memory = true, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy);

-void GetPyObjFromTensor(const Tensor& rtensor, pybind11::object& obj,
-                        const DataTransferManager* data_transfer_manager = nullptr,
-                        const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions = nullptr);
+pybind11::object GetPyObjFromTensor(const OrtValue& rtensor,
+                                    const DataTransferManager* data_transfer_manager = nullptr,
+                                    const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions = nullptr);

+// The functions below are used to convert OrtValue to numpy arrays
+
+/// <summary>
+/// This function operates on string tensors. Strings are always
+/// copied to python and converted to UTF-16/UCS-4/32 depending on the platform.
+/// This is accomplished using py::cast()
+///
+/// It is an error to pass a non-tensor or a non-string tensor to this function.
+/// </summary>
+/// <param name="tensor">Tensor that contains strings</param>
+/// <returns>py::array object</returns>
+pybind11::array StringTensorToNumpyArray(const Tensor& tensor);
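
A compact sketch of what this conversion can look like, modeled on the inline code this commit removes from onnxruntime_pybind_sparse_tensor.cc further down; the helper name and exact body are illustrative:

// Strings are always copied: each UTF-8 std::string becomes a Python
// unicode object stored in an NPY_OBJECT array.
py::array StringTensorToNumpySketch(const Tensor& tensor) {
  const auto dims = tensor.Shape().GetDims();
  std::vector<py::ssize_t> shape(dims.begin(), dims.end());
  py::array result(py::dtype("object"), shape);
  auto* out = static_cast<py::object*>(result.mutable_data());
  const std::string* src = tensor.Data<std::string>();
  for (int64_t i = 0, n = tensor.Shape().Size(); i < n; ++i) {
    out[i] = py::cast(src[i]);  // UTF-8 to Python unicode, one copy per string
  }
  return result;
}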

+/// <summary>
+/// Creates a numpy array with shape over OrtValue memory. The numpy array
+/// does not own the memory; instead it holds a copy of the OrtValue in a py::capsule.
+/// The OrtValue is destroyed when the numpy array is garbage collected.
+/// This is used when the OrtValue memory is on CPU.
+/// </summary>
+/// <param name="ort_value">OrtValue with data</param>
+/// <returns>numpy array</returns>
+pybind11::array PrimitiveTensorToNumpyOverOrtValue(const OrtValue& ort_value);
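
A minimal sketch of the capsule technique described here, assuming a heap-allocated copy of the OrtValue travels with the array as its base object; the function name and the hard-coded float dtype are illustrative, not the actual implementation:

py::array ArrayOverOrtValueSketch(const OrtValue& ort_value) {
  const Tensor& tensor = ort_value.Get<Tensor>();
  // OrtValue shares ownership of the buffer, so copying it is cheap and
  // keeps the memory alive for as long as the capsule exists.
  auto* holder = new OrtValue(ort_value);
  py::capsule base(holder, [](void* p) { delete reinterpret_cast<OrtValue*>(p); });
  const auto dims = tensor.Shape().GetDims();
  std::vector<py::ssize_t> shape(dims.begin(), dims.end());
  // Passing `base` makes numpy borrow the buffer instead of copying it;
  // float is assumed here, the real code maps the element type dynamically.
  return py::array(py::dtype::of<float>(), shape, tensor.DataRaw(), base);
}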

+/// <summary>
+/// Creates a numpy array with shape holding a copy of the OrtValue data.
+/// This function is used when the OrtValue memory is not on CPU.
+/// Either data_transfer or func must be non-null.
+/// </summary>
+/// <param name="ort_value">Source OrtValue with memory that is not on CPU.</param>
+/// <param name="data_transfer">data transfer manager</param>
+/// <param name="func">copy function used if the data transfer manager is not available.</param>
+/// <returns>numpy array with a copy of the data</returns>
+pybind11::array PrimitiveTensorToNumpyFromDevice(const OrtValue& ort_value,
+                                                 const DataTransferManager* data_transfer,
+                                                 MemCpyFunc func);
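
A rough sketch of the device-to-host path, under the assumption that the destination numpy array is allocated first and then filled by whichever copy mechanism is available; the helper name, the float dtype, and the MemCpyFunc argument order are assumptions:

py::array CopyToNumpySketch(const OrtValue& ort_value,
                            const DataTransferManager* data_transfer,
                            MemCpyFunc func) {
  const Tensor& src = ort_value.Get<Tensor>();
  const auto dims = src.Shape().GetDims();
  std::vector<py::ssize_t> shape(dims.begin(), dims.end());
  py::array result(py::dtype::of<float>(), shape);  // owns its host memory
  if (data_transfer != nullptr) {
    // Wrap result.mutable_data() in a CPU Tensor and let the manager copy,
    // e.g. data_transfer->CopyTensor(src, dst); elided in this sketch.
  } else {
    // Assumed signature: func(dst, src, byte_count).
    func(result.mutable_data(), src.DataRaw(), src.SizeInBytes());
  }
  return result;
}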

template <class T>
struct DecRefFn {
12 changes: 5 additions & 7 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -302,18 +302,16 @@ void addOrtValueMethods(pybind11::module& m) {
      .def("numpy", [](const OrtValue* ml_value) -> py::object {
        ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects");

-        py::object obj;
-
#ifdef USE_CUDA
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCudaToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction());
#elif USE_ROCM
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetRocmToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction());
#elif USE_CANN
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCannToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction());
#elif USE_DML
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetDmlToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction());
#else
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, nullptr);
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr);
#endif
        return obj;
      })
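
The refactored GetPyObjFromTensor plausibly dispatches between the three helpers declared in the header above, based on element type and on where the tensor memory lives. A hedged sketch of that dispatch, simplified to a single MemCpyFunc instead of the per-device function map the real signature takes:

py::object GetPyObjFromTensorSketch(const OrtValue& ort_value,
                                    const DataTransferManager* dtm,
                                    MemCpyFunc func) {
  const Tensor& tensor = ort_value.Get<Tensor>();
  if (tensor.IsDataTypeString()) {
    return StringTensorToNumpyArray(tensor);  // strings are always copied
  }
  if (tensor.Location().device.Type() == OrtDevice::CPU) {
    return PrimitiveTensorToNumpyOverOrtValue(ort_value);  // zero-copy view
  }
  return PrimitiveTensorToNumpyFromDevice(ort_value, dtm, func);  // device copy
}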
13 changes: 1 addition & 12 deletions onnxruntime/python/onnxruntime_pybind_sparse_tensor.cc
@@ -305,18 +305,7 @@ void addSparseTensorMethods(pybind11::module& m) {
        if (sparse_tensor.IsDataTypeString()) {
          // Strings can not be on GPU and require conversion UTF-8 to Python UNICODE
          // We need to create a copy.
-          const int numpy_type = OnnxRuntimeTensorToNumpyType(DataTypeImpl::GetType<std::string>());
-          ORT_ENFORCE(NPY_OBJECT == numpy_type, "We are expecting to map strings to NPY_OBJECT type");
-          const auto& values_shape = sparse_tensor.Values().Shape();
-          py::dtype dtype("object");
-          py::array result(dtype, values_shape.GetDims(), {});
-          auto* out_ptr = static_cast<py::object*>(
-              PyArray_DATA(reinterpret_cast<PyArrayObject*>(result.ptr())));
-          const std::string* src = sparse_tensor.Values().Data<std::string>();
-          for (int64_t i = 0, size = values_shape.Size(); i < size; ++i, src++) {
-            out_ptr[i] = py::cast(*src);
-          }
-          return result;
+          return StringTensorToNumpyArray(sparse_tensor.Values());
        } else {
          utils::MLTypeCallDispatcher<float, double, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t>
              t_disp(sparse_tensor.GetElementType());
