Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enables alternative correlation methods #767

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/explorer/backend/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ defmodule Explorer.Backend.DataFrame do
@callback nil_count(df) :: df()
@callback explode(df, out_df :: df(), columns :: [column_name()]) :: df()
@callback unnest(df, out_df :: df(), columns :: [column_name()]) :: df()
@callback correlation(df, out_df :: df(), ddof :: integer()) :: df()
@callback correlation(df, out_df :: df(), ddof :: integer(), method :: String.t()) :: df()
@callback covariance(df, out_df :: df(), ddof :: integer()) :: df()

# Two or more table verbs
Expand Down
6 changes: 3 additions & 3 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ defmodule Explorer.Backend.LazySeries do
count: 1,
nil_count: 1,
skew: 2,
correlation: 3,
correlation: 4,
covariance: 3,
all: 1,
any: 1,
Expand Down Expand Up @@ -498,8 +498,8 @@ defmodule Explorer.Backend.LazySeries do
end

@impl true
def correlation(%Series{} = left, %Series{} = right, ddof) do
args = [series_or_lazy_series!(left), series_or_lazy_series!(right), ddof]
def correlation(%Series{} = left, %Series{} = right, ddof, method) do
args = [series_or_lazy_series!(left), series_or_lazy_series!(right), ddof, method]
data = new(:correlation, args, {:f, 64}, true)

Backend.Series.new(data, {:f, 64})
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ defmodule Explorer.Backend.Series do
@callback nil_count(s) :: number() | lazy_s()
@callback product(s) :: float() | non_finite() | lazy_s() | nil
@callback skew(s, bias? :: boolean()) :: float() | non_finite() | lazy_s() | nil
@callback correlation(s, s, ddof :: non_neg_integer()) ::
@callback correlation(s, s, ddof :: non_neg_integer(), method :: String.t()) ::
float() | non_finite() | lazy_s() | nil
@callback covariance(s, s, ddof :: non_neg_integer()) :: float() | non_finite() | lazy_s() | nil
@callback all?(s) :: boolean() | lazy_s()
Expand Down
16 changes: 13 additions & 3 deletions lib/explorer/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5658,7 +5658,7 @@ defmodule Explorer.DataFrame do
def frequencies(_df, []), do: raise(ArgumentError, "columns cannot be empty")

@doc """
Calculates the pairwise Pearson's correlation of numeric columns.
Calculates the pairwise correlation of numeric columns.

The returned dataframe is the correlation matrix.

Expand All @@ -5678,6 +5678,9 @@ defmodule Explorer.DataFrame do
* `:column_name` - the name of the column with column names. Defaults to "names".
* `:ddof` - the 'delta degrees of freedom' - the divisor used in the correlation
calculation. Defaults to 1.
* `method` refers to the correlation method. The following methods are available:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* `method` refers to the correlation method. The following methods are available:
* `:method` refers to the correlation method. The following methods are available:

- `"pearson"` : Standard correlation coefficient. (default)
- `"spearman"` : Spearman rank correlation.

## Examples

Expand All @@ -5693,9 +5696,16 @@ defmodule Explorer.DataFrame do
@doc type: :single
@spec correlation(df :: DataFrame.t(), opts :: Keyword.t()) :: df :: DataFrame.t()
def correlation(df, opts \\ []) do
opts = Keyword.validate!(opts, column_name: "names", columns: names(df), ddof: 1)
opts =
Keyword.validate!(opts,
column_name: "names",
columns: names(df),
ddof: 1,
method: "pearson"
)

out_df = pairwised_df(df, opts)
Shared.apply_impl(df, :correlation, [out_df, opts[:ddof]])
Shared.apply_impl(df, :correlation, [out_df, opts[:ddof], opts[:method]])
end

@doc """
Expand Down
12 changes: 7 additions & 5 deletions lib/explorer/polars_backend/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -765,13 +765,15 @@ defmodule Explorer.PolarsBackend.DataFrame do
end

@impl true
def correlation(df, out_df, ddof) do
pairwised(df, out_df, ddof, :correlation)
def correlation(df, out_df, ddof, method) do
pairwised(df, out_df, fn left, right ->
PolarsSeries.correlation(left, right, ddof, method)
end)
end

@impl true
def covariance(df, out_df, ddof) do
pairwised(df, out_df, ddof, :covariance)
pairwised(df, out_df, fn left, right -> PolarsSeries.covariance(left, right, ddof) end)
end

# Two or more table verbs
Expand Down Expand Up @@ -840,14 +842,14 @@ defmodule Explorer.PolarsBackend.DataFrame do

# helpers

defp pairwised(df, out_df, ddof, operation) do
defp pairwised(df, out_df, operation) do
[column_name | cols] = out_df.names

pairwised_results =
Enum.map(cols, fn left ->
corr_series =
cols
|> Enum.map(fn right -> apply(PolarsSeries, operation, [df[left], df[right], ddof]) end)
|> Enum.map(fn right -> operation.(df[left], df[right]) end)
|> Shared.from_list({:f, 64})
|> Shared.create_series()

Expand Down
6 changes: 3 additions & 3 deletions lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ defmodule Explorer.PolarsBackend.Expression do
slice: 3,
concat: 1,
column: 1,
correlation: 3,
correlation: 4,
covariance: 3
]

Expand Down Expand Up @@ -205,8 +205,8 @@ defmodule Explorer.PolarsBackend.Expression do
Native.expr_concat(expr_list)
end

def to_expr(%LazySeries{op: :correlation, args: [series1, series2, ddof]}) do
Native.expr_correlation(to_expr(series1), to_expr(series2), ddof)
def to_expr(%LazySeries{op: :correlation, args: [series1, series2, ddof, method]}) do
Native.expr_correlation(to_expr(series1), to_expr(series2), ddof, method)
end

def to_expr(%LazySeries{op: :covariance, args: [series1, series2, ddof]}) do
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/polars_backend/lazy_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ defmodule Explorer.PolarsBackend.LazyFrame do
end

not_available_funs = [
correlation: 3,
correlation: 4,
covariance: 3,
describe: 2,
nil_count: 1,
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ defmodule Explorer.PolarsBackend.Native do
def s_cumulative_sum(_s, _reverse), do: err()
def s_cumulative_product(_s, _reverse), do: err()
def s_skew(_s, _bias), do: err()
def s_correlation(_s1, _s2, _ddof), do: err()
def s_correlation(_s1, _s2, _ddof, _method), do: err()
def s_covariance(_s1, _s2, _ddof), do: err()
def s_distinct(_s), do: err()
def s_divide(_s, _other), do: err()
Expand Down
5 changes: 3 additions & 2 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,9 @@ defmodule Explorer.PolarsBackend.Series do
do: Shared.apply_series(series, :s_skew, [bias?])

@impl true
def correlation(left, right, ddof),
do: Shared.apply_series(matching_size!(left, right), :s_correlation, [right.data, ddof])
def correlation(left, right, ddof, method),
do:
Shared.apply_series(matching_size!(left, right), :s_correlation, [right.data, ddof, method])

@impl true
def covariance(left, right, ddof),
Expand Down
13 changes: 9 additions & 4 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2660,11 +2660,15 @@ defmodule Explorer.Series do
do: dtype_error("skew/2", dtype, @numeric_dtypes)

@doc """
Compute the Pearson's correlation between two series.
Compute the correlation between two series.

The parameter `ddof` refers to the 'delta degrees of freedom' - the divisor
used in the correlation calculation. Defaults to 1.

The parameter `method` refers to the correlation method. The following methods are available:
- `"pearson"` : Standard correlation coefficient. (default)
- `"spearman"` : Spearman rank correlation.

## Supported dtypes

* `:integer`
Expand All @@ -2682,11 +2686,12 @@ defmodule Explorer.Series do
@spec correlation(
left :: Series.t() | number(),
right :: Series.t() | number(),
ddof :: non_neg_integer()
opts :: Keyword.t()
) ::
float() | non_finite() | nil
def correlation(left, right, ddof \\ 1) when K.and(is_integer(ddof), ddof >= 0) do
basic_numeric_operation(:correlation, left, right, [ddof])
def correlation(left, right, opts \\ []) do
opts = Keyword.validate!(opts, ddof: 1, method: "pearson")
basic_numeric_operation(:correlation, left, right, [opts[:ddof], opts[:method]])
end

@doc """
Expand Down
1 change: 1 addition & 0 deletions native/explorer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ features = [
"peaks",
"moment",
"rank",
"propagate_nans",
]

[dependencies.polars-ops]
Expand Down
12 changes: 9 additions & 3 deletions native/explorer/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
// wrapped in an Elixir struct.

use polars::prelude::{
col, concat_str, cov, pearson_corr, when, IntoLazy, LiteralValue, SortOptions,
col, concat_str, cov, pearson_corr, spearman_rank_corr, when, IntoLazy, LiteralValue,
SortOptions,
};
use polars::prelude::{DataType, Expr, Literal, StrptimeOptions, TimeUnit};

Expand Down Expand Up @@ -494,10 +495,15 @@ pub fn expr_skew(data: ExExpr, bias: bool) -> ExExpr {
}

#[rustler::nif]
pub fn expr_correlation(left: ExExpr, right: ExExpr, ddof: u8) -> ExExpr {
pub fn expr_correlation(left: ExExpr, right: ExExpr, ddof: u8, method: &str) -> ExExpr {
let left_expr = left.clone_inner().cast(DataType::Float64);
let right_expr = right.clone_inner().cast(DataType::Float64);
ExExpr::new(pearson_corr(left_expr, right_expr, ddof))

match method {
"pearson" => ExExpr::new(pearson_corr(left_expr, right_expr, ddof)),
"spearman" => ExExpr::new(spearman_rank_corr(left_expr, right_expr, ddof, true)),
&_ => todo!("not supported yet"),
}
}

#[rustler::nif]
Expand Down
29 changes: 25 additions & 4 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,15 +1008,36 @@ pub fn s_skew(env: Env, s: ExSeries, bias: bool) -> Result<Term, ExplorerError>
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_correlation(
env: Env,
pub fn s_correlation<'a>(
env: Env<'a>,
s1: ExSeries,
s2: ExSeries,
ddof: u8,
) -> Result<Term, ExplorerError> {
method: &str,
) -> Result<Term<'a>, ExplorerError> {
let s1 = s1.clone_inner().cast(&DataType::Float64)?;
let s2 = s2.clone_inner().cast(&DataType::Float64)?;
let corr = pearson_corr(s1.f64()?, s2.f64()?, ddof);

let corr = match method {
"pearson" => pearson_corr(s1.f64()?, s2.f64()?, ddof),
"spearman" => {
let df = df!("s1" => s1, "s2" => s2)?;
let lazy_df = df
.lazy()
.with_column(spearman_rank_corr(col("s1"), col("s2"), ddof, true).alias("corr"));
let result = lazy_df.collect()?;
let item = result.column("corr")?.get(0)?;
match item {
AnyValue::Float64(x) => Some(x),
_ => None,
}
}
&_ => {
return Err(ExplorerError::Other(format!(
"method is not supported {method}"
)))
}
};
Ok(term_from_optional_float(corr, env))
}

Expand Down
11 changes: 11 additions & 0 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3868,6 +3868,17 @@ defmodule Explorer.DataFrameTest do
}
end

test "spearman rank method" do
df = DF.new(dogs: [1, 8, 3], cats: [4, 5, 2])
df1 = DF.correlation(df, method: "spearman")

assert DF.to_columns(df1, atom_keys: true) == %{
names: ["dogs", "cats"],
dogs: [1.0, 0.5],
cats: [0.5, 1.0]
}
end

test "three integer columns and custom column name" do
df = DF.new(dogs: [1, 2, 3], cats: [3, 2, 1], frogs: [7, 8, 9])
df1 = DF.correlation(df, column_name: "variables")
Expand Down
13 changes: 12 additions & 1 deletion test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4161,6 +4161,17 @@ defmodule Explorer.SeriesTest do
end
end

test "explicit pearson and spearman rank methods for correlation" do
s1 = Series.from_list([1, 8, 3])
s2 = Series.from_list([4, 5, 2])
assert abs(Series.correlation(s1, s2, method: "spearman") - 0.5) < 1.0e-4
assert abs(Series.correlation(s1, s2, method: "pearson") - 0.5447047794019223) < 1.0e-4

assert_raise RuntimeError, ~r/Generic Error: method is not supported not_a_method/, fn ->
Series.correlation(s1, s2, method: "not_a_method")
end
end

test "impossible correlation and covariance" do
s1 = Series.from_list([], dtype: {:f, 64})
s2 = Series.from_list([], dtype: {:f, 64})
Expand All @@ -4187,7 +4198,7 @@ defmodule Explorer.SeriesTest do
s2 = Series.from_list(["a", "b"])

assert_raise ArgumentError,
"cannot invoke Explorer.Series.correlation/3 with mismatched dtypes: {:f, 64} and :string",
"cannot invoke Explorer.Series.correlation/4 with mismatched dtypes: {:f, 64} and :string",
fn -> Series.correlation(s1, s2) end

assert_raise ArgumentError,
Expand Down
Loading