Implement numpy array over CPU OrtValues on return values #20539

Merged · 10 commits · May 8, 2024
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_inference_collection.py
@@ -646,7 +646,7 @@ def get_outputs_as_ortvaluevector(self):
return self._iobinding.get_outputs()

def copy_outputs_to_cpu(self):
"""Copy output contents to CPU (if on another device). No-op if already on the CPU."""
"""Copy output contents to CPU."""
return self._iobinding.copy_outputs_to_cpu()

def clear_binding_inputs(self):
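For context, here is a minimal sketch of how the updated docstring plays out from Python. The model path and tensor names are purely illustrative and not part of this PR:

import numpy as np
import onnxruntime as ort

# Hypothetical model and tensor names, used only for illustration.
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
binding = sess.io_binding()

x = np.random.rand(1, 3).astype(np.float32)
binding.bind_cpu_input("X", x)   # bind an input from a CPU numpy array
binding.bind_output("Y")         # let ORT allocate the output

sess.run_with_iobinding(binding)

# With this change, copy_outputs_to_cpu() always performs a copy,
# even when the bound outputs already live on the CPU.
outputs = binding.copy_outputs_to_cpu()
print(outputs[0].shape)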
18 changes: 11 additions & 7 deletions onnxruntime/python/onnxruntime_pybind_iobinding.cc
@@ -161,23 +161,27 @@ void addIoBindingMethods(pybind11::module& m) {
return io_binding->Get()->GetOutputs();
},
py::return_value_policy::reference_internal)
.def("copy_outputs_to_cpu", [](const SessionIOBinding* io_binding) -> std::vector<py::object> {
.def("copy_outputs_to_cpu", [](const SessionIOBinding* io_binding) -> py::list {
const std::vector<OrtValue>& outputs = io_binding->Get()->GetOutputs();
std::vector<py::object> rfetch;
rfetch.reserve(outputs.size());

size_t pos = 0;
const auto& dtm = io_binding->GetInferenceSession()->GetDataTransferManager();

py::list result;
for (const auto& ort_value : outputs) {
if (ort_value.IsTensor()) {
rfetch.push_back(AddTensorAsPyObj(ort_value, &dtm, nullptr));
// We make a copy of the tensor to CPU even if it is already on CPU,
// as the function name implies, using the DataTransferManager.
py::array arr = PrimitiveTensorToNumpyFromDevice(ort_value, &dtm);
result.append(py::cast<py::object>(arr));
} else if (ort_value.IsSparseTensor()) {
rfetch.push_back(GetPyObjectFromSparseTensor(pos, ort_value, &dtm));
result.append(GetPyObjectFromSparseTensor(pos, ort_value, &dtm));
} else {
rfetch.push_back(AddNonTensorAsPyObj(ort_value, &dtm, nullptr));
result.append(AddNonTensorAsPyObj(ort_value, &dtm, nullptr));
}
++pos;
}
return rfetch;
return result;
});
}

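Continuing the sketch above with the same hypothetical binding, the rewritten lambda means copy_outputs_to_cpu() now returns a plain Python list whose tensor entries are independent numpy copies made through the DataTransferManager. Calling it twice is assumed to simply re-copy the bound outputs:

cpu_copies = binding.copy_outputs_to_cpu()   # list of numpy arrays
first = cpu_copies[0]
first += 1.0                                 # safe: mutates only the copy

# A second call re-copies from the bound OrtValues (assumption: the
# binding still holds them), so the in-place edit above is not reflected.
again = binding.copy_outputs_to_cpu()[0]
assert not np.array_equal(first, again)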
43 changes: 40 additions & 3 deletions onnxruntime/python/onnxruntime_pybind_mlvalue.h
@@ -16,6 +16,8 @@
#include "core/framework/ort_value.h"
#include "core/session/inference_session.h"

#include <variant>

GitHub Actions / Lint C++ (cpplint via reviewdog), onnxruntime_pybind_mlvalue.h:19: Found C++ system header after other header. Should be: onnxruntime_pybind_mlvalue.h, c system, c++ system, other. [build/include_order] [4]

PYBIND11_MAKE_OPAQUE(std::vector<OrtValue>);

namespace onnxruntime {
@@ -40,6 +42,8 @@

using MemCpyFunc = void (*)(void*, const void*, size_t);

using DataTransferAlternative = std::variant<const DataTransferManager*, MemCpyFunc>;

void CpuToCpuMemCpy(void*, const void*, size_t);

void CopyDataToTensor(const pybind11::array& py_array, int npy_type, Tensor& tensor, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy);
@@ -117,9 +121,42 @@
const std::string& name_input, const pybind11::object& value, OrtValue* p_mlvalue,
bool accept_only_numpy_array = false, bool use_numpy_data_memory = true, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy);

void GetPyObjFromTensor(const Tensor& rtensor, pybind11::object& obj,
const DataTransferManager* data_transfer_manager = nullptr,
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions = nullptr);
pybind11::object GetPyObjFromTensor(const OrtValue& rtensor,
const DataTransferManager* data_transfer_manager = nullptr,
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions = nullptr);

GitHub Actions / Lint C++ (cpplint via reviewdog), onnxruntime_pybind_mlvalue.h:126: Lines should be <= 120 characters long [whitespace/line_length] [2]
GitHub Actions / Lint C++ (cpplint via reviewdog), onnxruntime_pybind_mlvalue.h:126: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]

// The functions below are used to convert OrtValue contents to numpy arrays

/// <summary>
/// This function operates on string tensors. Strings are always
/// copied to Python and converted to UTF-16/UCS-4/32 depending on the platform.
/// This is accomplished using py::cast().
///
/// It is an error to pass a non-tensor or a non-string tensor to this function.
/// </summary>
/// <param name="tensor">Tensor that contains strings</param>
/// <returns>py::array object</returns>
pybind11::array StringTensorToNumpyArray(const Tensor& tensor);

/// <summary>
/// Creates a numpy array with the tensor's shape over the OrtValue memory. The numpy array
/// does not own the memory, but it holds a copy of the OrtValue in a py::capsule.
/// The OrtValue is destroyed when the numpy array is garbage collected.
/// This is used when the OrtValue memory is on CPU.
/// </summary>
/// <param name="ort_value">OrtValue with data</param>
/// <returns>numpy array</returns>
pybind11::array PrimitiveTensorToNumpyOverOrtValue(const OrtValue& ort_value);

/// <summary>
/// Creates a numpy array with the tensor's shape that holds a copy of the OrtValue data.
/// This function is used when the OrtValue memory is not on CPU.
/// </summary>
/// <param name="ort_value">Source memory that is not on CPU.</param>
/// <param name="data_transfer">A variant encapsulating the alternatives for copying the data</param>
/// <returns>numpy array with a copy of the data</returns>
pybind11::array PrimitiveTensorToNumpyFromDevice(const OrtValue& ort_value,
const DataTransferAlternative& data_transfer);

template <class T>
struct DecRefFn {
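The new PrimitiveTensorToNumpyOverOrtValue declaration describes a zero-copy path: the numpy array views the OrtValue memory, and a py::capsule keeps the OrtValue alive until the array is garbage collected. A rough Python analogue of that ownership pattern, using plain numpy rather than ORT code, is sketched below:

import numpy as np

class FakeOrtValue:
    """Stand-in for the OrtValue the py::capsule would hold (not ORT code)."""
    def __init__(self, nbytes):
        self.buffer = bytearray(nbytes)   # the memory the array will view

owner = FakeOrtValue(4 * 6)                                # room for 6 float32 values
arr = np.frombuffer(owner.buffer, dtype=np.float32).reshape(2, 3)

# The array does not own its data; its base chain references owner.buffer,
# so the memory stays valid for as long as the array does, mirroring the
# lifetime guarantee the capsule gives the wrapped OrtValue.
assert arr.base is not None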
22 changes: 10 additions & 12 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -233,20 +233,20 @@ void addOrtValueMethods(pybind11::module& m) {
#endif
})
.def("shape", [](const OrtValue* ort_value) -> py::list {
py::list shape_arr;
#if !defined(DISABLE_SPARSE_TENSORS)
// OrtValue can only be a Tensor/SparseTensor, make this generic to handle non-Tensors
ORT_ENFORCE(ort_value->IsTensor() || ort_value->IsSparseTensor(),
"Only OrtValues that are Tensors/SpareTensors are currently supported");

const auto& dims = (ort_value->IsTensor())
? ort_value->Get<Tensor>().Shape().GetDims()
: ort_value->Get<SparseTensor>().DenseShape().GetDims();
const auto dims = (ort_value->IsTensor())
? ort_value->Get<Tensor>().Shape().GetDims()
: ort_value->Get<SparseTensor>().DenseShape().GetDims();
#else
ORT_ENFORCE(ort_value->IsTensor(), "Only OrtValues that are Tensors are supported in this build");
const auto& dims = ort_value->Get<Tensor>().Shape().GetDims();
const auto dims = ort_value->Get<Tensor>().Shape().GetDims();
#endif

py::list shape_arr;
for (auto dim : dims) {
// For sequence tensors - we would append a list of dims to the outermost list
// For now only tensors are supported in OrtValue
@@ -302,18 +302,16 @@
.def("numpy", [](const OrtValue* ml_value) -> py::object {
ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects");

py::object obj;

#ifdef USE_CUDA
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCudaToHostMemCpyFunction());
py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction());
#elif USE_ROCM
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetRocmToHostMemCpyFunction());
py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction());
#elif USE_CANN
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCannToHostMemCpyFunction());
py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction());
#elif USE_DML
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetDmlToHostMemCpyFunction());
py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction());
#else
GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, nullptr);
py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr);
#endif
return obj;
})
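As a quick reference, a sketch of the Python surface these two bindings back. The tensor contents are illustrative:

import numpy as np
import onnxruntime as ort

x = np.arange(6, dtype=np.float32).reshape(2, 3)   # illustrative data
val = ort.OrtValue.ortvalue_from_numpy(x)          # CPU-resident OrtValue

print(val.shape())   # [2, 3], the py::list assembled above
print(val.numpy())   # array produced through GetPyObjFromTensor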
15 changes: 2 additions & 13 deletions onnxruntime/python/onnxruntime_pybind_sparse_tensor.cc
@@ -305,18 +305,7 @@ void addSparseTensorMethods(pybind11::module& m) {
if (sparse_tensor.IsDataTypeString()) {
// Strings cannot be on GPU and require conversion from UTF-8 to Python UNICODE.
// We need to create a copy.
const int numpy_type = OnnxRuntimeTensorToNumpyType(DataTypeImpl::GetType<std::string>());
ORT_ENFORCE(NPY_OBJECT == numpy_type, "We are expecting to map strings to NPY_OBJECT type");
const auto& values_shape = sparse_tensor.Values().Shape();
py::dtype dtype("object");
py::array result(dtype, values_shape.GetDims(), {});
auto* out_ptr = static_cast<py::object*>(
PyArray_DATA(reinterpret_cast<PyArrayObject*>(result.ptr())));
const std::string* src = sparse_tensor.Values().Data<std::string>();
for (int64_t i = 0, size = values_shape.Size(); i < size; ++i, src++) {
out_ptr[i] = py::cast(*src);
}
return result;
return StringTensorToNumpyArray(sparse_tensor.Values());
} else {
utils::MLTypeCallDispatcher<float, double, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t>
t_disp(sparse_tensor.GetElementType());
@@ -386,7 +375,7 @@
})
.def("dense_shape", [](const PySparseTensor* py_tensor) -> py::list {
const SparseTensor& st = py_tensor->Instance();
const auto& dims = st.DenseShape().GetDims();
const auto dims = st.DenseShape().GetDims();
// We create a copy of dimensions, it is small
py::list py_dims;
for (auto d : dims) {
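For reference, a sketch of the sparse-tensor calls these bindings serve. The COO data is made up, and the factory and accessor names (sparse_coo_from_numpy, OrtDevice.make, values, dense_shape) reflect my reading of the Python wrapper, so treat them as assumptions rather than something this PR defines:

import numpy as np
import onnxruntime as ort

# Illustrative COO sparse data; API names per the Python wrapper (assumption).
dense_shape = np.array([3, 3], dtype=np.int64)
values = np.array([1.0, 2.0], dtype=np.float32)
indices = np.array([[0, 1], [2, 2]], dtype=np.int64)

cpu = ort.OrtDevice.make("cpu", 0)
st = ort.SparseTensor.sparse_coo_from_numpy(dense_shape, values, indices, cpu)

print(st.dense_shape())   # [3, 3], built by the dense_shape() binding above
print(st.values())        # numpy array over the values buffer; string values are copied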