Skip to content

Commit

Permalink
Merge pull request #945 from scap3yvt/944-feature-add-dp-enabled-trai…
Browse files Browse the repository at this point in the history
…ning

Added functionality related to DP training
  • Loading branch information
sarthakpati authored Oct 1, 2024
2 parents 10bd05f + 53566b0 commit 2f33623
Show file tree
Hide file tree
Showing 12 changed files with 505 additions and 4 deletions.
11 changes: 10 additions & 1 deletion GANDLF/compute/inference_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,16 @@ def inference_loop(
assert file_to_load != None, "The 'best_file' was not found"

main_dict = torch.load(file_to_load, map_location=parameters["device"])
model.load_state_dict(main_dict["model_state_dict"])
state_dict = main_dict["model_state_dict"]
if parameters.get("differential_privacy"):
# this is required for torch==1.11 and for DP inference
new_state_dict = {}
for key, val in state_dict.items():
new_key = key.replace("_module.", "")
new_state_dict[new_key] = val # remove `module.`
state_dict = new_state_dict

model.load_state_dict(state_dict)
parameters["previous_parameters"] = main_dict.get("parameters", None)
model.eval()
elif parameters["model"]["type"].lower() == "openvino":
Expand Down
68 changes: 68 additions & 0 deletions GANDLF/compute/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
from .forward_pass import validate_network
from .generic import create_pytorch_objects

from GANDLF.privacy.opacus.model_handling import empty_collate
from GANDLF.privacy.opacus import handle_dynamic_batch_size, prep_for_opacus_training
from opacus.utils.batch_memory_manager import wrap_data_loader

# hides torchio citation request, see https://github.com/fepegar/torchio/issues/235
os.environ["TORCHIO_HIDE_CITATION_PROMPT"] = "1"

Expand Down Expand Up @@ -91,6 +95,14 @@ def train_network(
for batch_idx, (subject) in enumerate(
tqdm(train_dataloader, desc="Looping over training data")
):
if params.get("differential_privacy"):
subject, params["batch_size"] = handle_dynamic_batch_size(
subject=subject, params=params
)
assert not isinstance(
model, torch.nn.DataParallel
), "Differential privacy is not supported with DataParallel or DistributedDataParallel. Please use a single GPU or DDP with Opacus."

optimizer.zero_grad()
image = ( # 5D tensor: (B, C, H, W, D)
torch.cat(
Expand Down Expand Up @@ -212,6 +224,23 @@ def train_network(
return average_epoch_train_loss, average_epoch_train_metric


def train_network_wrapper(model, train_dataloader, optimizer, params):
"""
Wrapper Function to handle train_dataloader for benign and DP cases and pass on to train a network for a single epoch
"""

if params.get("differential_privacy"):
with train_dataloader as memory_safe_data_loader:
epoch_train_loss, epoch_train_metric = train_network(
model, memory_safe_data_loader, optimizer, params
)
else:
epoch_train_loss, epoch_train_metric = train_network(
model, train_dataloader, optimizer, params
)
return epoch_train_loss, epoch_train_metric


def training_loop(
training_data: pd.DataFrame,
validation_data: pd.DataFrame,
Expand Down Expand Up @@ -368,6 +397,7 @@ def training_loop(
logger_csv_filename=os.path.join(output_dir, "logs_validation.csv"),
metrics=metrics_log,
mode="valid",
add_epsilon=bool(params.get("differential_privacy")),
)
if testingDataDefined:
test_logger = Logger(
Expand All @@ -392,6 +422,36 @@ def training_loop(

print("Using device:", device, flush=True)

if params.get("differential_privacy"):
print(
"Using Opacus to make training differentially private with respect to the training data."
)

model, optimizer, train_dataloader, privacy_engine = prep_for_opacus_training(
model=model,
optimizer=optimizer,
train_dataloader=train_dataloader,
params=params,
)

train_dataloader.collate_fn = empty_collate(train_dataloader.dataset[0])

# train_dataloader = BatchMemoryManager(
# data_loader=train_dataloader,
# max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
# optimizer=optimizer,
# )
batch_size = params["batch_size"]
max_physical_batch_size = params["differential_privacy"].get(
"physical_batch_size"
)
if max_physical_batch_size and max_physical_batch_size != batch_size:
train_dataloader = wrap_data_loader(
data_loader=train_dataloader,
max_batch_size=max_physical_batch_size,
optimizer=optimizer,
)

# Iterate for number of epochs
for epoch in range(start_epoch, epochs):
if params["track_memory_usage"]:
Expand Down Expand Up @@ -453,6 +513,14 @@ def training_loop(

patience += 1

# if training with differential privacy, print privacy epsilon
if params.get("differential_privacy"):
delta = params["differential_privacy"]["delta"]
this_epsilon = privacy_engine.get_epsilon(delta)
print(f" Epoch Final Privacy: (ε = {this_epsilon:.2f}, δ = {delta})")
# save for logging
epoch_valid_metric["epsilon"] = this_epsilon

# Write the losses to a logger
train_logger.write(epoch, epoch_train_loss, epoch_train_metric)
valid_logger.write(epoch, epoch_valid_loss, epoch_valid_metric)
Expand Down
5 changes: 5 additions & 0 deletions GANDLF/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .utils import version_check
from GANDLF.data.post_process import postprocessing_after_reverse_one_hot_encoding
from GANDLF.privacy.opacus import parse_opacus_params

from GANDLF.metrics import surface_distance_ids
from importlib.metadata import version
Expand Down Expand Up @@ -710,6 +711,10 @@ def _parseConfig(
temp_dict["type"] = params["optimizer"]
params["optimizer"] = temp_dict

# initialize defaults for DP
if params.get("differential_privacy"):
params = parse_opacus_params(params, initialize_key)

# initialize defaults for inference mechanism
inference_mechanism = {"grid_aggregator_overlap": "crop", "patch_overlap": 0}
initialize_inference_mechanism = False
Expand Down
15 changes: 12 additions & 3 deletions GANDLF/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@


class Logger:
def __init__(self, logger_csv_filename: str, metrics: List[str], mode: str) -> None:
def __init__(
self,
logger_csv_filename: str,
metrics: List[str],
mode: str,
add_epsilon: bool = False,
) -> None:
"""
Logger class to log the training and validation metrics to a csv file.
May append to existing file if headers match; elsewise raises an error.
Logger class to log the training and validation metrics to a csv file. May append to existing file if headers match; elsewise raises an error.
Args:
logger_csv_filename (str): Path to a filename where the csv has to be stored.
metrics (Dict[str, float]): The metrics to be logged.
mode (str): The mode of the logger, used as suffix to metric names. Normally may be `train` / `val` / `test`
add_epsilon (bool): Whether to log epsilon values or not (differential privacy measurement)
"""
self.filename = logger_csv_filename
mode = mode.lower()
Expand All @@ -28,6 +35,8 @@ def __init__(self, logger_csv_filename: str, metrics: List[str], mode: str) -> N
new_header = ["epoch_no", f"{mode}_loss"] + [
f"{mode}_{metric}" for metric in metrics
]
if add_epsilon:
new_header.append(f"{self.mode}_epsilon")

# TODO: do we really need to support appending to existing files?
if os.path.exists(self.filename):
Expand Down
4 changes: 4 additions & 0 deletions GANDLF/models/imagenet_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ def __init__(self, parameters) -> None:
aux_params=classifier_head_parameters,
)

# all BatchNorm should be replaced with InstanceNorm for DP experiments
if "differential_privacy" in parameters:
self.replace_batchnorm(self.model)

if self.n_dimensions == 3:
self.model = self.converter(self.model).model

Expand Down
Empty file added GANDLF/privacy/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions GANDLF/privacy/opacus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .config_parsing import parse_opacus_params
from .model_handling import opacus_model_fix, prep_for_opacus_training
from .training_utils import handle_dynamic_batch_size
59 changes: 59 additions & 0 deletions GANDLF/privacy/opacus/config_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Callable


def parse_opacus_params(params: dict, initialize_key: Callable) -> dict:
"""
Function to set defaults and augment the parameters related to making a trained model differentially
private with respect to the training data.
Args:
params (dict): Training parameters.
initialize_key (Callable): Function to fill in value for a missing key.
Returns:
dict: Updated training parameters.
"""

if not isinstance(params["differential_privacy"], dict):
print(
"WARNING: Non dictionary value for the key: 'differential_privacy' was used, replacing with default valued dictionary."
)
params["differential_privacy"] = {}
# these are some defaults
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "noise_multiplier", 10.0
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "max_grad_norm", 1.0
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "accountant", "rdp"
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "secure_mode", False
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "allow_opacus_model_fix", True
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "delta", 1e-5
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "physical_batch_size", params["batch_size"]
)

if params["differential_privacy"]["physical_batch_size"] > params["batch_size"]:
print(
f"WARNING: The physical batch size {params['differential_privacy']['physical_batch_size']} is greater"
f"than the batch size {params['batch_size']}, setting the physical batch size to the batch size."
)
params["differential_privacy"]["physical_batch_size"] = params["batch_size"]

# these keys need to be parsed as floats, not strings
for key in ["noise_multiplier", "max_grad_norm", "delta", "epsilon"]:
if key in params["differential_privacy"]:
params["differential_privacy"][key] = float(
params["differential_privacy"][key]
)

return params
Loading

0 comments on commit 2f33623

Please sign in to comment.