Skip to content

Commit

Permalink
feat: add nw.Int128, nw.UInt128 (#1570)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Dec 12, 2024
1 parent d94775a commit f0395ae
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/api-reference/dtypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
members:
- Array
- List
- Int128
- Int64
- Int32
- Int16
- Int8
- UInt128
- UInt64
- UInt32
- UInt16
Expand Down
4 changes: 4 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from narwhals.dtypes import Int16
from narwhals.dtypes import Int32
from narwhals.dtypes import Int64
from narwhals.dtypes import Int128
from narwhals.dtypes import List
from narwhals.dtypes import Object
from narwhals.dtypes import String
Expand All @@ -29,6 +30,7 @@
from narwhals.dtypes import UInt16
from narwhals.dtypes import UInt32
from narwhals.dtypes import UInt64
from narwhals.dtypes import UInt128
from narwhals.dtypes import Unknown
from narwhals.expr import Expr
from narwhals.expr import all_ as all
Expand Down Expand Up @@ -94,6 +96,7 @@
"Int16",
"Int32",
"Int64",
"Int128",
"LazyFrame",
"List",
"Object",
Expand All @@ -105,6 +108,7 @@
"UInt16",
"UInt32",
"UInt64",
"UInt128",
"Unknown",
"all",
"all_horizontal",
Expand Down
6 changes: 5 additions & 1 deletion narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType:
dtypes = import_dtypes_module(version)
if duckdb_dtype == "HUGEINT":
return dtypes.Int128()
if duckdb_dtype == "BIGINT":
return dtypes.Int64()
if duckdb_dtype == "INTEGER":
Expand All @@ -32,6 +34,8 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType:
return dtypes.Int16()
if duckdb_dtype == "TINYINT":
return dtypes.Int8()
if duckdb_dtype == "UHUGEINT":
return dtypes.UInt128()
if duckdb_dtype == "UBIGINT":
return dtypes.UInt64()
if duckdb_dtype == "UINTEGER":
Expand Down Expand Up @@ -72,7 +76,7 @@ def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType:
native_to_narwhals_dtype(match_.group(1), version),
int(match_.group(2)),
)
return dtypes.Unknown()
return dtypes.Unknown() # pragma: no cover


class DuckDBInterchangeFrame:
Expand Down
6 changes: 6 additions & 0 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def native_to_narwhals_dtype(
return dtypes.Float64()
if dtype == pl.Float32:
return dtypes.Float32()
if dtype == getattr(pl, "Int128", None): # type: ignore[operator] # pragma: no cover
# Not available for Polars pre 1.8.0
return dtypes.Int128()
if dtype == pl.Int64:
return dtypes.Int64()
if dtype == pl.Int32:
Expand All @@ -86,6 +89,9 @@ def native_to_narwhals_dtype(
return dtypes.Int16()
if dtype == pl.Int8:
return dtypes.Int8()
if dtype == getattr(pl, "UInt128", None): # type: ignore[operator] # pragma: no cover
# Not available for Polars pre 1.8.0
return dtypes.UInt128()
if dtype == pl.UInt64:
return dtypes.UInt64()
if dtype == pl.UInt32:
Expand Down
8 changes: 8 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class NumericType(DType): ...
class TemporalType(DType): ...


class Int128(NumericType):
    """128-bit signed integer type.

    DuckDB's ``HUGEINT`` columns are translated to this dtype, and so is
    Polars' ``Int128`` when the installed Polars version exposes it.
    """


class Int64(NumericType):
"""64-bit signed integer type.
Expand Down Expand Up @@ -147,6 +151,10 @@ class Int8(NumericType):
"""


class UInt128(NumericType):
    """128-bit unsigned integer type.

    DuckDB's ``UHUGEINT`` columns are translated to this dtype, and so is
    Polars' ``UInt128`` when the installed Polars version exposes it.
    """


class UInt64(NumericType):
"""64-bit unsigned integer type.
Expand Down
4 changes: 4 additions & 0 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from narwhals.stable.v1.dtypes import Int16
from narwhals.stable.v1.dtypes import Int32
from narwhals.stable.v1.dtypes import Int64
from narwhals.stable.v1.dtypes import Int128
from narwhals.stable.v1.dtypes import List
from narwhals.stable.v1.dtypes import Object
from narwhals.stable.v1.dtypes import String
Expand All @@ -54,6 +55,7 @@
from narwhals.stable.v1.dtypes import UInt16
from narwhals.stable.v1.dtypes import UInt32
from narwhals.stable.v1.dtypes import UInt64
from narwhals.stable.v1.dtypes import UInt128
from narwhals.stable.v1.dtypes import Unknown
from narwhals.translate import _from_native_impl
from narwhals.translate import get_native_namespace
Expand Down Expand Up @@ -3519,6 +3521,7 @@ def scan_csv(
"Int16",
"Int32",
"Int64",
"Int128",
"LazyFrame",
"List",
"Object",
Expand All @@ -3530,6 +3533,7 @@ def scan_csv(
"UInt16",
"UInt32",
"UInt64",
"UInt128",
"Unknown",
"all",
"all_horizontal",
Expand Down
4 changes: 4 additions & 0 deletions narwhals/stable/v1/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from narwhals.dtypes import Int16
from narwhals.dtypes import Int32
from narwhals.dtypes import Int64
from narwhals.dtypes import Int128
from narwhals.dtypes import List
from narwhals.dtypes import NumericType
from narwhals.dtypes import Object
Expand All @@ -24,6 +25,7 @@
from narwhals.dtypes import UInt16
from narwhals.dtypes import UInt32
from narwhals.dtypes import UInt64
from narwhals.dtypes import UInt128
from narwhals.dtypes import Unknown


Expand Down Expand Up @@ -118,6 +120,7 @@ def __hash__(self) -> int:
"Int16",
"Int32",
"Int64",
"Int128",
"List",
"NumericType",
"Object",
Expand All @@ -127,5 +130,6 @@ def __hash__(self) -> int:
"UInt16",
"UInt32",
"UInt64",
"UInt128",
"Unknown",
]
4 changes: 4 additions & 0 deletions narwhals/stable/v1/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from narwhals.stable.v1._dtypes import Int16
from narwhals.stable.v1._dtypes import Int32
from narwhals.stable.v1._dtypes import Int64
from narwhals.stable.v1._dtypes import Int128
from narwhals.stable.v1._dtypes import List
from narwhals.stable.v1._dtypes import NumericType
from narwhals.stable.v1._dtypes import Object
Expand All @@ -24,6 +25,7 @@
from narwhals.stable.v1._dtypes import UInt16
from narwhals.stable.v1._dtypes import UInt32
from narwhals.stable.v1._dtypes import UInt64
from narwhals.stable.v1._dtypes import UInt128
from narwhals.stable.v1._dtypes import Unknown

__all__ = [
Expand All @@ -42,6 +44,7 @@
"Int16",
"Int32",
"Int64",
"Int128",
"List",
"NumericType",
"Object",
Expand All @@ -51,5 +54,6 @@
"UInt16",
"UInt32",
"UInt64",
"UInt128",
"Unknown",
]
2 changes: 2 additions & 0 deletions narwhals/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,12 @@ def __dataframe__(self, *args: Any, **kwargs: Any) -> Any: ...


class DTypes:
Int128: type[dtypes.Int128]
Int64: type[dtypes.Int64]
Int32: type[dtypes.Int32]
Int16: type[dtypes.Int16]
Int8: type[dtypes.Int8]
UInt128: type[dtypes.UInt128]
UInt64: type[dtypes.UInt64]
UInt32: type[dtypes.UInt32]
UInt16: type[dtypes.UInt16]
Expand Down
20 changes: 20 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import timezone
from typing import Literal

import duckdb
import numpy as np
import pandas as pd
import polars as pl
Expand Down Expand Up @@ -197,3 +198,22 @@ def test_pandas_fixed_offset_1302() -> None:
assert result == nw.Datetime("ns", "+01:00")
else: # pragma: no cover
pass


def test_huge_int() -> None:
    """Check that DuckDB 128-bit integer columns map to Int128 / UInt128.

    A small integer column is cast to ``int128`` and then to ``uint128``
    inside DuckDB, and the Narwhals schema dtype of the resulting relation
    is asserted for each cast.
    """
    # `df` is referenced only by name inside the SQL text below, which is why
    # linters see it as unused (hence the noqa). DuckDB is expected to resolve
    # that name from the Python variables in scope, so keep the local as-is.
    df = pl.DataFrame({"a": [1, 2, 3]}) # noqa: F841
    # Signed 128-bit cast.
    rel = duckdb.sql("""
select cast(a as int128) as a
from df
""")
    result = nw.from_native(rel).schema
    assert result["a"] == nw.Int128
    # Unsigned 128-bit cast.
    rel = duckdb.sql("""
select cast(a as uint128) as a
from df
""")
    result = nw.from_native(rel).schema
    assert result["a"] == nw.UInt128

# TODO(unassigned): once other libraries support Int128/UInt128,
# add tests for them too
8 changes: 0 additions & 8 deletions tests/frame/interchange_schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from datetime import timedelta

import duckdb
import pandas as pd
import polars as pl
import pytest

Expand Down Expand Up @@ -243,10 +242,3 @@ def test_get_level() -> None:
nw.get_level(nw.from_native(df.__dataframe__(), eager_or_interchange_only=True))
== "interchange"
)


def test_unknown_dtype() -> None:
    """Check that a DuckDB column cast to ``int128`` is reported as ``nw.Unknown``."""
    df = pd.DataFrame({"a": [1, 2, 3]})
    rel = duckdb.from_df(df).select("cast(a as int128) as a")
    result = nw.from_native(rel).schema
    assert result == {"a": nw.Unknown}

0 comments on commit f0395ae

Please sign in to comment.