Commit

✨ build dataloader fct
- use everywhere.
Henry committed Jul 9, 2024
1 parent c70d328 commit 1c72316
Showing 1 changed file with 42 additions and 41 deletions.
83 changes: 42 additions & 41 deletions src/move/data/perturbations.py
@@ -22,21 +22,23 @@
 logger = get_logger(__name__)


-def _build_perturbed_dataloader(baseline_dataset, perturbed, batch_size):
-    # currently for continuous data only
-    perturbed_dataset = MOVEDataset(
-        baseline_dataset.cat_all,
-        perturbed,
-        baseline_dataset.cat_shapes,
-        baseline_dataset.con_shapes,
+def _build_dataloader(
+    cat_data, con_data, cat_shapes, con_shapes, batch_size, shuffle=False
+):
+    dataset = MOVEDataset(
+        cat_data,
+        con_data,
+        cat_shapes,
+        con_shapes,
     )

-    perturbed_dataloader = DataLoader(
-        perturbed_dataset,
-        shuffle=False,
+    dataloader = DataLoader(
+        dataset,
+        shuffle=shuffle,
         batch_size=batch_size,
     )
-    return perturbed_dataloader
+    return dataloader


 def _pertub_cont_feat_col(
@@ -113,15 +115,11 @@ def perturb_categorical_data(
             baseline_dataset.num_samples, *target_shape
         )
         target_dataset[:, i, :] = torch.FloatTensor(target_value)
-        perturbed_dataset = MOVEDataset(
-            perturbed_cat,
-            baseline_dataset.con_all,
-            baseline_dataset.cat_shapes,
-            baseline_dataset.con_shapes,
-        )
-        perturbed_dataloader = DataLoader(
-            perturbed_dataset,
-            shuffle=False,
+        perturbed_dataloader = _build_dataloader(
+            cat_data=perturbed_cat,
+            con_data=baseline_dataset.con_all,
+            cat_shapes=baseline_dataset.cat_shapes,
+            con_shapes=baseline_dataset.con_shapes,
             batch_size=baseline_dataloader.batch_size,
         )
         dataloaders.append(perturbed_dataloader)
@@ -161,9 +159,11 @@ def perturb_continuous_data(
     for i in range(num_features):
         perturbed_con = baseline_dataset.con_all.clone()
         perturbed_con[:, start_idx + i] = torch.FloatTensor([target_value])
-        perturbed_dataloader = _build_perturbed_dataloader(
-            baseline_dataset=baseline_dataset,
-            perturbed=perturbed_con,
+        perturbed_dataloader = _build_dataloader(
+            cat_data=baseline_dataset.cat_all,
+            con_data=perturbed_con,
+            cat_shapes=baseline_dataset.cat_shapes,
+            con_shapes=baseline_dataset.con_shapes,
             batch_size=baseline_dataloader.batch_size,
         )
         dataloaders.append(perturbed_dataloader)
@@ -209,18 +209,13 @@ def perturb_categorical_data_one(
             baseline_dataset.num_samples, *target_shape
         )
         target_dataset[:, i, :] = torch.FloatTensor(target_value)
-        perturbed_dataset = MOVEDataset(
-            perturbed_cat,
-            baseline_dataset.con_all,
-            baseline_dataset.cat_shapes,
-            baseline_dataset.con_shapes,
-        )
-        perturbed_dataloader = DataLoader(
-            perturbed_dataset,
-            shuffle=False,
+        perturbed_dataloader = _build_dataloader(
+            cat_data=perturbed_cat,
+            con_data=baseline_dataset.con_all,
+            cat_shapes=baseline_dataset.cat_shapes,
+            con_shapes=baseline_dataset.con_shapes,
             batch_size=baseline_dataloader.batch_size,
         )
-
     return perturbed_dataloader

@@ -254,9 +249,11 @@ def perturb_continuous_data_one(

     perturbed_con = baseline_dataset.con_all.clone()
     perturbed_con[:, start_idx + index_pert_feat] = torch.FloatTensor([target_value])
-    perturbed_dataloader = _build_perturbed_dataloader(
-        baseline_dataset=baseline_dataset,
-        perturbed=perturbed_con,
+    perturbed_dataloader = _build_dataloader(
+        cat_data=baseline_dataset.cat_all,
+        con_data=perturbed_con,
+        cat_shapes=baseline_dataset.cat_shapes,
+        con_shapes=baseline_dataset.con_shapes,
         batch_size=baseline_dataloader.batch_size,
     )

@@ -315,9 +312,11 @@ def perturb_continuous_data_extended(
         )
         perturbations_list.append(perturbed_con[:, start_idx + i].numpy())

-        perturbed_dataloader = _build_perturbed_dataloader(
-            baseline_dataset=baseline_dataset,
-            perturbed=perturbed_con,
+        perturbed_dataloader = _build_dataloader(
+            cat_data=baseline_dataset.cat_all,
+            con_data=perturbed_con,
+            cat_shapes=baseline_dataset.cat_shapes,
+            con_shapes=baseline_dataset.con_shapes,
             batch_size=baseline_dataloader.batch_size,
         )
         dataloaders.append(perturbed_dataloader)
@@ -397,9 +396,11 @@ def perturb_continuous_data_extended_one(
         f"Creating perturbed dataset and dataloader for feature {index_pert_feat}"
     )

-    perturbed_dataloader = _build_perturbed_dataloader(
-        baseline_dataset=baseline_dataset,
-        perturbed=perturbed_con,
+    perturbed_dataloader = _build_dataloader(
+        cat_data=baseline_dataset.cat_all,
+        con_data=perturbed_con,
+        cat_shapes=baseline_dataset.cat_shapes,
+        con_shapes=baseline_dataset.con_shapes,
         batch_size=baseline_dataloader.batch_size,
    )

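For context, a minimal, self-contained sketch of the pattern this commit introduces: a single dataloader factory that every perturbation function calls instead of repeating the MOVEDataset and DataLoader construction. The TensorDataset stand-in, the tensor shapes, and the zero target value below are illustrative assumptions, not MOVE's actual API.

import torch
from torch.utils.data import DataLoader, TensorDataset


def build_dataloader(cat_data, con_data, batch_size, shuffle=False):
    """Wrap paired tensors in a dataset and a dataloader in one place.

    Stand-in for the refactored ``_build_dataloader``; ``TensorDataset``
    replaces ``MOVEDataset`` so the sketch runs on its own.
    """
    dataset = TensorDataset(cat_data, con_data)
    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)


# Hypothetical stand-ins for baseline_dataset.cat_all / con_all.
cat_all = torch.zeros(100, 4)
con_all = torch.randn(100, 8)

# Each perturbation clones the continuous block, overwrites one column,
# and reuses the same helper rather than repeating the dataset/dataloader
# boilerplate in every function.
dataloaders = []
for i in range(con_all.shape[1]):
    perturbed_con = con_all.clone()
    perturbed_con[:, i] = 0.0  # hypothetical target_value
    dataloaders.append(build_dataloader(cat_all, perturbed_con, batch_size=16))

print(len(dataloaders), "perturbed dataloaders built with one helper")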