Properly count multi-dimensional arrays.
djgagne committed Jul 27, 2024
1 parent 94db93d commit f142235
Showing 2 changed files with 22 additions and 6 deletions.
bridgescaler/distributed.py (17 changes: 14 additions & 3 deletions)
@@ -7,7 +7,6 @@
 import xarray as xr
 from functools import partial
 from scipy.stats import logistic
-from warnings import warn
 from numba import guvectorize, float32, float64, void
 CENTROID_DTYPE = np.dtype([('mean', np.float64), ('weight', np.float64)])

@@ -170,7 +169,13 @@ def fit(self, x, weight=None):
         if not self._fit:
             self.x_columns_ = x_columns
             self.is_array_ = is_array
-            self.n_ += xv.shape[0]
+            if len(xv.shape) > 2:
+                if self.channels_last:
+                    self.n_ += np.prod(xv.shape[:-1])
+                else:
+                    self.n_ += xv.shape[0] * np.prod(xv.shape[2:])
+            else:
+                self.n_ += xv.shape[0]
             self.mean_x_ = np.zeros(xv.shape[channel_dim], dtype=xv.dtype)
             self.var_x_ = np.zeros(xv.shape[channel_dim], dtype=xv.dtype)
             if self.channels_last:
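With this change, n_ counts every value that enters each channel's statistics rather than only the leading batch dimension. A minimal sketch of the arithmetic (the shapes here are hypothetical, not from the repo's tests):

import numpy as np

# channels_last: every axis except the trailing channel axis is a sample axis.
xv_last = np.zeros((32, 64, 64, 5))    # (batch, height, width, channels)
n_last = np.prod(xv_last.shape[:-1])   # 32 * 64 * 64 = 131072 samples per channel

# channels_first: skip axis 1 (channels); batch and spatial axes are samples.
xv_first = np.zeros((32, 5, 64, 64))   # (batch, channels, height, width)
n_first = xv_first.shape[0] * np.prod(xv_first.shape[2:])  # also 131072

assert n_last == n_first == 131072     # the old code would have counted only 32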
@@ -193,7 +198,13 @@ def fit(self, x, weight=None):
             x_col_order = self.get_column_order(x_columns)
             # update derived from
             # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups
-            new_n = xv.shape[0]
+            if len(xv.shape) > 2:
+                if self.channels_last:
+                    new_n = np.prod(xv.shape[:-1])
+                else:
+                    new_n = xv.shape[0] * np.prod(xv.shape[2:])
+            else:
+                new_n = xv.shape[0]
             for i, o in enumerate(x_col_order):
                 if self.channels_last:
                     new_mean = np.nanmean(xv[..., i])
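The Stack Exchange thread linked in the comment covers pooling the statistics of two groups. A standard form of that identity, for group sizes $n_1, n_2$ with means $\mu_1, \mu_2$ and population variances $\sigma_1^2, \sigma_2^2$ (the code's exact normalization may differ), is:

$$\mu = \frac{n_1 \mu_1 + n_2 \mu_2}{n_1 + n_2}, \qquad \sigma^2 = \frac{n_1 \sigma_1^2 + n_2 \sigma_2^2}{n_1 + n_2} + \frac{n_1 n_2 \,(\mu_1 - \mu_2)^2}{(n_1 + n_2)^2}$$

Both terms weight each group by its true sample count, which is why the corrected new_n matters: undercounting a multi-dimensional batch as xv.shape[0] would skew the pooled mean and variance toward the previously fitted data.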
bridgescaler/tests/distributed_test.py (11 changes: 8 additions & 3 deletions)
@@ -46,12 +46,13 @@ def make_test_data():

 def test_dstandard_scaler():
     all_ds_2d = np.vstack(test_data["numpy_2d"])
+    all_ds_4d = np.vstack(test_data["numpy_4d"])
     dsses_2d = []
     dsses_4d = []
     for n in range(test_data["n_examples"].size):
         dsses_2d.append(DStandardScaler())
         dsses_2d[-1].fit(test_data["numpy_2d"][n])
-        dsses_4d.append(DStandardScaler())
+        dsses_4d.append(DStandardScaler(channels_last=True))
         dsses_4d[-1].fit(test_data["numpy_4d"][n])
         save_scaler(dsses_2d[-1], "scaler.json")
         new_scaler = load_scaler("scaler.json")
@@ -70,11 +71,15 @@ def test_dstandard_scaler():
     dss_total_4d = np.sum(dsses_4d)
     mean_2d, var_2d = dss_total_2d.get_scales()
     mean_4d, var_4d = dss_total_4d.get_scales()
-
+    all_2d_var = all_ds_2d.var(axis=0)
+    all_4d_var = np.array([all_ds_4d[..., i].var() for i in range(all_ds_4d.shape[-1])])
+    all_4d_mean = np.array([all_ds_4d[..., i].mean() for i in range(all_ds_4d.shape[-1])])
     assert mean_2d.shape[0] == test_data["means"].shape[0] and var_2d.shape[0] == test_data["sds"].shape[0], "Stat shape mismatch"
     assert mean_4d.shape[0] == test_data["means"].shape[0] and var_4d.shape[0] == test_data["sds"].shape[0], "Stat shape mismatch"
     assert np.max(np.abs(mean_2d - all_ds_2d.mean(axis=0))) < 1e-8, "significant difference in means"
-    assert np.max(np.abs(var_2d - all_ds_2d.var(axis=0, ddof=1))) < 1e-5, "significant difference in variances"
+    assert np.max(np.abs(var_2d - all_2d_var) / all_2d_var) < 0.005, "significant difference in variances"
+    assert np.max(np.abs(mean_4d - all_4d_mean) / all_4d_mean) < 0.001, "significant difference in means"
+    assert np.max(np.abs(var_4d - all_4d_var) / all_4d_var) < 0.001, "significant difference in variances"
     sub_cols = ["d", "b"]
     pd_sub_trans = pd_dss.transform(test_data["pandas"][0][sub_cols])
     assert pd_sub_trans.shape[1] == len(sub_cols), "Did not subset properly"
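The new assertions compare the pooled per-channel statistics against a direct computation over the stacked 4-D data, using relative tolerances. The pattern under test looks roughly like the sketch below (shard shapes are hypothetical; DStandardScaler, fit, get_scales, and pooling via np.sum are as used in this diff):

import numpy as np
from bridgescaler.distributed import DStandardScaler

# Fit one scaler per data shard, then pool them; np.sum reduces the list
# with the scalers' addition operator, which merges counts, means, and variances.
shards = [np.random.randn(16, 8, 8, 3) for _ in range(4)]
scalers = []
for shard in shards:
    scalers.append(DStandardScaler(channels_last=True))
    scalers[-1].fit(shard)
pooled = np.sum(scalers)
mean, var = pooled.get_scales()  # one entry per channel (3 here)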
