Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[API] Add new dlpack API #20546

Merged
merged 19 commits into from
Nov 29, 2021
8 changes: 8 additions & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,14 @@ integrationtest_ubuntu_cpp_package_gpu() {
cpp-package/tests/ci_test.sh
}

# Run the DLPack data-interchange test module
# (tests/python/array-api/test_data_interchange.py) under pytest.
# Pinned "+cu113" builds of torch/torchvision/torchaudio are installed first
# because the test module imports torch.
test_python3_data_interchange_gpu() {
# -e: stop at the first failing command; -x: echo each command as it runs.
set -ex
# The wheels are resolved from the PyTorch cu113 index page given via -f.
python3 -m pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 \
-f https://download.pytorch.org/whl/cu113/torch_stable.html
# MXNET_ENGINE_TYPE is set for this single pytest invocation only.
MXNET_ENGINE_TYPE=ThreadedEngineAsync \
python3 -m pytest --durations=50 tests/python/array-api/test_data_interchange.py
}

integrationtest_ubuntu_cpu_onnx() {
set -ex
export PYTHONPATH=./python/
Expand Down
14 changes: 14 additions & 0 deletions ci/jenkins/Jenkins_steps.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,20 @@ def test_unix_cpp_package_gpu(lib_name) {
}]
}

// Build the 'Data Interchange' pipeline step.
// The step runs on a NODE_LINUX_GPU_G4 node inside its own workspace, unpacks the
// build identified by lib_name, runs the test_python3_data_interchange_gpu runtime
// function in the 'ubuntu_gpu_cu111' docker image and publishes test coverage.
// Returns a single-entry map from the step name to its closure.
def test_unix_python3_data_interchange_gpu(lib_name) {
return ['Data Interchange': {
node(NODE_LINUX_GPU_G4) {
ws('workspace/it-data-interchange') {
// The whole step is bounded by the shared max_time limit (in minutes).
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init(lib_name, mx_lib)
// NOTE(review): the third argument `true` presumably requests a GPU-enabled
// container -- confirm against utils.docker_run.
// NOTE(review): the image name says cu111 while the runtime function installs
// "+cu113" torch wheels -- confirm the combination is intended.
utils.docker_run('ubuntu_gpu_cu111', 'test_python3_data_interchange_gpu', true)
utils.publish_test_coverage()
}
}
}
}]
}

def test_centos7_python3_cpu(lib_name) {
return ['Python3: CentOS 7 CPU': {
node(NODE_LINUX_CPU) {
Expand Down
1 change: 1 addition & 0 deletions ci/jenkins/Jenkinsfile_unix_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ core_logic: {
custom_steps.test_unix_python3_onednn_gpu('onednn_gpu'),
custom_steps.test_unix_python3_onednn_nocudnn_gpu('onednn_gpu_nocudnn'),
custom_steps.test_unix_cpp_package_gpu('gpu'),
custom_steps.test_unix_python3_data_interchange_gpu('gpu'),
// TODO(szha): fix and reenable the hanging issue. tracked in #18098
// custom_steps.test_unix_distributed_kvstore_gpu('gpu'),
// TODO(spanev): reenable when byteps is updated with the new dep engine API
Expand Down
14 changes: 14 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3102,6 +3102,20 @@ MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func,
*/
MXNET_DLL int MXCheckDynamicShapeOp(SymbolHandle sym_handle, bool* has_dynamic_shape);

/*!
 * \brief Synchronize the consumer stream with the producer stream where the NDArray lives.
 * \param handle NDArray handle of producer.
 * \param stream The consumer's stream pointer, passed as an integer value.
 * \return status code; NOTE(review): presumably 0 on success and -1 on failure like
 *         the other C API entry points -- confirm.
 * \note NOTE(review): `int` may be narrower than a stream pointer on 64-bit
 *       platforms, and the Python caller passes ctypes.c_int64 -- confirm the
 *       intended width.
 */
MXNET_DLL int MXPushStreamDep(NDArrayHandle handle, int stream);

/*!
 * \brief Get the current stream pointer for the GPU with the given device id.
 * \param device_id Current device id.
 * \param stream Output parameter; receives the stream pointer as an integer value.
 * \return status code; NOTE(review): presumably 0 on success and -1 on failure like
 *         the other C API entry points -- confirm.
 * \note NOTE(review): the result is stored through an `int*`, which may be narrower
 *       than a stream pointer on 64-bit platforms -- confirm the intended width.
 */
MXNET_DLL int MXGetCurrentStream(int device_id, int* stream);

/*!
* \brief Push a new NVTX range. Requires building with CUDA and NVTX.
* \param name Name of the range.
Expand Down
6 changes: 6 additions & 0 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,12 @@ class NDArray {
* trigger computation.
*/
void WaitToWrite() const;
/*!
 * \brief Synchronize the destination stream provided by consumer with the
 * source stream that current NDArray lives on.
 * \param stream the consumer's stream pointer, passed as an integer value.
 * \note NOTE(review): `int` may be narrower than a stream pointer on 64-bit
 *       platforms -- confirm the intended width.
 */
void StreamSync(int stream) const;
/*! \return the associated variable of the ndarray.*/
inline Engine::VarHandle var() const {
return ptr_->var;
Expand Down
32 changes: 30 additions & 2 deletions python/mxnet/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
"""DLPack API of MXNet."""

import ctypes
from .base import _LIB, c_str, check_call, NDArrayHandle
import enum

from mxnet.device import current_device
from .base import _LIB, c_str, check_call, NDArrayHandle, mx_int

DLPackHandle = ctypes.c_void_p

Expand All @@ -39,6 +42,18 @@ def _dlpack_deleter(pycapsule):

_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter)

class DLDeviceType(enum.IntEnum):
    """Integer device-type codes exchanged through ``__dlpack_device__``.

    The members are plain integers. A trailing comma after a value would turn
    it into a one-element tuple that only works because ``IntEnum`` unpacks it
    into ``int()``, so none is used here.
    """
    DLCPU = 1
    DLGPU = 2
    DLCPUPINNED = 3
    DLOPENCL = 4
    DLVULKAN = 7
    DLMETAL = 8
    DLVPI = 9
    DLROCM = 10
    DLEXTDEV = 12


class DLContext(ctypes.Structure):
    """ctypes structure holding a device type code and a device id, both C ints."""
    _fields_ = [
        ("device_type", ctypes.c_int),
        ("device_id", ctypes.c_int),
    ]
Expand Down Expand Up @@ -94,8 +109,21 @@ def ndarray_from_dlpack(array_cls):
fn : dlpack -> array_cls
"""
def from_dlpack(dlpack):
tp = type(dlpack)
if tp.__module__ == "builtins" and tp.__name__ == "PyCapsule":
dlpack = ctypes.py_object(dlpack)
elif hasattr(dlpack, "__dlpack__"):
device, device_id = dlpack.__dlpack_device__()
if device != DLDeviceType.DLGPU:
dlpack = ctypes.py_object(dlpack.__dlpack__())
else:
s = mx_int()
check_call(_LIB.MXGetCurrentStream(
ctypes.c_int(device_id), ctypes.byref(s)))
dlpack = ctypes.py_object(dlpack.__dlpack__(stream=s.value))
else:
raise AttributeError("Required PyCapsule or object with __dlpack__")
handle = NDArrayHandle()
dlpack = ctypes.py_object(dlpack)
assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError(
'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.')
dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor))
Expand Down
43 changes: 41 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
from ..ndarray.ndarray import _storage_type
from ..dlpack import ndarray_from_numpy, ndarray_from_dlpack
from ..dlpack import ndarray_from_numpy, ndarray_to_dlpack_for_write, DLDeviceType,\
ndarray_from_dlpack
from .utils import _get_np_op
from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import
from . import fallback
Expand Down Expand Up @@ -446,6 +447,45 @@ def __array_namespace__(self, api_version=None):
return sys.modules[self.__module__]


def __dlpack__(self, stream=None):
    """Exports the array for consumption by from_dlpack() as a DLPack capsule.

    Parameters
    ----------
    stream : int, optional
        A Python integer representing a pointer to a stream (CUDA or ROCm).
        Stream is provided by the consumer to the producer to instruct the producer
        to ensure that operations can safely be performed on the array. The pointer must
        be positive integer or -1. If stream is -1, the value must be used by the consumer
        to signal "producer must not perform any synchronization".
        The range is not validated here; only the type and the device are checked.

    Returns
    -------
    capsule : PyCapsule
        A DLPack capsule for the array, containing a DLPackManagedTensor.

    Raises
    ------
    TypeError
        If `stream` is neither None nor an integer.
    ValueError
        If a stream is given while the array does not live on a 'gpu' device.
    """
    if stream is not None:
        # Accept int and its subclasses, but not bool: True/False are never
        # meaningful stream handles even though bool derives from int.
        if isinstance(stream, bool) or not isinstance(stream, int):
            raise TypeError('The input stream must be int or None')
        device_type = self.device.device_type
        if device_type != "gpu":
            raise ValueError('Stream {} is not supported in current device {}'
                             .format(stream, device_type))
        # -1 means "do not synchronize", so the stream dependency is skipped.
        if stream != -1:
            # NOTE(review): the C declaration of MXPushStreamDep takes `int stream`
            # while a 64-bit value is passed here -- confirm the intended width.
            check_call(_LIB.MXPushStreamDep(self.handle, ctypes.c_int64(int(stream))))
    to_dlpack_write = ndarray_to_dlpack_for_write()
    return to_dlpack_write(self)


def __dlpack_device__(self):
    """Returns device type and device ID in DLPack format.

    Returns
    -------
    device : tuple of (DLDeviceType, int)
        The DLPack device type code and the device id of this array.

    Raises
    ------
    ValueError
        If the array's device type is not one of 'cpu', 'gpu' or 'cpu_pinned'.
    """
    devtype_map = {'cpu': DLDeviceType.DLCPU,
                   'gpu': DLDeviceType.DLGPU,
                   'cpu_pinned': DLDeviceType.DLCPUPINNED}
    device = self.device
    try:
        dl_device_type = devtype_map[device.device_type]
    except KeyError:
        raise ValueError(
            'Unknown device type {} for DLPack'.format(device.device_type)) from None
    return (dl_device_type, device.device_id)


def _get_np_basic_indexing(self, key):
"""
This function indexes ``self`` with a tuple of `slice` objects only.
Expand Down Expand Up @@ -13183,7 +13223,6 @@ def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=N
array(-128, dtype=int8)
"""
return _mx_nd_np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
# pylint: enable=redefined-outer-name, too-many-arguments


@set_module('mxnet.numpy')
Expand Down
18 changes: 18 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3947,6 +3947,24 @@ int MXShallowCopyNDArray(NDArrayHandle src_handle, NDArrayHandle* out) {
API_END_HANDLE_ERROR(delete ret);
}

// C API entry point: forwards the consumer stream value to NDArray::StreamSync
// on the array behind `handle`.
int MXPushStreamDep(NDArrayHandle handle, int stream) {
  API_BEGIN();
  NDArray* producer = static_cast<NDArray*>(handle);
  producer->StreamSync(stream);
  API_END();
}

// C API entry point: writes the stream value obtained for GPU `device_id`
// into `*stream`. Fails when MXNet is built without CUDA.
int MXGetCurrentStream(int device_id, int* stream) {
  API_BEGIN();
#if MXNET_USE_CUDA
  CHECK(stream != nullptr) << "Output stream pointer must not be null.";
  // The stream wrapper lives on the stack: the earlier `new` was never
  // matched by a delete, so every call leaked one object.
  mshadow::Stream<gpu> local_stream;
  RunContext rctx{Context::GPU(device_id), &local_stream, nullptr};
  mshadow::Stream<gpu>* cur_stream = rctx.get_stream<gpu>();
  // NOTE(review): the stream pointer is narrowed to `int` on assignment because of
  // the `int*` output parameter -- confirm the intended width.
  *stream = reinterpret_cast<int64_t>(mshadow::Stream<gpu>::GetStream(cur_stream));
#else
  LOG(FATAL) << "GPU is not enabled.";
#endif
  API_END();
}

int MXNVTXRangePush(const char* name, mx_uint color) {
API_BEGIN();
#if MXNET_USE_CUDA && MXNET_USE_NVTX
Expand Down
59 changes: 59 additions & 0 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2483,6 +2483,65 @@ void NDArray::WaitToWrite() const {
Engine::Get()->WaitForVar(ptr_->var);
}

// Make the consumer's stream wait on the writer events recorded in this array's
// sync object, so work the consumer later enqueues on that stream is ordered
// after them. `stream` is the consumer's stream pointer as an integer.
// Without CUDA support this is a fatal error.
void NDArray::StreamSync(int stream) const {
// A none array has nothing to synchronize.
if (is_none())
return;
// NOTE(review): presumably forces any deferred computation that produces this
// array before events are inspected -- confirm against Imperative::DCInfo.
Imperative::DCInfo::Compute(*this);
#if MXNET_USE_CUDA
// NOTE(review): the lambda captures `this` as a raw pointer and is handed to the
// engine; confirm the NDArray is guaranteed to outlive its execution.
Engine::Get()->PushAsync(
[this, stream](RunContext ctx,
Engine::CallbackOnStart on_start,
Engine::CallbackOnComplete on_complete) {
on_start();
// Reinterpret the integer value as the consumer's CUDA stream handle.
cudaStream_t consumer = reinterpret_cast<cudaStream_t>(stream);
// For each producer stream, the single writer event the consumer must wait on.
std::unordered_map<cudaStream_t, engine::EventInfo> events_per_stream;
auto& sync_obj = this->var()->sync_object;
// The sync object's event lists are read and modified below under its mutex.
std::lock_guard<std::mutex> l(sync_obj.mutex);
auto& reader_events = sync_obj.reader_events;
// Drop reader events whose underlying event has expired.
reader_events.erase(
std::remove_if(reader_events.begin(),
reader_events.end(),
[&](const engine::EventInfo e_i) { return e_i.event.expired(); }),
reader_events.end());
for (auto& writer : sync_obj.writer_event) {
// An expired writer event clears the whole writer list and ends the scan;
// entries already collected in events_per_stream are kept.
if (writer.event.expired()) {
sync_obj.writer_event.clear();
break;
}
// Writers recorded on the consumer's own stream need no extra wait.
if (writer.stream != consumer) {
// `found` is true when some reader event is already on the consumer stream;
// in that case this writer is not collected.
bool found = false;
for (const auto& reader : reader_events) {
if (reader.stream == consumer) {
found = true;
break;
}
}
if (!found) {
auto event_stream = writer.stream;
// Keep, per producer stream, the writer event with the largest pool_index.
if (events_per_stream.count(event_stream) > 0) {
if (events_per_stream[event_stream].pool_index < writer.pool_index) {
events_per_stream[event_stream] = writer;
}
} else {
events_per_stream.emplace(event_stream, writer);
}
}
}
}
// Enqueue a wait on the consumer stream for every collected event.
for (auto event : events_per_stream) {
// NOTE(review): lock() result is dereferenced without a null check -- confirm
// the event cannot expire between the expired() test above and this point.
auto ev = event.second.event.lock();
MSHADOW_CUDA_CALL(cudaStreamWaitEvent(consumer, *ev, 0));
}
on_complete();
},
this->ctx(),
// NOTE(review): the two empty lists are presumably the const and mutable
// variable dependencies of the pushed operation -- confirm.
{},
{});
#else
LOG(FATAL) << "GPU is not enabled";
#endif
}

#if MXNET_PREDICT_ONLY == 0
// register API function
// those with underscore will be registered at NDArray
Expand Down
65 changes: 65 additions & 0 deletions tests/python/array-api/test_data_interchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
from mxnet import np
import torch
import numpy
import pytest


def test_dlpack_torch_mxnet_torch():
    """Round-trip a CUDA tensor torch -> mxnet -> torch through DLPack."""
    producer_stream = torch.cuda.Stream()
    with torch.cuda.stream(producer_stream):
        x = torch.tensor((5,), device='cuda:0', dtype=torch.float64) + 1
    producer_stream.synchronize()

    # torch -> mxnet
    nx = np.from_dlpack(x)
    assert nx.device == mx.gpu(0)

    # mxnet -> torch, imported while a fresh torch stream is current
    consumer_stream = torch.cuda.Stream()
    with torch.cuda.stream(consumer_stream):
        z = torch.from_dlpack(nx)
    consumer_stream.synchronize()

    z += 1
    assert z == x

def test_dlpack_mxnet_torch_mxnet():
    """Round-trip a GPU array mxnet -> torch -> mxnet through DLPack."""
    x = np.array([5], device=mx.gpu(), dtype="float64") + 1

    # mxnet -> torch, imported while a fresh torch stream is current
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        tx = torch.from_dlpack(x)
    stream.synchronize()

    # torch -> mxnet
    z = np.from_dlpack(tx)
    z += 1
    assert z.device == mx.gpu(0)
    assert z == x

def test_dlpack_error_message():
    """Check the exception types raised for invalid DLPack inputs.

    Setup code is kept outside each ``pytest.raises`` block so that only the
    call under test can satisfy the expectation.
    """
    # AttributeError: the input is neither a PyCapsule nor an object with __dlpack__.
    # A bare object is used because newer NumPy arrays do implement __dlpack__.
    not_exportable = object()
    with pytest.raises(AttributeError):
        np.from_dlpack(not_exportable)

    # TypeError: stream must be int or None, not a torch stream object.
    stream = torch.cuda.Stream()
    gpu_array = np.array([5], device=mx.gpu(), dtype="float64")
    with pytest.raises(TypeError):
        gpu_array.__dlpack__(stream=stream)

    # ValueError: a CPU array cannot be exported with a stream.
    cpu_array = np.array([5], dtype="float64")
    with pytest.raises(ValueError):
        cpu_array.__dlpack__(stream=0)