Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mismatched types in Series.pow #821

Merged
merged 20 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/explorer/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2747,7 +2747,7 @@ defmodule Explorer.DataFrame do
#Explorer.DataFrame<
Polars[3 x 2]
a string ["a", "b", "c"]
b f64 [1.0, 4.0, 9.0]
b s64 [1, 4, 9]
>

It's possible to "reuse" a variable for different computations:
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ defmodule Explorer.PolarsBackend.Native do
def s_peak_max(_s), do: err()
def s_peak_min(_s), do: err()
def s_select(_pred, _on_true, _on_false), do: err()
def s_pow(_s, _other), do: err()
def s_pow(_s, _s_dtype, _other, _other_dtype), do: err()
def s_log_natural(_s_argument), do: err()
def s_log(_s_argument, _base_as_float), do: err()
def s_quantile(_s, _quantile, _strategy), do: err()
Expand Down
7 changes: 6 additions & 1 deletion lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,12 @@ defmodule Explorer.PolarsBackend.Series do

@impl true
def pow(left, right),
do: Shared.apply_series(matching_size!(left, right), :s_pow, [right.data])
do:
Shared.apply_series(matching_size!(left, right), :s_pow, [
left.dtype,
right.data,
right.dtype
])

@impl true
def log(%Series{} = argument), do: Shared.apply_series(argument, :s_log_natural, [])
Expand Down
16 changes: 13 additions & 3 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3428,7 +3428,19 @@ defmodule Explorer.Series do
"""
@doc type: :element_wise
@spec pow(left :: Series.t() | number(), right :: Series.t() | number()) :: Series.t()
def pow(left, right), do: basic_numeric_operation(:pow, left, right)
# Raises `left` to the power of `right`. At least one operand must be a
# `%Series{}`; each operand is normalised with `cast_for_pow/1` before the
# `:pow` operation is dispatched through `apply_series_list/2`.
#
# Raises `ArgumentError` when neither operand is a series.
def pow(left, right) do
  operands = [left, right]

  if Enum.any?(operands, &match?(%Series{}, &1)) do
    apply_series_list(:pow, Enum.map(operands, &cast_for_pow/1))
  else
    raise ArgumentError, "at least one input must be a series"
  end
end

# Atoms accepted alongside floats as scalar operands (see the last
# `cast_for_pow/1` clause).
@non_finite [:nan, :infinity, :neg_infinity]
# Normalises one `pow/2` operand into a series. A series passes through unchanged.
defp cast_for_pow(%Series{} = series), do: series
# A non-negative integer scalar becomes a one-element series with dtype `{:u, 64}`.
# NOTE(review): presumably unsigned so that an integer base keeps an integer
# result dtype - confirm against the dtype rules of `pow/2`.
defp cast_for_pow(u) when K.and(is_integer(u), u >= 0), do: from_list([u], dtype: {:u, 64})
billylanchantin marked this conversation as resolved.
Show resolved Hide resolved
# A negative integer scalar becomes a one-element series with dtype `{:s, 64}`.
defp cast_for_pow(s) when K.and(is_integer(s), s < 0), do: from_list([s], dtype: {:s, 64})
# A float, or one of the `@non_finite` atoms, becomes a one-element series built
# by `from_list/1` with no explicit dtype. Any other scalar matches no clause.
defp cast_for_pow(f) when K.or(is_float(f), K.in(f, @non_finite)), do: from_list([f])

@doc """
Calculates the natural logarithm.
Expand Down Expand Up @@ -3752,8 +3764,6 @@ defmodule Explorer.Series do
def atan(%Series{dtype: dtype}),
do: dtype_error("atan/1", dtype, [{:f, 32}, {:f, 64}])

defp basic_numeric_operation(operation, left, right, args \\ [])

defp basic_numeric_operation(operation, %Series{} = left, right, args) when is_numeric(right),
do: basic_numeric_operation(operation, left, from_same_value(left, right), args)

Expand Down
106 changes: 48 additions & 58 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1231,65 +1231,55 @@ pub fn s_n_distinct(s: ExSeries) -> Result<usize, ExplorerError> {
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_pow(s: ExSeries, other: ExSeries) -> Result<ExSeries, ExplorerError> {
match (s.dtype().is_integer(), other.dtype().is_integer()) {
(true, true) => {
let cast1 = s.cast(&DataType::Int64)?;
let mut iter1 = cast1.i64()?.into_iter();

match other.strict_cast(&DataType::UInt32) {
Ok(casted) => {
let mut iter2 = casted.u32()?.into_iter();

let res = if s.len() == 1 {
let v1 = iter1.next().unwrap();
iter2
.map(|v2| v1.and_then(|left| v2.map(|right| left.pow(right))))
.collect()
} else if other.len() == 1 {
let v2 = iter2.next().unwrap();
iter1
.map(|v1| v1.and_then(|left| v2.map(|right| left.pow(right))))
.collect()
} else {
iter1
.zip(iter2)
.map(|(v1, v2)| v1.and_then(|left| v2.map(|right| left.pow(right))))
.collect()
};

Ok(ExSeries::new(res))
}
Err(_) => Err(ExplorerError::Other(
"negative exponent with an integer base".into(),
)),
}
}
(_, _) => {
let cast1 = s.cast(&DataType::Float64)?;
let cast2 = other.cast(&DataType::Float64)?;
let mut iter1 = cast1.f64()?.into_iter();
let mut iter2 = cast2.f64()?.into_iter();

let res = if s.len() == 1 {
let v1 = iter1.next().unwrap();
iter2
.map(|v2| v1.and_then(|left| v2.map(|right| left.powf(right))))
.collect()
} else if other.len() == 1 {
let v2 = iter2.next().unwrap();
iter1
.map(|v1| v1.and_then(|left| v2.map(|right| left.powf(right))))
.collect()
} else {
iter1
.zip(iter2)
.map(|(v1, v2)| v1.and_then(|left| v2.map(|right| left.powf(right))))
.collect()
};
// Computes `base ^ exponent` by building a Polars lazy expression, collecting
// it, and returning the resulting "result" column as a new `ExSeries`.
//
// Three input shapes are handled:
//   * both series have the same length: column-against-column `pow`;
//   * the base has length 1: its first value (via `first`) is used as a
//     literal against the "exponent" column;
//   * the exponent has length 1: its first value (via `first`) is used as a
//     literal against the "base" column.
//
// Errors from `df!`, `collect` and `column` are propagated with `?`.
// Any other combination of lengths panics.
// NOTE(review): the length-mismatch branch panics although the function
// returns `Result`; presumably callers validate sizes first - confirm, or
// consider returning an `ExplorerError` instead.
pub fn s_pow(
base_exseries: ExSeries,
base_exdtype: ExSeriesDtype,
exponent_exseries: ExSeries,
exponent_exdtype: ExSeriesDtype,
billylanchantin marked this conversation as resolved.
Show resolved Hide resolved
billylanchantin marked this conversation as resolved.
Show resolved Hide resolved
) -> Result<ExSeries, ExplorerError> {
let df_with_result = if base_exseries.len() == exponent_exseries.len() {
// Same length (this also covers two length-1 series): pair the values row by row.
df!(
"base" => base_exseries.clone_inner().into_series(),
"exponent" => exponent_exseries.clone_inner().into_series()
)?
.lazy()
.with_column((col("base").pow(col("exponent"))).alias("result"))
} else if base_exseries.len() == 1 {
// Single base value: turn it into a typed literal and apply it to the exponent column.
let base = first(base_exseries, base_exdtype);

df!( "exponent" => exponent_exseries.clone_inner().into_series() )?
.lazy()
.with_column((base.lit().pow(col("exponent"))).alias("result"))
} else if exponent_exseries.len() == 1 {
// Single exponent value: turn it into a typed literal and apply it to the base column.
let exponent = first(exponent_exseries, exponent_exdtype);

df!( "base" => base_exseries.clone_inner().into_series() )?
.lazy()
.with_column((col("base").pow(exponent.lit())).alias("result"))
} else {
panic!("both series must have the same length or one must have length 1")
};

Ok(ExSeries::new(res))
}
// Materialise the lazy frame and pull out the computed column.
let result = df_with_result.collect()?.column("result")?.clone();

Ok(ExSeries::new(result))
}

fn first(exseries: ExSeries, exdtype: ExSeriesDtype) -> LiteralValue {
let dtype = DataType::try_from(&exdtype).unwrap();

match dtype {
DataType::UInt8 => LiteralValue::UInt8(exseries.u8().unwrap().get(0).unwrap()),
DataType::UInt16 => LiteralValue::UInt16(exseries.u16().unwrap().get(0).unwrap()),
DataType::UInt32 => LiteralValue::UInt32(exseries.u32().unwrap().get(0).unwrap()),
DataType::UInt64 => LiteralValue::UInt64(exseries.u64().unwrap().get(0).unwrap()),
DataType::Int8 => LiteralValue::Int8(exseries.i8().unwrap().get(0).unwrap()),
DataType::Int16 => LiteralValue::Int16(exseries.i16().unwrap().get(0).unwrap()),
DataType::Int32 => LiteralValue::Int32(exseries.i32().unwrap().get(0).unwrap()),
DataType::Int64 => LiteralValue::Int64(exseries.i64().unwrap().get(0).unwrap()),
DataType::Float32 => LiteralValue::Float32(exseries.f32().unwrap().get(0).unwrap()),
DataType::Float64 => LiteralValue::Float64(exseries.f64().unwrap().get(0).unwrap()),
_ => panic!("unsupported dtype for pow: must be integer or float subtype"),
}
}

Expand Down
7 changes: 3 additions & 4 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ defmodule Explorer.DataFrameTest do
df = DF.new(a: [1, 2, 3, 4, 5, 6, 5], b: [9, 8, 7, 6, 5, 4, 3])

message =
"expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:s, 64}"
"expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:f, 64}"

assert_raise ArgumentError, message, fn ->
DF.filter_with(df, fn ldf ->
Expand Down Expand Up @@ -819,7 +819,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
"calc5" => {:s, 64},
"calc5" => {:f, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64},
"calc8" => {:f, 64},
Expand Down Expand Up @@ -861,8 +861,7 @@ defmodule Explorer.DataFrameTest do
"calc2" => {:s, 64},
"calc3" => {:s, 64},
"calc4" => {:f, 64},
# TODO: This should be float after #374 is resolved
"calc5" => {:s, 64},
"calc5" => {:f, 64},
"calc5_1" => {:f, 64},
"calc6" => {:s, 64},
"calc7" => {:s, 64}
Expand Down
107 changes: 86 additions & 21 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2036,23 +2036,86 @@ defmodule Explorer.SeriesTest do
end

describe "pow/2" do
test "pow of an integer series with an integer series" do
s1 = Series.from_list([1, 2, 3])
s2 = Series.from_list([3, 2, 1])
test "pow(uint, uint) == uint" do
for u_base <- [8, 16, 32, 64], u_power <- [8, 16, 32, 64] do
base = Series.from_list([1, 2, 3], dtype: {:u, u_base})
power = Series.from_list([3, 2, 1], dtype: {:u, u_power})

result = Series.pow(s1, s2)
result = Series.pow(base, power)

assert result.dtype == {:s, 64}
assert Series.to_list(result) == [1, 4, 3]
assert result.dtype == {:u, u_base}
assert Series.to_list(result) == [1, 4, 3]
end
end

# A signed-integer base raised to an unsigned-integer exponent keeps the
# base's signed dtype, for every combination of base and exponent bit widths.
test "pow(sint, uint) == sint" do
for s_base <- [8, 16, 32, 64], u_power <- [8, 16, 32, 64] do
base = Series.from_list([1, 2, 3], dtype: {:s, s_base})
power = Series.from_list([3, 2, 1], dtype: {:u, u_power})

result = Series.pow(base, power)

assert result.dtype == {:s, s_base}
assert Series.to_list(result) == [1, 4, 3]
end
end

# A signed-integer base raised to a signed-integer exponent always yields
# `{:f, 64}`, whatever the bit widths of the two inputs.
# NOTE(review): presumably because a signed exponent may be negative - confirm.
test "pow(sint, sint) == float" do
for s_base <- [8, 16, 32, 64], s_power <- [8, 16, 32, 64] do
base = Series.from_list([1, 2, 3], dtype: {:s, s_base})
power = Series.from_list([3, 2, 1], dtype: {:s, s_power})

result = Series.pow(base, power)

assert result.dtype == {:f, 64}
# `===` so that integer results would not pass for the expected floats.
assert Series.to_list(result) === [1.0, 4.0, 3.0]
end
end

# A float base raised to any integer exponent (signed or unsigned, any bit
# width) keeps the base's float dtype.
test "pow(float, uint_or_sint) = float" do
for f_base <- [32, 64], d_power <- [:s, :u], n_power <- [8, 16, 32, 64] do
base = Series.from_list([1, 2, 3], dtype: {:f, f_base})
power = Series.from_list([3, 2, 1], dtype: {d_power, n_power})

result = Series.pow(base, power)

assert result.dtype == {:f, f_base}
assert Series.to_list(result) === [1.0, 4.0, 3.0]
end
end

# An integer base (signed or unsigned, any bit width) raised to a float
# exponent yields `{:f, 64}`, even when the exponent is `{:f, 32}`.
test "pow(uint_or_sint, float) = float" do
for d_base <- [:s, :u], n_base <- [8, 16, 32, 64], f_power <- [32, 64] do
base = Series.from_list([1, 2, 3], dtype: {d_base, n_base})
power = Series.from_list([3, 2, 1], dtype: {:f, f_power})

result = Series.pow(base, power)

assert result.dtype == {:f, 64}
assert Series.to_list(result) === [1.0, 4.0, 3.0]
end
end

# A float base raised to a float exponent takes the base's float dtype; the
# exponent's width does not affect the result dtype.
test "pow(float, float) = float" do
for f_base <- [32, 64], f_power <- [32, 64] do
base = Series.from_list([1, 2, 3], dtype: {:f, f_base})
power = Series.from_list([3, 2, 1], dtype: {:f, f_power})

result = Series.pow(base, power)

assert result.dtype == {:f, f_base}
assert Series.to_list(result) === [1.0, 4.0, 3.0]
end
end

test "pow of an integer series with an integer series that contains negative integer" do
s1 = Series.from_list([1, 2, 3])
s2 = Series.from_list([1, -2, 3])

assert_raise RuntimeError, ~r"negative exponent with an integer base", fn ->
Series.pow(s1, s2)
end
result = Series.pow(s1, s2)

assert result.dtype == {:f, 64}
assert Series.to_list(result) === [1.0, 0.25, 27.0]
end

test "pow of an integer series with a float series" do
Expand Down Expand Up @@ -2081,7 +2144,7 @@ defmodule Explorer.SeriesTest do

result = Series.pow(s1, s2)

assert result.dtype == {:s, 64}
assert result.dtype == {:f, 64}
assert Series.to_list(result) == [1, nil, 3]
end

Expand All @@ -2091,7 +2154,7 @@ defmodule Explorer.SeriesTest do

result = Series.pow(s1, s2)

assert result.dtype == {:s, 64}
assert result.dtype == {:f, 64}
billylanchantin marked this conversation as resolved.
Show resolved Hide resolved
assert Series.to_list(result) == [1, nil, 3]
end

Expand All @@ -2101,7 +2164,7 @@ defmodule Explorer.SeriesTest do

result = Series.pow(s1, s2)

assert result.dtype == {:s, 64}
assert result.dtype == {:f, 64}
assert Series.to_list(result) == [1, nil, 3]
end

Expand All @@ -2117,9 +2180,10 @@ defmodule Explorer.SeriesTest do
test "pow of an integer series with a negative integer scalar value on the right-hand side" do
s1 = Series.from_list([1, 2, 3])

assert_raise RuntimeError,
~r"negative exponent with an integer base",
fn -> Series.pow(s1, -2) end
result = Series.pow(s1, -2)

assert result.dtype == {:f, 64}
assert Series.to_list(result) === [1.0, 1 / 4, 1 / 9]
end

test "pow of an integer series with a float scalar value on the right-hand side" do
Expand Down Expand Up @@ -2172,24 +2236,25 @@ defmodule Explorer.SeriesTest do

result = Series.pow(2, s1)

assert result.dtype == {:s, 64}
assert Series.to_list(result) == [2, 4, 8]
assert result.dtype == {:f, 64}
assert Series.to_list(result) === [2.0, 4.0, 8.0]
end

test "pow of an integer series that contains negative integer with an integer scalar value on the left-hand side" do
s1 = Series.from_list([1, -2, 3])

assert_raise RuntimeError, ~r"negative exponent with an integer base", fn ->
Series.pow(2, s1)
end
result = Series.pow(2, s1)

assert result.dtype == {:f, 64}
assert Series.to_list(result) === [2.0, 0.25, 8.0]
end

test "pow of an integer series with a negative integer scalar value on the left-hand side" do
s1 = Series.from_list([1, 2, 3])

result = Series.pow(-2, s1)

assert result.dtype == {:s, 64}
assert result.dtype == {:f, 64}
assert Series.to_list(result) == [-2, 4, -8]
end

Expand Down
Loading