From 20a4b47cdb11fd995a0003d39b4c921a23a56915 Mon Sep 17 00:00:00 2001 From: Anthony Khong Date: Tue, 27 Jun 2023 19:22:39 +0700 Subject: [PATCH 1/4] implement cut --- lib/explorer/backend/lazy_series.ex | 1 + lib/explorer/backend/series.ex | 5 +++++ lib/explorer/polars_backend/native.ex | 1 + lib/explorer/polars_backend/series.ex | 15 +++++++++++++++ lib/explorer/series.ex | 21 +++++++++++++++++++++ native/explorer/Cargo.lock | 12 ++++++++++++ native/explorer/Cargo.toml | 3 +++ native/explorer/src/lib.rs | 1 + native/explorer/src/series.rs | 23 +++++++++++++++++++++++ test/explorer/series_test.exs | 14 ++++++++++++++ 10 files changed, 96 insertions(+) diff --git a/lib/explorer/backend/lazy_series.ex b/lib/explorer/backend/lazy_series.ex index df0e043df..5d550785f 100644 --- a/lib/explorer/backend/lazy_series.ex +++ b/lib/explorer/backend/lazy_series.ex @@ -883,6 +883,7 @@ defmodule Explorer.Backend.LazySeries do categories: 1, categorise: 2, frequencies: 1, + cut: 6, mask: 2, to_iovec: 1, to_list: 1 diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index 008d7b23c..1151deecd 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -154,6 +154,11 @@ defmodule Explorer.Backend.Series do @callback n_distinct(s) :: integer() | lazy_s() @callback frequencies(s) :: df + # Categorisation + + @callback cut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil, boolean()) :: + df + # Rolling @callback window_sum( diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 7eb3837ef..42f7e11f7 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -321,6 +321,7 @@ defmodule Explorer.PolarsBackend.Native do def s_upcase(_s), do: err() def s_unordered_distinct(_s), do: err() def s_frequencies(_s), do: err() + def s_cut(_s, _bins, _labels, _break_point_label, _category_label, _maintain_order), do: err() def s_variance(_s), do: 
err() def s_window_max(_s, _window_size, _weight, _ignore_null, _min_periods), do: err() def s_window_mean(_s, _window_size, _weight, _ignore_null, _min_periods), do: err() diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index 3dc3a0df5..b39cd8a66 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -441,6 +441,21 @@ defmodule Explorer.PolarsBackend.Series do |> DataFrame.rename(["values", "counts"]) end + # Categorisation + + @impl true + def cut(series, bins, labels, break_point_label, category_label, maintain_order) do + Shared.apply(:s_cut, [ + series.data, + bins, + labels, + break_point_label, + category_label, + maintain_order + ]) + |> Shared.create_dataframe() + end + # Window @impl true diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 2582bc78c..e2d7f6597 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3539,6 +3539,27 @@ defmodule Explorer.Series do @doc type: :aggregation def frequencies(series), do: apply_series(series, :frequencies) + @doc """ + TODO + """ + @doc type: :aggregation + def cut( + series, + bins, + labels \\ nil, + break_point_label \\ nil, + category_label \\ nil, + maintain_order \\ false + ), + do: + apply_series(series, :cut, [ + Enum.map(bins, &(&1 / 1.0)), + labels, + break_point_label, + category_label, + maintain_order + ]) + @doc """ Counts the number of elements in a series. 
diff --git a/native/explorer/Cargo.lock b/native/explorer/Cargo.lock index 4eea32135..184edd87c 100644 --- a/native/explorer/Cargo.lock +++ b/native/explorer/Cargo.lock @@ -383,6 +383,7 @@ dependencies = [ "chrono", "mimalloc", "polars", + "polars-algo", "polars-ops", "rand", "rand_pcg", @@ -1031,6 +1032,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "polars-algo" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51197645c231af62aa06c224aca357a09a836c2d2cc1342cd06884b1edc7d840" +dependencies = [ + "polars-core", + "polars-lazy", + "polars-ops", +] + [[package]] name = "polars-arrow" version = "0.29.0" diff --git a/native/explorer/Cargo.toml b/native/explorer/Cargo.toml index bc9280edb..f6152a5bb 100644 --- a/native/explorer/Cargo.toml +++ b/native/explorer/Cargo.toml @@ -70,3 +70,6 @@ features = [ [dependencies.polars-ops] version = "0.29" + +[dependencies.polars-algo] +version = "0.29" diff --git a/native/explorer/src/lib.rs b/native/explorer/src/lib.rs index 00e7f38d7..10f5fd7f0 100644 --- a/native/explorer/src/lib.rs +++ b/native/explorer/src/lib.rs @@ -392,6 +392,7 @@ rustler::init!( s_to_iovec, s_unordered_distinct, s_frequencies, + s_cut, s_variance, s_window_max, s_window_mean, diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 61009f830..15a20c0ae 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -7,6 +7,7 @@ use crate::{ use encoding::encode_datetime; use polars::export::arrow::array::Utf8Array; use polars::prelude::*; +use polars_algo::cut; use rustler::{Binary, Encoder, Env, ListIterator, Term, TermType}; use std::{result::Result, slice}; @@ -325,6 +326,28 @@ pub fn s_frequencies(series: ExSeries) -> Result<ExDataFrame, ExplorerError> { Ok(ExDataFrame::new(df)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn s_cut( + series: ExSeries, + bins: Vec<f64>, + labels: Option<Vec<&str>>, + break_point_label: Option<&str>, + category_label: Option<&str>, + maintain_order: bool, 
+) -> Result<ExDataFrame, ExplorerError> { + let series = series.clone_inner(); + let bins = Series::new("", bins); + let df = cut( + &series, + bins, + labels, + break_point_label, + category_label, + maintain_order, + )?; + Ok(ExDataFrame::new(df)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn s_slice_by_indices(series: ExSeries, indices: Vec<u32>) -> Result<ExSeries, ExplorerError> { let idx = UInt32Chunked::from_vec("idx", indices); diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 8215ddc24..9b7bbf111 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -3883,4 +3883,18 @@ defmodule Explorer.SeriesTest do end end end + + describe "categorisation functions" do + test "cut/6" do + series = -30..30//5 |> Enum.map(&(&1 / 10)) |> Enum.to_list() |> Series.from_list() + df = Series.cut(series, [-1, 1]) + freqs = Series.frequencies(df[:category]) + assert Series.to_list(freqs[:values]) == ["(-inf, -1.0]", "(-1.0, 1.0]", "(1.0, inf]"] + assert Series.to_list(freqs[:counts]) == [5, 4, 4] + + series = Series.from_list([1, 2, 3, nil, nil]) + df = Series.cut(series, [2]) + assert [_, _, _, nil, nil] = Series.to_list(df[:category]) + end + end end From b418decbe5bd38f3d4754b5315e748f75bbf1002 Mon Sep 17 00:00:00 2001 From: Anthony Khong Date: Tue, 27 Jun 2023 19:52:06 +0700 Subject: [PATCH 2/4] document cut/6 and clean up tests --- lib/explorer/series.ex | 55 ++++++++++++++++++++++++----------- test/explorer/series_test.exs | 21 ++++++++++++- 2 files changed, 58 insertions(+), 18 deletions(-) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index e2d7f6597..9de8703b5 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3540,25 +3540,46 @@ defmodule Explorer.Series do def frequencies(series), do: apply_series(series, :frequencies) @doc """ - TODO + Bins values into discrete values. + + Given a `bins` length of N, there will be N+1 categories. + + ## Options + + * `:labels` - The labels assigned to the bins. 
Given `bins` of + length N, `:labels` must be of length N+1. Defaults to the bin + bounds (e.g. `(-inf, -1.0]`, `(-1.0, 1.0]`, `(1.0, inf]`). + + * `:break_point_label` - The name given to the breakpoint column. + Defaults to `break_point`. + + * `:category_label` - The name given to the category column. + Defaults to `category`. + + * `:maintain_order` - Whether to keep the rows in the order of the input series. + Defaults to `false`. + + ## Examples + + iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0]) + iex> Explorer.Series.cut(s, [1.5, 2.5]) + #Explorer.DataFrame< + Polars[3 x 3] + float [1.0, 2.0, 3.0] + break_point float [1.5, 2.5, Inf] + category category ["(-inf, 1.5]", "(1.5, 2.5]", "(2.5, inf]"] + > """ @doc type: :aggregation - def cut( - series, - bins, - labels \\ nil, - break_point_label \\ nil, - category_label \\ nil, - maintain_order \\ false - ), - do: - apply_series(series, :cut, [ - Enum.map(bins, &(&1 / 1.0)), - labels, - break_point_label, - category_label, - maintain_order - ]) + def cut(series, bins, opts \\ []) do + apply_series(series, :cut, [ + Enum.map(bins, &(&1 / 1.0)), + Keyword.get(opts, :labels), + Keyword.get(opts, :break_point_label), + Keyword.get(opts, :category_label), + Keyword.get(opts, :maintain_order, false) + ]) + end + @doc """ Counts the number of elements in a series. 
diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 9b7bbf111..eef667d67 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -3885,16 +3885,35 @@ defmodule Explorer.SeriesTest do end describe "categorisation functions" do - test "cut/6" do + test "cut/6 with no nils" do series = -30..30//5 |> Enum.map(&(&1 / 10)) |> Enum.to_list() |> Series.from_list() df = Series.cut(series, [-1, 1]) freqs = Series.frequencies(df[:category]) assert Series.to_list(freqs[:values]) == ["(-inf, -1.0]", "(-1.0, 1.0]", "(1.0, inf]"] assert Series.to_list(freqs[:counts]) == [5, 4, 4] + end + test "cut/6 with nils" do series = Series.from_list([1, 2, 3, nil, nil]) df = Series.cut(series, [2]) assert [_, _, _, nil, nil] = Series.to_list(df[:category]) end + + test "cut/6 options" do + series = Series.from_list([1, 2, 3]) + + assert_raise RuntimeError, + "Polars Error: lengths don't match: labels count must equal bins count", + fn -> Series.cut(series, [2], labels: ["x"]) end + + df = + Series.cut(series, [2], + labels: ["x", "y"], + break_point_label: "bp", + category_label: "cat" + ) + + assert Explorer.DataFrame.names(df) == ["", "bp", "cat"] + end end end From c2b0aaf90fff8c9d399d7d4ac0c39b4589ac5989 Mon Sep 17 00:00:00 2001 From: Anthony Khong Date: Tue, 27 Jun 2023 20:05:28 +0700 Subject: [PATCH 3/4] add qcut/6 test --- lib/explorer/backend/lazy_series.ex | 1 + lib/explorer/backend/series.ex | 2 ++ lib/explorer/polars_backend/native.ex | 4 +++ lib/explorer/polars_backend/series.ex | 13 ++++++++ lib/explorer/series.ex | 43 +++++++++++++++++++++++++++ native/explorer/src/lib.rs | 1 + native/explorer/src/series.rs | 23 +++++++++++++- test/explorer/series_test.exs | 15 ++++++++++ 8 files changed, 101 insertions(+), 1 deletion(-) diff --git a/lib/explorer/backend/lazy_series.ex b/lib/explorer/backend/lazy_series.ex index 5d550785f..6c0ffdf4c 100644 --- a/lib/explorer/backend/lazy_series.ex +++ b/lib/explorer/backend/lazy_series.ex 
@@ -884,6 +884,7 @@ defmodule Explorer.Backend.LazySeries do categorise: 2, frequencies: 1, cut: 6, + qcut: 6, mask: 2, to_iovec: 1, to_list: 1 diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index 1151deecd..fd9f9893a 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -158,6 +158,8 @@ defmodule Explorer.Backend.Series do @callback cut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil, boolean()) :: df + @callback qcut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil, boolean()) :: + df # Rolling diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 42f7e11f7..ea0f95770 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -322,6 +322,10 @@ defmodule Explorer.PolarsBackend.Native do def s_unordered_distinct(_s), do: err() def s_frequencies(_s), do: err() def s_cut(_s, _bins, _labels, _break_point_label, _category_label, _maintain_order), do: err() + + def s_qcut(_s, _quantiles, _labels, _break_point_label, _category_label, _maintain_order), + do: err() + def s_variance(_s), do: err() def s_window_max(_s, _window_size, _weight, _ignore_null, _min_periods), do: err() def s_window_mean(_s, _window_size, _weight, _ignore_null, _min_periods), do: err() diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index b39cd8a66..5aaabc65e 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -456,6 +456,19 @@ defmodule Explorer.PolarsBackend.Series do |> Shared.create_dataframe() end + @impl true + def qcut(series, quantiles, labels, break_point_label, category_label, maintain_order) do + Shared.apply(:s_qcut, [ + series.data, + quantiles, + labels, + break_point_label, + category_label, + maintain_order + ]) + |> Shared.create_dataframe() + end + # Window @impl true diff --git a/lib/explorer/series.ex 
b/lib/explorer/series.ex index 9de8703b5..431438cf6 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3581,6 +3581,49 @@ defmodule Explorer.Series do ]) end + @doc """ + Bins values into discrete values based on their quantiles. + + Given a `quantiles` length of N, there will be N+1 categories. Each + element of `quantiles` is expected to be between 0.0 and 1.0. + + ## Options + + * `:labels` - The labels assigned to the bins. Given `quantiles` of + length N, `:labels` must be of length N+1. Defaults to the bin + bounds (e.g. `(-inf, -1.0]`, `(-1.0, 1.0]`, `(1.0, inf]`). + + * `:break_point_label` - The name given to the breakpoint column. + Defaults to `break_point`. + + * `:category_label` - The name given to the category column. + Defaults to `category`. + + * `:maintain_order` - Whether to keep the rows in the order of the input series. + Defaults to `false`. + + ## Examples + + iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0, 4.0, 5.0]) + iex> Explorer.Series.qcut(s, [0.25, 0.75]) + #Explorer.DataFrame< + Polars[5 x 3] + float [1.0, 2.0, 3.0, 4.0, 5.0] + break_point float [2.0, 2.0, 4.0, 4.0, Inf] + category category ["(-inf, 2.0]", "(-inf, 2.0]", "(2.0, 4.0]", "(2.0, 4.0]", "(4.0, inf]"] + > + """ + @doc type: :aggregation + def qcut(series, quantiles, opts \\ []) do + apply_series(series, :qcut, [ + Enum.map(quantiles, &(&1 / 1.0)), + Keyword.get(opts, :labels), + Keyword.get(opts, :break_point_label), + Keyword.get(opts, :category_label), + Keyword.get(opts, :maintain_order, false) + ]) + end + @doc """ Counts the number of elements in a series. 
diff --git a/native/explorer/src/lib.rs b/native/explorer/src/lib.rs index 10f5fd7f0..5608913e1 100644 --- a/native/explorer/src/lib.rs +++ b/native/explorer/src/lib.rs @@ -393,6 +393,7 @@ rustler::init!( s_unordered_distinct, s_frequencies, s_cut, + s_qcut, s_variance, s_window_max, s_window_mean, diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 15a20c0ae..5ce7fd028 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -7,7 +7,7 @@ use crate::{ use encoding::encode_datetime; use polars::export::arrow::array::Utf8Array; use polars::prelude::*; -use polars_algo::cut; +use polars_algo::{cut, qcut}; use rustler::{Binary, Encoder, Env, ListIterator, Term, TermType}; use std::{result::Result, slice}; @@ -348,6 +348,27 @@ pub fn s_cut( Ok(ExDataFrame::new(df)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn s_qcut( + series: ExSeries, + quantiles: Vec<f64>, + labels: Option<Vec<&str>>, + break_point_label: Option<&str>, + category_label: Option<&str>, + maintain_order: bool, +) -> Result<ExDataFrame, ExplorerError> { + let series = series.clone_inner(); + let df = qcut( + &series, + &quantiles, + labels, + break_point_label, + category_label, + maintain_order, + )?; + Ok(ExDataFrame::new(df)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn s_slice_by_indices(series: ExSeries, indices: Vec<u32>) -> Result<ExSeries, ExplorerError> { let idx = UInt32Chunked::from_vec("idx", indices); diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index eef667d67..b4f66b84b 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -3915,5 +3915,20 @@ defmodule Explorer.SeriesTest do assert Explorer.DataFrame.names(df) == ["", "bp", "cat"] end + + test "qcut/6" do + series = Enum.to_list(-5..3) |> Series.from_list() + df = Series.qcut(series, [0.0, 0.25, 0.75]) + freqs = Series.frequencies(df[:category]) + + assert Series.to_list(freqs[:values]) == [ + "(-3.0, 1.0]", + "(-5.0, -3.0]", + "(1.0, inf]", + "(-inf, -5.0]" + ] + + assert 
Series.to_list(freqs[:counts]) == [4, 2, 2, 1] + end end end From 7b3d54636804fb971617a7fa5fc3e1c871a496de Mon Sep 17 00:00:00 2001 From: Anthony Khong Date: Tue, 27 Jun 2023 20:33:53 +0700 Subject: [PATCH 4/4] rename unnamed column to `values` instead --- lib/explorer/polars_backend/series.ex | 2 ++ lib/explorer/series.ex | 4 ++-- test/explorer/series_test.exs | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index 5aaabc65e..7fef9a471 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -454,6 +454,7 @@ defmodule Explorer.PolarsBackend.Series do maintain_order ]) |> Shared.create_dataframe() + |> DataFrame.rename(%{"" => "values"}) end @impl true @@ -467,6 +468,7 @@ defmodule Explorer.PolarsBackend.Series do maintain_order ]) |> Shared.create_dataframe() + |> DataFrame.rename(%{"" => "values"}) end # Window diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 431438cf6..45f38970d 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3565,7 +3565,7 @@ defmodule Explorer.Series do iex> Explorer.Series.cut(s, [1.5, 2.5]) #Explorer.DataFrame< Polars[3 x 3] - float [1.0, 2.0, 3.0] + values float [1.0, 2.0, 3.0] break_point float [1.5, 2.5, Inf] category category ["(-inf, 1.5]", "(1.5, 2.5]", "(2.5, inf]"] > @@ -3608,7 +3608,7 @@ defmodule Explorer.Series do iex> Explorer.Series.qcut(s, [0.25, 0.75]) #Explorer.DataFrame< Polars[5 x 3] - float [1.0, 2.0, 3.0, 4.0, 5.0] + values float [1.0, 2.0, 3.0, 4.0, 5.0] break_point float [2.0, 2.0, 4.0, 4.0, Inf] category category ["(-inf, 2.0]", "(-inf, 2.0]", "(2.0, 4.0]", "(2.0, 4.0]", "(4.0, inf]"] > diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index b4f66b84b..de06d0e6a 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -3913,7 +3913,7 @@ defmodule Explorer.SeriesTest do category_label: "cat" ) - 
assert Explorer.DataFrame.names(df) == ["", "bp", "cat"] + assert Explorer.DataFrame.names(df) == ["values", "bp", "cat"] end test "qcut/6" do