diff --git a/src/serialize/serializer.rs b/src/serialize/serializer.rs index 59fdfc7d..921c5628 100644 --- a/src/serialize/serializer.rs +++ b/src/serialize/serializer.rs @@ -257,7 +257,10 @@ impl<'p> Serialize for PyObjectSerializer { err!(RECURSION_LIMIT_REACHED) } let dict = ffi!(PyObject_GetAttr(self.ptr, DICT_STR)); - if unlikely!(dict.is_null()) { + let ob_type = ob_type!(self.ptr); + if unlikely!( + dict.is_null() || ffi!(PyDict_Contains((*ob_type).tp_dict, SLOTS_STR)) == 1 + ) { unsafe { pyo3::ffi::PyErr_Clear() }; DataclassFallbackSerializer::new( self.ptr, diff --git a/src/typeref.rs b/src/typeref.rs index 3f66bf25..f56f371a 100644 --- a/src/typeref.rs +++ b/src/typeref.rs @@ -53,6 +53,7 @@ pub static mut EMPTY_UNICODE: *mut PyObject = 0 as *mut PyObject; pub static mut DST_STR: *mut PyObject = 0 as *mut PyObject; pub static mut DICT_STR: *mut PyObject = 0 as *mut PyObject; pub static mut DATACLASS_FIELDS_STR: *mut PyObject = 0 as *mut PyObject; +pub static mut SLOTS_STR: *mut PyObject = 0 as *mut PyObject; pub static mut FIELD_TYPE_STR: *mut PyObject = 0 as *mut PyObject; pub static mut ARRAY_STRUCT_STR: *mut PyObject = 0 as *mut PyObject; pub static mut DTYPE_STR: *mut PyObject = 0 as *mut PyObject; @@ -122,6 +123,7 @@ pub fn init_typerefs() { DICT_STR = PyUnicode_InternFromString("__dict__\0".as_ptr() as *const c_char); DATACLASS_FIELDS_STR = PyUnicode_InternFromString("__dataclass_fields__\0".as_ptr() as *const c_char); + SLOTS_STR = PyUnicode_InternFromString("__slots__\0".as_ptr() as *const c_char); FIELD_TYPE_STR = PyUnicode_InternFromString("_field_type\0".as_ptr() as *const c_char); ARRAY_STRUCT_STR = pyo3::ffi::PyUnicode_InternFromString("__array_struct__\0".as_ptr() as *const c_char); diff --git a/test/test_dataclass.py b/test/test_dataclass.py index 6c7be238..f47074b2 100644 --- a/test/test_dataclass.py +++ b/test/test_dataclass.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: (Apache-2.0 OR MIT) +import abc import unittest import uuid from dataclasses import InitVar, asdict, dataclass, field @@ -95,6 +96,23 @@ def __post_init__(self, a: str, b: str): self.ab = f"{a} {b}" +class AbstractBase(abc.ABC): + @abc.abstractmethod + def key(self): + raise NotImplementedError + + +@dataclass(frozen=True) +class ConcreteAbc(AbstractBase): + + __slots__ = ("attr",) + + attr: float + + def key(self): + return "dkjf" + + class DataclassTests(unittest.TestCase): def test_dataclass(self): """ @@ -291,3 +309,9 @@ def default(obj): orjson.dumps(obj, option=orjson.OPT_PASSTHROUGH_DATACLASS, default=default), b'{"name":"a","number":1}', ) + + +class AbstractDataclassTests(unittest.TestCase): + def test_dataclass_abc(self): + obj = ConcreteAbc(1.0) + self.assertEqual(orjson.dumps(obj), b'{"attr":1.0}')