diff --git a/cpp/src/interop/dlpack.cpp b/cpp/src/interop/dlpack.cpp
index e5da4794ca3..be64a9c9bc1 100644
--- a/cpp/src/interop/dlpack.cpp
+++ b/cpp/src/interop/dlpack.cpp
@@ -148,20 +148,33 @@ std::unique_ptr
from_dlpack(DLManagedTensor const* managed_tensor,
CUDF_EXPECTS(tensor.device.device_id == device_id, "DLTensor device ID must be current device");
}
- // Currently only 1D and 2D tensors are supported
- CUDF_EXPECTS(tensor.ndim > 0 && tensor.ndim <= 2, "DLTensor must be 1D or 2D");
-
+ // We only support 1D and 2D tensors with some restrictions on layout
+ if (tensor.ndim == 1) {
+ // 1D tensors must have dense layout (strides == nullptr <=> dense row-major)
+ CUDF_EXPECTS(nullptr == tensor.strides || tensor.strides[0] == 1,
+ "from_dlpack of 1D DLTensor only for unit-stride data");
+ } else if (tensor.ndim == 2) {
+ // 2D tensors must have column-major layout and the fastest dimension must have dense layout
+ CUDF_EXPECTS((
+ // 1D tensor reshaped into (N, 1) is fine
+ tensor.shape[1] == 1 && (nullptr == tensor.strides || tensor.strides[0] == 1))
+ // General case
+ || (nullptr != tensor.strides && tensor.strides[0] == 1 &&
+ tensor.strides[1] >= tensor.shape[0]),
+ "from_dlpack of 2D DLTensor only for column-major unit-stride data");
+ } else {
+ CUDF_FAIL("DLTensor must be 1D or 2D");
+ }
CUDF_EXPECTS(tensor.shape[0] >= 0,
- "DLTensor first dim should be of shape greater than or equal-to 0.");
+ "DLTensor first dim should be of shape greater than or equal to 0.");
CUDF_EXPECTS(tensor.shape[0] < std::numeric_limits::max(),
"DLTensor first dim exceeds size supported by cudf");
if (tensor.ndim > 1) {
CUDF_EXPECTS(tensor.shape[1] >= 0,
- "DLTensor second dim should be of shape greater than or equal-to 0.");
+ "DLTensor second dim should be of shape greater than or equal to 0.");
CUDF_EXPECTS(tensor.shape[1] < std::numeric_limits::max(),
"DLTensor second dim exceeds size supported by cudf");
}
-
size_t const num_columns = (tensor.ndim == 2) ? static_cast(tensor.shape[1]) : 1;
// Validate and convert data type to cudf
diff --git a/cpp/tests/interop/dlpack_test.cpp b/cpp/tests/interop/dlpack_test.cpp
index 2528c3e5a83..a722f66951f 100644
--- a/cpp/tests/interop/dlpack_test.cpp
+++ b/cpp/tests/interop/dlpack_test.cpp
@@ -208,6 +208,120 @@ TEST_F(DLPackUntypedTests, UnsupportedLanesFromDlpack)
EXPECT_THROW(cudf::from_dlpack(tensor.get()), cudf::logic_error);
}
+TEST_F(DLPackUntypedTests, UnsupportedBroadcast1DTensorFromDlpack)
+{
+ using T = float;
+ constexpr int ndim = 1;
+ // Broadcasted (stride-0) 1D tensor
+ auto const data = cudf::test::make_type_param_vector({1});
+ int64_t shape[ndim] = {5};
+ int64_t strides[ndim] = {0};
+
+ DLManagedTensor tensor{};
+ tensor.dl_tensor.device.device_type = kDLCPU;
+ tensor.dl_tensor.dtype = get_dtype();
+ tensor.dl_tensor.ndim = ndim;
+ tensor.dl_tensor.byte_offset = 0;
+ tensor.dl_tensor.shape = shape;
+ tensor.dl_tensor.strides = strides;
+
+ thrust::host_vector host_vector(data.begin(), data.end());
+ tensor.dl_tensor.data = host_vector.data();
+
+ EXPECT_THROW(cudf::from_dlpack(&tensor), cudf::logic_error);
+}
+
+TEST_F(DLPackUntypedTests, UnsupportedStrided1DTensorFromDlpack)
+{
+ using T = float;
+ constexpr int ndim = 1;
+ // Strided 1D tensor
+ auto const data = cudf::test::make_type_param_vector({1, 2, 3, 4});
+ int64_t shape[ndim] = {2};
+ int64_t strides[ndim] = {2};
+
+ DLManagedTensor tensor{};
+ tensor.dl_tensor.device.device_type = kDLCPU;
+ tensor.dl_tensor.dtype = get_dtype();
+ tensor.dl_tensor.ndim = ndim;
+ tensor.dl_tensor.byte_offset = 0;
+ tensor.dl_tensor.shape = shape;
+ tensor.dl_tensor.strides = strides;
+
+ thrust::host_vector host_vector(data.begin(), data.end());
+ tensor.dl_tensor.data = host_vector.data();
+
+ EXPECT_THROW(cudf::from_dlpack(&tensor), cudf::logic_error);
+}
+
+TEST_F(DLPackUntypedTests, UnsupportedImplicitRowMajor2DTensorFromDlpack)
+{
+ using T = float;
+ constexpr int ndim = 2;
+ // Row major 2D tensor
+ auto const data = cudf::test::make_type_param_vector({1, 2, 3, 4});
+ int64_t shape[ndim] = {2, 2};
+
+ DLManagedTensor tensor{};
+ tensor.dl_tensor.device.device_type = kDLCPU;
+ tensor.dl_tensor.dtype = get_dtype();
+ tensor.dl_tensor.ndim = ndim;
+ tensor.dl_tensor.byte_offset = 0;
+ tensor.dl_tensor.shape = shape;
+ tensor.dl_tensor.strides = nullptr;
+
+ thrust::host_vector host_vector(data.begin(), data.end());
+ tensor.dl_tensor.data = host_vector.data();
+
+ EXPECT_THROW(cudf::from_dlpack(&tensor), cudf::logic_error);
+}
+
+TEST_F(DLPackUntypedTests, UnsupportedExplicitRowMajor2DTensorFromDlpack)
+{
+ using T = float;
+ constexpr int ndim = 2;
+ // Row major 2D tensor with explicit strides
+ auto const data = cudf::test::make_type_param_vector({1, 2, 3, 4});
+ int64_t shape[ndim] = {2, 2};
+ int64_t strides[ndim] = {2, 1};
+
+ DLManagedTensor tensor{};
+ tensor.dl_tensor.device.device_type = kDLCPU;
+ tensor.dl_tensor.dtype = get_dtype();
+ tensor.dl_tensor.ndim = ndim;
+ tensor.dl_tensor.byte_offset = 0;
+ tensor.dl_tensor.shape = shape;
+ tensor.dl_tensor.strides = strides;
+
+ thrust::host_vector host_vector(data.begin(), data.end());
+ tensor.dl_tensor.data = host_vector.data();
+
+ EXPECT_THROW(cudf::from_dlpack(&tensor), cudf::logic_error);
+}
+
+TEST_F(DLPackUntypedTests, UnsupportedStridedColMajor2DTensorFromDlpack)
+{
+ using T = float;
+ constexpr int ndim = 2;
+ // Column major, but strided in fastest dimension
+ auto const data = cudf::test::make_type_param_vector({1, 2, 3, 4, 5, 6, 7, 8});
+ int64_t shape[ndim] = {2, 2};
+ int64_t strides[ndim] = {2, 4};
+
+ DLManagedTensor tensor{};
+ tensor.dl_tensor.device.device_type = kDLCPU;
+ tensor.dl_tensor.dtype = get_dtype();
+ tensor.dl_tensor.ndim = ndim;
+ tensor.dl_tensor.byte_offset = 0;
+ tensor.dl_tensor.shape = shape;
+ tensor.dl_tensor.strides = strides;
+
+ thrust::host_vector host_vector(data.begin(), data.end());
+ tensor.dl_tensor.data = host_vector.data();
+
+ EXPECT_THROW(cudf::from_dlpack(&tensor), cudf::logic_error);
+}
+
template
class DLPackTimestampTests : public BaseFixture {
};
diff --git a/docs/cudf/source/api_docs/general_functions.rst b/docs/cudf/source/api_docs/general_functions.rst
index a4a08a6be7f..c666bcb147d 100644
--- a/docs/cudf/source/api_docs/general_functions.rst
+++ b/docs/cudf/source/api_docs/general_functions.rst
@@ -23,6 +23,7 @@ Top-level conversions
:toctree: api/
cudf.to_numeric
+ cudf.from_dlpack
Top-level dealing with datetimelike
-----------------------------------