From 3429155dad20d89f8a90a5188e5a01657dc2e31c Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Wed, 1 May 2024 17:52:03 -0700
Subject: [PATCH] Implement numpy array over CPU OrtValues on return

---
 .../python/onnxruntime_pybind_iobinding.cc    |  15 +-
 .../python/onnxruntime_pybind_mlvalue.h       |  42 +++-
 .../python/onnxruntime_pybind_ortvalue.cc     |  12 +-
 .../onnxruntime_pybind_sparse_tensor.cc       |  13 +-
 .../python/onnxruntime_pybind_state.cc        | 185 +++++++++++-------
 5 files changed, 162 insertions(+), 105 deletions(-)

diff --git a/onnxruntime/python/onnxruntime_pybind_iobinding.cc b/onnxruntime/python/onnxruntime_pybind_iobinding.cc
index 59d5a77bfbea3..bbd1d109195a5 100644
--- a/onnxruntime/python/onnxruntime_pybind_iobinding.cc
+++ b/onnxruntime/python/onnxruntime_pybind_iobinding.cc
@@ -161,23 +161,24 @@ void addIoBindingMethods(pybind11::module& m) {
             return io_binding->Get()->GetOutputs();
           },
           py::return_value_policy::reference_internal)
-      .def("copy_outputs_to_cpu", [](const SessionIOBinding* io_binding) -> std::vector<py::object> {
+      .def("copy_outputs_to_cpu", [](const SessionIOBinding* io_binding) -> py::list {
         const std::vector<OrtValue>& outputs = io_binding->Get()->GetOutputs();
-        std::vector<py::object> rfetch;
-        rfetch.reserve(outputs.size());
+
         size_t pos = 0;
         const auto& dtm = io_binding->GetInferenceSession()->GetDataTransferManager();
+
+        py::list result;
         for (const auto& ort_value : outputs) {
           if (ort_value.IsTensor()) {
-            rfetch.push_back(AddTensorAsPyObj(ort_value, &dtm, nullptr));
+            result.append(AddTensorAsPyObj(ort_value, &dtm, nullptr));
           } else if (ort_value.IsSparseTensor()) {
-            rfetch.push_back(GetPyObjectFromSparseTensor(pos, ort_value, &dtm));
+            result.append(GetPyObjectFromSparseTensor(pos, ort_value, &dtm));
           } else {
-            rfetch.push_back(AddNonTensorAsPyObj(ort_value, &dtm, nullptr));
+            result.append(AddNonTensorAsPyObj(ort_value, &dtm, nullptr));
           }
           ++pos;
         }
-        return rfetch;
+        return result;
       });
 }
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h
index e3f277bcb9c41..e802c5f4a4832 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h
@@ -117,9 +117,45 @@ void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const
                           const std::string& name_input, const pybind11::object& value, OrtValue* p_mlvalue,
                           bool accept_only_numpy_array = false, bool use_numpy_data_memory = true, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy);
 
-void GetPyObjFromTensor(const Tensor& rtensor, pybind11::object& obj,
-                        const DataTransferManager* data_transfer_manager = nullptr,
-                        const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions = nullptr);
+pybind11::object GetPyObjFromTensor(const OrtValue& rtensor,
+                                    const DataTransferManager* data_transfer_manager = nullptr,
+                                    const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions = nullptr);
+
+// The below two functions are used to convert OrtValue to numpy arrays
+
+/// <summary>
+/// This function operates on string tensors. Strings are always
+/// copied to Python and converted to UTF-16/UCS-4/32 depending on the platform.
+/// This is accomplished using py::cast()
+///
+/// It is an error to pass a non-tensor or a non-string tensor to this function.
+/// </summary>
+/// <param name="tensor">Tensor that contains strings</param>
+/// <returns>py::array object</returns>
+pybind11::array StringTensorToNumpyArray(const Tensor& tensor);
+
+/// <summary>
+/// Creates a numpy array with shape over OrtValue memory. Numpy array
+/// does not own the memory, but it holds a copy of the OrtValue in a py::capsule.
+/// OrtValue is destroyed when the numpy array is garbage collected.
+/// This is used when the OrtValue memory is on CPU.
+/// </summary>
+/// <param name="ort_value">OrtValue with data</param>
+/// <returns>numpy array</returns>
+pybind11::array PrimitiveTensorToNumpyOverOrtValue(const OrtValue& ort_value);
+
+/// <summary>
+/// Creates a numpy array with a copy of the OrtValue data.
+/// This function is used when the OrtValue memory is not on CPU.
+/// Either data_transfer or func should be non-null.
+/// </summary>
+/// <param name="ort_value">Source memory that is not on CPU.</param>
+/// <param name="data_transfer">data transfer manager</param>
+/// <param name="func">copy function if data transfer manager is not available.</param>
+/// <returns></returns>
+pybind11::array PrimitiveTensorToNumpyFromDevice(const OrtValue& ort_value,
+                                                 const DataTransferManager* data_transfer,
+                                                 MemCpyFunc func);
 
 template <typename T>
 struct DecRefFn {
diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
index dc4a4dcc13b7f..dfaf50c2cc99b 100644
--- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -302,18 +302,16 @@ void addOrtValueMethods(pybind11::module& m) {
       .def("numpy", [](const OrtValue* ml_value) -> py::object {
         ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects");
 
-        py::object obj;
-
 #ifdef USE_CUDA
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCudaToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction());
 #elif USE_ROCM
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetRocmToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction());
 #elif USE_CANN
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetCannToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction());
 #elif USE_DML
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, GetDmlToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction());
 #else
-        GetPyObjFromTensor(ml_value->Get<Tensor>(), obj, nullptr, nullptr);
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr);
 #endif
         return obj;
       })
diff --git a/onnxruntime/python/onnxruntime_pybind_sparse_tensor.cc b/onnxruntime/python/onnxruntime_pybind_sparse_tensor.cc
index 5c3118081da89..80d7d140408ae 100644
--- a/onnxruntime/python/onnxruntime_pybind_sparse_tensor.cc
+++ b/onnxruntime/python/onnxruntime_pybind_sparse_tensor.cc
@@ -305,18 +305,7 @@ void addSparseTensorMethods(pybind11::module& m) {
         if (sparse_tensor.IsDataTypeString()) {
           // Strings can not be on GPU and require conversion UTF-8 to Python UNICODE
           // We need to create a copy.
-          const int numpy_type = OnnxRuntimeTensorToNumpyType(DataTypeImpl::GetType<std::string>());
-          ORT_ENFORCE(NPY_OBJECT == numpy_type, "We are expecting to map strings to NPY_OBJECT type");
-          const auto& values_shape = sparse_tensor.Values().Shape();
-          py::dtype dtype("object");
-          py::array result(dtype, values_shape.GetDims(), {});
-          auto* out_ptr = static_cast<py::object*>(
-              PyArray_DATA(reinterpret_cast<PyArrayObject*>(result.ptr())));
-          const std::string* src = sparse_tensor.Values().Data<std::string>();
-          for (int64_t i = 0, size = values_shape.Size(); i < size; ++i, src++) {
-            out_ptr[i] = py::cast(*src);
-          }
-          return result;
+          return StringTensorToNumpyArray(sparse_tensor.Values());
         } else {
           utils::MLTypeCallDispatcher t_disp(sparse_tensor.GetElementType());
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 7fc6515d3d50a..653f8f6581743 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -166,66 +166,100 @@ static py::object AddNonTensor(const OrtValue& val,
   return py::cast(val.Get<T>());
 }
 
-// In all cases, we may not have access to a DataTransferManager, hence the user may specify functions that
-// pretty much does what a DataTransferManager does - copy data from device(s) to the host
-void GetPyObjFromTensor(const Tensor& rtensor, py::object& obj,
-                        const DataTransferManager* data_transfer_manager,
-                        const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
-  std::vector<npy_intp> npy_dims;
-  const TensorShape& shape = rtensor.Shape();
-
-  for (size_t n = 0; n < shape.NumDimensions(); ++n) {
-    npy_dims.push_back(shape[n]);
+// This function is used to return strings from a string tensor to python
+// as a numpy array of strings
+// Strings are always on CPU and must always be copied to python memory
+py::array StringTensorToNumpyArray(const Tensor& tensor) {
+  // Create the result and allocate memory with the right size
+  py::array result(py::dtype(NPY_OBJECT), tensor.Shape().GetDims());
+  const auto span = tensor.DataAsSpan<std::string>();
+  auto* mutable_data = reinterpret_cast<py::object*>(result.mutable_data());
+  for (size_t i = 0, lim = span.size(); i < lim; ++i) {
+    mutable_data[i] = py::cast(span[i]);
   }
+  return result;
+}
 
-  MLDataType dtype = rtensor.DataType();
-  const int numpy_type = OnnxRuntimeTensorToNumpyType(dtype);
-  obj = py::reinterpret_steal<py::object>(PyArray_SimpleNew(
-      narrow<int>(shape.NumDimensions()), npy_dims.data(), numpy_type));
-
-  void* out_ptr = static_cast<void*>(
-      PyArray_DATA(reinterpret_cast<PyArrayObject*>(obj.ptr())));
-
-  if (numpy_type != NPY_OBJECT) {
-    // if it is not cpu tensor, need to copy to host
-    auto device_type = rtensor.Location().device.Type();
-    if (device_type != OrtDevice::CPU) {
-      if (!data_transfer_manager && !mem_cpy_to_host_functions)
-        throw std::runtime_error(
-            "GetPyObjFromTensor: Either data transfer manager or a "
-            "function to copy data to the host is needed to convert non-CPU tensor to numpy array");
-      static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
-
-      // Prefer DataTransferManager if available
-      if (data_transfer_manager) {
-        auto span = gsl::make_span(reinterpret_cast<char*>(out_ptr), dtype->Size() * shape.Size());
-        ORT_THROW_IF_ERROR(CopyTensorDataToByteSpan(
-            *data_transfer_manager, rtensor, cpu_alloc_info, span));
-      } else {
-        auto mem_cpy_to_host = mem_cpy_to_host_functions->find(device_type);
+pybind11::array PrimitiveTensorToNumpyOverOrtValue(const OrtValue& ort_value) {
+  const Tensor& tensor = ort_value.Get<Tensor>();
+  // The capsule destructor must be stateless
+  // We create a copy of the OrtValue on the heap.
+  auto memory_release = [](void* data) {
+    auto* ort_value = reinterpret_cast<OrtValue*>(data);
+    delete ort_value;
+  };
 
-        ORT_ENFORCE(mem_cpy_to_host != mem_cpy_to_host_functions->end(),
-                    "Unable to locate a function that can copy data to the host from the device");
+  const int numpy_type = OnnxRuntimeTensorToNumpyType(tensor.DataType());
+  auto ort_value_ptr = std::make_unique<OrtValue>(ort_value);
+  // Not using array_t because it may not handle MLFloat16 properly
+  pybind11::array result(py::dtype(numpy_type), tensor.Shape().GetDims(),
+                         tensor.DataRaw(),
+                         pybind11::capsule(ort_value_ptr.get(), memory_release));
 
-        ORT_ENFORCE(mem_cpy_to_host->second != 0,
-                    "No function that can copy data to the host from the device provided");
+  ort_value_ptr.release();
+  return result;
+}
 
-        mem_cpy_to_host->second(out_ptr, rtensor.DataRaw(), dtype->Size() * shape.Size());
-      }
 
+pybind11::array PrimitiveTensorToNumpyFromDevice(const OrtValue& ort_value,
+                                                 const DataTransferManager* data_transfer,
+                                                 MemCpyFunc func) {
+  const Tensor& tensor = ort_value.Get<Tensor>();
+  const int numpy_type = OnnxRuntimeTensorToNumpyType(tensor.DataType());
+  pybind11::array result(py::dtype(numpy_type), tensor.Shape().GetDims());
+  void* data = result.mutable_data();
+
+  if (data_transfer != nullptr) {
+    static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator};
+    const auto span = gsl::make_span(reinterpret_cast<char*>(data), tensor.SizeInBytes());
+    ORT_THROW_IF_ERROR(CopyTensorDataToByteSpan(*data_transfer, tensor, cpu_alloc_info, span));
+  } else if (func != nullptr) {
+    func(data, tensor.DataRaw(), tensor.SizeInBytes());
+  } else {
+    throw std::runtime_error(
+        "Data transfer manager and memcpy function cannot both be null"
+        " in PrimitiveTensorToNumpyFromDevice");
+  }
+  return result;
+}
+
+// In all cases, we may not have access to a DataTransferManager, hence the user may specify functions that
+// pretty much do what a DataTransferManager does - copy data from device(s) to the host
+py::object GetPyObjFromTensor(const OrtValue& ort_value,
+                              const DataTransferManager* data_transfer_manager,
+                              const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
+  ORT_ENFORCE(ort_value.IsTensor(), "This function only supports tensors");
+
+  const auto& tensor = ort_value.Get<Tensor>();
+  if (tensor.IsDataTypeString()) {
+    ORT_ENFORCE(tensor.Location().device.Type() == OrtDevice::CPU, "Strings can only be on CPU");
+    // Create a numpy array of strings (python objects) by copy/converting them
+    py::array result = StringTensorToNumpyArray(tensor);
+    return py::cast(result);
+  }
 
-    } else
-      memcpy(out_ptr, rtensor.DataRaw(dtype), dtype->Size() * shape.Size());
+  const auto device_type = tensor.Location().device.Type();
+  // Create a numpy array on top of the OrtValue memory, no copy
+  if (device_type == OrtDevice::CPU) {
+    py::array result = PrimitiveTensorToNumpyOverOrtValue(ort_value);
+    return py::cast(result);
+  }
+
+  if (!data_transfer_manager && !mem_cpy_to_host_functions) {
+    throw std::runtime_error(
+        "GetPyObjFromTensor: Either data transfer manager or a "
+        "function to copy data to the host is needed to convert non-CPU tensor to numpy array");
+  }
+
+  py::array result;
+  if (data_transfer_manager != nullptr) {
+    result = PrimitiveTensorToNumpyFromDevice(ort_value, data_transfer_manager, nullptr);
   } else {
-    // Handle string type.
-    // Copying strings to cpu from device is currently not supported
-    ORT_ENFORCE(rtensor.Location().device.Type() == OrtDevice::CPU,
-                "Copying string tensors located on another device to the host is currently not supported");
-    py::object* outObj = static_cast<py::object*>(out_ptr);
-    const std::string* src = rtensor.Data<std::string>();
-    for (int i = 0; i < rtensor.Shape().Size(); i++, src++) {
-      outObj[i] = py::cast(*src);
-    }
+    auto mem_cpy_to_host = mem_cpy_to_host_functions->find(device_type);
+    ORT_ENFORCE(mem_cpy_to_host != mem_cpy_to_host_functions->end(),
+                "Unable to locate a function that can copy data to the host from the device");
+    result = PrimitiveTensorToNumpyFromDevice(ort_value, nullptr, mem_cpy_to_host->second);
   }
+  return py::cast(result);
 }
 
 const char* GetDeviceName(const OrtDevice& device) {
@@ -292,9 +326,8 @@ py::object AddNonTensor<TensorSeq>(const OrtValue& val,
                                    const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
   const auto& seq_tensors = val.Get<TensorSeq>();
   py::list py_list;
-  for (const auto& rtensor : seq_tensors) {
-    py::object obj;
-    GetPyObjFromTensor(rtensor.Get<Tensor>(), obj, data_transfer_manager, mem_cpy_to_host_functions);
+  for (const auto& ort_value : seq_tensors) {
+    py::object obj = GetPyObjFromTensor(ort_value, data_transfer_manager, mem_cpy_to_host_functions);
     py_list.append(obj);
   }
   // XToolChain kills the build
@@ -347,10 +380,7 @@ py::object AddNonTensorAsPyObj(const OrtValue& val,
 
 py::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager,
                             const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* mem_cpy_to_host_functions) {
-  const Tensor& rtensor = val.Get<Tensor>();
-  py::object obj;
-  GetPyObjFromTensor(rtensor, obj, data_transfer_manager, mem_cpy_to_host_functions);
-  return obj;
+  return GetPyObjFromTensor(val, data_transfer_manager, mem_cpy_to_host_functions);
 }
 
 static std::unique_ptr<IExecutionProvider> LoadExecutionProvider(
@@ -1863,11 +1893,12 @@ including arg name, arg type (contains both type and shape).)pbdoc")
           },
           R"pbdoc(Load a model saved in ONNX or ORT format.)pbdoc")
       .def("run",
-           [](PyInferenceSession* sess, std::vector<std::string> output_names,
-              std::map<std::string, py::object> pyfeeds, RunOptions* run_options = nullptr)
-               -> std::vector<py::object> {
+           [](PyInferenceSession* sess, const std::vector<std::string>& output_names,
+              const std::map<std::string, py::object>& pyfeeds, RunOptions* run_options = nullptr)
+               -> py::list {
             NameMLValMap feeds;
-            for (auto feed : pyfeeds) {
+            feeds.reserve(pyfeeds.size());
+            for (const auto& feed : pyfeeds) {
              // No need to process 'None's sent in by the user
              // to feed Optional inputs in the graph.
             // We just won't include anything in the feed and ORT
@@ -1885,6 +1916,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
            }
 
            std::vector<OrtValue> fetches;
+           fetches.reserve(output_names.size());
            common::Status status;
 
            {
@@ -1897,29 +1929,28 @@
              }
            }
 
-            std::vector<py::object> rfetch;
-            rfetch.reserve(fetches.size());
+            py::list result;
            size_t pos = 0;
-            for (auto fet : fetches) {
+            for (const auto& fet : fetches) {
              if (fet.IsAllocated()) {
                if (fet.IsTensor()) {
-                  rfetch.push_back(AddTensorAsPyObj(fet, nullptr, nullptr));
+                  result.append(AddTensorAsPyObj(fet, nullptr, nullptr));
                } else if (fet.IsSparseTensor()) {
-                  rfetch.push_back(GetPyObjectFromSparseTensor(pos, fet, nullptr));
+                  result.append(GetPyObjectFromSparseTensor(pos, fet, nullptr));
                } else {
-                  rfetch.push_back(AddNonTensorAsPyObj(fet, nullptr, nullptr));
+                  result.append(AddNonTensorAsPyObj(fet, nullptr, nullptr));
                }
              } else {
                // Send back None because the corresponding OrtValue was empty
-                rfetch.push_back(py::none());
+                result.append(py::none());
              }
              ++pos;
            }
-            return rfetch;
+            return result;
          })
      .def("run_async",
          [](PyInferenceSession* sess,
-           std::vector<std::string> output_names,
-           std::map<std::string, py::object> pyfeeds,
+           const std::vector<std::string>& output_names,
+           const std::map<std::string, py::object>& pyfeeds,
            PyCallback callback, py::object user_data = {},
            RunOptions* run_options = nullptr) -> void {
@@ -1928,7 +1959,7 @@
            async_resource->user_data = user_data;
            // prepare feeds
            async_resource->ReserveFeeds(pyfeeds.size());
-           for (auto feed : pyfeeds) {
+           for (const auto& feed : pyfeeds) {
              if (!feed.second.is(py::none())) {
                OrtValue ml_value;
                auto px = sess->GetSessionHandle()->GetModelInputs();
@@ -1945,7 +1976,7 @@
            }
            // prepare fetches
            async_resource->ReserveFetches(output_names.size());
-           for (auto& output_name : output_names) {
+           for (const auto& output_name : output_names) {
              async_resource->fetch_names.push_back(output_name);
              async_resource->fetch_names_raw.push_back(async_resource->fetch_names.back().c_str());
              async_resource->fetches_raw.push_back({});
@@ -1968,15 +1999,17 @@
      /// a Tensor, SparseTensor or a TensorSequence.
      .def("run_with_ort_values",
          [](PyInferenceSession* sess, const py::dict& feeds, const std::vector<std::string>& output_names,
             RunOptions* run_options = nullptr) -> std::vector<OrtValue> {
            NameMLValMap ort_feeds;
+           ort_feeds.reserve(feeds.size());
            // item is always a copy since dict returns a value and not a ref
            // and Apple XToolChain barks
-           for (const auto item : feeds) {
+           for (const auto& item : feeds) {
              auto name = item.first.cast<std::string>();
              const OrtValue* ort_value = item.second.cast<const OrtValue*>();
              ort_feeds.emplace(name, *ort_value);
            }
 
            std::vector<OrtValue> fetches;
+           fetches.reserve(output_names.size());
            {
              // release GIL to allow multiple python threads to invoke Run() in parallel.
              py::gil_scoped_release release;
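
Illustrative usage (not part of the patch): a minimal Python sketch of the user-visible effect, assuming a model file "model.onnx" with a single output and an input named "input" (both names, and the input shape, are hypothetical). With this change, a CPU output of run() comes back as a numpy array constructed directly over the OrtValue buffer and kept alive by a py::capsule, rather than a copied buffer, so the array does not own its memory.

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    x = np.random.rand(1, 3, 224, 224).astype(np.float32)  # shape is illustrative
    (out,) = sess.run(None, {"input": x})
    # On CPU the returned array is a view over the OrtValue's memory; it remains
    # valid for the lifetime of the array and is expected not to own its data.
    print(out.dtype, out.shape, out.flags["OWNDATA"])  # OWNDATA expected to be False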