From 1c7231678326fb3bc283c2af1211d45cc946ff52 Mon Sep 17 00:00:00 2001
From: Henry
Date: Tue, 9 Jul 2024 16:53:06 +0200
Subject: [PATCH] :sparkles: build dataloader fct - use everywhere.

Replace `_build_perturbed_dataloader(baseline_dataset, perturbed,
batch_size)` with the more general `_build_dataloader(cat_data, con_data,
cat_shapes, con_shapes, batch_size, shuffle=False)` and call it from every
perturbation function instead of building the `MOVEDataset`/`DataLoader`
pair inline.

At each call site `cat_data` receives the categorical tensor and `con_data`
the continuous tensor; only the perturbed one differs from the baseline
dataset. `shuffle` defaults to False, matching the previous behaviour.
---
Reviewer note: indentation and blank context lines in this patch were
reconstructed from a whitespace-collapsed copy. Run `git apply --check`
before applying; if context does not match, retry with
`git apply --ignore-whitespace`.

 src/move/data/perturbations.py | 83 +++++++++++++++++-----------------
 1 file changed, 42 insertions(+), 41 deletions(-)

diff --git a/src/move/data/perturbations.py b/src/move/data/perturbations.py
index 56d8f93..438ff81 100644
--- a/src/move/data/perturbations.py
+++ b/src/move/data/perturbations.py
@@ -22,21 +22,23 @@
 logger = get_logger(__name__)
 
 
-def _build_perturbed_dataloader(baseline_dataset, perturbed, batch_size):
+def _build_dataloader(
+    cat_data, con_data, cat_shapes, con_shapes, batch_size, shuffle=False
+):
     # currently for continuous data only
-    perturbed_dataset = MOVEDataset(
-        baseline_dataset.cat_all,
-        perturbed,
-        baseline_dataset.cat_shapes,
-        baseline_dataset.con_shapes,
+    dataset = MOVEDataset(
+        cat_data,
+        con_data,
+        cat_shapes,
+        con_shapes,
     )
 
-    perturbed_dataloader = DataLoader(
-        perturbed_dataset,
-        shuffle=False,
+    dataloader = DataLoader(
+        dataset,
+        shuffle=shuffle,
         batch_size=batch_size,
     )
-    return perturbed_dataloader
+    return dataloader
 
 
 def _pertub_cont_feat_col(
@@ -113,15 +115,11 @@ def perturb_categorical_data(
             baseline_dataset.num_samples, *target_shape
         )
         target_dataset[:, i, :] = torch.FloatTensor(target_value)
-        perturbed_dataset = MOVEDataset(
-            perturbed_cat,
-            baseline_dataset.con_all,
-            baseline_dataset.cat_shapes,
-            baseline_dataset.con_shapes,
-        )
-        perturbed_dataloader = DataLoader(
-            perturbed_dataset,
-            shuffle=False,
+        perturbed_dataloader = _build_dataloader(
+            cat_data=perturbed_cat,
+            con_data=baseline_dataset.con_all,
+            cat_shapes=baseline_dataset.cat_shapes,
+            con_shapes=baseline_dataset.con_shapes,
             batch_size=baseline_dataloader.batch_size,
         )
         dataloaders.append(perturbed_dataloader)
@@ -161,9 +159,11 @@ def perturb_continuous_data(
     for i in range(num_features):
         perturbed_con = baseline_dataset.con_all.clone()
         perturbed_con[:, start_idx + i] = torch.FloatTensor([target_value])
-        perturbed_dataloader = _build_perturbed_dataloader(
-            baseline_dataset=baseline_dataset,
-            perturbed=perturbed_con,
+        perturbed_dataloader = _build_dataloader(
+            cat_data=baseline_dataset.cat_all,
+            con_data=perturbed_con,
+            cat_shapes=baseline_dataset.cat_shapes,
+            con_shapes=baseline_dataset.con_shapes,
             batch_size=baseline_dataloader.batch_size,
         )
         dataloaders.append(perturbed_dataloader)
@@ -209,18 +209,13 @@ def perturb_categorical_data_one(
         baseline_dataset.num_samples, *target_shape
     )
     target_dataset[:, i, :] = torch.FloatTensor(target_value)
-    perturbed_dataset = MOVEDataset(
-        perturbed_cat,
-        baseline_dataset.con_all,
-        baseline_dataset.cat_shapes,
-        baseline_dataset.con_shapes,
-    )
-    perturbed_dataloader = DataLoader(
-        perturbed_dataset,
-        shuffle=False,
+    perturbed_dataloader = _build_dataloader(
+        cat_data=perturbed_cat,
+        con_data=baseline_dataset.con_all,
+        cat_shapes=baseline_dataset.cat_shapes,
+        con_shapes=baseline_dataset.con_shapes,
         batch_size=baseline_dataloader.batch_size,
     )
-
     return perturbed_dataloader
 
 
@@ -254,9 +249,11 @@ def perturb_continuous_data_one(
     perturbed_con = baseline_dataset.con_all.clone()
     perturbed_con[:, start_idx + index_pert_feat] = torch.FloatTensor([target_value])
 
-    perturbed_dataloader = _build_perturbed_dataloader(
-        baseline_dataset=baseline_dataset,
-        perturbed=perturbed_con,
+    perturbed_dataloader = _build_dataloader(
+        cat_data=baseline_dataset.cat_all,
+        con_data=perturbed_con,
+        cat_shapes=baseline_dataset.cat_shapes,
+        con_shapes=baseline_dataset.con_shapes,
         batch_size=baseline_dataloader.batch_size,
     )
 
@@ -315,9 +312,11 @@ def perturb_continuous_data_extended(
         )
         perturbations_list.append(perturbed_con[:, start_idx + i].numpy())
 
-        perturbed_dataloader = _build_perturbed_dataloader(
-            baseline_dataset=baseline_dataset,
-            perturbed=perturbed_con,
+        perturbed_dataloader = _build_dataloader(
+            cat_data=baseline_dataset.cat_all,
+            con_data=perturbed_con,
+            cat_shapes=baseline_dataset.cat_shapes,
+            con_shapes=baseline_dataset.con_shapes,
             batch_size=baseline_dataloader.batch_size,
         )
         dataloaders.append(perturbed_dataloader)
@@ -397,9 +396,11 @@ def perturb_continuous_data_extended_one(
         f"Creating perturbed dataset and dataloader for feature {index_pert_feat}"
     )
 
-    perturbed_dataloader = _build_perturbed_dataloader(
-        baseline_dataset=baseline_dataset,
-        perturbed=perturbed_con,
+    perturbed_dataloader = _build_dataloader(
+        cat_data=baseline_dataset.cat_all,
+        con_data=perturbed_con,
+        cat_shapes=baseline_dataset.cat_shapes,
+        con_shapes=baseline_dataset.con_shapes,
         batch_size=baseline_dataloader.batch_size,
     )
 