From 6f0d0d145d137eeef478268721f7178aa79ab188 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 10:25:54 -0800 Subject: [PATCH 01/18] Create a folder for attack.composer. --- mart/attack/composer/__init__.py | 1 + mart/attack/{composer.py => composer/modular.py} | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 mart/attack/composer/__init__.py rename mart/attack/{composer.py => composer/modular.py} (97%) diff --git a/mart/attack/composer/__init__.py b/mart/attack/composer/__init__.py new file mode 100644 index 00000000..8aafd38a --- /dev/null +++ b/mart/attack/composer/__init__.py @@ -0,0 +1 @@ +from .modular import * diff --git a/mart/attack/composer.py b/mart/attack/composer/modular.py similarity index 97% rename from mart/attack/composer.py rename to mart/attack/composer/modular.py index f2c840bf..d070436b 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer/modular.py @@ -13,9 +13,9 @@ from mart.nn import SequentialDict if TYPE_CHECKING: - from .perturber import Perturber + from ..perturber import Perturber -__all__ = ["Composer"] +__all__ = ["Composer", "Additive", "Mask", "Overlay"] class Composer(torch.nn.Module): From f2816aa71e75ee3b6a5b34c0593dc626901f1d02 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 13:08:37 -0800 Subject: [PATCH 02/18] Add composer modules for unbounded patch adversary. --- mart/attack/composer/__init__.py | 1 + mart/attack/composer/patch.py | 73 +++++++++++++++++++ .../composer/modules/pert_extract_rect.yaml | 2 + .../modules/pert_rect_perspective.yaml | 2 + .../composer/modules/pert_rect_size.yaml | 2 + 5 files changed, 80 insertions(+) create mode 100644 mart/attack/composer/patch.py create mode 100644 mart/configs/attack/composer/modules/pert_extract_rect.yaml create mode 100644 mart/configs/attack/composer/modules/pert_rect_perspective.yaml create mode 100644 mart/configs/attack/composer/modules/pert_rect_size.yaml diff --git a/mart/attack/composer/__init__.py b/mart/attack/composer/__init__.py index 8aafd38a..3d51ec1d 100644 --- a/mart/attack/composer/__init__.py +++ b/mart/attack/composer/__init__.py @@ -1 +1,2 @@ from .modular import * +from .patch import * diff --git a/mart/attack/composer/patch.py b/mart/attack/composer/patch.py new file mode 100644 index 00000000..c1baeca8 --- /dev/null +++ b/mart/attack/composer/patch.py @@ -0,0 +1,73 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from __future__ import annotations + +import torch +import torchvision.transforms.functional as F + +__all__ = [ + "PertRectSize", + "PertExtractRect", + "PertRectPerspective", +] + + +class PertRectSize(torch.nn.Module): + """Calculate the size of the smallest rectangle that can be transformed with the highest pixel + fidelity.""" + + @staticmethod + def get_smallest_rect(coords): + # Calculate the distance between two points. 
+ coords_shifted = torch.cat([coords[1:], coords[0:1]]) + w1, h2, w2, h1 = torch.sqrt( + torch.sum(torch.pow(torch.subtract(coords, coords_shifted), 2), dim=1) + ) + + height = int(max(h1, h2).round()) + width = int(max(w1, w2).round()) + return height, width + + def forward(self, coords): + height, width = self.get_smallest_rect(coords) + return {"height": height, "width": width} + + +class PertExtractRect(torch.nn.Module): + """Extract a small rectangular patch from the input size one.""" + + def forward(self, perturbation, height, width): + perturbation = perturbation[:, :height, :width] + return perturbation + + +class PertRectPerspective(torch.nn.Module): + """Pad perturbation to input size, then perspective transform the top-left rectangle.""" + + def forward(self, perturbation, input, coords): + # Pad to the input size. + height_inp, width_inp = input.shape[-2:] + height_pert, width_pert = perturbation.shape[-2:] + height_pad = height_inp - height_pert + width_pad = width_inp - width_pert + perturbation = F.pad(perturbation, padding=[0, 0, width_pad, height_pad]) + + # F.perspective() requires startpoints and endpoints in CPU. + startpoints = torch.tensor( + [[0, 0], [width_pert, 0], [width_pert, height_pert], [0, height_pert]] + ) + endpoints = coords.cpu() + + perturbation = F.perspective( + img=perturbation, + startpoints=startpoints, + endpoints=endpoints, + interpolation=F.InterpolationMode.BILINEAR, + fill=0, + ) + + return perturbation diff --git a/mart/configs/attack/composer/modules/pert_extract_rect.yaml b/mart/configs/attack/composer/modules/pert_extract_rect.yaml new file mode 100644 index 00000000..ffdc0fa5 --- /dev/null +++ b/mart/configs/attack/composer/modules/pert_extract_rect.yaml @@ -0,0 +1,2 @@ +pert_extract_rect: + _target_: mart.attack.composer.PertExtractRect diff --git a/mart/configs/attack/composer/modules/pert_rect_perspective.yaml b/mart/configs/attack/composer/modules/pert_rect_perspective.yaml new file mode 100644 index 00000000..af28fb1f --- /dev/null +++ b/mart/configs/attack/composer/modules/pert_rect_perspective.yaml @@ -0,0 +1,2 @@ +pert_rect_perspective: + _target_: mart.attack.composer.PertRectPerspective diff --git a/mart/configs/attack/composer/modules/pert_rect_size.yaml b/mart/configs/attack/composer/modules/pert_rect_size.yaml new file mode 100644 index 00000000..60062b53 --- /dev/null +++ b/mart/configs/attack/composer/modules/pert_rect_size.yaml @@ -0,0 +1,2 @@ +pert_rect_size: + _target_: mart.attack.composer.PertRectSize From c84af699ec37271d98ea99a1816a03a40fe5bb13 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 13:10:17 -0800 Subject: [PATCH 03/18] Add config of Adam optimizer. --- mart/configs/optimizer/adam.yaml | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 mart/configs/optimizer/adam.yaml diff --git a/mart/configs/optimizer/adam.yaml b/mart/configs/optimizer/adam.yaml new file mode 100644 index 00000000..311bed55 --- /dev/null +++ b/mart/configs/optimizer/adam.yaml @@ -0,0 +1,12 @@ +_target_: mart.optim.OptimizerFactory +optimizer: + _target_: hydra.utils.get_method + path: torch.optim.Adam +lr: ??? +betas: + - 0.9 + - 0.999 +eps: 1e-08 +weight_decay: 0 +bias_decay: 0 +norm_decay: 0 From 6a600b6f4a13e8c9622ba4f755d2421d68728f0e Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 13:11:27 -0800 Subject: [PATCH 04/18] Add LoadCoords for patch adversary. 
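LoadCoords reads "<image file stem>_coords.npy" from the given folder and stores the
patch corner coordinates as a float tensor in target["coords"], where the composer
modules (PertRectSize, PertRectPerspective) pick them up. A minimal, self-contained
usage sketch; the folder name, file name and the 4x2 corner layout (top-left,
top-right, bottom-right, bottom-left) are illustrative assumptions:

    import os
    import numpy as np
    from mart.transforms import LoadCoords

    # Hypothetical corner file; the real folder comes from the datamodule config.
    os.makedirs("patch_metadata", exist_ok=True)
    corners = np.array([[0, 0], [100, 0], [100, 60], [0, 60]], dtype=np.float32)
    np.save("patch_metadata/frame_0001_coords.npy", corners)

    load_coords = LoadCoords(folder="patch_metadata")
    image, target = load_coords(None, {"file_name": "frame_0001.png"})
    # target["coords"] is now a [4, 2] float tensor of corner coordinates.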
---
 mart/transforms/extended.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/mart/transforms/extended.py b/mart/transforms/extended.py
index 13cd0e74..ba827865 100644
--- a/mart/transforms/extended.py
+++ b/mart/transforms/extended.py
@@ -8,6 +8,7 @@
 import os
 from typing import Dict, Optional, Tuple
 
+import numpy as np
 import torch
 from PIL import Image, ImageOps
 from torch import Tensor
@@ -26,6 +27,7 @@
     "Lambda",
     "SplitLambda",
     "LoadPerturbableMask",
+    "LoadCoords",
     "ConvertInstanceSegmentationToPerturbable",
     "RandomHorizontalFlip",
     "ConvertCocoPolysToMask",
@@ -139,6 +141,25 @@ def __call__(self, image, target):
         return image, target
 
 
+class LoadCoords(ExTransform):
+    """Load patch coordinates and add to target."""
+
+    def __init__(self, folder) -> None:
+        self.folder = folder
+        self.to_tensor = T.ToTensor()
+
+    def __call__(self, image, target):
+        file_name = os.path.splitext(target["file_name"])[0]
+        coords_fname = f"{file_name}_coords.npy"
+        coords_fpath = os.path.join(self.folder, coords_fname)
+        coords = np.load(coords_fpath)
+
+        coords = self.to_tensor(coords)[0]
+        # Convert to float to be differentiable.
+        target["coords"] = coords
+        return image, target
+
+
 class RandomHorizontalFlip(T.RandomHorizontalFlip, ExTransform):
     """Flip the image and annotations including boxes, masks, keypoints and the
     perturable_masks."""

From 14b9f2ac58b4aead5297f249035aa97ae9ac35b8 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Fri, 2 Feb 2024 13:12:42 -0800
Subject: [PATCH 05/18] Add a config of unbounded patch adversary.

---
 .../object_detection_patch_adversary.yaml | 44 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)
 create mode 100644 mart/configs/attack/object_detection_patch_adversary.yaml

diff --git a/mart/configs/attack/object_detection_patch_adversary.yaml b/mart/configs/attack/object_detection_patch_adversary.yaml
new file mode 100644
index 00000000..ceff346d
--- /dev/null
+++ b/mart/configs/attack/object_detection_patch_adversary.yaml
@@ -0,0 +1,44 @@
+defaults:
+  - adversary
+  - /optimizer@optimizer: adam
+  - enforcer: default
+  - composer: default
+  - composer/perturber/initializer: uniform
+  - composer/perturber/projector: range
+  - composer/modules:
+      [pert_rect_size, pert_extract_rect, pert_rect_perspective, overlay]
+  - gradient_modifier: sign
+  - gain: rcnn_training_loss
+  - objective: zero_ap
+  - override /callbacks@callbacks: [progress_bar, image_visualizer]
+
+max_iters: ???
+lr: ???
+
+optimizer:
+  maximize: True
+  lr: ${..lr}
+
+enforcer:
+  # No constraints with complex renderer in the pipeline.
+  # TODO: Constraint on digital perturbation?
+  constraints: {}
+
+composer:
+  perturber:
+    initializer:
+      min: 0
+      max: 255
+    projector:
+      min: 0
+      max: 255
+  sequence:
+    seq010:
+      pert_rect_size: ["target.coords"]
+    seq020:
+      pert_extract_rect:
+        ["perturbation", "pert_rect_size.height", "pert_rect_size.width"]
+    seq040:
+      pert_rect_perspective: ["pert_extract_rect", "input", "target.coords"]
+    seq050:
+      overlay: ["pert_rect_perspective", "input", "target.perturbable_mask"]

From b2bac033a893d1670c1edeb26cae2b88fd0080e6 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Fri, 2 Feb 2024 13:25:49 -0800
Subject: [PATCH 06/18] Add a datamodule config for carla patch adversary.
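The transform chain keeps images on the 0-255 pixel scale that the patch composer
operates on, while preserving gradients: Denormalize(center=0, scale=255) is assumed
to rescale the [0, 1] output of ToTensor back to [0, 255], and the partial
fake-quantize rounds to integer pixel values without cutting the gradient. A small
sketch of that last step, with illustrative values:

    import torch

    x255 = torch.rand(3, 8, 8, requires_grad=True) * 255
    xq = torch.fake_quantize_per_tensor_affine(
        x255, scale=1.0, zero_point=0, quant_min=0, quant_max=255
    )
    # Forward value equals x255.round().clamp(0, 255); gradients still flow
    # back to x255 for in-range values, which is what the attack needs.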
--- mart/configs/datamodule/carla_patch.yaml | 38 ++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 mart/configs/datamodule/carla_patch.yaml diff --git a/mart/configs/datamodule/carla_patch.yaml b/mart/configs/datamodule/carla_patch.yaml new file mode 100644 index 00000000..0a00121e --- /dev/null +++ b/mart/configs/datamodule/carla_patch.yaml @@ -0,0 +1,38 @@ +defaults: + - default.yaml + +train_dataset: null + +val_dataset: null + +test_dataset: + _target_: mart.datamodules.coco.CocoDetection + root: ??? + annFile: ${.root}/kwcoco_annotations.json + modalities: ["rgb"] + transforms: + _target_: mart.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: mart.transforms.ConvertCocoPolysToMask + - _target_: mart.transforms.LoadPerturbableMask + perturb_mask_folder: ${....root}/foreground_mask/ + - _target_: mart.transforms.LoadCoords + folder: ${....root}/patch_metadata/ + - _target_: mart.transforms.Denormalize + center: 0 + scale: 255 + - _target_: torch.fake_quantize_per_tensor_affine + _partial_: true + # (x/1+0).round().clamp(0, 255) * 1 + scale: 1 + zero_point: 0 + quant_min: 0 + quant_max: 255 + +num_workers: 0 +ims_per_batch: 1 + +collate_fn: + _target_: hydra.utils.get_method + path: mart.datamodules.coco.collate_fn From b2bac033a893d1670c1edeb26cae2b88fd0080e6 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 13:36:17 -0800 Subject: [PATCH 07/18] Fix the simple Linf projection. --- mart/attack/projector.py | 4 +++- .../configs/attack/composer/perturber/projector/linf.yaml | 6 ++++++ tests/test_projector.py | 8 ++++---- 3 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 mart/configs/attack/composer/perturber/projector/linf.yaml diff --git a/mart/attack/projector.py b/mart/attack/projector.py index 8f88c0dd..bb30aa3e 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -131,7 +131,9 @@ def __init__(self, eps: int | float, p: int | float = torch.inf): @torch.no_grad() def project_(self, perturbation, *, input, target): pert_norm = perturbation.norm(p=self.p) - if pert_norm > self.eps: + if self.p == torch.inf: + perturbation.clamp_(-self.eps, self.eps) + elif pert_norm > self.eps: # We only upper-bound the norm. perturbation.mul_(self.eps / pert_norm) diff --git a/mart/configs/attack/composer/perturber/projector/linf.yaml b/mart/configs/attack/composer/perturber/projector/linf.yaml new file mode 100644 index 00000000..2051ae77 --- /dev/null +++ b/mart/configs/attack/composer/perturber/projector/linf.yaml @@ -0,0 +1,6 @@ +_target_: mart.attack.projector.Lp +# p is actually torch.inf by default. +p: + _target_: builtins.float + _args_: ["inf"] +eps: ??? 
diff --git a/tests/test_projector.py b/tests/test_projector.py index 19cb5c44..e2c6abe9 100644 --- a/tests/test_projector.py +++ b/tests/test_projector.py @@ -158,7 +158,7 @@ def test_compose(input_data, target_data): tensor.norm.return_value = 10 compose(tensor, input=input_data, target=target_data) - # RangeProjector, RangeAdditiveProjector, and LinfAdditiveRangeProjector calls `clamp_` - assert tensor.clamp_.call_count == 3 - # LpProjector and MaskProjector calls `mul_` - assert tensor.mul_.call_count == 2 + # RangeProjector, RangeAdditiveProjector, LpProjector_inf, and LinfAdditiveRangeProjector calls `clamp_` + assert tensor.clamp_.call_count == 4 + # MaskProjector calls `mul_` + assert tensor.mul_.call_count == 1 From 4ecf0a6551d8587a9449a5e502764abd0751972b Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 13:38:25 -0800 Subject: [PATCH 08/18] Add composer module PertImageBase for Lp bounded patch adversary. --- mart/attack/composer/patch.py | 43 +++++++++++++++++++ .../composer/modules/pert_image_base.yaml | 3 ++ 2 files changed, 46 insertions(+) create mode 100644 mart/configs/attack/composer/modules/pert_image_base.yaml diff --git a/mart/attack/composer/patch.py b/mart/attack/composer/patch.py index c1baeca8..b9a6623b 100644 --- a/mart/attack/composer/patch.py +++ b/mart/attack/composer/patch.py @@ -8,11 +8,13 @@ import torch import torchvision.transforms.functional as F +from torchvision.io import read_image __all__ = [ "PertRectSize", "PertExtractRect", "PertRectPerspective", + "PertImageBase", ] @@ -71,3 +73,44 @@ def forward(self, perturbation, input, coords): ) return perturbation + + +class PertImageBase(torch.nn.Module): + """Resize an image and add to perturbation.""" + + def __init__(self, fpath): + super().__init__() + # RGBA -> RGB + self.image_orig = read_image(fpath)[:3, :, :] + self.image = None + + # Project the result with the pixel value constraint. + self.image_clamp = FakeClamp(min=0, max=255) + + def forward(self, perturbation): + # Initialize the image with new shape of perturbation. + if self.image is None or self.image.shape[-2:] != perturbation.shape[-2:]: + height, width = perturbation.shape[-2:] + self.image = F.resize(self.image_orig, size=[height, width]) + self.image = self.image.to(device=perturbation.device) + + perturbation = self.image + perturbation + perturbation = self.image_clamp(perturbation) + + return perturbation + + +class FakeClamp(torch.nn.Module): + """Clamp the data, but keep the gradient.""" + + def __init__(self, *, min, max): + super().__init__() + self.min = min + self.max = max + + def forward(self, a): + with torch.no_grad(): + delta = a.clamp(min=self.min, max=self.max) - a + + a = a + delta + return a diff --git a/mart/configs/attack/composer/modules/pert_image_base.yaml b/mart/configs/attack/composer/modules/pert_image_base.yaml new file mode 100644 index 00000000..50f16f08 --- /dev/null +++ b/mart/configs/attack/composer/modules/pert_image_base.yaml @@ -0,0 +1,3 @@ +pert_image_base: + _target_: mart.attack.composer.PertImageBase + fpath: ??? From d9fe249df0dc3b894da254d434af6df873c1744e Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 13:41:19 -0800 Subject: [PATCH 09/18] Add config of lp-bounded patch adversary. 
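Unlike the unbounded variant, a single eps value here drives both the uniform
initializer (min/max = -eps/+eps) and the Linf projector, and pert_image_base adds
the bounded perturbation onto a resized base image, so the optimized patch stays
within an element-wise eps of that image (up to the 0-255 clamp). A minimal sketch
of the constraint itself, with illustrative shape and eps:

    import torch
    from mart.attack.projector import Lp

    eps = 8.0
    projector = Lp(eps=eps, p=torch.inf)
    perturbation = torch.empty(3, 50, 50).uniform_(-2 * eps, 2 * eps)
    projector.project_(perturbation, input=None, target=None)
    assert perturbation.abs().max() <= eps  # clamped into the L-inf ball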
--- .../object_detection_lp_patch_adversary.yaml | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 mart/configs/attack/object_detection_lp_patch_adversary.yaml diff --git a/mart/configs/attack/object_detection_lp_patch_adversary.yaml b/mart/configs/attack/object_detection_lp_patch_adversary.yaml new file mode 100644 index 00000000..cc6a4f54 --- /dev/null +++ b/mart/configs/attack/object_detection_lp_patch_adversary.yaml @@ -0,0 +1,55 @@ +defaults: + - adversary + - /optimizer@optimizer: adam + - enforcer: default + - composer: default + - composer/perturber/initializer: uniform + - composer/perturber/projector: linf + - composer/modules: + [ + pert_rect_size, + pert_extract_rect, + pert_image_base, + pert_rect_perspective, + overlay, + ] + - gradient_modifier: sign + - gain: rcnn_training_loss + - objective: zero_ap + - override /callbacks@callbacks: [progress_bar, image_visualizer] + +max_iters: ??? +lr: ??? +eps: ??? + +optimizer: + maximize: True + lr: ${..lr} + +enforcer: + # No constraints with complex renderer in the pipeline. + # TODO: Constraint on digital perturbation? + constraints: {} + +composer: + perturber: + initializer: + min: ${negate:${....eps}} + max: ${....eps} + projector: + eps: ${....eps} + modules: + pert_image_base: + fpath: ??? + sequence: + seq010: + pert_rect_size: ["target.coords"] + seq020: + pert_extract_rect: + ["perturbation", "pert_rect_size.height", "pert_rect_size.width"] + seq030: + pert_image_base: ["pert_extract_rect"] + seq040: + pert_rect_perspective: ["pert_image_base", "input", "target.coords"] + seq050: + overlay: ["pert_rect_perspective", "input", "target.perturbable_mask"] From 14515007a1778d6b433727eda5bed5de42efead6 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 13:48:44 -0800 Subject: [PATCH 10/18] Add a fake renderer composer module. --- mart/attack/composer/patch.py | 24 +++++++++++++++++++ .../composer/modules/fake_renderer.yaml | 2 ++ 2 files changed, 26 insertions(+) create mode 100644 mart/configs/attack/composer/modules/fake_renderer.yaml diff --git a/mart/attack/composer/patch.py b/mart/attack/composer/patch.py index b9a6623b..a23f4ae1 100644 --- a/mart/attack/composer/patch.py +++ b/mart/attack/composer/patch.py @@ -14,6 +14,7 @@ "PertRectSize", "PertExtractRect", "PertRectPerspective", + "FakeRenderer", "PertImageBase", ] @@ -75,6 +76,29 @@ def forward(self, perturbation, input, coords): return perturbation +class FakeRenderer(torch.nn.Module): + """Replace image with a re-rendered image, but keep the gradient on the perturbation.""" + + def forward(self, perturbation, input, renderer): + """Use the same perturbation and target.coordinates to re-render input in Simulation. + + perturbation is the rectangular patch. or a masked frame, with rectangular coordinates + in target, so we can extract the rectangle patch for rendering. input is the digitally + composed frame. target should include everything that needs to re-render a frame with the + perturbation. + """ + + with torch.no_grad(): + input_rendered = renderer(perturbation) + input_rendered = input_rendered.clamp(0, 255) + delta = input_rendered - input + + # Fake differentiable rendering, or BPDA, but keep the mask constraint. 
+ input = input + delta + + return input + + class PertImageBase(torch.nn.Module): """Resize an image and add to perturbation.""" diff --git a/mart/configs/attack/composer/modules/fake_renderer.yaml b/mart/configs/attack/composer/modules/fake_renderer.yaml new file mode 100644 index 00000000..13e19660 --- /dev/null +++ b/mart/configs/attack/composer/modules/fake_renderer.yaml @@ -0,0 +1,2 @@ +fake_renderer: + _target_: mart.attack.composer.FakeRenderer From a5ffef3b3bb812f25e2f44e224c051d8dcb1b617 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 13:57:28 -0800 Subject: [PATCH 11/18] Teardown a test dataset gracefully for the rendering-in-loop adversary. --- mart/datamodules/modular.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mart/datamodules/modular.py b/mart/datamodules/modular.py index 3f7ea3f6..dc2e9137 100644 --- a/mart/datamodules/modular.py +++ b/mart/datamodules/modular.py @@ -132,3 +132,9 @@ def test_dataloader(self): collate_fn=self.collate_fn, **kwargs, ) + + def teardown(self, *, stage): + # Run teardown if dataset has it. + # An interactive dataset may have threads that we need to teardown. + if stage == "test" and hasattr(self.test_dataset, "teardown"): + self.test_dataset.teardown() From 70c0b8eaf5737083d86f1587fd36983998a2f02f Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 14:00:05 -0800 Subject: [PATCH 12/18] Add configs of simulation-in-loop adversary. --- ...ect_detection_lp_patch_adversary_simulation.yaml | 13 +++++++++++++ ...object_detection_patch_adversary_simulation.yaml | 13 +++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 mart/configs/attack/object_detection_lp_patch_adversary_simulation.yaml create mode 100644 mart/configs/attack/object_detection_patch_adversary_simulation.yaml diff --git a/mart/configs/attack/object_detection_lp_patch_adversary_simulation.yaml b/mart/configs/attack/object_detection_lp_patch_adversary_simulation.yaml new file mode 100644 index 00000000..13316241 --- /dev/null +++ b/mart/configs/attack/object_detection_lp_patch_adversary_simulation.yaml @@ -0,0 +1,13 @@ +defaults: + - object_detection_lp_patch_adversary + +composer: + modules: + fake_renderer: + _target_: mart.attack.composer.FakeRenderer + + sequence: + seq060: + # Ignore output from overlay. + fake_renderer: + ["pert_image_base", "pert_rect_perspective", "target.renderer"] diff --git a/mart/configs/attack/object_detection_patch_adversary_simulation.yaml b/mart/configs/attack/object_detection_patch_adversary_simulation.yaml new file mode 100644 index 00000000..bcb2b8d6 --- /dev/null +++ b/mart/configs/attack/object_detection_patch_adversary_simulation.yaml @@ -0,0 +1,13 @@ +defaults: + - object_detection_patch_adversary + +composer: + modules: + fake_renderer: + _target_: mart.attack.composer.FakeRenderer + + sequence: + seq060: + # Ignore output from overlay. + fake_renderer: + ["pert_extract_rect", "pert_rect_perspective", "target.renderer"] From e4fe503ba1e5b58d6e70f95db2aad57fb283c1ad Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Feb 2024 14:06:41 -0800 Subject: [PATCH 13/18] Add a datamodule config for CARLA patch rendering. 
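The transforms mirror carla_patch.yaml; the interactive CarlaDataset is additionally
expected to put a per-sample renderer into the target so that the fake_renderer
module from the simulation configs above can re-render the frame each step (this
expectation is read from those configs, not guaranteed here). For reference, how
FakeRenderer uses such a renderer while keeping gradients, sketched with a stand-in
renderer:

    import torch
    from mart.attack.composer import FakeRenderer

    def renderer(patch):  # stand-in for the CARLA round trip
        return torch.full((3, 4, 4), 7.0)

    patch = torch.zeros(3, 4, 4)                        # pixels sent to the simulator
    warped = torch.zeros(3, 4, 4, requires_grad=True)   # differentiable digital composite
    out = FakeRenderer()(patch, warped, renderer)
    out.sum().backward()
    # `out` carries the rendered pixels, yet warped.grad is all ones:
    # rendering is treated as identity in the backward pass (BPDA).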
--- .../datamodule/carla_patch_rendering.yaml | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 mart/configs/datamodule/carla_patch_rendering.yaml diff --git a/mart/configs/datamodule/carla_patch_rendering.yaml b/mart/configs/datamodule/carla_patch_rendering.yaml new file mode 100644 index 00000000..53d4dbd3 --- /dev/null +++ b/mart/configs/datamodule/carla_patch_rendering.yaml @@ -0,0 +1,38 @@ +defaults: + - default.yaml + +train_dataset: null + +val_dataset: null + +test_dataset: + _target_: oscar_datagen_tools.dataset.dataset.CarlaDataset + root: ??? + modality: "rgb" + annFile: ${.root}/kwcoco_annotations.json + num_insertion_ticks: 50 + localhost: true + overrides: [] + transforms: + _target_: mart.transforms.Compose + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: mart.transforms.ConvertCocoPolysToMask + - _target_: mart.transforms.LoadPerturbableMask + perturb_mask_folder: ${....root}/foreground_mask/ + - _target_: mart.transforms.LoadCoords + folder: ${....root}/patch_metadata/ + - _target_: mart.transforms.Denormalize + center: 0 + scale: 255 + - _target_: torch.fake_quantize_per_tensor_affine + _partial_: true + # (x/1+0).round().clamp(0, 255) * 1 + scale: 1 + zero_point: 0 + quant_min: 0 + quant_max: 255 + +collate_fn: + _target_: hydra.utils.get_method + path: mart.datamodules.coco.collate_fn From 08cbffc04ed0f5adb3be5b94e7b79267daedfc4c Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Sat, 3 Feb 2024 10:41:21 -0800 Subject: [PATCH 14/18] Update CarlaDataset config. --- mart/configs/datamodule/carla_patch_rendering.yaml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mart/configs/datamodule/carla_patch_rendering.yaml b/mart/configs/datamodule/carla_patch_rendering.yaml index 53d4dbd3..f748abce 100644 --- a/mart/configs/datamodule/carla_patch_rendering.yaml +++ b/mart/configs/datamodule/carla_patch_rendering.yaml @@ -7,21 +7,19 @@ val_dataset: null test_dataset: _target_: oscar_datagen_tools.dataset.dataset.CarlaDataset - root: ??? + simulation_run: ??? modality: "rgb" - annFile: ${.root}/kwcoco_annotations.json + annFile: ${.simulation_run}/kwcoco_annotations.json num_insertion_ticks: 50 - localhost: true - overrides: [] transforms: _target_: mart.transforms.Compose transforms: - _target_: torchvision.transforms.ToTensor - _target_: mart.transforms.ConvertCocoPolysToMask - _target_: mart.transforms.LoadPerturbableMask - perturb_mask_folder: ${....root}/foreground_mask/ + perturb_mask_folder: ${....simulation_run}/foreground_mask/ - _target_: mart.transforms.LoadCoords - folder: ${....root}/patch_metadata/ + folder: ${....simulation_run}/patch_metadata/ - _target_: mart.transforms.Denormalize center: 0 scale: 255 From a17e224d425a637d233ca7415d986b8d8ca8b2d7 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Sat, 3 Feb 2024 11:04:16 -0800 Subject: [PATCH 15/18] Add a composer.visualize switch to see intermediate images. 
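When visualize is enabled, every tensor-valued output of the composer sequence is
divided by 255 and written to the working directory as "<module_name>.png" after
each composition, which helps inspect the patch pipeline (extracted rectangle,
perspective warp, overlay) step by step. With the Hydra configs above it can
presumably be appended from the command line, for example:

    python -m mart experiment=... +attack.composer.visualize=true

The exact override path depends on the experiment config in use.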
--- mart/attack/composer/modular.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mart/attack/composer/modular.py b/mart/attack/composer/modular.py index d070436b..c614a3c5 100644 --- a/mart/attack/composer/modular.py +++ b/mart/attack/composer/modular.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Iterable import torch +from torchvision.transforms.functional import to_pil_image from mart.nn import SequentialDict @@ -19,7 +20,7 @@ class Composer(torch.nn.Module): - def __init__(self, perturber: Perturber, modules, sequence) -> None: + def __init__(self, perturber: Perturber, modules, sequence, visualize: bool = False) -> None: """_summary_ Args: @@ -34,6 +35,7 @@ def __init__(self, perturber: Perturber, modules, sequence) -> None: if isinstance(sequence, dict): sequence = [sequence[key] for key in sorted(sequence)] self.functions = SequentialDict(modules, {"composer": sequence}) + self.visualize = visualize def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]): return self.perturber.configure_perturbation(input) @@ -76,6 +78,12 @@ def _compose( input=input, target=target, perturbation=perturbation, step="composer" ) + # Visualize intermediate images. + if self.visualize: + for key, value in output.items(): + if isinstance(value, torch.Tensor): + to_pil_image(value / 255).save(f"{key}.png") + # SequentialDict returns a dictionary DotDict, # but we only need the return value of the most recently executed module. last_added_key = next(reversed(output)) From 39dccf9f22c33db44b62fabad5fdd4fa71728cd5 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Sun, 4 Feb 2024 08:27:18 -0800 Subject: [PATCH 16/18] Revert "Teardown a test dataset gracefully for the rendering-in-loop adversary." This reverts commit a5ffef3b3bb812f25e2f44e224c051d8dcb1b617. --- mart/datamodules/modular.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mart/datamodules/modular.py b/mart/datamodules/modular.py index dc2e9137..3f7ea3f6 100644 --- a/mart/datamodules/modular.py +++ b/mart/datamodules/modular.py @@ -132,9 +132,3 @@ def test_dataloader(self): collate_fn=self.collate_fn, **kwargs, ) - - def teardown(self, *, stage): - # Run teardown if dataset has it. - # An interactive dataset may have threads that we need to teardown. - if stage == "test" and hasattr(self.test_dataset, "teardown"): - self.test_dataset.teardown() From fcc1aaa5a2edb22f9cd11a629f311a83cc7875de Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Sun, 4 Feb 2024 08:28:24 -0800 Subject: [PATCH 17/18] Revert "Add a composer.visualize switch to see intermediate images." This reverts commit a17e224d425a637d233ca7415d986b8d8ca8b2d7. 
--- mart/attack/composer/modular.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mart/attack/composer/modular.py b/mart/attack/composer/modular.py index c614a3c5..d070436b 100644 --- a/mart/attack/composer/modular.py +++ b/mart/attack/composer/modular.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Iterable import torch -from torchvision.transforms.functional import to_pil_image from mart.nn import SequentialDict @@ -20,7 +19,7 @@ class Composer(torch.nn.Module): - def __init__(self, perturber: Perturber, modules, sequence, visualize: bool = False) -> None: + def __init__(self, perturber: Perturber, modules, sequence) -> None: """_summary_ Args: @@ -35,7 +34,6 @@ def __init__(self, perturber: Perturber, modules, sequence, visualize: bool = Fa if isinstance(sequence, dict): sequence = [sequence[key] for key in sorted(sequence)] self.functions = SequentialDict(modules, {"composer": sequence}) - self.visualize = visualize def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]): return self.perturber.configure_perturbation(input) @@ -78,12 +76,6 @@ def _compose( input=input, target=target, perturbation=perturbation, step="composer" ) - # Visualize intermediate images. - if self.visualize: - for key, value in output.items(): - if isinstance(value, torch.Tensor): - to_pil_image(value / 255).save(f"{key}.png") - # SequentialDict returns a dictionary DotDict, # but we only need the return value of the most recently executed module. last_added_key = next(reversed(output)) From ee33d5ed53ebbd18b9250532bee206134d54265c Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Sat, 3 Feb 2024 11:04:16 -0800 Subject: [PATCH 18/18] Add a composer.visualize switch to see intermediate images. --- mart/attack/composer/modular.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mart/attack/composer/modular.py b/mart/attack/composer/modular.py index d070436b..c614a3c5 100644 --- a/mart/attack/composer/modular.py +++ b/mart/attack/composer/modular.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Iterable import torch +from torchvision.transforms.functional import to_pil_image from mart.nn import SequentialDict @@ -19,7 +20,7 @@ class Composer(torch.nn.Module): - def __init__(self, perturber: Perturber, modules, sequence) -> None: + def __init__(self, perturber: Perturber, modules, sequence, visualize: bool = False) -> None: """_summary_ Args: @@ -34,6 +35,7 @@ def __init__(self, perturber: Perturber, modules, sequence) -> None: if isinstance(sequence, dict): sequence = [sequence[key] for key in sorted(sequence)] self.functions = SequentialDict(modules, {"composer": sequence}) + self.visualize = visualize def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]): return self.perturber.configure_perturbation(input) @@ -76,6 +78,12 @@ def _compose( input=input, target=target, perturbation=perturbation, step="composer" ) + # Visualize intermediate images. + if self.visualize: + for key, value in output.items(): + if isinstance(value, torch.Tensor): + to_pil_image(value / 255).save(f"{key}.png") + # SequentialDict returns a dictionary DotDict, # but we only need the return value of the most recently executed module. last_added_key = next(reversed(output))
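For quick, standalone verification of the geometric modules introduced in this
series, the chain below runs outside of Hydra with illustrative coordinates and
sizes (real values come from LoadCoords and the CARLA datamodules):

    import torch
    from mart.attack.composer import PertExtractRect, PertRectPerspective, PertRectSize

    frame = torch.zeros(3, 120, 160)
    coords = torch.tensor([[20.0, 10.0], [100.0, 15.0], [95.0, 70.0], [25.0, 65.0]])
    perturbation = torch.rand(3, 120, 160) * 255

    size = PertRectSize()(coords)                       # {"height": ..., "width": ...}
    rect = PertExtractRect()(perturbation, size["height"], size["width"])
    warped = PertRectPerspective()(rect, frame, coords)
    # `warped` has the frame's spatial size with the patch projected onto the
    # quadrilateral given by coords; Overlay then pastes it into the frame.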