diff --git a/lib/src/common.cpp b/lib/src/common.cpp index 9cc0020..cccaaa0 100644 --- a/lib/src/common.cpp +++ b/lib/src/common.cpp @@ -1,4 +1,5 @@ #include "def.h" +#include "utils.h" #include "pybind11/pybind11.h" #include "pybind11/numpy.h" @@ -174,9 +175,14 @@ pybind11::array_t compute_column_nan_counts(const MatrixPointer& ma /** Grouped stats **/ -pybind11::array_t compute_row_sums_by_group(const MatrixPointer& mat, const pybind11::array_t& grouping, int num_threads) { - auto gptr = static_cast(grouping.request().ptr); - size_t ngroups = tatami_stats::total_groups(gptr, mat->ncol()); +pybind11::array_t compute_row_sums_by_group(const MatrixPointer& mat, const pybind11::array& grouping, int num_threads) { + auto gptr = check_numpy_array(grouping); + size_t ncol = mat->ncol(); + if (grouping.size() != ncol) { + throw std::runtime_error("'grouping' should have length equal to the number of columns"); + } + + size_t ngroups = tatami_stats::total_groups(gptr, ncol); size_t nrow = mat->nrow(); pybind11::array_t output({ nrow, ngroups }); @@ -192,9 +198,14 @@ pybind11::array_t compute_row_sums_by_group(const MatrixPointer& ma return output; } -pybind11::array_t compute_column_sums_by_group(const MatrixPointer& mat, const pybind11::array_t& grouping, int num_threads) { - auto gptr = static_cast(grouping.request().ptr); - size_t ngroups = tatami_stats::total_groups(gptr, mat->nrow()); +pybind11::array_t compute_column_sums_by_group(const MatrixPointer& mat, const pybind11::array& grouping, int num_threads) { + auto gptr = check_numpy_array(grouping); + size_t nrow = mat->nrow(); + if (grouping.size() != nrow) { + throw std::runtime_error("'grouping' should have length equal to the number of rows"); + } + + size_t ngroups = tatami_stats::total_groups(gptr, nrow); size_t ncol = mat->ncol(); pybind11::array_t output({ ncol, ngroups }); @@ -211,8 +222,13 @@ pybind11::array_t compute_column_sums_by_group(const MatrixPointer& } pybind11::array_t compute_row_variances_by_group(const MatrixPointer& mat, const pybind11::array_t& grouping, int num_threads) { - auto gptr = static_cast(grouping.request().ptr); - auto group_sizes = tatami_stats::tabulate_groups(gptr, mat->ncol()); + auto gptr = check_numpy_array(grouping); + size_t ncol = mat->ncol(); + if (grouping.size() != ncol) { + throw std::runtime_error("'grouping' should have length equal to the number of columns"); + } + + auto group_sizes = tatami_stats::tabulate_groups(gptr, ncol); size_t ngroups = group_sizes.size(); size_t nrow = mat->nrow(); pybind11::array_t output({ nrow, ngroups }); @@ -230,8 +246,13 @@ pybind11::array_t compute_row_variances_by_group(const MatrixPointe } pybind11::array_t compute_column_variances_by_group(const MatrixPointer& mat, const pybind11::array_t& grouping, int num_threads) { - auto gptr = static_cast(grouping.request().ptr); - auto group_sizes = tatami_stats::tabulate_groups(gptr, mat->ncol()); + auto gptr = check_numpy_array(grouping); + size_t nrow = mat->nrow(); + if (grouping.size() != nrow) { + throw std::runtime_error("'grouping' should have length equal to the number of rows"); + } + + auto group_sizes = tatami_stats::tabulate_groups(gptr, nrow); size_t ngroups = group_sizes.size(); size_t ncol = mat->ncol(); pybind11::array_t output({ ncol, ngroups }); @@ -249,8 +270,13 @@ pybind11::array_t compute_column_variances_by_group(const MatrixPoi } pybind11::array_t compute_row_medians_by_group(const MatrixPointer& mat, const pybind11::array_t& grouping, int num_threads) { - auto gptr = static_cast(grouping.request().ptr); - auto group_sizes = tatami_stats::tabulate_groups(gptr, mat->ncol()); + auto gptr = check_numpy_array(grouping); + size_t ncol = mat->ncol(); + if (grouping.size() != ncol) { + throw std::runtime_error("'grouping' should have length equal to the number of columns"); + } + + auto group_sizes = tatami_stats::tabulate_groups(gptr, ncol); size_t ngroups = group_sizes.size(); size_t nrow = mat->nrow(); pybind11::array_t output({ nrow, ngroups }); @@ -268,8 +294,13 @@ pybind11::array_t compute_row_medians_by_group(const MatrixPointer& } pybind11::array_t compute_column_medians_by_group(const MatrixPointer& mat, const pybind11::array_t& grouping, int num_threads) { - auto gptr = static_cast(grouping.request().ptr); - auto group_sizes = tatami_stats::tabulate_groups(gptr, mat->ncol()); + auto gptr = check_numpy_array(grouping); + size_t nrow = mat->nrow(); + if (grouping.size() != nrow) { + throw std::runtime_error("'grouping' should have length equal to the number of rows"); + } + + auto group_sizes = tatami_stats::tabulate_groups(gptr, nrow); size_t ngroups = group_sizes.size(); size_t ncol = mat->ncol(); pybind11::array_t output({ ncol, ngroups }); @@ -288,16 +319,16 @@ pybind11::array_t compute_column_medians_by_group(const MatrixPoint /** Extraction **/ -pybind11::array_t extract_dense_subset(MatrixPointer mat, - bool row_noop, const pybind11::array_t& row_sub, - bool col_noop, const pybind11::array_t& col_sub) -{ +pybind11::array_t extract_dense_subset(MatrixPointer mat, bool row_noop, const pybind11::array& row_sub, bool col_noop, const pybind11::array& col_sub) { if (!row_noop) { - auto tmp = tatami::make_DelayedSubset<0>(std::move(mat), tatami::ArrayView(row_sub.data(), row_sub.size())); + auto rptr = check_numpy_array(row_sub); + auto tmp = tatami::make_DelayedSubset<0>(std::move(mat), tatami::ArrayView(rptr, row_sub.size())); mat.swap(tmp); } + if (!col_noop) { - auto tmp = tatami::make_DelayedSubset<1>(std::move(mat), tatami::ArrayView(col_sub.data(), col_sub.size())); + auto cptr = check_numpy_array(col_sub); + auto tmp = tatami::make_DelayedSubset<1>(std::move(mat), tatami::ArrayView(cptr, col_sub.size())); mat.swap(tmp); } @@ -308,16 +339,16 @@ pybind11::array_t extract_dense_subset(MatrixPointer mat, return output; } -pybind11::object extract_sparse_subset(MatrixPointer mat, - bool row_noop, const pybind11::array_t& row_sub, - bool col_noop, const pybind11::array_t& col_sub) -{ +pybind11::object extract_sparse_subset(MatrixPointer mat, bool row_noop, const pybind11::array& row_sub, bool col_noop, const pybind11::array& col_sub) { if (!row_noop) { - auto tmp = tatami::make_DelayedSubset<0>(std::move(mat), tatami::ArrayView(row_sub.data(), row_sub.size())); + auto rptr = check_numpy_array(row_sub); + auto tmp = tatami::make_DelayedSubset<0>(std::move(mat), tatami::ArrayView(rptr, row_sub.size())); mat.swap(tmp); } + if (!col_noop) { - auto tmp = tatami::make_DelayedSubset<1>(std::move(mat), tatami::ArrayView(col_sub.data(), col_sub.size())); + auto cptr = check_numpy_array(col_sub); + auto tmp = tatami::make_DelayedSubset<1>(std::move(mat), tatami::ArrayView(cptr, col_sub.size())); mat.swap(tmp); } diff --git a/lib/src/compressed_sparse_matrix.cpp b/lib/src/compressed_sparse_matrix.cpp index 94dfdd0..053823f 100644 --- a/lib/src/compressed_sparse_matrix.cpp +++ b/lib/src/compressed_sparse_matrix.cpp @@ -1,4 +1,5 @@ #include "def.h" +#include "utils.h" #include "pybind11/pybind11.h" #include "pybind11/numpy.h" @@ -8,65 +9,77 @@ #include template -MatrixPointer initialize_compressed_sparse_matrix_raw(MatrixIndex nr, MatrixValue nc, const Data_* dptr, const Index_* iptr, const pybind11::array_t& indptr, bool byrow) { - size_t nz = indptr.at(indptr.size() - 1); - tatami::ArrayView dview(dptr, nz); - tatami::ArrayView iview(iptr, nz); - tatami::ArrayView pview(static_cast(indptr.request().ptr), indptr.size()); - return MatrixPointer(new tatami::CompressedSparseMatrix(nr, nc, std::move(dview), std::move(iview), std::move(pview), byrow)); +MatrixPointer initialize_compressed_sparse_matrix_raw(MatrixIndex nr, MatrixValue nc, const pybind11::array& data, const pybind11::array& index, const pybind11::array& indptr, bool byrow) { + size_t expected = (byrow ? nr : nc); + if (indptr.size() != expected + 1) { + throw std::runtime_error("unexpected length for the 'indptr' array"); + } + tatami::ArrayView pview(check_numpy_array(indptr), indptr.size()); + + size_t nz = pview[pview.size() - 1]; + if (data.size() != nz) { + throw std::runtime_error("unexpected length for the 'data' array"); + } + tatami::ArrayView dview(check_contiguous_numpy_array(data), nz); + + if (data.size() != nz) { + throw std::runtime_error("unexpected length for the 'data' array"); + } + tatami::ArrayView iview(check_contiguous_numpy_array(index), nz); + + typedef tatami::CompressedSparseMatrix Spmat; + return MatrixPointer(new Spmat(nr, nc, std::move(dview), std::move(iview), std::move(pview), byrow)); } template -MatrixPointer initialize_compressed_sparse_matrix_itype(MatrixIndex nr, MatrixValue nc, const Data_* dptr, const pybind11::array& index, const pybind11::array_t& indptr, bool byrow) { +MatrixPointer initialize_compressed_sparse_matrix_itype(MatrixIndex nr, MatrixValue nc, const pybind11::array& data, const pybind11::array& index, const pybind11::array& indptr, bool byrow) { auto dtype = index.dtype(); - auto iptr = index.request().ptr; if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_raw(nr, nc, dptr, static_cast(iptr), indptr, byrow); + return initialize_compressed_sparse_matrix_raw(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_raw(nr, nc, dptr, static_cast(iptr), indptr, byrow); + return initialize_compressed_sparse_matrix_raw(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_raw(nr, nc, dptr, static_cast(iptr), indptr, byrow); + return initialize_compressed_sparse_matrix_raw(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_raw(nr, nc, dptr, static_cast(iptr), indptr, byrow); + return initialize_compressed_sparse_matrix_raw(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_raw(nr, nc, dptr, static_cast(iptr), indptr, byrow); + return initialize_compressed_sparse_matrix_raw(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_raw(nr, nc, dptr, static_cast(iptr), indptr, byrow); + return initialize_compressed_sparse_matrix_raw(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_raw(nr, nc, dptr, static_cast(iptr), indptr, byrow); + return initialize_compressed_sparse_matrix_raw(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_raw(nr, nc, dptr, static_cast(iptr), indptr, byrow); + return initialize_compressed_sparse_matrix_raw(nr, nc, data, index, indptr, byrow); } throw std::runtime_error("unrecognized index type '" + std::string(dtype.kind(), 1) + std::to_string(dtype.itemsize()) + "' for compressed sparse matrix initialization"); return MatrixPointer(); } -MatrixPointer initialize_compressed_sparse_matrix(MatrixIndex nr, MatrixValue nc, const pybind11::array& data, const pybind11::array& index, const pybind11::array_t& indptr, bool byrow) { +MatrixPointer initialize_compressed_sparse_matrix(MatrixIndex nr, MatrixValue nc, const pybind11::array& data, const pybind11::array& index, const pybind11::array& indptr, bool byrow) { auto dtype = data.dtype(); - auto dptr = data.request().ptr; if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); - } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); + return initialize_compressed_sparse_matrix_itype< double>(nr, nc, data, index, indptr, byrow); + } else if (dtype.is(pybind11::dtype::of())) { + return initialize_compressed_sparse_matrix_itype< float>(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); + return initialize_compressed_sparse_matrix_itype< int64_t>(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); + return initialize_compressed_sparse_matrix_itype< int32_t>(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); + return initialize_compressed_sparse_matrix_itype< int16_t>(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); + return initialize_compressed_sparse_matrix_itype< int8_t>(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); + return initialize_compressed_sparse_matrix_itype(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); + return initialize_compressed_sparse_matrix_itype(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); + return initialize_compressed_sparse_matrix_itype(nr, nc, data, index, indptr, byrow); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_compressed_sparse_matrix_itype(nr, nc, reinterpret_cast(dptr), index, indptr, byrow); + return initialize_compressed_sparse_matrix_itype< uint8_t>(nr, nc, data, index, indptr, byrow); } throw std::runtime_error("unrecognized data type '" + std::string(dtype.kind(), 1) + std::to_string(dtype.itemsize()) + "' for compressed sparse matrix initialization"); diff --git a/lib/src/delayed_subset.cpp b/lib/src/delayed_subset.cpp index c6ceb57..ec71733 100644 --- a/lib/src/delayed_subset.cpp +++ b/lib/src/delayed_subset.cpp @@ -1,4 +1,5 @@ #include "def.h" +#include "utils.h" #include "pybind11/pybind11.h" #include "pybind11/numpy.h" @@ -6,8 +7,9 @@ #include #include -MatrixPointer initialize_delayed_subset(MatrixPointer mat, const pybind11::array_t& subset, bool byrow) { - return tatami::make_DelayedSubset(std::move(mat), tatami::ArrayView(static_cast(subset.request().ptr), subset.size()), byrow); +MatrixPointer initialize_delayed_subset(MatrixPointer mat, const pybind11::array& subset, bool byrow) { + auto sptr = check_numpy_array(subset); + return tatami::make_DelayedSubset(std::move(mat), tatami::ArrayView(sptr, subset.size()), byrow); } void init_delayed_subset(pybind11::module& m) { diff --git a/lib/src/delayed_unary_isometric_operation_with_args.cpp b/lib/src/delayed_unary_isometric_operation_with_args.cpp index e4ec0e7..e68141a 100644 --- a/lib/src/delayed_unary_isometric_operation_with_args.cpp +++ b/lib/src/delayed_unary_isometric_operation_with_args.cpp @@ -1,4 +1,5 @@ #include "def.h" +#include "utils.h" #include "pybind11/pybind11.h" #include "pybind11/numpy.h" @@ -9,8 +10,13 @@ #include template -MatrixPointer initialize_delayed_unary_isometric_operation_with_vector_internal(MatrixPointer mat, const std::string& op, bool by_row, const pybind11::array_t& arg) { - tatami::ArrayView aview(static_cast(arg.request().ptr), by_row ? mat->nrow() : mat->ncol()); +MatrixPointer initialize_delayed_unary_isometric_operation_with_vector_internal(MatrixPointer mat, const std::string& op, bool by_row, const pybind11::array& arg) { + auto aptr = check_numpy_array(arg); + size_t expected = by_row ? mat->nrow() : mat->ncol(); + if (expected != arg.size()) { + throw std::runtime_error("unexpected length of array for isometric unary operation"); + } + tatami::ArrayView aview(aptr, expected); if (op == "add") { return tatami::make_DelayedUnaryIsometricOperation(std::move(mat), tatami::make_DelayedUnaryIsometricAddVector(std::move(aview), by_row)); @@ -52,7 +58,7 @@ MatrixPointer initialize_delayed_unary_isometric_operation_with_vector_internal( return MatrixPointer(); } -MatrixPointer initialize_delayed_unary_isometric_operation_with_vector(MatrixPointer mat, const std::string& op, bool right, bool by_row, const pybind11::array_t& args) { +MatrixPointer initialize_delayed_unary_isometric_operation_with_vector(MatrixPointer mat, const std::string& op, bool right, bool by_row, const pybind11::array& args) { if (right) { return initialize_delayed_unary_isometric_operation_with_vector_internal(std::move(mat), op, by_row, args); } else { diff --git a/lib/src/dense_matrix.cpp b/lib/src/dense_matrix.cpp index b2eeb9a..ec1ce74 100644 --- a/lib/src/dense_matrix.cpp +++ b/lib/src/dense_matrix.cpp @@ -1,4 +1,6 @@ #include "def.h" +#include "utils.h" + #include "tatami/tatami.hpp" #include "pybind11/pybind11.h" @@ -9,8 +11,24 @@ #include template -MatrixPointer initialize_dense_matrix_internal(MatrixIndex nr, MatrixIndex nc, const Type_* ptr, bool byrow) { - tatami::ArrayView view(ptr, static_cast(nr) * static_cast(nc)); +MatrixPointer initialize_dense_matrix_internal(MatrixIndex nr, MatrixIndex nc, const pybind11::array& buffer) { + size_t expected = static_cast(nr) * static_cast(nc); + if (buffer.size() != expected) { + throw std::runtime_error("unexpected size for the dense matrix buffer"); + } + + auto flag = buffer.flags(); + bool byrow = false; + if (flag & pybind11::array::c_style) { + byrow = true; + } else if (flag & pybind11::array::f_style) { + byrow = false; + } else { + throw std::runtime_error("numpy array contents should be contiguous"); + } + + auto ptr = get_numpy_array_data(buffer); + tatami::ArrayView view(ptr, expected); return MatrixPointer(new tatami::DenseMatrix(nr, nc, std::move(view), byrow)); } @@ -19,28 +37,27 @@ MatrixPointer initialize_dense_matrix(MatrixIndex nr, MatrixIndex nc, const pybi // order, as this should be handled by the caller; we don't provide any // protection from GC for the arrays referenced by the views. auto dtype = buffer.dtype(); - bool byrow = buffer.flags() & pybind11::array::c_style; if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); + return initialize_dense_matrix_internal< double>(nr, nc, buffer); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); + return initialize_dense_matrix_internal< float>(nr, nc, buffer); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); + return initialize_dense_matrix_internal< int64_t>(nr, nc, buffer); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); + return initialize_dense_matrix_internal< int32_t>(nr, nc, buffer); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); + return initialize_dense_matrix_internal< int16_t>(nr, nc, buffer); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); + return initialize_dense_matrix_internal< int8_t>(nr, nc, buffer); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); - } else if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); + return initialize_dense_matrix_internal(nr, nc, buffer); + } else if (dtype.is(pybind11::dtype::of())) { + return initialize_dense_matrix_internal(nr, nc, buffer); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); + return initialize_dense_matrix_internal(nr, nc, buffer); } else if (dtype.is(pybind11::dtype::of())) { - return initialize_dense_matrix_internal(nr, nc, reinterpret_cast(buffer.request().ptr), byrow); + return initialize_dense_matrix_internal< uint8_t>(nr, nc, buffer); } throw std::runtime_error("unrecognized array type '" + std::string(dtype.kind(), 1) + std::to_string(dtype.itemsize()) + "' for dense matrix initialization"); diff --git a/lib/src/fragmented_sparse_matrix.cpp b/lib/src/fragmented_sparse_matrix.cpp index 0f3e9cd..2fe962d 100644 --- a/lib/src/fragmented_sparse_matrix.cpp +++ b/lib/src/fragmented_sparse_matrix.cpp @@ -1,4 +1,5 @@ #include "def.h" +#include "utils.h" #include "pybind11/pybind11.h" #include "pybind11/numpy.h" @@ -23,11 +24,15 @@ MatrixPointer initialize_fragmented_sparse_matrix_raw(MatrixIndex nr, MatrixValu } // This better not involve any copies. - auto castdata = curdata.cast >(); - data_vec.emplace_back(static_cast(castdata.request().ptr), castdata.size()); + auto castdata = curdata.cast(); auto curidx = indices[i]; - auto castidx = curidx.cast >(); - idx_vec.emplace_back(static_cast(castidx.request().ptr), castidx.size()); + auto castidx = curidx.cast(); + + if (castdata.size() != castidx.size()) { + throw std::runtime_error("mismatching lengths for the index/data vectors"); + } + data_vec.emplace_back(check_numpy_array(castdata), castdata.size()); + idx_vec.emplace_back(check_numpy_array(castidx), castidx.size()); } return MatrixPointer(new tatami::FragmentedSparseMatrix(nr, nc, std::move(data_vec), std::move(idx_vec), byrow, false)); diff --git a/lib/src/utils.h b/lib/src/utils.h new file mode 100644 index 0000000..621e430 --- /dev/null +++ b/lib/src/utils.h @@ -0,0 +1,37 @@ +#ifndef UTILS_H +#define UTILS_H + +#include "pybind11/pybind11.h" +#include "pybind11/numpy.h" + +#include + +// As a general rule, we avoid using pybind11::array_t as function arguments, +// because pybind11 might auto-cast and create an allocation that we then +// create a view on; on function exit, our view would be a dangling reference +// once the allocation gets destructed. So, we accept instead a pybind11::array +// and make sure it has our desired type and contiguous storage. + +template +const Expected_* get_numpy_array_data(const pybind11::array& x) { + return static_cast(x.request().ptr); +} + +template +const Expected_* check_contiguous_numpy_array(const pybind11::array& x) { + auto flag = x.flags(); + if (!(flag & pybind11::array::c_style) || !(flag & pybind11::array::f_style)) { + throw std::runtime_error("NumPy array contents should be contiguous"); + } + return get_numpy_array_data(x); +} + +template +const Expected_* check_numpy_array(const pybind11::array& x) { + if (!x.dtype().is(pybind11::dtype::of())) { + throw std::runtime_error("unexpected dtype for NumPy array"); + } + return check_contiguous_numpy_array(x); +} + +#endif