From 0a544c2ab3c14e3feff42380518f73778b8b3d7d Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 16 May 2024 10:15:58 +0100 Subject: [PATCH] Defer to C++ equality and hashing for pylibcudf DataType and Aggregation objects (#15732) Since the C++ layer provides implementations of these, use them, rather than redoing an implementation. This avoids things ever getting out of sync. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/15732 --- python/cudf/cudf/_lib/pylibcudf/aggregation.pyx | 8 ++++++++ python/cudf/cudf/_lib/pylibcudf/libcudf/aggregation.pxd | 3 +++ python/cudf/cudf/_lib/pylibcudf/libcudf/types.pxd | 5 +++-- python/cudf/cudf/_lib/pylibcudf/types.pyx | 6 +++--- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx b/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx index 672b1ba2221..7bb64e32a1b 100644 --- a/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx @@ -79,6 +79,14 @@ cdef class Aggregation: "Aggregations should not be constructed directly. Use one of the factories." ) + def __eq__(self, other): + return type(self) is type(other) and ( + dereference(self.c_obj).is_equal(dereference((other).c_obj)) + ) + + def __hash__(self): + return dereference(self.c_obj).do_hash() + # TODO: Ideally we would include the return type here, but we need to do so # in a way that Sphinx understands (currently have issues due to # https://github.com/cython/cython/issues/5609). diff --git a/python/cudf/cudf/_lib/pylibcudf/libcudf/aggregation.pxd b/python/cudf/cudf/_lib/pylibcudf/libcudf/aggregation.pxd index e0e01207589..8c14bc45723 100644 --- a/python/cudf/cudf/_lib/pylibcudf/libcudf/aggregation.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/libcudf/aggregation.pxd @@ -1,4 +1,5 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. +from libc.stddef cimport size_t from libc.stdint cimport int32_t from libcpp cimport bool from libcpp.memory cimport unique_ptr @@ -51,6 +52,8 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil: cdef cppclass aggregation: Kind kind unique_ptr[aggregation] clone() + size_t do_hash() noexcept + bool is_equal(const aggregation const) noexcept cdef cppclass rolling_aggregation(aggregation): pass diff --git a/python/cudf/cudf/_lib/pylibcudf/libcudf/types.pxd b/python/cudf/cudf/_lib/pylibcudf/libcudf/types.pxd index 13aebdff726..8e94ec296cf 100644 --- a/python/cudf/cudf/_lib/pylibcudf/libcudf/types.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/libcudf/types.pxd @@ -88,8 +88,9 @@ cdef extern from "cudf/types.hpp" namespace "cudf" nogil: data_type(const data_type&) except + data_type(type_id id) except + data_type(type_id id, int32_t scale) except + - type_id id() except + - int32_t scale() except + + type_id id() noexcept + int32_t scale() noexcept + bool operator==(const data_type&, const data_type&) noexcept cpdef enum class interpolation(int32_t): LINEAR diff --git a/python/cudf/cudf/_lib/pylibcudf/types.pyx b/python/cudf/cudf/_lib/pylibcudf/types.pyx index ebe4d66fa20..de10196e289 100644 --- a/python/cudf/cudf/_lib/pylibcudf/types.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/types.pyx @@ -47,9 +47,9 @@ cdef class DataType: return self.c_obj.scale() def __eq__(self, other): - if not isinstance(other, DataType): - return False - return self.id() == other.id() and self.scale() == other.scale() + return type(self) is type(other) and ( + self.c_obj == (other).c_obj + ) @staticmethod cdef DataType from_libcudf(data_type dt):