Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rank method as atom #770

Merged
merged 2 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ defmodule Explorer.Backend.Series do
@callback shift(s, offset :: integer, default :: nil) :: s
@callback rank(
s,
method :: String.t(),
method :: atom(),
descending :: boolean(),
seed :: option(integer())
) :: s | lazy_s()
Expand Down
27 changes: 15 additions & 12 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1765,12 +1765,12 @@ defmodule Explorer.Series do
## Options

* `:method` - Determines how ranks are assigned to tied elements. The following methods are available:
- `"average"` : Each value receives the average rank that would be assigned to all tied values. (default)
- `"min"` : Tied values are assigned the minimum rank. Also known as "competition" ranking.
- `"max"` : Tied values are assigned the maximum of their ranks.
- `"dense"` : Similar to `"min"`, but the rank of the next highest element is assigned the rank immediately after those assigned to the tied elements.
- `"ordinal"` : Each value is given a distinct rank based on its occurrence in the series.
- `"random"` : Similar to `"ordinal"`, but the rank for ties is not dependent on the order that the values occur in the Series.
- `:average` : Each value receives the average rank that would be assigned to all tied values. (default)
- `:min` : Tied values are assigned the minimum rank. Also known as "competition" ranking.
- `:max` : Tied values are assigned the maximum of their ranks.
- `:dense` : Similar to `:min`, but the next highest element is assigned the rank immediately after those assigned to the tied elements.
- `:ordinal` : Each value is given a distinct rank based on its occurrence in the series.
- `:random` : Similar to `:ordinal`, but the rank for ties is not dependent on the order that the values occur in the Series.
* `:descending` - Rank in descending order.
* `:seed` - An integer to be used as a random seed. If nil, a random value between 0 and 2^64 − 1 will be used. (default: nil)

Expand All @@ -1784,36 +1784,36 @@ defmodule Explorer.Series do
>

iex> s = Explorer.Series.from_list([1.1, 2.4, 3.2])
iex> Explorer.Series.rank(s, method: "ordinal")
iex> Explorer.Series.rank(s, method: :ordinal)
#Explorer.Series<
Polars[3]
integer [1, 2, 3]
>

iex> s = Explorer.Series.from_list([ ~N[2022-07-07 17:44:13.020548], ~N[2022-07-07 17:43:08.473561], ~N[2022-07-07 17:45:00.116337] ])
iex> Explorer.Series.rank(s, method: "average")
iex> Explorer.Series.rank(s, method: :average)
#Explorer.Series<
Polars[3]
f64 [2.0, 1.0, 3.0]
>

iex> s = Explorer.Series.from_list([3, 6, 1, 1, 6])
iex> Explorer.Series.rank(s, method: "min")
iex> Explorer.Series.rank(s, method: :min)
#Explorer.Series<
Polars[5]
integer [3, 4, 1, 1, 4]
>

iex> s = Explorer.Series.from_list([3, 6, 1, 1, 6])
iex> Explorer.Series.rank(s, method: "dense")
iex> Explorer.Series.rank(s, method: :dense)
#Explorer.Series<
Polars[5]
integer [2, 3, 1, 1, 3]
>


iex> s = Explorer.Series.from_list([3, 6, 1, 1, 6])
iex> Explorer.Series.rank(s, method: "random", seed: 42)
iex> Explorer.Series.rank(s, method: :random, seed: 42)
#Explorer.Series<
Polars[5]
integer [3, 4, 2, 1, 5]
Expand All @@ -1824,7 +1824,10 @@ defmodule Explorer.Series do
def rank(series, opts \\ [])

def rank(series, opts) do
opts = Keyword.validate!(opts, method: "average", descending: false, seed: nil)
opts = Keyword.validate!(opts, method: :average, descending: false, seed: nil)

if K.not(K.in(opts[:method], [:average, :min, :max, :dense, :ordinal, :random])),
do: raise(ArgumentError, "unsupported rank method #{inspect(opts[:method])}")

apply_series(series, :rank, [opts[:method], opts[:descending], opts[:seed]])
end
Expand Down
10 changes: 10 additions & 0 deletions native/explorer/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,16 @@ pub enum ExCorrelationMethod {
Spearman,
}

// Tie-handling method for rank operations.
//
// This is the type of the `method` argument of the `s_rank` and `expr_rank`
// NIFs; `parse_rank_method_options` maps each variant onto the polars
// `RankMethod` variant of the same name.
//
// NOTE(review): presumably `NifTaggedEnum` decodes these unit variants from the
// snake_case atoms passed by the Elixir side (`:average`, `:min`, `:max`,
// `:dense`, `:ordinal`, `:random`) — confirm against the rustler documentation.
#[derive(NifTaggedEnum)]
pub enum ExRankMethod {
Average,
Min,
Max,
Dense,
Ordinal,
Random,
}

impl TryFrom<ExParquetCompression> for ParquetCompression {
type Error = ExplorerError;

Expand Down
9 changes: 7 additions & 2 deletions native/explorer/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use polars::prelude::{
use polars::prelude::{DataType, Expr, Literal, StrptimeOptions, TimeUnit};

use crate::datatypes::{
ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExValidValue,
ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExRankMethod, ExSeriesDtype, ExValidValue,
};
use crate::series::{cast_str_to_f64, ewm_opts, rolling_opts};
use crate::{ExDataFrame, ExExpr, ExSeries};
Expand Down Expand Up @@ -269,7 +269,12 @@ pub fn expr_sample_frac(
}

#[rustler::nif]
pub fn expr_rank(expr: ExExpr, method: &str, descending: bool, seed: Option<u64>) -> ExExpr {
pub fn expr_rank(
expr: ExExpr,
method: ExRankMethod,
descending: bool,
seed: Option<u64>,
) -> ExExpr {
let expr = expr.clone_inner();
let rank_options = crate::parse_rank_method_options(method, descending);

Expand Down
24 changes: 10 additions & 14 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
atoms,
datatypes::{
ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExSeriesIoType, ExTime,
ExValidValue,
ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExRankMethod, ExSeriesDtype,
ExSeriesIoType, ExTime, ExValidValue,
},
encoding, ExDataFrame, ExSeries, ExplorerError,
};
Expand Down Expand Up @@ -1358,7 +1358,7 @@ pub fn s_sample_frac(
#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_rank(
series: ExSeries,
method: &str,
method: ExRankMethod,
descending: bool,
seed: Option<u64>,
) -> Result<ExSeries, ExplorerError> {
Expand All @@ -1384,36 +1384,32 @@ pub fn s_rank(
}
}

pub fn parse_rank_method_options(strategy: &str, descending: bool) -> RankOptions {
pub fn parse_rank_method_options(strategy: ExRankMethod, descending: bool) -> RankOptions {
match strategy {
"ordinal" => RankOptions {
ExRankMethod::Ordinal => RankOptions {
method: RankMethod::Ordinal,
descending,
},
"random" => RankOptions {
ExRankMethod::Random => RankOptions {
method: RankMethod::Random,
descending,
},
"average" => RankOptions {
ExRankMethod::Average => RankOptions {
method: RankMethod::Average,
descending,
},
"min" => RankOptions {
ExRankMethod::Min => RankOptions {
method: RankMethod::Min,
descending,
},
"max" => RankOptions {
ExRankMethod::Max => RankOptions {
method: RankMethod::Max,
descending,
},
"dense" => RankOptions {
ExRankMethod::Dense => RankOptions {
method: RankMethod::Dense,
descending,
},
_ => RankOptions {
method: RankMethod::Average,
descending,
},
}
}

Expand Down
2 changes: 1 addition & 1 deletion test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ defmodule Explorer.DataFrameTest do
f: distinct(a),
g: unordered_distinct(a),
h: -a,
i: rank(a, method: "ordinal"),
i: rank(a, method: :ordinal),
j: rank(c)
)

Expand Down
18 changes: 13 additions & 5 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4091,31 +4091,31 @@ defmodule Explorer.SeriesTest do

test "rank of a series of floats (method: ordinal)" do
s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1])
r = Series.rank(s, method: "ordinal")
r = Series.rank(s, method: :ordinal)
assert Series.to_list(r) === [8, 2, 5, 3, 9, 10, 6, 7, 1, 4]
end

test "rank of a series of floats (method: min)" do
s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1])
r = Series.rank(s, method: "min")
r = Series.rank(s, method: :min)
assert Series.to_list(r) === [8, 2, 5, 3, 9, 10, 6, 6, 1, 3]
end

test "rank of a series of floats (method: max)" do
s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1])
r = Series.rank(s, method: "max")
r = Series.rank(s, method: :max)
assert Series.to_list(r) === [8, 2, 5, 4, 9, 10, 7, 7, 1, 4]
end

test "rank of a series of floats (method: dense)" do
s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1])
r = Series.rank(s, method: "dense")
r = Series.rank(s, method: :dense)
assert Series.to_list(r) === [6, 2, 4, 3, 7, 8, 5, 5, 1, 3]
end

test "rank of a series of floats (method: random)" do
s = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1])
r = Series.rank(s, method: "random", seed: 4242)
r = Series.rank(s, method: :random, seed: 4242)
assert Series.to_list(r) === [8, 2, 5, 4, 9, 10, 7, 6, 1, 3]
end

Expand All @@ -4136,6 +4136,14 @@ defmodule Explorer.SeriesTest do
r = Series.rank(s)
assert Series.to_list(r) === [3.0, 5.0, 6.0, nil, 4.0, 2.0, 1.0, 7.0]
end

test "invalid rank method" do
  # An atom outside the supported set must be rejected with an ArgumentError
  # whose message names the offending value.
  expected_message = ~s(unsupported rank method :not_a_method)
  floats = Series.from_list([3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1])

  assert_raise ArgumentError, expected_message, fn ->
    Series.rank(floats, method: :not_a_method, seed: 4242)
  end
end
end

describe "skew/2" do
Expand Down
Loading