Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added functionality related to DP training #945

Merged
Merged
11 changes: 10 additions & 1 deletion GANDLF/compute/inference_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,16 @@ def inference_loop(
assert file_to_load != None, "The 'best_file' was not found"

main_dict = torch.load(file_to_load, map_location=parameters["device"])
model.load_state_dict(main_dict["model_state_dict"])
state_dict = main_dict["model_state_dict"]
if parameters.get("differential_privacy"):
# this is required for torch==1.11 and for DP inference
new_state_dict = {}
for key, val in state_dict.items():
new_key = key.replace("_module.", "")
new_state_dict[new_key] = val # remove `module.`
state_dict = new_state_dict

model.load_state_dict(state_dict)
parameters["previous_parameters"] = main_dict.get("parameters", None)
model.eval()
elif parameters["model"]["type"].lower() == "openvino":
Expand Down
68 changes: 68 additions & 0 deletions GANDLF/compute/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
from .forward_pass import validate_network
from .generic import create_pytorch_objects

from GANDLF.privacy.opacus.model_handling import empty_collate
from GANDLF.privacy.opacus import handle_dynamic_batch_size, prep_for_opacus_training
from opacus.utils.batch_memory_manager import wrap_data_loader

# hides torchio citation request, see https://github.com/fepegar/torchio/issues/235
os.environ["TORCHIO_HIDE_CITATION_PROMPT"] = "1"

Expand Down Expand Up @@ -91,6 +95,14 @@ def train_network(
for batch_idx, (subject) in enumerate(
tqdm(train_dataloader, desc="Looping over training data")
):
if params.get("differential_privacy"):
subject, params["batch_size"] = handle_dynamic_batch_size(
subject=subject, params=params
)
assert not isinstance(
model, torch.nn.DataParallel
), "Differential privacy is not supported with DataParallel or DistributedDataParallel. Please use a single GPU or DDP with Opacus."

optimizer.zero_grad()
image = ( # 5D tensor: (B, C, H, W, D)
torch.cat(
Expand Down Expand Up @@ -212,6 +224,23 @@ def train_network(
return average_epoch_train_loss, average_epoch_train_metric


def train_network_wrapper(model, train_dataloader, optimizer, params):
"""
Wrapper Function to handle train_dataloader for benign and DP cases and pass on to train a network for a single epoch
"""

if params.get("differential_privacy"):
with train_dataloader as memory_safe_data_loader:
epoch_train_loss, epoch_train_metric = train_network(
model, memory_safe_data_loader, optimizer, params
)
else:
epoch_train_loss, epoch_train_metric = train_network(
model, train_dataloader, optimizer, params
)
return epoch_train_loss, epoch_train_metric


def training_loop(
training_data: pd.DataFrame,
validation_data: pd.DataFrame,
Expand Down Expand Up @@ -368,6 +397,7 @@ def training_loop(
logger_csv_filename=os.path.join(output_dir, "logs_validation.csv"),
metrics=metrics_log,
mode="valid",
add_epsilon=bool(params.get("differential_privacy")),
)
if testingDataDefined:
test_logger = Logger(
Expand All @@ -392,6 +422,36 @@ def training_loop(

print("Using device:", device, flush=True)

if params.get("differential_privacy"):
print(
"Using Opacus to make training differentially private with respect to the training data."
)

model, optimizer, train_dataloader, privacy_engine = prep_for_opacus_training(
model=model,
optimizer=optimizer,
train_dataloader=train_dataloader,
params=params,
)

train_dataloader.collate_fn = empty_collate(train_dataloader.dataset[0])

# train_dataloader = BatchMemoryManager(
# data_loader=train_dataloader,
# max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
# optimizer=optimizer,
# )
batch_size = params["batch_size"]
max_physical_batch_size = params["differential_privacy"].get(
"physical_batch_size"
)
if max_physical_batch_size and max_physical_batch_size != batch_size:
train_dataloader = wrap_data_loader(
data_loader=train_dataloader,
max_batch_size=max_physical_batch_size,
optimizer=optimizer,
)

# Iterate for number of epochs
for epoch in range(start_epoch, epochs):
if params["track_memory_usage"]:
Expand Down Expand Up @@ -453,6 +513,14 @@ def training_loop(

patience += 1

# if training with differential privacy, print privacy epsilon
if params.get("differential_privacy"):
delta = params["differential_privacy"]["delta"]
this_epsilon = privacy_engine.get_epsilon(delta)
print(f" Epoch Final Privacy: (ε = {this_epsilon:.2f}, δ = {delta})")
# save for logging
epoch_valid_metric["epsilon"] = this_epsilon

# Write the losses to a logger
train_logger.write(epoch, epoch_train_loss, epoch_train_metric)
valid_logger.write(epoch, epoch_valid_loss, epoch_valid_metric)
Expand Down
5 changes: 5 additions & 0 deletions GANDLF/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .utils import version_check
from GANDLF.data.post_process import postprocessing_after_reverse_one_hot_encoding
from GANDLF.privacy.opacus import parse_opacus_params

from GANDLF.metrics import surface_distance_ids
from importlib.metadata import version
Expand Down Expand Up @@ -710,6 +711,10 @@ def _parseConfig(
temp_dict["type"] = params["optimizer"]
params["optimizer"] = temp_dict

# initialize defaults for DP
if params.get("differential_privacy"):
params = parse_opacus_params(params, initialize_key)

# initialize defaults for inference mechanism
inference_mechanism = {"grid_aggregator_overlap": "crop", "patch_overlap": 0}
initialize_inference_mechanism = False
Expand Down
15 changes: 12 additions & 3 deletions GANDLF/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@


class Logger:
def __init__(self, logger_csv_filename: str, metrics: List[str], mode: str) -> None:
def __init__(
self,
logger_csv_filename: str,
metrics: List[str],
mode: str,
add_epsilon: bool = False,
) -> None:
"""
Logger class to log the training and validation metrics to a csv file.
May append to existing file if headers match; elsewise raises an error.
Logger class to log the training and validation metrics to a csv file. May append to existing file if headers match; elsewise raises an error.

Args:
logger_csv_filename (str): Path to a filename where the csv has to be stored.
metrics (Dict[str, float]): The metrics to be logged.
mode (str): The mode of the logger, used as suffix to metric names. Normally may be `train` / `val` / `test`
add_epsilon (bool): Whether to log epsilon values or not (differential privacy measurement)
"""
self.filename = logger_csv_filename
mode = mode.lower()
Expand All @@ -28,6 +35,8 @@ def __init__(self, logger_csv_filename: str, metrics: List[str], mode: str) -> N
new_header = ["epoch_no", f"{mode}_loss"] + [
f"{mode}_{metric}" for metric in metrics
]
if add_epsilon:
new_header.append(f"{self.mode}_epsilon")

# TODO: do we really need to support appending to existing files?
if os.path.exists(self.filename):
Expand Down
4 changes: 4 additions & 0 deletions GANDLF/models/imagenet_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ def __init__(self, parameters) -> None:
aux_params=classifier_head_parameters,
)

# all BatchNorm should be replaced with InstanceNorm for DP experiments
if "differential_privacy" in parameters:
self.replace_batchnorm(self.model)

if self.n_dimensions == 3:
self.model = self.converter(self.model).model

Expand Down
Empty file added GANDLF/privacy/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions GANDLF/privacy/opacus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .config_parsing import parse_opacus_params
from .model_handling import opacus_model_fix, prep_for_opacus_training
from .training_utils import handle_dynamic_batch_size
59 changes: 59 additions & 0 deletions GANDLF/privacy/opacus/config_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Callable


def parse_opacus_params(params: dict, initialize_key: Callable) -> dict:
"""
Function to set defaults and augment the parameters related to making a trained model differentially
private with respect to the training data.

Args:
params (dict): Training parameters.
initialize_key (Callable): Function to fill in value for a missing key.

Returns:
dict: Updated training parameters.
"""

if not isinstance(params["differential_privacy"], dict):
print(
"WARNING: Non dictionary value for the key: 'differential_privacy' was used, replacing with default valued dictionary."
)
params["differential_privacy"] = {}
# these are some defaults
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "noise_multiplier", 10.0
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "max_grad_norm", 1.0
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "accountant", "rdp"
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "secure_mode", False
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "allow_opacus_model_fix", True
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "delta", 1e-5
)
params["differential_privacy"] = initialize_key(
params["differential_privacy"], "physical_batch_size", params["batch_size"]
)

if params["differential_privacy"]["physical_batch_size"] > params["batch_size"]:
print(
f"WARNING: The physical batch size {params['differential_privacy']['physical_batch_size']} is greater"
f"than the batch size {params['batch_size']}, setting the physical batch size to the batch size."
)
params["differential_privacy"]["physical_batch_size"] = params["batch_size"]

# these keys need to be parsed as floats, not strings
for key in ["noise_multiplier", "max_grad_norm", "delta", "epsilon"]:
if key in params["differential_privacy"]:
params["differential_privacy"][key] = float(
params["differential_privacy"][key]
)

return params
Loading
Loading