Skip to content

Commit

Permalink
Py_DECREF() PyArrayInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
ijl committed Oct 1, 2021
1 parent 623a524 commit d52f2c9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/serialize/numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ impl<'a> NumpyArray {
} else {
let num_dimensions = unsafe { (*array).nd as usize };
if num_dimensions == 0 {
ffi!(Py_DECREF(capsule));
return Err(PyArrayError::UnsupportedDataType);
}
match ItemType::find(array, ptr) {
Expand Down Expand Up @@ -221,7 +222,8 @@ impl<'a> NumpyArray {
impl Drop for NumpyArray {
fn drop(&mut self) {
if self.depth == 0 {
ffi!(Py_XDECREF(self.capsule as *mut pyo3::ffi::PyObject))
ffi!(Py_DECREF(self.array as *mut pyo3::ffi::PyObject));
ffi!(Py_DECREF(self.capsule as *mut pyo3::ffi::PyObject));
}
}
}
Expand Down Expand Up @@ -710,7 +712,9 @@ impl NumpyDatetimeUnit {
fn from_pyobject(ptr: *mut PyObject) -> Self {
let dtype = ffi!(PyObject_GetAttr(ptr, DTYPE_STR));
let descr = ffi!(PyObject_GetAttr(dtype, DESCR_STR));
ffi!(Py_DECREF(dtype));
let el0 = ffi!(PyList_GET_ITEM(descr, 0));
ffi!(Py_DECREF(descr));
let descr_str = ffi!(PyTuple_GET_ITEM(el0, 1));
let mut str_size: pyo3::ffi::Py_ssize_t = 0;
let uni = crate::unicode::read_utf8_from_str(descr_str, &mut str_size);
Expand Down
11 changes: 11 additions & 0 deletions test/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,3 +688,14 @@ def test_numpy_datetime_omit_microseconds(self):
),
b'{"year":"2021-01-01T00:00:00","month":"2021-01-01T00:00:00","day":"2021-01-01T00:00:00","hour":"2021-01-01T00:00:00","minute":"2021-01-01T00:00:00","second":"2021-01-01T00:00:00","milli":"2021-01-01T00:00:00","micro":"2021-01-01T00:00:00","nano":"2021-01-01T00:00:00"}',
)

def test_numpy_repeated(self):
data = numpy.array([[[1, 2], [3, 4], [5, 6], [7, 8]]], numpy.int64)
for _ in range(0, 3):
self.assertEqual(
orjson.dumps(
data,
option=orjson.OPT_SERIALIZE_NUMPY,
),
b"[[[1,2],[3,4],[5,6],[7,8]]]",
)

0 comments on commit d52f2c9

Please sign in to comment.