From ece832f6d3cd3f2582ea6fd37ecd2bf6d302205f Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 2 May 2024 09:14:43 -0700 Subject: [PATCH 1/9] Add batch_c15n for [0,1] image input and imagenet-normalized input. --- mart/configs/batch_c15n/image_01.yaml | 6 ++++++ .../configs/batch_c15n/imagenet_normalized.yaml | 6 ++++++ .../batch_c15n/transform/255_to_imagenet.yaml | 4 ++++ .../batch_c15n/transform/divided_by_255.yaml | 3 +++ .../batch_c15n/transform/imagenet_to_255.yaml | 17 +++++++++++++++++ .../transform/times_255_and_round.yaml | 13 +++++++++++++ 6 files changed, 49 insertions(+) create mode 100644 mart/configs/batch_c15n/image_01.yaml create mode 100644 mart/configs/batch_c15n/imagenet_normalized.yaml create mode 100644 mart/configs/batch_c15n/transform/255_to_imagenet.yaml create mode 100644 mart/configs/batch_c15n/transform/divided_by_255.yaml create mode 100644 mart/configs/batch_c15n/transform/imagenet_to_255.yaml create mode 100644 mart/configs/batch_c15n/transform/times_255_and_round.yaml diff --git a/mart/configs/batch_c15n/image_01.yaml b/mart/configs/batch_c15n/image_01.yaml new file mode 100644 index 00000000..f3f8e758 --- /dev/null +++ b/mart/configs/batch_c15n/image_01.yaml @@ -0,0 +1,6 @@ +defaults: + - list + - transform: times_255_and_round + - transform@untransform: divided_by_255 + +input_key: 0 diff --git a/mart/configs/batch_c15n/imagenet_normalized.yaml b/mart/configs/batch_c15n/imagenet_normalized.yaml new file mode 100644 index 00000000..e41fb3ff --- /dev/null +++ b/mart/configs/batch_c15n/imagenet_normalized.yaml @@ -0,0 +1,6 @@ +defaults: + - dict + - transform: imagenet_to_255 + - transform@untransform: 255_to_imagenet + +input_key: image diff --git a/mart/configs/batch_c15n/transform/255_to_imagenet.yaml b/mart/configs/batch_c15n/transform/255_to_imagenet.yaml new file mode 100644 index 00000000..9bb9ebef --- /dev/null +++ b/mart/configs/batch_c15n/transform/255_to_imagenet.yaml @@ -0,0 +1,4 @@ +_target_: 
torchvision.transforms.Normalize +# from 0-1 scale statistics: mean=[0.485, 0.456, 0.406]*255 std=[0.229, 0.224, 0.225]*255 +mean: [123.6750, 116.2800, 103.5300] +std: [58.3950, 57.1200, 57.3750] diff --git a/mart/configs/batch_c15n/transform/divided_by_255.yaml b/mart/configs/batch_c15n/transform/divided_by_255.yaml new file mode 100644 index 00000000..92a63b7c --- /dev/null +++ b/mart/configs/batch_c15n/transform/divided_by_255.yaml @@ -0,0 +1,3 @@ +_target_: torchvision.transforms.Normalize +mean: 0 +std: 255 diff --git a/mart/configs/batch_c15n/transform/imagenet_to_255.yaml b/mart/configs/batch_c15n/transform/imagenet_to_255.yaml new file mode 100644 index 00000000..66a4ef53 --- /dev/null +++ b/mart/configs/batch_c15n/transform/imagenet_to_255.yaml @@ -0,0 +1,17 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: mart.transforms.Denormalize + # from 0-1 scale statistics: mean=[0.485, 0.456, 0.406]*255 std=[0.229, 0.224, 0.225]*255 + center: + _target_: torch.as_tensor + data: [123.6750, 116.2800, 103.5300] + scale: + _target_: torch.as_tensor + data: [58.3950, 57.1200, 57.3750] + - _target_: torch.fake_quantize_per_tensor_affine + _partial_: true + # (x/1+0).round().clamp(0, 255) * 1 + scale: 1 + zero_point: 0 + quant_min: 0 + quant_max: 255 diff --git a/mart/configs/batch_c15n/transform/times_255_and_round.yaml b/mart/configs/batch_c15n/transform/times_255_and_round.yaml new file mode 100644 index 00000000..dbeff64d --- /dev/null +++ b/mart/configs/batch_c15n/transform/times_255_and_round.yaml @@ -0,0 +1,13 @@ +_target_: torchvision.transforms.Compose +transforms: + - _target_: mart.transforms.Denormalize + center: 0 + scale: 255 + # Fix potential numeric error. 
+ - _target_: torch.fake_quantize_per_tensor_affine + _partial_: true + # (x/1+0).round().clamp(0, 255) * 1 + scale: 1 + zero_point: 0 + quant_min: 0 + quant_max: 255 From 726a0af0cc93120d7acf329cf841e49566d2b583 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 2 May 2024 09:18:27 -0700 Subject: [PATCH 2/9] Turn off inference mode before creating perturbations. --- mart/attack/adversary.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index 9fbdec75..1555620b 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -151,6 +151,8 @@ def configure_gradient_clipping( for group in optimizer.param_groups: self.gradient_modifier(group["params"]) + # Turn off the inference mode, so we will create perturbation that requires gradient. + @torch.inference_mode(False) @silent() def fit(self, input, target, *, model: Callable): # The attack also needs access to the model at every iteration. From b0c307970a07264a2125595d02c9b0b6256982e1 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 2 May 2024 09:18:51 -0700 Subject: [PATCH 3/9] Switch to training mode before running LightningModule.training_step(). 
--- mart/callbacks/adversary_connector.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/mart/callbacks/adversary_connector.py b/mart/callbacks/adversary_connector.py index a35b225b..db8fe7b0 100644 --- a/mart/callbacks/adversary_connector.py +++ b/mart/callbacks/adversary_connector.py @@ -7,7 +7,7 @@ from __future__ import annotations import types -from typing import Callable +from typing import Any, Callable from lightning.pytorch.callbacks import Callback @@ -16,6 +16,20 @@ __all__ = ["AdversaryConnector"] +class training_mode: + """A context that switches a torch.nn.Module object to the training mode.""" + + def __init__(self, module): + self.module = module + self.training = self.module.training + + def __enter__(self): + self.module.train(True) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + self.module.train(self.training) + + class AdversaryConnector(Callback): """Perturbs inputs to be adversarial.""" @@ -81,7 +95,9 @@ def model(input, target): # LightningModule must have "training_step". # Disable logging if we have to reuse training_step() of the target model. with MonkeyPatch(pl_module, "log", lambda *args, **kwargs: None): - outputs = pl_module.training_step(batch, dataloader_idx) + # Switch the model to the training mode so training_step works as expected. + with training_mode(pl_module): + outputs = pl_module.training_step(batch, dataloader_idx) return outputs # Canonicalize the batch to work with Adversary. From eba3c2bb6ca988ebc7f1ea9171a2f4f28238ee44 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 2 May 2024 09:21:16 -0700 Subject: [PATCH 4/9] Add utils for config instantiation. 
--- mart/utils/config.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/mart/utils/config.py b/mart/utils/config.py index 36dbb6e6..3a4268ad 100644 --- a/mart/utils/config.py +++ b/mart/utils/config.py @@ -8,14 +8,17 @@ import os +import hydra from hydra import compose as hydra_compose from hydra import initialize_config_dir +from lightning.pytorch.callbacks.callback import Callback +from omegaconf import OmegaConf DEFAULT_VERSION_BASE = "1.2" DEFAULT_CONFIG_DIR = "." DEFAULT_CONFIG_NAME = "lightning.yaml" -__all__ = ["compose"] +__all__ = ["compose", "instantiate", "Instantiator", "CallbackInstantiator"] def compose( @@ -40,3 +43,28 @@ def compose( cfg = cfg[key] return cfg + + +def instantiate(cfg_path): + """Instantiate an object from a Hydra yaml config file.""" + config = OmegaConf.load(cfg_path) + obj = hydra.utils.instantiate(config) + return obj + + +class Instantiator: + def __new__(cls, cfg_path): + return instantiate(cfg_path) + + +class CallbackInstantiator(Callback): + """Type checking for Lightning Callback.""" + + def __new__(cls, cfg_path): + obj = instantiate(cfg_path) + if isinstance(obj, Callback): + return obj + else: + raise ValueError( + f"We expect to instantiate a lightning Callback from {cfg_path}, but we get {type(obj)} instead." + ) From dd10c795b4a7101796c1803dde21816c01f4de37 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 2 May 2024 09:27:44 -0700 Subject: [PATCH 5/9] Add mart.utils.Get() to extract a value from kwargs dict. 
--- mart/utils/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mart/utils/utils.py b/mart/utils/utils.py index f4a0a4ec..dbf3fb21 100644 --- a/mart/utils/utils.py +++ b/mart/utils/utils.py @@ -28,6 +28,7 @@ "save_file", "task_wrapper", "flatten_dict", + "Get", ] log = pylogger.get_pylogger(__name__) @@ -293,3 +294,13 @@ def get_dottedpath_items(d: dict, parent: Optional[str] = None): ret[key] = value return ret + + +class Get: + """Get a value from the kwargs dictionary by key.""" + + def __init__(self, key): + self.key = key + + def __call__(self, **kwargs): + return kwargs[self.key] From e0dc984b59d948c10cb786f351a4f0358a5f3183 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Tue, 14 May 2024 11:59:34 -0700 Subject: [PATCH 6/9] Comment --- mart/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/utils/config.py b/mart/utils/config.py index 3a4268ad..578b4351 100644 --- a/mart/utils/config.py +++ b/mart/utils/config.py @@ -58,7 +58,7 @@ def __new__(cls, cfg_path): class CallbackInstantiator(Callback): - """Type checking for Lightning Callback.""" + """Satisfying type checking for Lightning Callback.""" def __new__(cls, cfg_path): obj = instantiate(cfg_path) From f4e9acc3a88d2bf08d69a3bc322f789ae80201ce Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Tue, 14 May 2024 22:22:31 -0700 Subject: [PATCH 7/9] Clean up. 
--- mart/configs/batch_c15n/image_01.yaml | 6 ------ mart/configs/batch_c15n/imagenet_normalized.yaml | 6 ------ 2 files changed, 12 deletions(-) delete mode 100644 mart/configs/batch_c15n/image_01.yaml delete mode 100644 mart/configs/batch_c15n/imagenet_normalized.yaml diff --git a/mart/configs/batch_c15n/image_01.yaml b/mart/configs/batch_c15n/image_01.yaml deleted file mode 100644 index f3f8e758..00000000 --- a/mart/configs/batch_c15n/image_01.yaml +++ /dev/null @@ -1,6 +0,0 @@ -defaults: - - list - - transform: times_255_and_round - - transform@untransform: divided_by_255 - -input_key: 0 diff --git a/mart/configs/batch_c15n/imagenet_normalized.yaml b/mart/configs/batch_c15n/imagenet_normalized.yaml deleted file mode 100644 index e41fb3ff..00000000 --- a/mart/configs/batch_c15n/imagenet_normalized.yaml +++ /dev/null @@ -1,6 +0,0 @@ -defaults: - - dict - - transform: imagenet_to_255 - - transform@untransform: 255_to_imagenet - -input_key: image From b283bf1d126d09f1aaf5d16d7e81e07cdb6c4a63 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Wed, 15 May 2024 09:37:10 -0700 Subject: [PATCH 8/9] Move to mart.nn.Get(). 
--- mart/nn/nn.py | 12 +++++++++++- mart/utils/utils.py | 11 ----------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 02113899..147a4773 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -13,7 +13,7 @@ import torch -__all__ = ["GroupNorm32", "SequentialDict", "ReturnKwargs", "CallWith", "Sum"] +__all__ = ["GroupNorm32", "SequentialDict", "ReturnKwargs", "CallWith", "Sum", "Get"] logger = logging.getLogger(__name__) @@ -300,3 +300,13 @@ def __init__(self): def forward(self, *args): return sum(args) + + +class Get: + """Get a value from the kwargs dictionary by key.""" + + def __init__(self, key): + self.key = key + + def __call__(self, **kwargs): + return kwargs[self.key] diff --git a/mart/utils/utils.py b/mart/utils/utils.py index dbf3fb21..f4a0a4ec 100644 --- a/mart/utils/utils.py +++ b/mart/utils/utils.py @@ -28,7 +28,6 @@ "save_file", "task_wrapper", "flatten_dict", - "Get", ] log = pylogger.get_pylogger(__name__) @@ -294,13 +293,3 @@ def get_dottedpath_items(d: dict, parent: Optional[str] = None): ret[key] = value return ret - - -class Get: - """Get a value from the kwargs dictionary by key.""" - - def __init__(self, key): - self.key = key - - def __call__(self, **kwargs): - return kwargs[self.key] From cd430e1e1f51641bd339bc3af7a68e9e1ccd08d3 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Thu, 16 May 2024 10:45:50 -0700 Subject: [PATCH 9/9] Add support to multi-level dicts. --- mart/nn/nn.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 147a4773..bff4a9b9 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -303,10 +303,16 @@ def forward(self, *args): class Get: - """Get a value from the kwargs dictionary by key.""" + """Get a value from the kwargs dictionary by key. + + The key can be a path to a nested dictionary, concatenated by dots. For example, + `Get(key="a.b")(a={"b": 1}) == 1`. 
+ """ def __init__(self, key): self.key = key def __call__(self, **kwargs): + # Add support to nested dicts. + kwargs = DotDict(kwargs) return kwargs[self.key]