Skip to content

Commit

Permalink
Add Series.cut/6 and Series.qcut/6 (#629)
Browse files Browse the repository at this point in the history
  • Loading branch information
anthony-khong authored Jun 28, 2023
1 parent 5d260cd commit 3c61262
Show file tree
Hide file tree
Showing 10 changed files with 238 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,8 @@ defmodule Explorer.Backend.LazySeries do
categories: 1,
categorise: 2,
frequencies: 1,
cut: 6,
qcut: 6,
mask: 2,
to_iovec: 1,
to_list: 1
Expand Down
7 changes: 7 additions & 0 deletions lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ defmodule Explorer.Backend.Series do
# Returns an integer, or a lazy series when evaluated lazily (`lazy_s()`).
@callback n_distinct(s) :: integer() | lazy_s()
# Returns a dataframe (`df`) rather than a series.
@callback frequencies(s) :: df

# Categorisation

# Bins a series at explicit break points and returns the result as a dataframe.
# `labels` and both column labels may be `nil`, in which case the backend
# chooses them.
@callback cut(
            s,
            bins :: [float()],
            labels :: [String.t()] | nil,
            break_point_label :: String.t() | nil,
            category_label :: String.t() | nil,
            maintain_order :: boolean()
          ) :: df
# Bins a series at the given quantiles and returns the result as a dataframe.
# `labels` and both column labels may be `nil`, in which case the backend
# chooses them.
@callback qcut(
            s,
            quantiles :: [float()],
            labels :: [String.t()] | nil,
            break_point_label :: String.t() | nil,
            category_label :: String.t() | nil,
            maintain_order :: boolean()
          ) :: df

# Rolling

@callback window_sum(
Expand Down
5 changes: 5 additions & 0 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,11 @@ defmodule Explorer.PolarsBackend.Native do
# Stubs whose Elixir bodies only call err() (defined elsewhere in this module).
# NOTE(review): presumably replaced by the native functions registered in
# `rustler::init!` when the NIF library loads — confirm.
def s_upcase(_s), do: err()
def s_unordered_distinct(_s), do: err()
def s_frequencies(_s), do: err()
# Stub for the native `s_cut` function (six arguments, mirroring the Rust
# signature); the Elixir body only calls err().
def s_cut(_s, _bins, _labels, _break_point_label, _category_label, _maintain_order) do
  err()
end

# Stub for the native `s_qcut` function (six arguments, mirroring the Rust
# signature); the Elixir body only calls err().
def s_qcut(_s, _quantiles, _labels, _break_point_label, _category_label, _maintain_order) do
  err()
end

# Stubs whose Elixir bodies only call err() (defined elsewhere in this module).
# NOTE(review): presumably replaced by the native functions registered in
# `rustler::init!` when the NIF library loads — confirm.
def s_variance(_s), do: err()
def s_window_max(_s, _window_size, _weight, _ignore_null, _min_periods), do: err()
def s_window_mean(_s, _window_size, _weight, _ignore_null, _min_periods), do: err()
Expand Down
30 changes: 30 additions & 0 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,36 @@ defmodule Explorer.PolarsBackend.Series do
|> DataFrame.rename(["values", "counts"])
end

# Categorisation

@impl true
def cut(series, bins, labels, break_point_label, category_label, maintain_order) do
  # Argument order mirrors the `s_cut` native function signature.
  nif_args = [series.data, bins, labels, break_point_label, category_label, maintain_order]

  :s_cut
  |> Shared.apply(nif_args)
  |> Shared.create_dataframe()
  # The column named "" is renamed to "values".
  # NOTE(review): presumably this is the unnamed input series — confirm.
  |> DataFrame.rename(%{"" => "values"})
end

@impl true
def qcut(series, quantiles, labels, break_point_label, category_label, maintain_order) do
  # Argument order mirrors the `s_qcut` native function signature.
  nif_args = [series.data, quantiles, labels, break_point_label, category_label, maintain_order]

  :s_qcut
  |> Shared.apply(nif_args)
  |> Shared.create_dataframe()
  # The column named "" is renamed to "values".
  # NOTE(review): presumably this is the unnamed input series — confirm.
  |> DataFrame.rename(%{"" => "values"})
end

# Window

@impl true
Expand Down
85 changes: 85 additions & 0 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3539,6 +3539,91 @@ defmodule Explorer.Series do
@doc type: :aggregation
def frequencies(series), do: apply_series(series, :frequencies)

@doc """
Bins values into discrete values.

Given a `bins` length of N, there will be N+1 categories.

## Options

  * `:labels` - The labels assigned to the bins. Given `bins` of
    length N, `:labels` must be of length N+1. Defaults to the bin
    bounds (e.g. `(-inf, -1.0]`, `(-1.0, 1.0]`, `(1.0, inf]`)

  * `:break_point_label` - The name given to the breakpoint column.
    Defaults to `break_point`.

  * `:category_label` - The name given to the category column.
    Defaults to `category`.

  * `:maintain_order` - Whether to keep the rows in the same order as
    the original series. Defaults to `false`.

Raises `ArgumentError` if `opts` contains an unknown option.

## Examples

    iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0])
    iex> Explorer.Series.cut(s, [1.5, 2.5])
    #Explorer.DataFrame<
      Polars[3 x 3]
      values float [1.0, 2.0, 3.0]
      break_point float [1.5, 2.5, Inf]
      category category ["(-inf, 1.5]", "(1.5, 2.5]", "(2.5, inf]"]
    >
"""
@doc type: :aggregation
def cut(series, bins, opts \\ []) do
  # Reject misspelled options up front instead of silently ignoring them.
  opts =
    Keyword.validate!(opts,
      labels: nil,
      break_point_label: nil,
      category_label: nil,
      maintain_order: false
    )

  apply_series(series, :cut, [
    # The backend callback is specified with float bins, so integers are coerced.
    Enum.map(bins, &(&1 / 1.0)),
    opts[:labels],
    opts[:break_point_label],
    opts[:category_label],
    opts[:maintain_order]
  ])
end

@doc """
Bins values into discrete values based on their quantiles.

Given a `quantiles` length of N, there will be N+1 categories. Each
element of `quantiles` is expected to be between 0.0 and 1.0.

## Options

  * `:labels` - The labels assigned to the bins. Given `quantiles` of
    length N, `:labels` must be of length N+1. Defaults to the bin
    bounds (e.g. `(-inf, -1.0]`, `(-1.0, 1.0]`, `(1.0, inf]`)

  * `:break_point_label` - The name given to the breakpoint column.
    Defaults to `break_point`.

  * `:category_label` - The name given to the category column.
    Defaults to `category`.

  * `:maintain_order` - Whether to keep the rows in the same order as
    the original series. Defaults to `false`.

Raises `ArgumentError` if `opts` contains an unknown option.

## Examples

    iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0, 4.0, 5.0])
    iex> Explorer.Series.qcut(s, [0.25, 0.75])
    #Explorer.DataFrame<
      Polars[5 x 3]
      values float [1.0, 2.0, 3.0, 4.0, 5.0]
      break_point float [2.0, 2.0, 4.0, 4.0, Inf]
      category category ["(-inf, 2.0]", "(-inf, 2.0]", "(2.0, 4.0]", "(2.0, 4.0]", "(4.0, inf]"]
    >
"""
@doc type: :aggregation
def qcut(series, quantiles, opts \\ []) do
  # Reject misspelled options up front instead of silently ignoring them.
  opts =
    Keyword.validate!(opts,
      labels: nil,
      break_point_label: nil,
      category_label: nil,
      maintain_order: false
    )

  apply_series(series, :qcut, [
    # The backend callback is specified with float quantiles, so integers are coerced.
    Enum.map(quantiles, &(&1 / 1.0)),
    opts[:labels],
    opts[:break_point_label],
    opts[:category_label],
    opts[:maintain_order]
  ])
end

@doc """
Counts the number of elements in a series.
Expand Down
12 changes: 12 additions & 0 deletions native/explorer/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions native/explorer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,6 @@ features = [

[dependencies.polars-ops]
version = "0.29"

[dependencies.polars-algo]
version = "0.29"
2 changes: 2 additions & 0 deletions native/explorer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ rustler::init!(
s_to_iovec,
s_unordered_distinct,
s_frequencies,
s_cut,
s_qcut,
s_variance,
s_window_max,
s_window_mean,
Expand Down
44 changes: 44 additions & 0 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
use encoding::encode_datetime;
use polars::export::arrow::array::Utf8Array;
use polars::prelude::*;
use polars_algo::{cut, qcut};
use rustler::{Binary, Encoder, Env, ListIterator, Term, TermType};
use std::{result::Result, slice};

Expand Down Expand Up @@ -325,6 +326,49 @@ pub fn s_frequencies(series: ExSeries) -> Result<ExDataFrame, ExplorerError> {
Ok(ExDataFrame::new(df))
}

// NIF wrapper around `polars_algo::cut`: bins `series` at the given break
// points and hands the resulting dataframe back to Elixir.
#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_cut(
    series: ExSeries,
    bins: Vec<f64>,
    labels: Option<Vec<&str>>,
    break_point_label: Option<&str>,
    category_label: Option<&str>,
    maintain_order: bool,
) -> Result<ExDataFrame, ExplorerError> {
    let input = series.clone_inner();
    // `cut` takes the break points as a Series; it is built with an empty name.
    let break_points = Series::new("", bins);
    let binned = cut(
        &input,
        break_points,
        labels,
        break_point_label,
        category_label,
        maintain_order,
    )
    .map(ExDataFrame::new)?;
    Ok(binned)
}

// NIF wrapper around `polars_algo::qcut`: bins `series` at the given
// quantiles and hands the resulting dataframe back to Elixir.
#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_qcut(
    series: ExSeries,
    quantiles: Vec<f64>,
    labels: Option<Vec<&str>>,
    break_point_label: Option<&str>,
    category_label: Option<&str>,
    maintain_order: bool,
) -> Result<ExDataFrame, ExplorerError> {
    let input = series.clone_inner();
    let binned = qcut(
        &input,
        &quantiles,
        labels,
        break_point_label,
        category_label,
        maintain_order,
    )
    .map(ExDataFrame::new)?;
    Ok(binned)
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_slice_by_indices(series: ExSeries, indices: Vec<u32>) -> Result<ExSeries, ExplorerError> {
let idx = UInt32Chunked::from_vec("idx", indices);
Expand Down
48 changes: 48 additions & 0 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3883,4 +3883,52 @@ defmodule Explorer.SeriesTest do
end
end
end

describe "categorisation functions" do
  # Test names use the public arity (series, bins/quantiles, opts), which is
  # what these tests call, rather than the six-argument backend callbacks.
  test "cut/3 with no nils" do
    # -3.0, -2.5, ..., 3.0 (13 values); Enum.map/2 already returns a list.
    series = -30..30//5 |> Enum.map(&(&1 / 10)) |> Series.from_list()
    df = Series.cut(series, [-1, 1])
    freqs = Series.frequencies(df[:category])
    assert Series.to_list(freqs[:values]) == ["(-inf, -1.0]", "(-1.0, 1.0]", "(1.0, inf]"]
    assert Series.to_list(freqs[:counts]) == [5, 4, 4]
  end

  test "cut/3 with nils" do
    series = Series.from_list([1, 2, 3, nil, nil])
    df = Series.cut(series, [2])
    # Nil inputs must stay nil in the category column.
    assert [_, _, _, nil, nil] = Series.to_list(df[:category])
  end

  test "cut/3 options" do
    series = Series.from_list([1, 2, 3])

    # One bin yields two categories, so a single label is rejected by Polars.
    assert_raise RuntimeError,
                 "Polars Error: lengths don't match: labels count must equal bins count",
                 fn -> Series.cut(series, [2], labels: ["x"]) end

    df =
      Series.cut(series, [2],
        labels: ["x", "y"],
        break_point_label: "bp",
        category_label: "cat"
      )

    assert Explorer.DataFrame.names(df) == ["values", "bp", "cat"]
  end

  test "qcut/3" do
    series = -5..3 |> Enum.to_list() |> Series.from_list()
    df = Series.qcut(series, [0.0, 0.25, 0.75])
    freqs = Series.frequencies(df[:category])

    assert Series.to_list(freqs[:values]) == [
             "(-3.0, 1.0]",
             "(-5.0, -3.0]",
             "(1.0, inf]",
             "(-inf, -5.0]"
           ]

    assert Series.to_list(freqs[:counts]) == [4, 2, 2, 1]
  end
end
end

0 comments on commit 3c61262

Please sign in to comment.