Skip to content

Commit

Permalink
add support for non-jagged trials
Browse files Browse the repository at this point in the history
  • Loading branch information
kushaangupta committed Dec 31, 2024
1 parent 8c55cec commit aec881f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
18 changes: 14 additions & 4 deletions neuro_py/ensemble/decoding/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,26 +210,36 @@ def zscore_trial_segs(
train_nan_cols = ~train_notnan_cols
if is_2D:
normed_train = np.divide(train-train_mean, train_std, where=train_notnan_cols)
normed_train.loc[:, train_nan_cols] = 0
# if train is not jagged, it gets converted completely to object
# np.ndarray. Hence, cannot exclusively use normed_train.loc
if isinstance(normed_train, pd.DataFrame):
normed_train = normed_train.loc
normed_train[:, train_nan_cols] = 0
else:
normed_train = np.empty_like(train)
for i, nsvstseg in enumerate(train):
zscored = np.divide(nsvstseg-train_mean, train_std, where=train_notnan_cols)
zscored.loc[:, train_nan_cols] = 0
if isinstance(zscored, pd.DataFrame):
zscored = zscored.loc
zscored[:, train_nan_cols] = 0
normed_train[i] = zscored

normed_rest_feats = []
if rest_feats is not None:
for feats in rest_feats:
if is_2D:
normed_feats = np.divide(feats-train_mean, train_std, where=train_notnan_cols)
normed_feats.loc[:, train_nan_cols] = 0
if isinstance(normed_feats, pd.DataFrame):
normed_feats = normed_feats.loc
normed_feats[:, train_nan_cols] = 0
normed_rest_feats.append(normed_feats)
else:
normed_feats = np.empty_like(feats)
for i, trialSegROI in enumerate(feats):
zscored = np.divide(feats[i]-train_mean, train_std, where=train_notnan_cols)
zscored.loc[:, train_nan_cols] = 0
if isinstance(zscored, pd.DataFrame):
zscored = zscored.loc
zscored[:, train_nan_cols] = 0
normed_feats[i] = zscored
normed_rest_feats.append(normed_feats)

Expand Down
2 changes: 1 addition & 1 deletion neuro_py/ensemble/decoding/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def partition_sets(
state vectors and behavioral variables.
"""
partitions = []
is_2D = nsv_trial_segs.ndim == 1
is_2D = nsv_trial_segs[0].ndim == 1
for (train_indices, val_indices, test_indices) in partitions_indices:
if is_2D:
if isinstance(nsv_trial_segs, pd.DataFrame):
Expand Down

0 comments on commit aec881f

Please sign in to comment.