Skip to content

Commit

Permalink
Add DecimalDtype to cuDF (#6675)
Browse files Browse the repository at this point in the history
  • Loading branch information
codereport authored Nov 7, 2020
1 parent aac682b commit 1c31f0e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- PR #6592 Add `cudf.to_numeric` function
- PR #6598 Add strings::contains API with target column parameter
- PR #6638 Add support for `pipe` API
- PR #6675 Add DecimalDtype to cuDF

## Improvements

Expand Down
30 changes: 30 additions & 0 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2020, NVIDIA CORPORATION.

import decimal
import pickle

import numpy as np
Expand Down Expand Up @@ -211,3 +212,32 @@ def __eq__(self, other):

def __repr__(self):
return f"StructDtype({self.fields})"


class DecimalDtype(ExtensionDtype):

name = "decimal"
_metadata = ("precision", "scale")

def __init__(self, precision, scale):
self._typ = pa.decimal128(precision, scale)

@property
def precision(self):
return self._typ.precision

@property
def scale(self):
return self._typ.scale

@property
def type(self):
# might need to account for precision and scale here
return decimal.Decimal

def to_arrow(self):
return self._typ

@classmethod
def from_arrow(cls, typ):
return cls(typ.precision, typ.scale)
13 changes: 12 additions & 1 deletion python/cudf/cudf/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import pytest

import cudf
from cudf.core.dtypes import CategoricalDtype, ListDtype, StructDtype
from cudf.core.dtypes import (
CategoricalDtype,
ListDtype,
StructDtype,
DecimalDtype,
)
from cudf.tests.utils import assert_eq


Expand Down Expand Up @@ -128,3 +133,9 @@ def test_struct_dtype_fields(fields):
fields = {"a": "int32", "b": StructDtype({"c": "int64", "d": "int32"})}
dt = StructDtype(fields)
assert_eq(dt.fields, fields)


def test_decimal_dtype():
dt = DecimalDtype(4, 2)
assert dt.to_arrow() == pa.decimal128(4, 2)
assert dt == DecimalDtype.from_arrow(pa.decimal128(4, 2))

0 comments on commit 1c31f0e

Please sign in to comment.