From 827d4fb228c61c6cea42a132bc75bef7e504f54a Mon Sep 17 00:00:00 2001 From: Cristine Guadelupe Date: Wed, 13 Dec 2023 20:40:21 -0300 Subject: [PATCH 1/2] Rank method as atom --- lib/explorer/backend/series.ex | 2 +- lib/explorer/series.ex | 24 ++++++++++++------------ native/explorer/src/datatypes.rs | 10 ++++++++++ native/explorer/src/expressions.rs | 9 +++++++-- native/explorer/src/series.rs | 24 ++++++++++-------------- test/explorer/data_frame_test.exs | 2 +- test/explorer/series_test.exs | 10 +++++----- 7 files changed, 46 insertions(+), 35 deletions(-) diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index bcbcafcea..19da12198 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -66,7 +66,7 @@ defmodule Explorer.Backend.Series do @callback shift(s, offset :: integer, default :: nil) :: s @callback rank( s, - method :: String.t(), + method :: atom(), descending :: boolean(), seed :: option(integer()) ) :: s | lazy_s() diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index a55d7dabc..509ac8f9c 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -1765,12 +1765,12 @@ defmodule Explorer.Series do ## Options * `:method` - Determine how ranks are assigned to tied elements. The following methods are available: - - `"average"` : Each value receives the average rank that would be assigned to all tied values. (default) - - `"min"` : Tied values are assigned the minimum rank. Also known as "competition" ranking. - - `"max"` : Tied values are assigned the maximum of their ranks. - - `"dense"` : Similar to `"min"`, but the rank of the next highest element is assigned the rank immediately after those assigned to the tied elements. - - `"ordinal"` : Each value is given a distinct rank based on its occurrence in the series. - - `"random"` : Similar to `"ordinal"`, but the rank for ties is not dependent on the order that the values occur in the Series. + - `:average` : Each value receives the average rank that would be assigned to all tied values. (default) + - `:min` : Tied values are assigned the minimum rank. Also known as "competition" ranking. + - `:max` : Tied values are assigned the maximum of their ranks. + - `:dense` : Similar to `:min`, but the rank of the next highest element is assigned the rank immediately after those assigned to the tied elements. + - `:ordinal` : Each value is given a distinct rank based on its occurrence in the series. + - `:random` : Similar to `:ordinal`, but the rank for ties is not dependent on the order that the values occur in the Series. * `:descending` - Rank in descending order. * `:seed` - An integer to be used as a random seed. If nil, a random value between 0 and 2^64 − 1 will be used. (default: nil) @@ -1784,28 +1784,28 @@ defmodule Explorer.Series do > iex> s = Explorer.Series.from_list([1.1, 2.4, 3.2]) - iex> Explorer.Series.rank(s, method: "ordinal") + iex> Explorer.Series.rank(s, method: :ordinal) #Explorer.Series< Polars[3] integer [1, 2, 3] > iex> s = Explorer.Series.from_list([ ~N[2022-07-07 17:44:13.020548], ~N[2022-07-07 17:43:08.473561], ~N[2022-07-07 17:45:00.116337] ]) - iex> Explorer.Series.rank(s, method: "average") + iex> Explorer.Series.rank(s, method: :average) #Explorer.Series< Polars[3] f64 [2.0, 1.0, 3.0] > iex> s = Explorer.Series.from_list([3, 6, 1, 1, 6]) - iex> Explorer.Series.rank(s, method: "min") + iex> Explorer.Series.rank(s, method: :min) #Explorer.Series< Polars[5] integer [3, 4, 1, 1, 4] > iex> s = Explorer.Series.from_list([3, 6, 1, 1, 6]) - iex> Explorer.Series.rank(s, method: "dense") + iex> Explorer.Series.rank(s, method: :dense) #Explorer.Series< Polars[5] integer [2, 3, 1, 1, 3] @@ -1813,7 +1813,7 @@ defmodule Explorer.Series do iex> s = Explorer.Series.from_list([3, 6, 1, 1, 6]) - iex> Explorer.Series.rank(s, method: "random", seed: 42) + iex> Explorer.Series.rank(s, method: :random, seed: 42) #Explorer.Series< Polars[5] integer [3, 4, 2, 1, 5] @@ -1824,7 +1824,7 @@ defmodule Explorer.Series do def rank(series, opts \\ []) def rank(series, opts) do - opts = Keyword.validate!(opts, method: "average", descending: false, seed: nil) + opts = Keyword.validate!(opts, method: :average, descending: false, seed: nil) apply_series(series, :rank, [opts[:method], opts[:descending], opts[:seed]]) end diff --git a/native/explorer/src/datatypes.rs b/native/explorer/src/datatypes.rs index 94c9d3498..e3c267549 100644 --- a/native/explorer/src/datatypes.rs +++ b/native/explorer/src/datatypes.rs @@ -511,6 +511,16 @@ pub enum ExCorrelationMethod { Spearman, } +#[derive(NifTaggedEnum)] +pub enum ExRankMethod { + Average, + Min, + Max, + Dense, + Ordinal, + Random, +} + impl TryFrom for ParquetCompression { type Error = ExplorerError; diff --git a/native/explorer/src/expressions.rs b/native/explorer/src/expressions.rs index d396d1890..a08b56d69 100644 --- a/native/explorer/src/expressions.rs +++ b/native/explorer/src/expressions.rs @@ -11,7 +11,7 @@ use polars::prelude::{ use polars::prelude::{DataType, Expr, Literal, StrptimeOptions, TimeUnit}; use crate::datatypes::{ - ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExValidValue, + ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExRankMethod, ExSeriesDtype, ExValidValue, }; use crate::series::{cast_str_to_f64, ewm_opts, rolling_opts}; use crate::{ExDataFrame, ExExpr, ExSeries}; @@ -269,7 +269,12 @@ pub fn expr_sample_frac( } #[rustler::nif] -pub fn expr_rank(expr: ExExpr, method: &str, descending: bool, seed: Option) -> ExExpr { +pub fn expr_rank( + expr: ExExpr, + method: ExRankMethod, + descending: bool, + seed: Option, +) -> ExExpr { let expr = expr.clone_inner(); let rank_options = crate::parse_rank_method_options(method, descending); diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 3cb31bc16..2fe56c8e5 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1,8 +1,8 @@ use crate::{ atoms, datatypes::{ - ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExSeriesIoType, ExTime, - ExValidValue, + ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExRankMethod, ExSeriesDtype, + ExSeriesIoType, ExTime, ExValidValue, }, encoding, ExDataFrame, ExSeries, ExplorerError, }; @@ -1358,7 +1358,7 @@ pub fn s_sample_frac( #[rustler::nif(schedule = "DirtyCpu")] pub fn s_rank( series: ExSeries, - method: &str, + method: ExRankMethod, descending: bool, seed: Option, ) -> Result { @@ -1384,36 +1384,32 @@ pub fn s_rank( } } -pub fn parse_rank_method_options(strategy: &str, descending: bool) -> RankOptions { +pub fn parse_rank_method_options(strategy: ExRankMethod, descending: bool) -> RankOptions { match strategy { - "ordinal" => RankOptions { + ExRankMethod::Ordinal => RankOptions { method: RankMethod::Ordinal, descending, }, - "random" => RankOptions { + ExRankMethod::Random => RankOptions { method: RankMethod::Random, descending, }, - "average" => RankOptions { + ExRankMethod::Average => RankOptions { method: RankMethod::Average, descending, }, - "min" => RankOptions { + ExRankMethod::Min => RankOptions { method: RankMethod::Min, descending, }, - "max" => RankOptions { + ExRankMethod::Max => RankOptions { method: RankMethod::Max, descending, }, - "dense" => RankOptions { + ExRankMethod::Dense => RankOptions { method: RankMethod::Dense, descending, }, - _ => RankOptions { - method: RankMethod::Average, - descending, - }, } } diff --git a/test/explorer/data_frame_test.exs b/test/explorer/data_frame_test.exs index f22c24812..76259dde4 100644 --- a/test/explorer/data_frame_test.exs +++ b/test/explorer/data_frame_test.exs @@ -736,7 +736,7 @@ defmodule Explorer.DataFrameTest do f: distinct(a), g: unordered_distinct(a), h: -a, - i: rank(a, method: "ordinal"), + i: rank(a, method: :ordinal), j: rank(c) ) diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index a91409106..44095c118 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -4091,31 +4091,31 @@ defmodule Explorer.SeriesTest do test "rank of a series of floats (method: ordinal)" do s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1]) - r = Series.rank(s, method: "ordinal") + r = Series.rank(s, method: :ordinal) assert Series.to_list(r) === [8, 2, 5, 3, 9, 10, 6, 7, 1, 4] end test "rank of a series of floats (method: min)" do s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1]) - r = Series.rank(s, method: "min") + r = Series.rank(s, method: :min) assert Series.to_list(r) === [8, 2, 5, 3, 9, 10, 6, 6, 1, 3] end test "rank of a series of floats (method: max)" do s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1]) - r = Series.rank(s, method: "max") + r = Series.rank(s, method: :max) assert Series.to_list(r) === [8, 2, 5, 4, 9, 10, 7, 7, 1, 4] end test "rank of a series of floats (method: dense)" do s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1]) - r = Series.rank(s, method: "dense") + r = Series.rank(s, method: :dense) assert Series.to_list(r) === [6, 2, 4, 3, 7, 8, 5, 5, 1, 3] end test "rank of a series of floats (method: random)" do s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1]) - r = Series.rank(s, method: "random", seed: 4242) + r = Series.rank(s, method: :random, seed: 4242) assert Series.to_list(r) === [8, 2, 5, 4, 9, 10, 7, 6, 1, 3] end From 82fbae34f5cda1f6e94af80a616caa2c210bfdaf Mon Sep 17 00:00:00 2001 From: Cristine Guadelupe Date: Wed, 13 Dec 2023 20:46:35 -0300 Subject: [PATCH 2/2] Validations --- lib/explorer/series.ex | 3 +++ test/explorer/series_test.exs | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 509ac8f9c..c4707031f 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -1826,6 +1826,9 @@ defmodule Explorer.Series do def rank(series, opts) do opts = Keyword.validate!(opts, method: :average, descending: false, seed: nil) + if K.not(K.in(opts[:method], [:average, :min, :max, :dense, :ordinal, :random])), + do: raise(ArgumentError, "unsupported rank method #{inspect(opts[:method])}") + apply_series(series, :rank, [opts[:method], opts[:descending], opts[:seed]]) end diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 44095c118..cb053c6d9 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -4136,6 +4136,14 @@ defmodule Explorer.SeriesTest do r = Series.rank(s) assert Series.to_list(r) === [3.0, 5.0, 6.0, nil, 4.0, 2.0, 1.0, 7.0] end + + test "invalid rank method" do + s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1]) + + assert_raise ArgumentError, ~s(unsupported rank method :not_a_method), fn -> + Series.rank(s, method: :not_a_method, seed: 4242) + end + end end describe "skew/2" do