Skip to content

Commit

Permalink
* extend description of train_model function
Browse files Browse the repository at this point in the history
* fix m2mlstm & ndt data pre-formatting
  • Loading branch information
kushaangupta committed Dec 25, 2024
1 parent f111c2d commit d4e80c5
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions neuro_py/ensemble/decoding/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def preprocess_data(hyperparams, ohe, nsv_train, nsv_val, nsv_test, bv_train, bv
bins_before = hyperparams['bins_before']
bins_current = hyperparams['bins_current']
bins_after = hyperparams['bins_after']
is_2D = nsv_train[0].ndim == 1
if hyperparams['model'] not in ('M2MLSTM', 'NDT'):
(
X_cov_train, X_flat_train, y_train,
Expand Down Expand Up @@ -262,6 +263,10 @@ def preprocess_data(hyperparams, ohe, nsv_train, nsv_val, nsv_test, bv_train, bv
num_workers=hyperparams['num_workers'], modeltype=hyperparams['model'])
hyperparams['model_args']['in_dim'] = X_train.shape[-1]
else:
if is_2D:
nsv_train, bv_train = [nsv_train], [bv_train]
nsv_val, bv_val = [nsv_val], [bv_val]
nsv_test, bv_test = [nsv_test], [bv_test]
if type(bv_train[0]) is pd.DataFrame:
y_train = [y.values[:, hyperparams['behaviors']] for y in bv_train]
else:
Expand Down Expand Up @@ -391,6 +396,9 @@ def shuffle_nsv_intrialsegs(nsv_trialsegs):
def train_model(partitions, hyperparams, resultspath=None, stop_partition=None):
"""Generic function to train a DNN model on the given data partitions.
In-built caching & checkpointing is used to save the best model based on the
validation loss.
Parameters
----------
partitions : array-like
Expand Down

0 comments on commit d4e80c5

Please sign in to comment.