Skip to content

Commit

Permalink
Support skip_na: keyword in Array#mean
Browse files Browse the repository at this point in the history
  • Loading branch information
mrkn committed Aug 2, 2021
1 parent 8b86c3c commit d7fdfe7
Showing 1 changed file with 131 additions and 62 deletions.
193 changes: 131 additions & 62 deletions ext/enumerable/statistics/extension/statistics.c
Original file line number Diff line number Diff line change
Expand Up @@ -667,50 +667,34 @@ static int opt_skip_na(VALUE opts)
return RTEST(skip_na);
}

/* call-seq:
* ary.sum(skip_na: false)
*
* Calculate the sum of the values in `ary`.
* This method utilizes
* [Kahan summation algorithm](https://en.wikipedia.org/wiki/Kahan_summation_algorithm)
* to compensate the result precision when the `ary` includes Float values.
*
* Note that This library does not redefine `sum` method introduced in Ruby 2.4.
*
* @return [Number] A summation value
*/
static VALUE
ary_sum(int argc, VALUE* argv, VALUE ary)
VALUE
ary_calculate_sum(VALUE ary, VALUE init, int skip_na, long *na_count_out)
{
VALUE e, v, r, opts;
VALUE e, v, r;
long i, n;
int block_given;
int skip_na;

if (rb_scan_args(argc, argv, "01:", &v, &opts) == 0) {
v = LONG2FIX(0);
}
skip_na = opt_skip_na(opts);

#ifndef HAVE_ENUM_SUM
if (!skip_na) {
return rb_funcall(orig_ary_sum, rb_intern("call"), argc, &v);
}
#endif
long na_count = 0;

block_given = rb_block_given_p();

if (RARRAY_LEN(ary) == 0)
return v;
if (RARRAY_LEN(ary) == 0) {
if (na_count_out != NULL) {
*na_count_out = 0;
}
return init;
}

n = 0;
r = Qundef;
v = init;
for (i = 0; i < RARRAY_LEN(ary); i++) {
e = RARRAY_AREF(ary, i);
if (block_given)
e = rb_yield(e);
if (skip_na && is_na(e))
if (skip_na && is_na(e)) {
++na_count;
continue;
}

if (FIXNUM_P(e)) {
n += FIX2LONG(e); /* should not overflow long type */
Expand All @@ -735,7 +719,7 @@ ary_sum(int argc, VALUE* argv, VALUE ary)
v = rb_fix_plus(LONG2FIX(n), v);
if (r != Qundef)
v = rb_rational_plus(r, v);
return v;
goto finish;

not_exact:
if (n != 0)
Expand All @@ -755,6 +739,11 @@ ary_sum(int argc, VALUE* argv, VALUE ary)
e = RARRAY_AREF(ary, i);
if (block_given)
e = rb_yield(e);
if (skip_na && is_na(e)) {
++na_count;
continue;
}

if (RB_FLOAT_TYPE_P(e))
has_float_value:
x = RFLOAT_VALUE(e);
Expand All @@ -772,7 +761,9 @@ ary_sum(int argc, VALUE* argv, VALUE ary)
c = (t - f) - y;
f = t;
}
return DBL2NUM(f);

v = DBL2NUM(f);
goto finish;

not_float:
v = DBL2NUM(f);
Expand All @@ -783,13 +774,53 @@ ary_sum(int argc, VALUE* argv, VALUE ary)
e = RARRAY_AREF(ary, i);
if (block_given)
e = rb_yield(e);
if (skip_na && is_na(e)) {
++na_count;
continue;
}
has_some_value:
v = rb_funcall(v, idPLUS, 1, e);
}

finish:
if (na_count_out != NULL) {
*na_count_out = na_count;
}
return v;
}

/* call-seq:
* ary.sum(skip_na: false)
*
* Calculate the sum of the values in `ary`.
* This method utilizes
* [Kahan summation algorithm](https://en.wikipedia.org/wiki/Kahan_summation_algorithm)
* to compensate the result precision when the `ary` includes Float values.
*
* Note that This library does not redefine `sum` method introduced in Ruby 2.4.
*
* @return [Number] A summation value
*/
static VALUE
ary_sum(int argc, VALUE* argv, VALUE ary)
{
VALUE v, opts;
int skip_na;

if (rb_scan_args(argc, argv, "01:", &v, &opts) == 0) {
v = LONG2FIX(0);
}
skip_na = opt_skip_na(opts);

#ifndef HAVE_ENUM_SUM
if (!skip_na) {
return rb_funcall(orig_ary_sum, rb_intern("call"), argc, &v);
}
#endif

return ary_calculate_sum(ary, v, skip_na, NULL);
}

static void
calculate_and_set_mean(VALUE *mean_ptr, VALUE sum, long const n)
{
Expand Down Expand Up @@ -818,9 +849,10 @@ calculate_and_set_mean(VALUE *mean_ptr, VALUE sum, long const n)
}

static void
ary_mean_variance(VALUE ary, VALUE *mean_ptr, VALUE *variance_ptr, size_t ddof)
ary_mean_variance(VALUE ary, VALUE *mean_ptr, VALUE *variance_ptr, size_t ddof, int skip_na)
{
long i;
long na_count;
size_t n = 0;
double m = 0.0, m2 = 0.0, f = 0.0, c = 0.0;

Expand All @@ -844,8 +876,8 @@ ary_mean_variance(VALUE ary, VALUE *mean_ptr, VALUE *variance_ptr, size_t ddof)

if (variance_ptr == NULL) {
VALUE init = DBL2NUM(0.0);
VALUE const sum = ary_sum(1, &init, ary);
long const n = RARRAY_LEN(ary);
VALUE const sum = ary_calculate_sum(ary, init, skip_na, &na_count);
long const n = RARRAY_LEN(ary) - na_count;
calculate_and_set_mean(mean_ptr, sum, n);
return;
}
Expand Down Expand Up @@ -886,26 +918,46 @@ ary_mean_variance(VALUE ary, VALUE *mean_ptr, VALUE *variance_ptr, size_t ddof)
}
}

static int
opt_population_p(VALUE opts)
struct variance_opts {
int population;
int skip_na;
};

static void
get_variance_opts(VALUE opts, struct variance_opts *out)
{
VALUE population = Qfalse;
assert(out != NULL);

out->population = 0;
out->skip_na = 0;

if (!NIL_P(opts)) {
#ifdef HAVE_RB_GET_KWARGS
ID kwargs = id_population;
rb_get_kwargs(opts, &kwargs, 0, 1, &population);
static ID kwarg_keys[2];
VALUE kwarg_vals;

if (!kwarg_keys[0]) {
kwarg_keys[0] = id_population;
kwarg_keys[1] = id_skip_na;
}

rb_get_kwargs(opts, &kwarg_keys, 0, 2, kwarg_vals);
out->population = (kwarg_vals[0] != Qundef) ? RTEST(kwarg_vals[0]) : out->population;
out->skip_na = (kwarg_vals[1] != Qundef) ? RTEST(kwarg_vals[1]) : out->skip_na;
#else
VALUE val = rb_hash_aref(opts, ID2SYM(id_population));
population = NIL_P(val) ? population : val;
VALUE val;

val = rb_hash_aref(opts, ID2SYM(id_population));
out->population = NIL_P(val) ? out->population : RTEST(val);

val = rb_hash_aref(opts, ID2SYM(id_skip_na));
out->skip_na = NIL_P(val) ? out->skip_na : RTEST(val);
#endif
}

return RTEST(population);
}

/* call-seq:
* ary.mean_variance(population: false)
* ary.mean_variance(population: false, skip_na: false)
*
* Calculate a mean and a variance of the values in `ary`.
* The first element of the result array is the mean, and the second is the variance.
Expand All @@ -923,19 +975,21 @@ opt_population_p(VALUE opts)
static VALUE
ary_mean_variance_m(int argc, VALUE* argv, VALUE ary)
{
VALUE opts, mean, variance;
struct variance_opts options;
VALUE opts, mean = Qnil, variance = Qnil;
size_t ddof = 1;

rb_scan_args(argc, argv, "0:", &opts);
if (opt_population_p(opts))
get_variance_opts(opts, &options);
if (options.population)
ddof = 0;

ary_mean_variance(ary, &mean, &variance, ddof);
ary_mean_variance(ary, &mean, &variance, ddof, options.skip_na);
return rb_assoc_new(mean, variance);
}

/* call-seq:
* ary.mean
* ary.mean(skip_na: false)
*
* Calculate a mean of the values in `ary`.
* This method utilizes
Expand All @@ -945,15 +999,20 @@ ary_mean_variance_m(int argc, VALUE* argv, VALUE ary)
* @return [Number] A mean value
*/
static VALUE
ary_mean(VALUE ary)
ary_mean(int argc, VALUE *argv, VALUE ary)
{
VALUE mean;
ary_mean_variance(ary, &mean, NULL, 1);
VALUE mean = Qnil, opts;
int skip_na;

rb_scan_args(argc, argv, ":", &opts);
skip_na = opt_skip_na(opts);

ary_mean_variance(ary, &mean, NULL, 1, skip_na);
return mean;
}

/* call-seq:
* ary.variance(population: false)
* ary.variance(population: false, skip_na: false)
*
* Calculate a variance of the values in `ary`.
* This method scan values in `ary` only once,
Expand All @@ -969,14 +1028,16 @@ ary_mean(VALUE ary)
static VALUE
ary_variance(int argc, VALUE* argv, VALUE ary)
{
struct variance_opts options;
VALUE opts, variance;
size_t ddof = 1;

rb_scan_args(argc, argv, "0:", &opts);
if (opt_population_p(opts))
get_variance_opts(opts, &options);
if (options.population)
ddof = 0;

ary_mean_variance(ary, NULL, &variance, ddof);
ary_mean_variance(ary, NULL, &variance, ddof, options.skip_na);
return variance;
}

Expand Down Expand Up @@ -1366,11 +1427,13 @@ enum_mean_variance(VALUE obj, VALUE *mean_ptr, VALUE *variance_ptr, size_t ddof)
static VALUE
enum_mean_variance_m(int argc, VALUE* argv, VALUE obj)
{
struct variance_opts options;
VALUE opts, mean, variance;
size_t ddof = 1;

rb_scan_args(argc, argv, "0:", &opts);
if (opt_population_p(opts))
get_variance_opts(opts, &options);
if (options.population)
ddof = 0;

enum_mean_variance(obj, &mean, &variance, ddof);
Expand Down Expand Up @@ -1412,11 +1475,13 @@ enum_mean(VALUE obj)
static VALUE
enum_variance(int argc, VALUE* argv, VALUE obj)
{
struct variance_opts options;
VALUE opts, variance;
size_t ddof = 1;

rb_scan_args(argc, argv, "0:", &opts);
if (opt_population_p(opts))
get_variance_opts(opts, &options);
if (options.population)
ddof = 0;

enum_mean_variance(obj, NULL, &variance, ddof);
Expand Down Expand Up @@ -1455,11 +1520,13 @@ sqrt_value(VALUE x)
static VALUE
enum_mean_stdev(int argc, VALUE* argv, VALUE obj)
{
struct variance_opts options;
VALUE opts, mean, variance;
size_t ddof = 1;

rb_scan_args(argc, argv, "0:", &opts);
if (opt_population_p(opts))
get_variance_opts(opts, &options);
if (options.population)
ddof = 0;

enum_mean_variance(obj, &mean, &variance, ddof);
Expand Down Expand Up @@ -1509,14 +1576,16 @@ enum_stdev(int argc, VALUE* argv, VALUE obj)
static VALUE
ary_mean_stdev(int argc, VALUE* argv, VALUE ary)
{
struct variance_opts options;
VALUE opts, mean, variance;
size_t ddof = 1;

rb_scan_args(argc, argv, "0:", &opts);
if (opt_population_p(opts))
get_variance_opts(opts, &options);
if (options.population)
ddof = 0;

ary_mean_variance(ary, &mean, &variance, ddof);
ary_mean_variance(ary, &mean, &variance, ddof, options.skip_na);
VALUE stdev = sqrt_value(variance);
return rb_assoc_new(mean, stdev);
}
Expand Down Expand Up @@ -1948,7 +2017,7 @@ any_value_counts(int argc, VALUE *argv, VALUE obj,
struct value_counts_opts opts;
struct value_counts_memo memo;

rb_scan_args(argc, argv, ":", &kwargs);
rb_scan_args(argc, argv, "0:", &kwargs);
value_counts_extract_opts(kwargs, &opts);

memo.result = rb_hash_new();
Expand Down Expand Up @@ -2471,7 +2540,7 @@ Init_extension(void)

rb_define_method(rb_cArray, "sum", ary_sum, -1);
rb_define_method(rb_cArray, "mean_variance", ary_mean_variance_m, -1);
rb_define_method(rb_cArray, "mean", ary_mean, 0);
rb_define_method(rb_cArray, "mean", ary_mean, -1);
rb_define_method(rb_cArray, "variance", ary_variance, -1);
rb_define_method(rb_cArray, "mean_stdev", ary_mean_stdev, -1);
rb_define_method(rb_cArray, "stdev", ary_stdev, -1);
Expand Down

0 comments on commit d7fdfe7

Please sign in to comment.