Skip to content

Commit

Permalink
fix breaking dnn input formatting tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kushaangupta committed Jan 1, 2025
1 parent aec881f commit 0799d68
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions neuro_py/ensemble/decoding/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def zscore_trial_segs(
Normalized train data, normalized rest features, and normalization parameters.
"""
is_2D = train[0].ndim == 1
concat_train = train if is_2D else np.concatenate(train)
concat_train = train if is_2D else np.concatenate(train).astype(float)
train_mean = normparams['X_train_mean'] if normparams is not None else bn.nanmean(concat_train, axis=0)
train_std = normparams['X_train_std'] if normparams is not None else bn.nanstd(concat_train, axis=0)

Expand All @@ -213,15 +213,17 @@ def zscore_trial_segs(
# if train is not jagged, it gets converted completely to object
# np.ndarray. Hence, cannot exclusively use normed_train.loc
if isinstance(normed_train, pd.DataFrame):
normed_train = normed_train.loc
normed_train[:, train_nan_cols] = 0
normed_train.loc[:, train_nan_cols] = 0
else:
normed_train[:, train_nan_cols] = 0
else:
normed_train = np.empty_like(train)
for i, nsvstseg in enumerate(train):
zscored = np.divide(nsvstseg-train_mean, train_std, where=train_notnan_cols)
if isinstance(zscored, pd.DataFrame):
zscored = zscored.loc
zscored[:, train_nan_cols] = 0
zscored.loc[:, train_nan_cols] = 0
else:
zscored[:, train_nan_cols] = 0
normed_train[i] = zscored

normed_rest_feats = []
Expand All @@ -230,16 +232,18 @@ def zscore_trial_segs(
if is_2D:
normed_feats = np.divide(feats-train_mean, train_std, where=train_notnan_cols)
if isinstance(normed_feats, pd.DataFrame):
normed_feats = normed_feats.loc
normed_feats[:, train_nan_cols] = 0
normed_feats.loc[:, train_nan_cols] = 0
else:
normed_feats[:, train_nan_cols] = 0
normed_rest_feats.append(normed_feats)
else:
normed_feats = np.empty_like(feats)
for i, trialSegROI in enumerate(feats):
zscored = np.divide(feats[i]-train_mean, train_std, where=train_notnan_cols)
if isinstance(zscored, pd.DataFrame):
zscored = zscored.loc
zscored[:, train_nan_cols] = 0
zscored.loc[:, train_nan_cols] = 0
else:
zscored[:, train_nan_cols] = 0
normed_feats[i] = zscored
normed_rest_feats.append(normed_feats)

Expand Down

0 comments on commit 0799d68

Please sign in to comment.