From ac55806368be034ebdd5ebc87896a6a70664a439 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Sun, 14 Jan 2024 12:31:48 -0500 Subject: [PATCH 01/19] first attempt: use lazy version underneath --- lib/explorer/polars_backend/native.ex | 2 +- lib/explorer/polars_backend/series.ex | 7 +- lib/explorer/series.ex | 16 +++- native/explorer/src/series.rs | 106 ++++++++++++-------------- 4 files changed, 70 insertions(+), 61 deletions(-) diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index f6eb24f64..4da764a65 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -373,7 +373,7 @@ defmodule Explorer.PolarsBackend.Native do def s_peak_max(_s), do: err() def s_peak_min(_s), do: err() def s_select(_pred, _on_true, _on_false), do: err() - def s_pow(_s, _other), do: err() + def s_pow(_s, _s_dtype, _other, _other_dtype), do: err() def s_log_natural(_s_argument), do: err() def s_log(_s_argument, _base_as_float), do: err() def s_quantile(_s, _quantile, _strategy), do: err() diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index f87fec557..722731709 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -314,7 +314,12 @@ defmodule Explorer.PolarsBackend.Series do @impl true def pow(left, right), - do: Shared.apply_series(matching_size!(left, right), :s_pow, [right.data]) + do: + Shared.apply_series(matching_size!(left, right), :s_pow, [ + left.dtype, + right.data, + right.dtype + ]) @impl true def log(%Series{} = argument), do: Shared.apply_series(argument, :s_log_natural, []) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 3eaa86a04..5fcd7acf9 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3428,7 +3428,21 @@ defmodule Explorer.Series do """ @doc type: :element_wise @spec pow(left :: Series.t() | number(), right :: Series.t() | number()) :: Series.t() - def pow(left, right), do: basic_numeric_operation(:pow, left, right) + def pow(left, right) do + # TODO: revert back to `basic_numeric_operation(:pow, left, right)` if/when we start inferring + # unsigned integer types from lists. + if K.or(match?(%Series{}, left), match?(%Series{}, right)) do + apply_series_list(:pow, [cast_for_pow(left), cast_for_pow(right)]) + else + raise ArgumentError, "at least one input must be a series" + end + end + + @non_finite [:nan, :infinity, :neg_infinity] + defp cast_for_pow(%Series{} = series), do: series + defp cast_for_pow(u) when K.and(is_integer(u), u >= 0), do: from_list([u], dtype: {:u, 64}) + defp cast_for_pow(s) when K.and(is_integer(s), s < 0), do: from_list([s], dtype: {:s, 64}) + defp cast_for_pow(f) when K.or(is_float(f), K.in(f, @non_finite)), do: from_list([f]) @doc """ Calculates the natural logarithm. diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index d36ad2c01..22a34aa36 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1231,65 +1231,55 @@ pub fn s_n_distinct(s: ExSeries) -> Result { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn s_pow(s: ExSeries, other: ExSeries) -> Result { - match (s.dtype().is_integer(), other.dtype().is_integer()) { - (true, true) => { - let cast1 = s.cast(&DataType::Int64)?; - let mut iter1 = cast1.i64()?.into_iter(); - - match other.strict_cast(&DataType::UInt32) { - Ok(casted) => { - let mut iter2 = casted.u32()?.into_iter(); - - let res = if s.len() == 1 { - let v1 = iter1.next().unwrap(); - iter2 - .map(|v2| v1.and_then(|left| v2.map(|right| left.pow(right)))) - .collect() - } else if other.len() == 1 { - let v2 = iter2.next().unwrap(); - iter1 - .map(|v1| v1.and_then(|left| v2.map(|right| left.pow(right)))) - .collect() - } else { - iter1 - .zip(iter2) - .map(|(v1, v2)| v1.and_then(|left| v2.map(|right| left.pow(right)))) - .collect() - }; - - Ok(ExSeries::new(res)) - } - Err(_) => Err(ExplorerError::Other( - "negative exponent with an integer base".into(), - )), - } - } - (_, _) => { - let cast1 = s.cast(&DataType::Float64)?; - let cast2 = other.cast(&DataType::Float64)?; - let mut iter1 = cast1.f64()?.into_iter(); - let mut iter2 = cast2.f64()?.into_iter(); - - let res = if s.len() == 1 { - let v1 = iter1.next().unwrap(); - iter2 - .map(|v2| v1.and_then(|left| v2.map(|right| left.powf(right)))) - .collect() - } else if other.len() == 1 { - let v2 = iter2.next().unwrap(); - iter1 - .map(|v1| v1.and_then(|left| v2.map(|right| left.powf(right)))) - .collect() - } else { - iter1 - .zip(iter2) - .map(|(v1, v2)| v1.and_then(|left| v2.map(|right| left.powf(right)))) - .collect() - }; +pub fn s_pow( + base_exseries: ExSeries, + base_exdtype: ExSeriesDtype, + exponent_exseries: ExSeries, + exponent_exdtype: ExSeriesDtype, +) -> Result { + let df_with_result = if base_exseries.len() == exponent_exseries.len() { + df!( + "base" => base_exseries.clone_inner().into_series(), + "exponent" => exponent_exseries.clone_inner().into_series() + )? + .lazy() + .with_column((col("base").pow(col("exponent"))).alias("result")) + } else if base_exseries.len() == 1 { + let base = first(base_exseries, base_exdtype); + + df!( "exponent" => exponent_exseries.clone_inner().into_series() )? + .lazy() + .with_column((base.lit().pow(col("exponent"))).alias("result")) + } else if exponent_exseries.len() == 1 { + let exponent = first(exponent_exseries, exponent_exdtype); + + df!( "base" => base_exseries.clone_inner().into_series() )? + .lazy() + .with_column((col("base").pow(exponent.lit())).alias("result")) + } else { + panic!("adsf") + }; - Ok(ExSeries::new(res)) - } + let result = df_with_result.collect()?.column("result")?.clone(); + + Ok(ExSeries::new(result)) +} + +fn first(exseries: ExSeries, exdtype: ExSeriesDtype) -> LiteralValue { + let dtype = DataType::try_from(&exdtype).unwrap(); + + match dtype { + DataType::UInt8 => LiteralValue::UInt8(exseries.u8().unwrap().get(0).unwrap()), + DataType::UInt16 => LiteralValue::UInt16(exseries.u16().unwrap().get(0).unwrap()), + DataType::UInt32 => LiteralValue::UInt32(exseries.u32().unwrap().get(0).unwrap()), + DataType::UInt64 => LiteralValue::UInt64(exseries.u64().unwrap().get(0).unwrap()), + DataType::Int8 => LiteralValue::Int8(exseries.i8().unwrap().get(0).unwrap()), + DataType::Int16 => LiteralValue::Int16(exseries.i16().unwrap().get(0).unwrap()), + DataType::Int32 => LiteralValue::Int32(exseries.i32().unwrap().get(0).unwrap()), + DataType::Int64 => LiteralValue::Int64(exseries.i64().unwrap().get(0).unwrap()), + DataType::Float32 => LiteralValue::Float32(exseries.f32().unwrap().get(0).unwrap()), + DataType::Float64 => LiteralValue::Float64(exseries.f64().unwrap().get(0).unwrap()), + _ => panic!("asdf"), } } From 442c71b27c7a1c6f66cf43de3c614c0614a75ceb Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Sun, 14 Jan 2024 12:31:59 -0500 Subject: [PATCH 02/19] fix tests --- test/explorer/data_frame_test.exs | 4 +- test/explorer/series_test.exs | 107 ++++++++++++++++++++++++------ 2 files changed, 88 insertions(+), 23 deletions(-) diff --git a/test/explorer/data_frame_test.exs b/test/explorer/data_frame_test.exs index 8ae01646c..0119c70c8 100644 --- a/test/explorer/data_frame_test.exs +++ b/test/explorer/data_frame_test.exs @@ -299,7 +299,7 @@ defmodule Explorer.DataFrameTest do df = DF.new(a: [1, 2, 3, 4, 5, 6, 5], b: [9, 8, 7, 6, 5, 4, 3]) message = - "expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:s, 64}" + "expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:f, 64}" assert_raise ArgumentError, message, fn -> DF.filter_with(df, fn ldf -> @@ -819,7 +819,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - "calc5" => {:s, 64}, + "calc5" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64}, "calc8" => {:f, 64}, diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index bcbdab0c1..fce966c9f 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -2036,23 +2036,86 @@ defmodule Explorer.SeriesTest do end describe "pow/2" do - test "pow of an integer series with an integer series" do - s1 = Series.from_list([1, 2, 3]) - s2 = Series.from_list([3, 2, 1]) + test "pow(uint, uint) == uint" do + for u_base <- [8, 16, 32, 64], u_power <- [8, 16, 32, 64] do + base = Series.from_list([1, 2, 3], dtype: {:u, u_base}) + power = Series.from_list([3, 2, 1], dtype: {:u, u_power}) - result = Series.pow(s1, s2) + result = Series.pow(base, power) - assert result.dtype == {:s, 64} - assert Series.to_list(result) == [1, 4, 3] + assert result.dtype == {:u, u_base} + assert Series.to_list(result) == [1, 4, 3] + end + end + + test "pow(sint, uint) == sint" do + for s_base <- [8, 16, 32, 64], u_power <- [8, 16, 32, 64] do + base = Series.from_list([1, 2, 3], dtype: {:s, s_base}) + power = Series.from_list([3, 2, 1], dtype: {:u, u_power}) + + result = Series.pow(base, power) + + assert result.dtype == {:s, s_base} + assert Series.to_list(result) == [1, 4, 3] + end + end + + test "pow(sint, sint) == float" do + for s_base <- [8, 16, 32, 64], s_power <- [8, 16, 32, 64] do + base = Series.from_list([1, 2, 3], dtype: {:s, s_base}) + power = Series.from_list([3, 2, 1], dtype: {:s, s_power}) + + result = Series.pow(base, power) + + assert result.dtype == {:f, 64} + assert Series.to_list(result) === [1.0, 4.0, 3.0] + end + end + + test "pow(float, uint_or_sint) = float" do + for f_base <- [32, 64], d_power <- [:s, :u], n_power <- [8, 16, 32, 64] do + base = Series.from_list([1, 2, 3], dtype: {:f, f_base}) + power = Series.from_list([3, 2, 1], dtype: {d_power, n_power}) + + result = Series.pow(base, power) + + assert result.dtype == {:f, f_base} + assert Series.to_list(result) === [1.0, 4.0, 3.0] + end + end + + test "pow(uint_or_sint, float) = float" do + for d_base <- [:s, :u], n_base <- [8, 16, 32, 64], f_power <- [32, 64] do + base = Series.from_list([1, 2, 3], dtype: {d_base, n_base}) + power = Series.from_list([3, 2, 1], dtype: {:f, f_power}) + + result = Series.pow(base, power) + + assert result.dtype == {:f, 64} + assert Series.to_list(result) === [1.0, 4.0, 3.0] + end + end + + test "pow(float, float) = float" do + for f_base <- [32, 64], f_power <- [32, 64] do + base = Series.from_list([1, 2, 3], dtype: {:f, f_base}) + power = Series.from_list([3, 2, 1], dtype: {:f, f_power}) + + result = Series.pow(base, power) + + assert result.dtype == {:f, f_base} + assert Series.to_list(result) === [1.0, 4.0, 3.0] + end end test "pow of an integer series with an integer series that contains negative integer" do s1 = Series.from_list([1, 2, 3]) s2 = Series.from_list([1, -2, 3]) - assert_raise RuntimeError, ~r"negative exponent with an integer base", fn -> - Series.pow(s1, s2) - end + result = Series.pow(s1, s2) + + assert result.dtype == {:f, 64} + assert Series.to_list(result) === [1.0, 0.25, 27.0] end test "pow of an integer series with a float series" do @@ -2081,7 +2144,7 @@ defmodule Explorer.SeriesTest do result = Series.pow(s1, s2) - assert result.dtype == {:s, 64} + assert result.dtype == {:f, 64} assert Series.to_list(result) == [1, nil, 3] end @@ -2091,7 +2154,7 @@ defmodule Explorer.SeriesTest do result = Series.pow(s1, s2) - assert result.dtype == {:s, 64} + assert result.dtype == {:f, 64} assert Series.to_list(result) == [1, nil, 3] end @@ -2101,7 +2164,7 @@ defmodule Explorer.SeriesTest do result = Series.pow(s1, s2) - assert result.dtype == {:s, 64} + assert result.dtype == {:f, 64} assert Series.to_list(result) == [1, nil, 3] end @@ -2117,9 +2180,10 @@ defmodule Explorer.SeriesTest do test "pow of an integer series with a negative integer scalar value on the right-hand side" do s1 = Series.from_list([1, 2, 3]) - assert_raise RuntimeError, - ~r"negative exponent with an integer base", - fn -> Series.pow(s1, -2) end + result = Series.pow(s1, -2) + + assert result.dtype == {:f, 64} + assert Series.to_list(result) === [1.0, 1 / 4, 1 / 9] end test "pow of an integer series with a float scalar value on the right-hand side" do @@ -2172,16 +2236,17 @@ defmodule Explorer.SeriesTest do result = Series.pow(2, s1) - assert result.dtype == {:s, 64} - assert Series.to_list(result) == [2, 4, 8] + assert result.dtype == {:f, 64} + assert Series.to_list(result) === [2.0, 4.0, 8.0] end test "pow of an integer series that contains negative integer with an integer scalar value on the left-hand side" do s1 = Series.from_list([1, -2, 3]) - assert_raise RuntimeError, ~r"negative exponent with an integer base", fn -> - Series.pow(2, s1) - end + result = Series.pow(2, s1) + + assert result.dtype == {:f, 64} + assert Series.to_list(result) === [2.0, 0.25, 8.0] end test "pow of an integer series with a negative integer scalar value on the left-hand side" do @@ -2189,7 +2254,7 @@ defmodule Explorer.SeriesTest do result = Series.pow(-2, s1) - assert result.dtype == {:s, 64} + assert result.dtype == {:f, 64} assert Series.to_list(result) == [-2, 4, -8] end From 8eb61cddc79841c0a6dfea3874029dee73a4b698 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Sun, 14 Jan 2024 12:52:26 -0500 Subject: [PATCH 03/19] missed some tests --- lib/explorer/data_frame.ex | 2 +- test/explorer/data_frame_test.exs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/explorer/data_frame.ex b/lib/explorer/data_frame.ex index 59c742d5c..56e34ee1a 100644 --- a/lib/explorer/data_frame.ex +++ b/lib/explorer/data_frame.ex @@ -2747,7 +2747,7 @@ defmodule Explorer.DataFrame do #Explorer.DataFrame< Polars[3 x 2] a string ["a", "b", "c"] - b f64 [1.0, 4.0, 9.0] + b s64 [1, 4, 9] > It's possible to "reuse" a variable for different computations: diff --git a/test/explorer/data_frame_test.exs b/test/explorer/data_frame_test.exs index 0119c70c8..288b48386 100644 --- a/test/explorer/data_frame_test.exs +++ b/test/explorer/data_frame_test.exs @@ -861,8 +861,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - # TODO: This should be float after #374 is resolved - "calc5" => {:s, 64}, + "calc5" => {:f, 64}, "calc5_1" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64} From 256493a79ae8f24c41fcc836858d65d86a4143ba Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Sun, 14 Jan 2024 12:57:30 -0500 Subject: [PATCH 04/19] fix a warning related to default defp args --- lib/explorer/series.ex | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 5fcd7acf9..b1a36b6c9 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3766,8 +3766,6 @@ defmodule Explorer.Series do def atan(%Series{dtype: dtype}), do: dtype_error("atan/1", dtype, [{:f, 32}, {:f, 64}]) - defp basic_numeric_operation(operation, left, right, args \\ []) - defp basic_numeric_operation(operation, %Series{} = left, right, args) when is_numeric(right), do: basic_numeric_operation(operation, left, from_same_value(left, right), args) From 14c814531b29a74c5ec1a1cd2e7b9b0bf963651f Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Sun, 14 Jan 2024 13:07:22 -0500 Subject: [PATCH 05/19] informative panics (whoops...) --- native/explorer/src/series.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 22a34aa36..ba63371c2 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1257,7 +1257,7 @@ pub fn s_pow( .lazy() .with_column((col("base").pow(exponent.lit())).alias("result")) } else { - panic!("adsf") + panic!("both series must have the same length or one must have length 1") }; let result = df_with_result.collect()?.column("result")?.clone(); @@ -1279,7 +1279,7 @@ fn first(exseries: ExSeries, exdtype: ExSeriesDtype) -> LiteralValue { DataType::Int64 => LiteralValue::Int64(exseries.i64().unwrap().get(0).unwrap()), DataType::Float32 => LiteralValue::Float32(exseries.f32().unwrap().get(0).unwrap()), DataType::Float64 => LiteralValue::Float64(exseries.f64().unwrap().get(0).unwrap()), - _ => panic!("asdf"), + _ => panic!("unsupported dtype for pow: must be integer or float subtype"), } } From 1775608d858ce711329a8df5cdcdf5a945c6cbc2 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Sun, 14 Jan 2024 14:03:11 -0500 Subject: [PATCH 06/19] drop todo -- infer signed int is likely here to stay --- lib/explorer/series.ex | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index b1a36b6c9..a56ebade7 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3429,8 +3429,6 @@ defmodule Explorer.Series do @doc type: :element_wise @spec pow(left :: Series.t() | number(), right :: Series.t() | number()) :: Series.t() def pow(left, right) do - # TODO: revert back to `basic_numeric_operation(:pow, left, right)` if/when we start inferring - # unsigned integer types from lists. if K.or(match?(%Series{}, left), match?(%Series{}, right)) do apply_series_list(:pow, [cast_for_pow(left), cast_for_pow(right)]) else From 66b45ce4228252efdd02a55fba775816ed261ccd Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Mon, 15 Jan 2024 10:15:47 -0500 Subject: [PATCH 07/19] pull dtype off ExSeries --- lib/explorer/polars_backend/native.ex | 2 +- lib/explorer/polars_backend/series.ex | 7 +------ native/explorer/src/series.rs | 29 +++++++++++---------------- 3 files changed, 14 insertions(+), 24 deletions(-) diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 4da764a65..f6eb24f64 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -373,7 +373,7 @@ defmodule Explorer.PolarsBackend.Native do def s_peak_max(_s), do: err() def s_peak_min(_s), do: err() def s_select(_pred, _on_true, _on_false), do: err() - def s_pow(_s, _s_dtype, _other, _other_dtype), do: err() + def s_pow(_s, _other), do: err() def s_log_natural(_s_argument), do: err() def s_log(_s_argument, _base_as_float), do: err() def s_quantile(_s, _quantile, _strategy), do: err() diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index 722731709..f87fec557 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -314,12 +314,7 @@ defmodule Explorer.PolarsBackend.Series do @impl true def pow(left, right), - do: - Shared.apply_series(matching_size!(left, right), :s_pow, [ - left.dtype, - right.data, - right.dtype - ]) + do: Shared.apply_series(matching_size!(left, right), :s_pow, [right.data]) @impl true def log(%Series{} = argument), do: Shared.apply_series(argument, :s_log_natural, []) diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index ba63371c2..d55eb6e05 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1231,29 +1231,24 @@ pub fn s_n_distinct(s: ExSeries) -> Result { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn s_pow( - base_exseries: ExSeries, - base_exdtype: ExSeriesDtype, - exponent_exseries: ExSeries, - exponent_exdtype: ExSeriesDtype, -) -> Result { - let df_with_result = if base_exseries.len() == exponent_exseries.len() { +pub fn s_pow(s: ExSeries, other: ExSeries) -> Result { + let df_with_result = if s.len() == other.len() { df!( - "base" => base_exseries.clone_inner().into_series(), - "exponent" => exponent_exseries.clone_inner().into_series() + "base" => s.clone_inner().into_series(), + "exponent" => other.clone_inner().into_series() )? .lazy() .with_column((col("base").pow(col("exponent"))).alias("result")) - } else if base_exseries.len() == 1 { - let base = first(base_exseries, base_exdtype); + } else if s.len() == 1 { + let base = first(s); - df!( "exponent" => exponent_exseries.clone_inner().into_series() )? + df!( "exponent" => other.clone_inner().into_series() )? .lazy() .with_column((base.lit().pow(col("exponent"))).alias("result")) - } else if exponent_exseries.len() == 1 { - let exponent = first(exponent_exseries, exponent_exdtype); + } else if other.len() == 1 { + let exponent = first(other); - df!( "base" => base_exseries.clone_inner().into_series() )? + df!( "base" => s.clone_inner().into_series() )? .lazy() .with_column((col("base").pow(exponent.lit())).alias("result")) } else { @@ -1265,8 +1260,8 @@ pub fn s_pow( Ok(ExSeries::new(result)) } -fn first(exseries: ExSeries, exdtype: ExSeriesDtype) -> LiteralValue { - let dtype = DataType::try_from(&exdtype).unwrap(); +fn first(exseries: ExSeries) -> LiteralValue { + let dtype = DataType::try_from(exseries.dtype().clone()).unwrap(); match dtype { DataType::UInt8 => LiteralValue::UInt8(exseries.u8().unwrap().get(0).unwrap()), From ed5338f6085507e7e4780dd31026a19dedb27f03 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Mon, 15 Jan 2024 10:20:19 -0500 Subject: [PATCH 08/19] fix rust linter error --- native/explorer/src/series.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index d55eb6e05..9edac9450 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1261,7 +1261,7 @@ pub fn s_pow(s: ExSeries, other: ExSeries) -> Result { } fn first(exseries: ExSeries) -> LiteralValue { - let dtype = DataType::try_from(exseries.dtype().clone()).unwrap(); + let dtype = exseries.dtype().clone(); match dtype { DataType::UInt8 => LiteralValue::UInt8(exseries.u8().unwrap().get(0).unwrap()), From 276e0f3d541a7d92fd2d72e02f5202ad604d78bf Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Mon, 15 Jan 2024 10:44:37 -0500 Subject: [PATCH 09/19] also test pow(uint, sint) == float64 --- test/explorer/series_test.exs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index fce966c9f..2fb7270c5 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -2048,6 +2048,18 @@ defmodule Explorer.SeriesTest do end end + test "pow(uint, sint) == float64" do + for u_base <- [8, 16, 32, 64], s_power <- [8, 16, 32, 64] do + base = Series.from_list([1, 2, 3], dtype: {:u, u_base}) + power = Series.from_list([3, 2, 1], dtype: {:s, s_power}) + + result = Series.pow(base, power) + + assert result.dtype == {:f, 64} + assert Series.to_list(result) == [1, 4, 3] + end + end + test "pow(sint, uint) == sint" do for s_base <- [8, 16, 32, 64], u_power <- [8, 16, 32, 64] do base = Series.from_list([1, 2, 3], dtype: {:s, s_base}) @@ -2060,7 +2072,7 @@ defmodule Explorer.SeriesTest do end end - test "pow(sint, sint) == float" do + test "pow(sint, sint) == float64" do for s_base <- [8, 16, 32, 64], s_power <- [8, 16, 32, 64] do base = Series.from_list([1, 2, 3], dtype: {:s, s_base}) power = Series.from_list([3, 2, 1], dtype: {:s, s_power}) From 1186f7959c6ee1691b888b558bead002a6dd7eb1 Mon Sep 17 00:00:00 2001 From: Billy Lanchantin Date: Mon, 15 Jan 2024 10:47:06 -0500 Subject: [PATCH 10/19] simplify dtype call Co-authored-by: lkarthee --- native/explorer/src/series.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 9edac9450..246cb743b 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1260,10 +1260,8 @@ pub fn s_pow(s: ExSeries, other: ExSeries) -> Result { Ok(ExSeries::new(result)) } -fn first(exseries: ExSeries) -> LiteralValue { - let dtype = exseries.dtype().clone(); - - match dtype { +fn first(s: ExSeries) -> LiteralValue { + match s.dtype() { DataType::UInt8 => LiteralValue::UInt8(exseries.u8().unwrap().get(0).unwrap()), DataType::UInt16 => LiteralValue::UInt16(exseries.u16().unwrap().get(0).unwrap()), DataType::UInt32 => LiteralValue::UInt32(exseries.u32().unwrap().get(0).unwrap()), From 5442d2c06ea8f1f1c0c7709fb74d75c97f70eb7d Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Mon, 15 Jan 2024 10:50:27 -0500 Subject: [PATCH 11/19] whoops, wrong variable --- native/explorer/src/series.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 246cb743b..c06d562ea 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1260,8 +1260,8 @@ pub fn s_pow(s: ExSeries, other: ExSeries) -> Result { Ok(ExSeries::new(result)) } -fn first(s: ExSeries) -> LiteralValue { - match s.dtype() { +fn first(exseries: ExSeries) -> LiteralValue { + match exseries.dtype() { DataType::UInt8 => LiteralValue::UInt8(exseries.u8().unwrap().get(0).unwrap()), DataType::UInt16 => LiteralValue::UInt16(exseries.u16().unwrap().get(0).unwrap()), DataType::UInt32 => LiteralValue::UInt32(exseries.u32().unwrap().get(0).unwrap()), From ad428d0be6efa617d96f8eb45686e83fbf5074bf Mon Sep 17 00:00:00 2001 From: Billy Lanchantin Date: Mon, 15 Jan 2024 13:31:22 -0500 Subject: [PATCH 12/19] use match instead of if/else if/else Co-authored-by: lkarthee --- native/explorer/src/series.rs | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index c06d562ea..a0a102938 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1232,27 +1232,26 @@ pub fn s_n_distinct(s: ExSeries) -> Result { #[rustler::nif(schedule = "DirtyCpu")] pub fn s_pow(s: ExSeries, other: ExSeries) -> Result { - let df_with_result = if s.len() == other.len() { - df!( + let df_with_result = match (s.len(), other.len()) { + (x, x1) if x == x1 => df!( "base" => s.clone_inner().into_series(), "exponent" => other.clone_inner().into_series() )? .lazy() - .with_column((col("base").pow(col("exponent"))).alias("result")) - } else if s.len() == 1 { - let base = first(s); - - df!( "exponent" => other.clone_inner().into_series() )? - .lazy() - .with_column((base.lit().pow(col("exponent"))).alias("result")) - } else if other.len() == 1 { - let exponent = first(other); - - df!( "base" => s.clone_inner().into_series() )? - .lazy() - .with_column((col("base").pow(exponent.lit())).alias("result")) - } else { - panic!("both series must have the same length or one must have length 1") + .with_column((col("base").pow(col("exponent"))).alias("result")), + (1, _) => { + let base = first(s); + df!( "exponent" => other.clone_inner().into_series() )? + .lazy() + .with_column((base.lit().pow(col("exponent"))).alias("result")) + } + (_, 1) => { + let exponent = first(other); + df!( "base" => s.clone_inner().into_series() )? + .lazy() + .with_column((col("base").pow(exponent.lit())).alias("result")) + } + _ => panic!("both series must have the same length or one must have length 1"), }; let result = df_with_result.collect()?.column("result")?.clone(); From 87dcc2ed9b5c35401e249e760027bd32753d08c8 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Mon, 15 Jan 2024 15:10:16 -0500 Subject: [PATCH 13/19] keep everything in elixir --- lib/explorer/polars_backend/native.ex | 1 - lib/explorer/polars_backend/series.ex | 20 ++++++++++-- lib/explorer/series.ex | 16 ++-------- native/explorer/src/lib.rs | 1 - native/explorer/src/series.rs | 45 --------------------------- test/explorer/series_test.exs | 2 +- 6 files changed, 22 insertions(+), 63 deletions(-) diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index f6eb24f64..8a21b117f 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -373,7 +373,6 @@ defmodule Explorer.PolarsBackend.Native do def s_peak_max(_s), do: err() def s_peak_min(_s), do: err() def s_select(_pred, _on_true, _on_false), do: err() - def s_pow(_s, _other), do: err() def s_log_natural(_s_argument), do: err() def s_log(_s_argument, _base_as_float), do: err() def s_quantile(_s, _quantile, _strategy), do: err() diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index f87fec557..42a4a87de 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -313,8 +313,24 @@ defmodule Explorer.PolarsBackend.Series do do: Shared.apply_series(matching_size!(left, right), :s_remainder, [right.data]) @impl true - def pow(left, right), - do: Shared.apply_series(matching_size!(left, right), :s_pow, [right.data]) + def pow(left, right) do + _ = matching_size!(left, right) + left_lazy = Explorer.Backend.LazySeries.new(:column, ["base"], left.dtype) + right_lazy = Explorer.Backend.LazySeries.new(:column, ["exponent"], right.dtype) + + {df_args, pow_args} = + case {size(left), size(right)} do + {n, n} -> {[{"base", left}, {"exponent", right}], [left_lazy, right_lazy]} + {1, _} -> {[{"exponent", right}], [Explorer.Series.at(left, 0), right_lazy]} + {_, 1} -> {[{"base", left}], [left_lazy, Explorer.Series.at(right, 0)]} + end + + df = Explorer.PolarsBackend.DataFrame.from_series(df_args) + pow = Explorer.Backend.LazySeries.new(:pow, pow_args, nil) + + Explorer.PolarsBackend.DataFrame.mutate_with(df, df, [{"pow", pow}]) + |> Explorer.PolarsBackend.DataFrame.pull("pow") + end @impl true def log(%Series{} = argument), do: Shared.apply_series(argument, :s_log_natural, []) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index a56ebade7..3eaa86a04 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3428,19 +3428,7 @@ defmodule Explorer.Series do """ @doc type: :element_wise @spec pow(left :: Series.t() | number(), right :: Series.t() | number()) :: Series.t() - def pow(left, right) do - if K.or(match?(%Series{}, left), match?(%Series{}, right)) do - apply_series_list(:pow, [cast_for_pow(left), cast_for_pow(right)]) - else - raise ArgumentError, "at least one input must be a series" - end - end - - @non_finite [:nan, :infinity, :neg_infinity] - defp cast_for_pow(%Series{} = series), do: series - defp cast_for_pow(u) when K.and(is_integer(u), u >= 0), do: from_list([u], dtype: {:u, 64}) - defp cast_for_pow(s) when K.and(is_integer(s), s < 0), do: from_list([s], dtype: {:s, 64}) - defp cast_for_pow(f) when K.or(is_float(f), K.in(f, @non_finite)), do: from_list([f]) + def pow(left, right), do: basic_numeric_operation(:pow, left, right) @doc """ Calculates the natural logarithm. @@ -3764,6 +3752,8 @@ defmodule Explorer.Series do def atan(%Series{dtype: dtype}), do: dtype_error("atan/1", dtype, [{:f, 32}, {:f, 64}]) + defp basic_numeric_operation(operation, left, right, args \\ []) + defp basic_numeric_operation(operation, %Series{} = left, right, args) when is_numeric(right), do: basic_numeric_operation(operation, left, from_same_value(left, right), args) diff --git a/native/explorer/src/lib.rs b/native/explorer/src/lib.rs index 857f26f84..5ee3b7b11 100644 --- a/native/explorer/src/lib.rs +++ b/native/explorer/src/lib.rs @@ -426,7 +426,6 @@ rustler::init!( s_peak_max, s_peak_min, s_select, - s_pow, s_quantile, s_quotient, s_rank, diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index a0a102938..7c8ec91e3 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1230,51 +1230,6 @@ pub fn s_n_distinct(s: ExSeries) -> Result { Ok(s.n_unique()?) } -#[rustler::nif(schedule = "DirtyCpu")] -pub fn s_pow(s: ExSeries, other: ExSeries) -> Result { - let df_with_result = match (s.len(), other.len()) { - (x, x1) if x == x1 => df!( - "base" => s.clone_inner().into_series(), - "exponent" => other.clone_inner().into_series() - )? - .lazy() - .with_column((col("base").pow(col("exponent"))).alias("result")), - (1, _) => { - let base = first(s); - df!( "exponent" => other.clone_inner().into_series() )? - .lazy() - .with_column((base.lit().pow(col("exponent"))).alias("result")) - } - (_, 1) => { - let exponent = first(other); - df!( "base" => s.clone_inner().into_series() )? - .lazy() - .with_column((col("base").pow(exponent.lit())).alias("result")) - } - _ => panic!("both series must have the same length or one must have length 1"), - }; - - let result = df_with_result.collect()?.column("result")?.clone(); - - Ok(ExSeries::new(result)) -} - -fn first(exseries: ExSeries) -> LiteralValue { - match exseries.dtype() { - DataType::UInt8 => LiteralValue::UInt8(exseries.u8().unwrap().get(0).unwrap()), - DataType::UInt16 => LiteralValue::UInt16(exseries.u16().unwrap().get(0).unwrap()), - DataType::UInt32 => LiteralValue::UInt32(exseries.u32().unwrap().get(0).unwrap()), - DataType::UInt64 => LiteralValue::UInt64(exseries.u64().unwrap().get(0).unwrap()), - DataType::Int8 => LiteralValue::Int8(exseries.i8().unwrap().get(0).unwrap()), - DataType::Int16 => LiteralValue::Int16(exseries.i16().unwrap().get(0).unwrap()), - DataType::Int32 => LiteralValue::Int32(exseries.i32().unwrap().get(0).unwrap()), - DataType::Int64 => LiteralValue::Int64(exseries.i64().unwrap().get(0).unwrap()), - DataType::Float32 => LiteralValue::Float32(exseries.f32().unwrap().get(0).unwrap()), - DataType::Float64 => LiteralValue::Float64(exseries.f64().unwrap().get(0).unwrap()), - _ => panic!("unsupported dtype for pow: must be integer or float subtype"), - } -} - #[rustler::nif(schedule = "DirtyCpu")] pub fn s_cast(s: ExSeries, to_type: ExSeriesDtype) -> Result { let dtype = DataType::try_from(&to_type)?; diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 2fb7270c5..59d760f97 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -2185,7 +2185,7 @@ defmodule Explorer.SeriesTest do result = Series.pow(s1, 2) - assert result.dtype == {:s, 64} + assert result.dtype == {:f, 64} assert Series.to_list(result) == [1, 4, 9] end From 16ef6985852c21a2776def4182d1ce631c333752 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Mon, 15 Jan 2024 16:48:00 -0500 Subject: [PATCH 14/19] also fix doctests --- lib/explorer/data_frame.ex | 2 +- lib/explorer/series.ex | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/explorer/data_frame.ex b/lib/explorer/data_frame.ex index 56e34ee1a..59c742d5c 100644 --- a/lib/explorer/data_frame.ex +++ b/lib/explorer/data_frame.ex @@ -2747,7 +2747,7 @@ defmodule Explorer.DataFrame do #Explorer.DataFrame< Polars[3 x 2] a string ["a", "b", "c"] - b s64 [1, 4, 9] + b f64 [1.0, 4.0, 9.0] > It's possible to "reuse" a variable for different computations: diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 3eaa86a04..35ad8f5b5 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3402,7 +3402,7 @@ defmodule Explorer.Series do iex> Explorer.Series.pow(s, 3) #Explorer.Series< Polars[3] - s64 [8, 64, 216] + f64 [8.0, 64.0, 216.0] > iex> s = [2, 4, 6] |> Explorer.Series.from_list() From fd78f8569eb709c6ad7f906c92bbb8c541aba2f0 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Mon, 15 Jan 2024 16:49:05 -0500 Subject: [PATCH 15/19] use === in tests with float assertions --- test/explorer/data_frame_test.exs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/explorer/data_frame_test.exs b/test/explorer/data_frame_test.exs index 288b48386..570a51c35 100644 --- a/test/explorer/data_frame_test.exs +++ b/test/explorer/data_frame_test.exs @@ -799,7 +799,7 @@ defmodule Explorer.DataFrameTest do calc10: is_nan(divide(a * 0.0, 0.0)) ) - assert DF.to_columns(df1, atom_keys: true) == %{ + assert DF.to_columns(df1, atom_keys: true) === %{ a: [1, 2, 4], calc1: [3, 4, 6], calc2: [-1, 0, 2], @@ -843,7 +843,7 @@ defmodule Explorer.DataFrameTest do calc7: remainder(2, a) ) - assert DF.to_columns(df1, atom_keys: true) == %{ + assert DF.to_columns(df1, atom_keys: true) === %{ a: [1, 2, 4], calc1: [3, 4, 6], calc2: [1, 0, -2], From 888721cd24e067b38346be45c2713842da0876a3 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Mon, 15 Jan 2024 17:46:33 -0500 Subject: [PATCH 16/19] use cast_to_pow to declare out_dtype --- lib/explorer/backend/lazy_series.ex | 4 ++-- lib/explorer/backend/series.ex | 2 +- lib/explorer/polars_backend/series.ex | 4 ++-- lib/explorer/series.ex | 22 +++++++++++++++++++++- test/explorer/data_frame_test.exs | 12 ++++++------ 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/lib/explorer/backend/lazy_series.ex b/lib/explorer/backend/lazy_series.ex index 531a99526..03fe7b08a 100644 --- a/lib/explorer/backend/lazy_series.ex +++ b/lib/explorer/backend/lazy_series.ex @@ -146,8 +146,8 @@ defmodule Explorer.Backend.LazySeries do @comparison_operations [:equal, :not_equal, :greater, :greater_equal, :less, :less_equal] - @basic_arithmetic_operations [:add, :subtract, :multiply, :divide] - @other_arithmetic_operations [:pow, :quotient, :remainder] + @basic_arithmetic_operations [:add, :subtract, :multiply, :divide, :pow] + @other_arithmetic_operations [:quotient, :remainder] @aggregation_operations [ :sum, diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index 87caf4e84..be0683b50 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -114,7 +114,7 @@ defmodule Explorer.Backend.Series do @callback divide(out_dtype :: dtype(), s, s) :: s @callback quotient(s, s) :: s @callback remainder(s, s) :: s - @callback pow(s, s) :: s + @callback pow(out_dtype :: dtype(), s, s) :: s @callback log(argument :: s) :: s @callback log(argument :: s, base :: float()) :: s @callback exp(s) :: s diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index 42a4a87de..39e9d9d16 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -313,7 +313,7 @@ defmodule Explorer.PolarsBackend.Series do do: Shared.apply_series(matching_size!(left, right), :s_remainder, [right.data]) @impl true - def pow(left, right) do + def pow(out_dtype, left, right) do _ = matching_size!(left, right) left_lazy = Explorer.Backend.LazySeries.new(:column, ["base"], left.dtype) right_lazy = Explorer.Backend.LazySeries.new(:column, ["exponent"], right.dtype) @@ -326,7 +326,7 @@ defmodule Explorer.PolarsBackend.Series do end df = Explorer.PolarsBackend.DataFrame.from_series(df_args) - pow = Explorer.Backend.LazySeries.new(:pow, pow_args, nil) + pow = Explorer.Backend.LazySeries.new(:pow, pow_args, out_dtype) Explorer.PolarsBackend.DataFrame.mutate_with(df, df, [{"pow", pow}]) |> Explorer.PolarsBackend.DataFrame.pull("pow") diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 35ad8f5b5..355ec60b5 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3428,7 +3428,27 @@ defmodule Explorer.Series do """ @doc type: :element_wise @spec pow(left :: Series.t() | number(), right :: Series.t() | number()) :: Series.t() - def pow(left, right), do: basic_numeric_operation(:pow, left, right) + def pow(left, right) do + [left, right] = cast_for_arithmetic("pow/2", [left, right]) + + if out_dtype = cast_to_pow(dtype(left), dtype(right)) do + apply_series_list(:pow, [out_dtype, left, right]) + else + dtype_mismatch_error("pow/2", left, right) + end + end + + # TODO: ensure these types are correct. + defp cast_to_pow({:u, l}, {:u, r}), do: {:u, max(l, r)} + defp cast_to_pow({:u, _}, {:s, _}), do: {:f, 64} + defp cast_to_pow({:u, _}, {:f, f}), do: {:f, f} + defp cast_to_pow({:s, s}, {:u, u}), do: {:s, min(64, max(2 * u, s))} + defp cast_to_pow({:s, _}, {:s, _}), do: {:f, 64} + defp cast_to_pow({:s, _}, {:f, f}), do: {:f, f} + defp cast_to_pow({:f, f}, {:u, _}), do: {:f, f} + defp cast_to_pow({:f, f}, {:s, _}), do: {:f, f} + defp cast_to_pow({:f, f}, {:f, _}), do: {:f, f} + defp cast_to_pow(_, _), do: nil @doc """ Calculates the natural logarithm. diff --git a/test/explorer/data_frame_test.exs b/test/explorer/data_frame_test.exs index 570a51c35..7d30381fd 100644 --- a/test/explorer/data_frame_test.exs +++ b/test/explorer/data_frame_test.exs @@ -883,7 +883,7 @@ defmodule Explorer.DataFrameTest do calc7: remainder(a, ^series) ) - assert DF.to_columns(df1, atom_keys: true) == %{ + assert DF.to_columns(df1, atom_keys: true) === %{ a: [1, 2, 4], calc1: [3, 3, 6], calc2: [-1, 1, 2], @@ -900,7 +900,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - "calc5" => {:s, 64}, + "calc5" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64} } @@ -921,7 +921,7 @@ defmodule Explorer.DataFrameTest do calc7: remainder(^series, a) ) - assert DF.to_columns(df1, atom_keys: true) == %{ + assert DF.to_columns(df1, atom_keys: true) === %{ a: [2, 1, 2], calc1: [3, 3, 6], calc2: [-1, 1, 2], @@ -938,7 +938,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - "calc5" => {:s, 64}, + "calc5" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64} } @@ -958,7 +958,7 @@ defmodule Explorer.DataFrameTest do calc7: remainder(b, c) ) - assert DF.to_columns(df1, atom_keys: true) == %{ + assert DF.to_columns(df1, atom_keys: true) === %{ a: [1, 2, 3], b: [20, 40, 60], c: [10, 0, 8], @@ -981,7 +981,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - "calc5" => {:s, 64}, + "calc5" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64} } From 86d5ca7266e9098512e5667b773f58daec980fba Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Mon, 15 Jan 2024 17:52:58 -0500 Subject: [PATCH 17/19] fix warning (again) --- lib/explorer/series.ex | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 355ec60b5..5010fcc9a 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3772,8 +3772,6 @@ defmodule Explorer.Series do def atan(%Series{dtype: dtype}), do: dtype_error("atan/1", dtype, [{:f, 32}, {:f, 64}]) - defp basic_numeric_operation(operation, left, right, args \\ []) - defp basic_numeric_operation(operation, %Series{} = left, right, args) when is_numeric(right), do: basic_numeric_operation(operation, left, from_same_value(left, right), args) From f11f863b5380f3b4b03b55389d41b8ec837ecda0 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Tue, 16 Jan 2024 17:38:35 -0500 Subject: [PATCH 18/19] consilidate cases Co-authored-by: josevalim --- lib/explorer/series.ex | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index eb2f99f0a..de31d1a90 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3441,16 +3441,12 @@ defmodule Explorer.Series do end end - # TODO: ensure these types are correct. defp cast_to_pow({:u, l}, {:u, r}), do: {:u, max(l, r)} - defp cast_to_pow({:u, _}, {:s, _}), do: {:f, 64} - defp cast_to_pow({:u, _}, {:f, f}), do: {:f, f} defp cast_to_pow({:s, s}, {:u, u}), do: {:s, min(64, max(2 * u, s))} - defp cast_to_pow({:s, _}, {:s, _}), do: {:f, 64} - defp cast_to_pow({:s, _}, {:f, f}), do: {:f, f} - defp cast_to_pow({:f, f}, {:u, _}), do: {:f, f} - defp cast_to_pow({:f, f}, {:s, _}), do: {:f, f} - defp cast_to_pow({:f, f}, {:f, _}), do: {:f, f} + defp cast_to_pow({:f, l}, {:f, r}), do: {:f, max(l, r)} + defp cast_to_pow({:f, l}, {n, _}) when K.in(n, [:u, :s]), do: {:f, l} + defp cast_to_pow({n, _}, {:f, r}) when K.in(n, [:u, :s]), do: {:f, r} + defp cast_to_pow({n, _}, {:s, _}) when K.in(n, [:u, :s]), do: {:f, 64} defp cast_to_pow(_, _), do: nil @doc """ From 058fea6a0597054a35e3a9e21e33ba0100f80608 Mon Sep 17 00:00:00 2001 From: William Lanchantin Date: Tue, 16 Jan 2024 17:38:59 -0500 Subject: [PATCH 19/19] pre-cast to ensure precision --- lib/explorer/polars_backend/series.ex | 4 ++++ test/explorer/series_test.exs | 9 +++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index a88eb96f3..5245bdf85 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -318,6 +318,10 @@ defmodule Explorer.PolarsBackend.Series do @impl true def pow(out_dtype, left, right) do _ = matching_size!(left, right) + + # We need to pre-cast or we may lose precision. + left = Explorer.Series.cast(left, out_dtype) + left_lazy = Explorer.Backend.LazySeries.new(:column, ["base"], left.dtype) right_lazy = Explorer.Backend.LazySeries.new(:column, ["exponent"], right.dtype) diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index fe6565b16..44817312c 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -2217,7 +2217,7 @@ defmodule Explorer.SeriesTest do result = Series.pow(base, power) - assert result.dtype == {:u, u_base} + assert result.dtype == {:u, max(u_base, u_power)} assert Series.to_list(result) == [1, 4, 3] end end @@ -2241,7 +2241,8 @@ defmodule Explorer.SeriesTest do result = Series.pow(base, power) - assert result.dtype == {:s, s_base} + # Unsigned integers have twice the precision as signed integers. + assert result.dtype == {:s, min(64, max(s_base, 2 * u_power))} assert Series.to_list(result) == [1, 4, 3] end end @@ -2277,7 +2278,7 @@ defmodule Explorer.SeriesTest do result = Series.pow(base, power) - assert result.dtype == {:f, 64} + assert result.dtype == {:f, f_power} assert Series.to_list(result) === [1.0, 4.0, 3.0] end end @@ -2289,7 +2290,7 @@ defmodule Explorer.SeriesTest do result = Series.pow(base, power) - assert result.dtype == {:f, f_base} + assert result.dtype == {:f, max(f_base, f_power)} assert Series.to_list(result) === [1.0, 4.0, 3.0] end end