Skip to content

Commit

Permalink
Expose ddof on covariance
Browse files Browse the repository at this point in the history
  • Loading branch information
cigrainger committed Nov 30, 2023
1 parent db29c5d commit 20e8d5e
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 20 deletions.
6 changes: 3 additions & 3 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ defmodule Explorer.Backend.LazySeries do
nil_count: 1,
skew: 2,
correlation: 3,
covariance: 2,
covariance: 3,
# Strings
contains: 2,
replace: 3,
Expand Down Expand Up @@ -504,8 +504,8 @@ defmodule Explorer.Backend.LazySeries do
end

@impl true
def covariance(%Series{} = left, %Series{} = right) do
args = [series_or_lazy_series!(left), series_or_lazy_series!(right)]
def covariance(%Series{} = left, %Series{} = right, ddof \\ 1) do
args = [series_or_lazy_series!(left), series_or_lazy_series!(right), ddof]
data = new(:covariance, args, {:f, 64}, true)

Backend.Series.new(data, {:f, 64})
Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ defmodule Explorer.Backend.Series do
@callback skew(s, bias? :: boolean()) :: float() | non_finite() | lazy_s() | nil
@callback correlation(s, s, ddof :: non_neg_integer()) ::
float() | non_finite() | lazy_s() | nil
@callback covariance(s, s) :: float() | non_finite() | lazy_s() | nil
@callback covariance(s, s, ddof :: non_neg_integer()) :: float() | non_finite() | lazy_s() | nil

# Cumulative

Expand Down
10 changes: 7 additions & 3 deletions lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ defmodule Explorer.PolarsBackend.Expression do
sum: 1,
unordered_distinct: 1,
variance: 2,
skew: 2,
covariance: 2
skew: 2
]

@first_only_expressions [
Expand Down Expand Up @@ -154,7 +153,8 @@ defmodule Explorer.PolarsBackend.Expression do
slice: 3,
concat: 1,
column: 1,
correlation: 3
correlation: 3,
covariance: 3
]

missing =
Expand Down Expand Up @@ -207,6 +207,10 @@ defmodule Explorer.PolarsBackend.Expression do
Native.expr_correlation(to_expr(series1), to_expr(series2), ddof)
end

def to_expr(%LazySeries{op: :covariance, args: [series1, series2, ddof]}) do
Native.expr_covariance(to_expr(series1), to_expr(series2), ddof)
end

def to_expr(%LazySeries{op: :format, args: [series_list]}) when is_list(series_list) do
expr_list = Enum.map(series_list, &to_expr/1)

Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ defmodule Explorer.PolarsBackend.Native do
def s_cumulative_product(_s, _reverse), do: err()
def s_skew(_s, _bias), do: err()
def s_correlation(_s1, _s2, _ddof), do: err()
def s_covariance(_s1, _s2), do: err()
def s_covariance(_s1, _s2, _ddof), do: err()
def s_distinct(_s), do: err()
def s_divide(_s, _other), do: err()
def s_dtype(_s), do: err()
Expand Down
4 changes: 2 additions & 2 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ defmodule Explorer.PolarsBackend.Series do
do: Shared.apply_series(matching_size!(left, right), :s_correlation, [right.data, ddof])

@impl true
def covariance(left, right),
do: Shared.apply_series(matching_size!(left, right), :s_covariance, [right.data])
def covariance(left, right, ddof),
do: Shared.apply_series(matching_size!(left, right), :s_covariance, [right.data, ddof])

# Cumulative

Expand Down
10 changes: 7 additions & 3 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2518,10 +2518,14 @@ defmodule Explorer.Series do
3.0
"""
@doc type: :aggregation
@spec covariance(left :: Series.t() | number(), right :: Series.t() | number()) ::
@spec covariance(
left :: Series.t() | number(),
right :: Series.t() | number(),
ddof :: non_neg_integer()
) ::
float() | non_finite() | nil
def covariance(left, right) do
basic_numeric_operation(:covariance, left, right)
def covariance(left, right, ddof \\ 1) do
basic_numeric_operation(:covariance, left, right, [ddof])
end

# Cumulative
Expand Down
4 changes: 1 addition & 3 deletions native/explorer/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,11 +501,9 @@ pub fn expr_correlation(left: ExExpr, right: ExExpr, ddof: u8) -> ExExpr {
}

#[rustler::nif]
pub fn expr_covariance(left: ExExpr, right: ExExpr) -> ExExpr {
pub fn expr_covariance(left: ExExpr, right: ExExpr, ddof: u8) -> ExExpr {
let left_expr = left.clone_inner().cast(DataType::Float64);
let right_expr = right.clone_inner().cast(DataType::Float64);
// TODO: make this a parameter.
let ddof: u8 = 1;
ExExpr::new(cov(left_expr, right_expr, ddof))
}

Expand Down
4 changes: 1 addition & 3 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,11 +1006,9 @@ pub fn s_correlation(
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_covariance(env: Env, s1: ExSeries, s2: ExSeries) -> Result<Term, ExplorerError> {
pub fn s_covariance(env: Env, s1: ExSeries, s2: ExSeries, ddof: u8) -> Result<Term, ExplorerError> {
let s1 = s1.clone_inner().cast(&DataType::Float64)?;
let s2 = s2.clone_inner().cast(&DataType::Float64)?;
// TODO: make ddof a parameter.
let ddof: u8 = 1;
let cov = cov(s1.f64()?, s2.f64()?, ddof);
Ok(term_from_optional_float(cov, env))
}
Expand Down
2 changes: 1 addition & 1 deletion test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4088,7 +4088,7 @@ defmodule Explorer.SeriesTest do
fn -> Series.correlation(s1, s2) end

assert_raise ArgumentError,
"cannot invoke Explorer.Series.covariance/2 with mismatched dtypes: {:f, 64} and :string",
"cannot invoke Explorer.Series.covariance/3 with mismatched dtypes: {:f, 64} and :string",
fn -> Series.covariance(s1, s2) end
end
end
Expand Down

0 comments on commit 20e8d5e

Please sign in to comment.