diff --git a/lib/explorer/backend/lazy_series.ex b/lib/explorer/backend/lazy_series.ex index ce26fe40e..764d0251b 100644 --- a/lib/explorer/backend/lazy_series.ex +++ b/lib/explorer/backend/lazy_series.ex @@ -1244,9 +1244,9 @@ defmodule Explorer.Backend.LazySeries do at_every: 2, categories: 1, categorise: 2, - cut: 5, + cut: 7, frequencies: 1, - qcut: 5, + qcut: 8, mask: 2, owner_import: 1, owner_export: 1, diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index 517a2b5ef..ded019664 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -190,9 +190,26 @@ defmodule Explorer.Backend.Series do # Categorisation - @callback cut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil) :: + @callback cut( + s, + bins :: [float()], + labels :: option([String.t()]), + break_point_label :: option(String.t()), + category_label :: option(String.t()), + left_close :: boolean(), + include_breaks :: boolean() + ) :: df - @callback qcut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil) :: + @callback qcut( + s, + quantiles :: [float()], + labels :: option([String.t()]), + break_point_label :: option(String.t()), + category_label :: option(String.t()), + allow_duplicates :: boolean(), + left_close :: boolean(), + include_breaks :: boolean() + ) :: df # Rolling diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 4d438b97e..688ed0827 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -418,13 +418,33 @@ defmodule Explorer.PolarsBackend.Native do def s_upcase(_s), do: err() def s_unordered_distinct(_s), do: err() def s_frequencies(_s), do: err() - def s_cut(_s, _bins, _labels, _break_point_label, _category_label), do: err() + + def s_cut( + _s, + _bins, + _labels, + _break_point_label, + _category_label, + _left_close, + _include_breaks + ), + do: err() + def s_substring(_s, _offset, _length), do: err() def s_split(_s, _by), do: err() def s_split_into(_s, _by, _num_fields), do: err() - def s_qcut(_s, _quantiles, _labels, _break_point_label, _category_label), - do: err() + def s_qcut( + _s, + _quantiles, + _labels, + _break_point_label, + _category_label, + _allow_duplicates, + _left_close, + _include_breaks + ), + do: err() def s_variance(_s, _ddof), do: err() def s_window_max(_s, _window_size, _weight, _ignore_null, _min_periods), do: err() diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index e21804f4b..56af73590 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -541,13 +541,15 @@ defmodule Explorer.PolarsBackend.Series do # Categorisation @impl true - def cut(series, bins, labels, break_point_label, category_label) do + def cut(series, bins, labels, break_point_label, category_label, left_close, include_breaks) do case Explorer.PolarsBackend.Native.s_cut( series.data, bins, labels, break_point_label, - category_label + category_label, + left_close, + include_breaks ) do {:ok, polars_df} -> Shared.create_dataframe!(polars_df) @@ -561,13 +563,25 @@ defmodule Explorer.PolarsBackend.Series do end @impl true - def qcut(series, quantiles, labels, break_point_label, category_label) do + def qcut( + series, + quantiles, + labels, + break_point_label, + category_label, + allow_duplicates, + left_close, + include_breaks + ) do Shared.apply(:s_qcut, [ series.data, quantiles, labels, break_point_label, - category_label + category_label, + allow_duplicates, + left_close, + include_breaks ]) |> Shared.create_dataframe!() end diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 26a503959..afec16156 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -4842,11 +4842,22 @@ defmodule Explorer.Series do """ @doc type: :aggregation def cut(series, bins, opts \\ []) do + opts = + Keyword.validate!(opts, + labels: nil, + break_point_label: nil, + category_label: nil, + left_close: false, + include_breaks: true + ) + apply_series(series, :cut, [ Enum.map(bins, &(&1 / 1.0)), - Keyword.get(opts, :labels), - Keyword.get(opts, :break_point_label), - Keyword.get(opts, :category_label) + opts[:labels], + opts[:break_point_label], + opts[:category_label], + opts[:left_close], + opts[:include_breaks] ]) end @@ -4868,6 +4879,9 @@ defmodule Explorer.Series do * `:category_label` - The name given to the category column. Defaults to `category`. + * `:allow_duplicates` - If quantiles can have duplicated values. + Defaults to `false`. + ## Examples iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0, 4.0, 5.0]) @@ -4881,11 +4895,24 @@ defmodule Explorer.Series do """ @doc type: :aggregation def qcut(series, quantiles, opts \\ []) do + opts = + Keyword.validate!(opts, + labels: nil, + break_point_label: nil, + category_label: nil, + allow_duplicates: false, + left_close: false, + include_breaks: true + ) + apply_series(series, :qcut, [ Enum.map(quantiles, &(&1 / 1.0)), - Keyword.get(opts, :labels), - Keyword.get(opts, :break_point_label), - Keyword.get(opts, :category_label) + opts[:labels], + opts[:break_point_label], + opts[:category_label], + opts[:allow_duplicates], + opts[:left_close], + opts[:include_breaks] ]) end diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 2af37ab7a..d5584a0a1 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -206,9 +206,10 @@ pub fn s_cut( labels: Option>, break_point_label: Option<&str>, category_label: Option<&str>, + left_close: bool, + include_breaks: bool, ) -> Result { let series = series.clone_inner(); - let left_close = false; // Cut is going to return a Series of a Struct. We need to convert it to a DF. let cut_series = cut( @@ -216,7 +217,7 @@ pub fn s_cut( bins, labels.map(|vec| vec.iter().map(|label| label.into()).collect()), left_close, - true, + include_breaks, )?; let mut cut_df = DataFrame::new(cut_series.struct_()?.fields_as_series())?; @@ -231,6 +232,7 @@ pub fn s_cut( Ok(ExDataFrame::new(cut_df.clone())) } +#[allow(clippy::too_many_arguments)] #[rustler::nif(schedule = "DirtyCpu")] pub fn s_qcut( series: ExSeries, @@ -238,10 +240,11 @@ pub fn s_qcut( labels: Option>, break_point_label: Option<&str>, category_label: Option<&str>, + allow_duplicates: bool, + left_close: bool, + include_breaks: bool, ) -> Result { let series = series.clone_inner(); - let left_close = false; - let allow_duplicates = false; let qcut_series: Series = qcut( &series, @@ -249,7 +252,7 @@ pub fn s_qcut( labels.map(|vec| vec.iter().map(|label| label.into()).collect()), left_close, allow_duplicates, - true, + include_breaks, )?; let mut qcut_df = DataFrame::new(qcut_series.struct_()?.fields_as_series())?; diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 3bfff11ca..307da6d75 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -5981,7 +5981,7 @@ defmodule Explorer.SeriesTest do assert Explorer.DataFrame.names(df) == ["values", "bp", "cat"] end - test "qcut/6" do + test "qcut/3" do series = Enum.to_list(-5..3) |> Series.from_list() df = Series.qcut(series, [0.0, 0.25, 0.75]) freqs = Series.frequencies(df[:category]) @@ -5995,6 +5995,20 @@ defmodule Explorer.SeriesTest do assert Series.to_list(freqs[:counts]) == [4, 2, 2, 1] end + + test "qcut/3 with duplicates" do + series = Explorer.Series.from_list([0.0, 0.0, 0.0, 3.0, 4.0, 5.0]) + df = Explorer.Series.qcut(series, [0.1, 0.25, 0.75], allow_duplicates: true) + freqs = Series.frequencies(df[:category]) + + assert Series.to_list(freqs[:values]) == [ + "(-inf, 0]", + "(3.75, inf]", + "(0, 3.75]" + ] + + assert Series.to_list(freqs[:counts]) == [3, 2, 1] + end end describe "join/2" do