Skip to content

Commit

Permalink
Enables alternative correlation methods (#767)
Browse files Browse the repository at this point in the history
  • Loading branch information
i-go-meow authored Dec 13, 2023
1 parent a92cfee commit 7550921
Show file tree
Hide file tree
Showing 16 changed files with 113 additions and 31 deletions.
2 changes: 1 addition & 1 deletion lib/explorer/backend/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ defmodule Explorer.Backend.DataFrame do
@callback nil_count(df) :: df()
@callback explode(df, out_df :: df(), columns :: [column_name()]) :: df()
@callback unnest(df, out_df :: df(), columns :: [column_name()]) :: df()
@callback correlation(df, out_df :: df(), ddof :: integer()) :: df()
@callback correlation(df, out_df :: df(), ddof :: integer(), method :: atom()) :: df()
@callback covariance(df, out_df :: df(), ddof :: integer()) :: df()

# Two or more table verbs
Expand Down
6 changes: 3 additions & 3 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ defmodule Explorer.Backend.LazySeries do
count: 1,
nil_count: 1,
skew: 2,
correlation: 3,
correlation: 4,
covariance: 3,
all: 1,
any: 1,
Expand Down Expand Up @@ -498,8 +498,8 @@ defmodule Explorer.Backend.LazySeries do
end

@impl true
def correlation(%Series{} = left, %Series{} = right, ddof) do
args = [series_or_lazy_series!(left), series_or_lazy_series!(right), ddof]
def correlation(%Series{} = left, %Series{} = right, ddof, method) do
args = [series_or_lazy_series!(left), series_or_lazy_series!(right), ddof, method]
data = new(:correlation, args, {:f, 64}, true)

Backend.Series.new(data, {:f, 64})
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ defmodule Explorer.Backend.Series do
@callback nil_count(s) :: number() | lazy_s()
@callback product(s) :: float() | non_finite() | lazy_s() | nil
@callback skew(s, bias? :: boolean()) :: float() | non_finite() | lazy_s() | nil
@callback correlation(s, s, ddof :: non_neg_integer()) ::
@callback correlation(s, s, ddof :: non_neg_integer(), method :: atom()) ::
float() | non_finite() | lazy_s() | nil
@callback covariance(s, s, ddof :: non_neg_integer()) :: float() | non_finite() | lazy_s() | nil
@callback all?(s) :: boolean() | lazy_s()
Expand Down
16 changes: 13 additions & 3 deletions lib/explorer/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5658,7 +5658,7 @@ defmodule Explorer.DataFrame do
def frequencies(_df, []), do: raise(ArgumentError, "columns cannot be empty")

@doc """
Calculates the pairwise Pearson's correlation of numeric columns.
Calculates the pairwise correlation of numeric columns.
The returned dataframe is the correlation matrix.
Expand All @@ -5678,6 +5678,9 @@ defmodule Explorer.DataFrame do
* `:column_name` - the name of the column with column names. Defaults to "names".
* `:ddof` - the 'delta degrees of freedom' - the divisor used in the correlation
calculation. Defaults to 1.
* `:method` - the correlation method. Defaults to `:pearson`. The following methods are available:
- `:pearson` : Pearson's (standard) correlation coefficient.
- `:spearman` : Spearman's rank correlation coefficient.
## Examples
Expand All @@ -5693,9 +5696,16 @@ defmodule Explorer.DataFrame do
@doc type: :single
@spec correlation(df :: DataFrame.t(), opts :: Keyword.t()) :: df :: DataFrame.t()
def correlation(df, opts \\ []) do
opts = Keyword.validate!(opts, column_name: "names", columns: names(df), ddof: 1)
opts =
Keyword.validate!(opts,
column_name: "names",
columns: names(df),
ddof: 1,
method: :pearson
)

out_df = pairwised_df(df, opts)
Shared.apply_impl(df, :correlation, [out_df, opts[:ddof]])
Shared.apply_impl(df, :correlation, [out_df, opts[:ddof], opts[:method]])
end

@doc """
Expand Down
12 changes: 7 additions & 5 deletions lib/explorer/polars_backend/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -765,13 +765,15 @@ defmodule Explorer.PolarsBackend.DataFrame do
end

@impl true
def correlation(df, out_df, ddof) do
pairwised(df, out_df, ddof, :correlation)
def correlation(df, out_df, ddof, method) do
  # Bind `ddof` and `method` up front so the pairwise helper only has to
  # supply the two columns being compared.
  correlate_pair = &PolarsSeries.correlation(&1, &2, ddof, method)

  pairwised(df, out_df, correlate_pair)
end

@impl true
def covariance(df, out_df, ddof) do
pairwised(df, out_df, ddof, :covariance)
pairwised(df, out_df, fn left, right -> PolarsSeries.covariance(left, right, ddof) end)
end

# Two or more table verbs
Expand Down Expand Up @@ -840,14 +842,14 @@ defmodule Explorer.PolarsBackend.DataFrame do

# helpers

defp pairwised(df, out_df, ddof, operation) do
defp pairwised(df, out_df, operation) do
[column_name | cols] = out_df.names

pairwised_results =
Enum.map(cols, fn left ->
corr_series =
cols
|> Enum.map(fn right -> apply(PolarsSeries, operation, [df[left], df[right], ddof]) end)
|> Enum.map(fn right -> operation.(df[left], df[right]) end)
|> Shared.from_list({:f, 64})
|> Shared.create_series()

Expand Down
6 changes: 3 additions & 3 deletions lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ defmodule Explorer.PolarsBackend.Expression do
slice: 3,
concat: 1,
column: 1,
correlation: 3,
correlation: 4,
covariance: 3
]

Expand Down Expand Up @@ -205,8 +205,8 @@ defmodule Explorer.PolarsBackend.Expression do
Native.expr_concat(expr_list)
end

def to_expr(%LazySeries{op: :correlation, args: [series1, series2, ddof]}) do
Native.expr_correlation(to_expr(series1), to_expr(series2), ddof)
def to_expr(%LazySeries{op: :correlation, args: [left, right, ddof, method]}) do
  # Both operands are converted to expressions first; `ddof` and `method`
  # are forwarded to the native call unchanged.
  left_expr = to_expr(left)
  right_expr = to_expr(right)

  Native.expr_correlation(left_expr, right_expr, ddof, method)
end

def to_expr(%LazySeries{op: :covariance, args: [series1, series2, ddof]}) do
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/polars_backend/lazy_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ defmodule Explorer.PolarsBackend.LazyFrame do
end

not_available_funs = [
correlation: 3,
correlation: 4,
covariance: 3,
describe: 2,
nil_count: 1,
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ defmodule Explorer.PolarsBackend.Native do
def s_cumulative_sum(_s, _reverse), do: err()
def s_cumulative_product(_s, _reverse), do: err()
def s_skew(_s, _bias), do: err()
def s_correlation(_s1, _s2, _ddof), do: err()
def s_correlation(_s1, _s2, _ddof, _method), do: err()
def s_covariance(_s1, _s2, _ddof), do: err()
def s_distinct(_s), do: err()
def s_divide(_s, _other), do: err()
Expand Down
5 changes: 3 additions & 2 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,9 @@ defmodule Explorer.PolarsBackend.Series do
do: Shared.apply_series(series, :s_skew, [bias?])

@impl true
def correlation(left, right, ddof),
do: Shared.apply_series(matching_size!(left, right), :s_correlation, [right.data, ddof])
def correlation(left, right, ddof, method) do
  # `matching_size!/2` runs before the native call (presumably raising when
  # the two series differ in size); its result is the series handed on.
  checked = matching_size!(left, right)

  Shared.apply_series(checked, :s_correlation, [right.data, ddof, method])
end

@impl true
def covariance(left, right, ddof),
Expand Down
17 changes: 13 additions & 4 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2660,11 +2660,15 @@ defmodule Explorer.Series do
do: dtype_error("skew/2", dtype, @numeric_dtypes)

@doc """
Compute the Pearson's correlation between two series.
Compute the correlation between two series.
The option `:ddof` refers to the 'delta degrees of freedom' - the divisor
used in the correlation calculation. Defaults to 1.
The option `:method` refers to the correlation method. Defaults to `:pearson`. The following methods are available:
- `:pearson` : Standard correlation coefficient.
- `:spearman` : Spearman rank correlation.
## Supported dtypes
* `:integer`
Expand All @@ -2682,11 +2686,16 @@ defmodule Explorer.Series do
@spec correlation(
left :: Series.t() | number(),
right :: Series.t() | number(),
ddof :: non_neg_integer()
opts :: Keyword.t()
) ::
float() | non_finite() | nil
def correlation(left, right, ddof \\ 1) when K.and(is_integer(ddof), ddof >= 0) do
basic_numeric_operation(:correlation, left, right, [ddof])
def correlation(left, right, opts \\ []) do
  # Accepts only `:ddof` and `:method`; unknown keys raise in `Keyword.validate!/2`.
  opts = Keyword.validate!(opts, ddof: 1, method: :pearson)
  ddof = opts[:ddof]
  method = opts[:method]

  # `:ddof` is forwarded as-is to the backend, whose native functions take it
  # as an unsigned integer. Check it here so a bad value fails with a
  # descriptive error rather than an opaque native argument error.
  if K.not(K.and(is_integer(ddof), ddof >= 0)),
    do: raise(ArgumentError, "ddof must be a non-negative integer, got: #{inspect(ddof)}")

  # Only the methods the backend knows how to dispatch are accepted.
  if K.not(K.in(method, [:pearson, :spearman])),
    do: raise(ArgumentError, "unsupported correlation method #{inspect(method)}")

  basic_numeric_operation(:correlation, left, right, [ddof, method])
end

@doc """
Expand Down
1 change: 1 addition & 0 deletions native/explorer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ features = [
"peaks",
"moment",
"rank",
"propagate_nans",
]

[dependencies.polars-ops]
Expand Down
6 changes: 6 additions & 0 deletions native/explorer/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,12 @@ pub enum ExParquetCompression {
Zstd(Option<i32>),
}

/// Correlation method received by the correlation NIFs (`expr_correlation`,
/// `s_correlation`) as their `method` argument.
///
/// NOTE(review): with `NifTaggedEnum`, unit variants are presumably decoded
/// from the Elixir atoms `:pearson` / `:spearman` — confirm against Rustler docs.
#[derive(NifTaggedEnum)]
pub enum ExCorrelationMethod {
/// Pearson correlation; callers dispatch to `pearson_corr`.
Pearson,
/// Spearman rank correlation; callers dispatch to `spearman_rank_corr`.
Spearman,
}

impl TryFrom<ExParquetCompression> for ParquetCompression {
type Error = ExplorerError;

Expand Down
22 changes: 18 additions & 4 deletions native/explorer/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
// wrapped in an Elixir struct.

use polars::prelude::{
col, concat_str, cov, pearson_corr, when, IntoLazy, LiteralValue, SortOptions,
col, concat_str, cov, pearson_corr, spearman_rank_corr, when, IntoLazy, LiteralValue,
SortOptions,
};
use polars::prelude::{DataType, Expr, Literal, StrptimeOptions, TimeUnit};

use crate::datatypes::{ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExValidValue};
use crate::datatypes::{
ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExValidValue,
};
use crate::series::{cast_str_to_f64, ewm_opts, rolling_opts};
use crate::{ExDataFrame, ExExpr, ExSeries};

Expand Down Expand Up @@ -494,10 +497,21 @@ pub fn expr_skew(data: ExExpr, bias: bool) -> ExExpr {
}

#[rustler::nif]
pub fn expr_correlation(left: ExExpr, right: ExExpr, ddof: u8) -> ExExpr {
pub fn expr_correlation(
left: ExExpr,
right: ExExpr,
ddof: u8,
method: ExCorrelationMethod,
) -> ExExpr {
let left_expr = left.clone_inner().cast(DataType::Float64);
let right_expr = right.clone_inner().cast(DataType::Float64);
ExExpr::new(pearson_corr(left_expr, right_expr, ddof))

match method {
ExCorrelationMethod::Pearson => ExExpr::new(pearson_corr(left_expr, right_expr, ddof)),
ExCorrelationMethod::Spearman => {
ExExpr::new(spearman_rank_corr(left_expr, right_expr, ddof, true))
}
}
}

#[rustler::nif]
Expand Down
21 changes: 19 additions & 2 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{
atoms,
datatypes::{
ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExSeriesIoType, ExTime, ExValidValue,
ExCorrelationMethod, ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExSeriesIoType, ExTime,
ExValidValue,
},
encoding, ExDataFrame, ExSeries, ExplorerError,
};
Expand Down Expand Up @@ -1013,10 +1014,26 @@ pub fn s_correlation(
s1: ExSeries,
s2: ExSeries,
ddof: u8,
method: ExCorrelationMethod,
) -> Result<Term, ExplorerError> {
let s1 = s1.clone_inner().cast(&DataType::Float64)?;
let s2 = s2.clone_inner().cast(&DataType::Float64)?;
let corr = pearson_corr(s1.f64()?, s2.f64()?, ddof);

let corr = match method {
ExCorrelationMethod::Pearson => pearson_corr(s1.f64()?, s2.f64()?, ddof),
ExCorrelationMethod::Spearman => {
let df = df!("s1" => s1, "s2" => s2)?;
let lazy_df = df
.lazy()
.with_column(spearman_rank_corr(col("s1"), col("s2"), ddof, true).alias("corr"));
let result = lazy_df.collect()?;
let item = result.column("corr")?.get(0)?;
match item {
AnyValue::Float64(x) => Some(x),
_ => None,
}
}
};
Ok(term_from_optional_float(corr, env))
}

Expand Down
11 changes: 11 additions & 0 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3868,6 +3868,17 @@ defmodule Explorer.DataFrameTest do
}
end

test "spearman rank method" do
  # Build the frame, correlate with the Spearman method, and compare the
  # resulting correlation matrix column by column.
  columns =
    [dogs: [1, 8, 3], cats: [4, 5, 2]]
    |> DF.new()
    |> DF.correlation(method: :spearman)
    |> DF.to_columns(atom_keys: true)

  expected = %{
    names: ["dogs", "cats"],
    dogs: [1.0, 0.5],
    cats: [0.5, 1.0]
  }

  assert columns == expected
end

test "three integer columns and custom column name" do
df = DF.new(dogs: [1, 2, 3], cats: [3, 2, 1], frogs: [7, 8, 9])
df1 = DF.correlation(df, column_name: "variables")
Expand Down
13 changes: 12 additions & 1 deletion test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4161,6 +4161,17 @@ defmodule Explorer.SeriesTest do
end
end

test "explicit pearson and spearman rank methods for correlation" do
  left = Series.from_list([1, 8, 3])
  right = Series.from_list([4, 5, 2])
  tolerance = 1.0e-4

  # Both supported methods, compared against reference values within a tolerance.
  spearman = Series.correlation(left, right, method: :spearman)
  pearson = Series.correlation(left, right, method: :pearson)

  assert abs(spearman - 0.5) < tolerance
  assert abs(pearson - 0.5447047794019223) < tolerance

  # Any other method must be rejected with a descriptive error.
  assert_raise ArgumentError, ~s(unsupported correlation method :not_a_method), fn ->
    Series.correlation(left, right, method: :not_a_method)
  end
end

test "impossible correlation and covariance" do
s1 = Series.from_list([], dtype: {:f, 64})
s2 = Series.from_list([], dtype: {:f, 64})
Expand All @@ -4187,7 +4198,7 @@ defmodule Explorer.SeriesTest do
s2 = Series.from_list(["a", "b"])

assert_raise ArgumentError,
"cannot invoke Explorer.Series.correlation/3 with mismatched dtypes: {:f, 64} and :string",
"cannot invoke Explorer.Series.correlation/4 with mismatched dtypes: {:f, 64} and :string",
fn -> Series.correlation(s1, s2) end

assert_raise ArgumentError,
Expand Down

0 comments on commit 7550921

Please sign in to comment.