Skip to content

Commit

Permalink
Use max precision instead of nil for decimal dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
philss committed Oct 14, 2024
1 parent bff71d4 commit 2b535a0
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 21 deletions.
27 changes: 21 additions & 6 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,27 @@ defmodule Explorer.PolarsBackend.Series do
# Arithmetic

@impl true
def add(_out_dtype, left, right),
do: Shared.apply_series(matching_size!(left, right), :s_add, [right.data])
# Adds `right` to `left` element-wise (after `matching_size!/2` has checked the pair).
#
# When the expected `out_dtype` is a decimal and the dtype of the computed series
# differs from it, the result is cast to `out_dtype`; otherwise it is returned as is.
def add(out_dtype, left, right) do
  sum = Shared.apply_series(matching_size!(left, right), :s_add, [right.data])

  # `and` short-circuits, so `dtype/1` is only consulted for decimal out dtypes.
  needs_cast? = match?({:decimal, _, _}, out_dtype) and out_dtype != dtype(sum)

  if needs_cast?, do: cast(sum, out_dtype), else: sum
end

@impl true
def subtract(_out_dtype, left, right) do
def subtract(out_dtype, left, right) do
left = matching_size!(left, right)

Shared.apply_series(left, :s_subtract, [right.data])
result = Shared.apply_series(left, :s_subtract, [right.data])

if match?({:decimal, _, _}, out_dtype) and out_dtype != dtype(result) do
cast(result, out_dtype)
else
result
end
end

@impl true
Expand All @@ -318,7 +331,8 @@ defmodule Explorer.PolarsBackend.Series do
# * `integer * duration -> duration` when `integer` is a scalar
# * `integer * duration -> integer` when `integer` is a series
# We need to return duration in these cases, so we need an additional cast.
# The same cast is applied when `out_dtype` is a decimal and the result dtype differs from it.
if match?({:duration, _}, out_dtype) and out_dtype != dtype(result) do
if (match?({:duration, _}, out_dtype) or match?({:decimal, _, _}, out_dtype)) and
out_dtype != dtype(result) do
cast(result, out_dtype)
else
result
Expand All @@ -333,7 +347,8 @@ defmodule Explorer.PolarsBackend.Series do
# * `duration / integer -> duration` when `integer` is a scalar
# * `duration / integer -> integer` when `integer` is a series
# We need to return duration in these cases, so we need an additional cast.
# The same cast is applied when `out_dtype` is a decimal and the result dtype differs from it.
if match?({:duration, _}, out_dtype) and out_dtype != dtype(result) do
if (match?({:duration, _}, out_dtype) or match?({:decimal, _, _}, out_dtype)) and
out_dtype != dtype(result) do
cast(result, out_dtype)
else
result
Expand Down
18 changes: 15 additions & 3 deletions lib/explorer/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,16 @@ defmodule Explorer.Shared do
{:naive_datetime, precision}
end

# A `nil` precision is not a valid option, but the backend may return a decimal
# dtype without precision, and we should cast in these cases. The missing
# precision is replaced by 38 and the resulting dtype is normalised again.
def normalise_dtype({:decimal, nil, scale}) do
  normalise_dtype({:decimal, 38, scale})
end

def normalise_dtype({:decimal, precision, scale} = dtype)
when is_integer(scale) and (is_nil(precision) or is_integer(precision)),
do: dtype
when is_integer(scale) and is_integer(precision) do
if precision in 0..38//1 and scale in 0..38//1 and scale <= precision do
dtype
end
end

def normalise_dtype(_dtype), do: nil

Expand Down Expand Up @@ -383,10 +390,15 @@ defmodule Explorer.Shared do
# A decimal merged with `:null` keeps the decimal dtype.
def merge_numeric_dtype({:decimal, _, _} = dtype, :null), do: dtype
def merge_numeric_dtype(:null, {:decimal, _, _} = dtype), do: dtype

# For now, float has priority over decimals due to Polars.
def merge_numeric_dtype({:decimal, _, _}, {:f, _} = dtype), do: dtype
def merge_numeric_dtype({:f, _} = dtype, {:decimal, _, _}), do: dtype

# A decimal merged with a signed (`:s`) or unsigned (`:u`) integer keeps the decimal dtype.
def merge_numeric_dtype({:decimal, _, _} = dtype, {int, _}) when int in [:s, :u], do: dtype
def merge_numeric_dtype({int, _}, {:decimal, _, _} = dtype) when int in [:s, :u], do: dtype

# Any pair not matched by an earlier clause has no merged dtype.
def merge_numeric_dtype(_, _), do: nil

@doc """
Expand Down
68 changes: 56 additions & 12 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,30 @@ defmodule Explorer.SeriesTest do
assert Series.dtype(s) == {:decimal, 38, 2}
end

test "with decimals without dtype option mixing with integers" do
  values = [
    Decimal.new("0"),
    Decimal.new("0.42"),
    nil,
    Decimal.new("5.12467"),
    42
  ]

  series = Series.from_list(values)

  # Every decimal is expected back at scale 5, the scale of "5.12467".
  assert series[1] === Decimal.new("0.42000")

  # NOTE(review): the integer 42 is expected back as 0.00042 rather than 42.00000.
  # It looks like the integer is taken as an unscaled value at scale 5 - confirm
  # this is the intended behaviour and not a limitation being pinned by the test.
  assert series[4] === Decimal.new("0.00042")

  assert Series.to_list(series) === [
           Decimal.new("0.00000"),
           Decimal.new("0.42000"),
           nil,
           Decimal.new("5.12467"),
           Decimal.new("0.00042")
         ]

  assert Series.dtype(series) == {:decimal, 38, 5}
end

test "mixing dates and integers with `:date` dtype" do
s = Series.from_list([1, nil, ~D[2024-06-13]], dtype: :date)

Expand Down Expand Up @@ -1350,8 +1374,8 @@ defmodule Explorer.SeriesTest do
end

test "compare decimal series" do
s1 = Series.from_list([1, 0, 2], dtype: {:decimal, nil, 2})
s2 = Series.from_list([1, 0, 3], dtype: {:decimal, nil, 2})
s1 = Series.from_list([1, 0, 2], dtype: {:decimal, 38, 2})
s2 = Series.from_list([1, 0, 3], dtype: {:decimal, 38, 2})

assert s1 |> Series.less_equal(s2) |> Series.to_list() == [true, true, true]
end
Expand Down Expand Up @@ -1448,8 +1472,8 @@ defmodule Explorer.SeriesTest do
end

test "with decimal series" do
s1 = Series.from_list([1, 2, 3], dtype: {:decimal, nil, 2})
s2 = Series.from_list([1, 0, 3], dtype: {:decimal, nil, 2})
s1 = Series.from_list([1, 2, 3], dtype: {:decimal, 38, 2})
s2 = Series.from_list([1, 0, 3], dtype: {:decimal, 38, 2})

assert s1 |> Series.in(s2) |> Series.to_list() == [true, false, true]
end
Expand Down Expand Up @@ -1895,12 +1919,12 @@ defmodule Explorer.SeriesTest do

s3 = Series.add(s1, s2)

assert s3.dtype == {:decimal, nil, 1}
assert s3.dtype == {:decimal, 38, 1}
assert Series.to_list(s3) == [Decimal.new("2.0"), Decimal.new("5.0"), Decimal.new("4.0")]

s4 = Series.from_list([Decimal.new("0.8561"), Decimal.new("3.5"), Decimal.new("0.910")])
s5 = Series.add(s1, s4)
assert s5.dtype == {:decimal, nil, 4}
assert s5.dtype == {:decimal, 38, 4}

assert Series.to_list(s5) == [
Decimal.new("2.0561"),
Expand All @@ -1918,6 +1942,26 @@ defmodule Explorer.SeriesTest do
assert s3.dtype == {:f, 64}
assert Series.to_list(s3) === [2.0, 5.0, 4.0]
end

test "adding decimal and signed integer series together" do
  decimals = Series.from_list([Decimal.new("1.2"), Decimal.new("2.0"), Decimal.new("3.1")])
  integers = Series.from_list([1, 2, 3])

  sum = Series.add(decimals, integers)

  # The decimal side is expected to win: scale 1 is kept, with max precision.
  assert sum.dtype == {:decimal, 38, 1}
  assert Series.to_list(sum) === [Decimal.new("2.2"), Decimal.new("4.0"), Decimal.new("6.1")]
end

test "adding decimal and unsigned integer series together" do
  decimals = Series.from_list([Decimal.new("1.2"), Decimal.new("2.0"), Decimal.new("3.1")])
  unsigned = Series.from_list([1, 2, 3], dtype: :u16)

  sum = Series.add(decimals, unsigned)

  # The decimal side is expected to win: scale 1 is kept, with max precision.
  assert sum.dtype == {:decimal, 38, 1}
  assert Series.to_list(sum) === [Decimal.new("2.2"), Decimal.new("4.0"), Decimal.new("6.1")]
end
end

describe "subtract/2" do
Expand Down Expand Up @@ -2076,7 +2120,7 @@ defmodule Explorer.SeriesTest do

s3 = Series.subtract(s1, s2)

assert s3.dtype == {:decimal, nil, 0}
assert s3.dtype == {:decimal, 38, 0}
assert Series.to_list(s3) == [Decimal.new("-3"), Decimal.new("-3"), Decimal.new("-3")]
end

Expand Down Expand Up @@ -2228,7 +2272,7 @@ defmodule Explorer.SeriesTest do

s3 = Series.multiply(s1, s2)

assert s3.dtype == {:decimal, nil, 0}
assert s3.dtype == {:decimal, 38, 0}
assert Series.to_list(s3) === [Decimal.new("4"), Decimal.new("10"), Decimal.new("18")]
end

Expand Down Expand Up @@ -4138,20 +4182,20 @@ defmodule Explorer.SeriesTest do

test "integer series to decimal" do
s = Series.from_list([1, 2, 3])
s1 = Series.cast(s, {:decimal, nil, 0})
s1 = Series.cast(s, {:decimal, 38, 0})
assert Series.to_list(s1) == [Decimal.new("1"), Decimal.new("2"), Decimal.new("3")]
# 38 is Polars' default for precision.
assert Series.dtype(s1) == {:decimal, 38, 0}

# increased scale
s2 = Series.cast(s, {:decimal, nil, 2})
s2 = Series.cast(s, {:decimal, 38, 2})
assert Series.to_list(s2) == [Decimal.new("1.00"), Decimal.new("2.00"), Decimal.new("3.00")]
assert Series.dtype(s2) == {:decimal, 38, 2}
end

test "float series to decimal" do
s = Series.from_list([1.345, 2.561, 3.97212])
s1 = Series.cast(s, {:decimal, nil, 3})
s1 = Series.cast(s, {:decimal, 38, 3})

assert Series.to_list(s1) == [
Decimal.new("1.345"),
Expand All @@ -4161,7 +4205,7 @@ defmodule Explorer.SeriesTest do

assert Series.dtype(s1) == {:decimal, 38, 3}

s2 = Series.cast(s, {:decimal, nil, 4})
s2 = Series.cast(s, {:decimal, 38, 4})

assert Series.to_list(s2) == [
Decimal.new("1.3450"),
Expand Down

0 comments on commit 2b535a0

Please sign in to comment.