Skip to content

Commit

Permalink
Add unbounded patch adversary for object detection (#241)
Browse files Browse the repository at this point in the history
* Create a folder for attack.composer.

* Add composer modules for unbounded patch adversary.

* Add config of Adam optimizer.

* Add LoadCoords for patch adversary.

* Add a config of unbounded patch adversary.

* Add a datamodule config for carla patch adversary.
  • Loading branch information
mzweilin authored Feb 4, 2024
1 parent 5ccf5f0 commit 1360ab0
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 2 deletions.
2 changes: 2 additions & 0 deletions mart/attack/composer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .modular import *
from .patch import *
4 changes: 2 additions & 2 deletions mart/attack/composer.py → mart/attack/composer/modular.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from mart.nn import SequentialDict

if TYPE_CHECKING:
from .perturber import Perturber
from ..perturber import Perturber

__all__ = ["Composer"]
__all__ = ["Composer", "Additive", "Mask", "Overlay"]


class Composer(torch.nn.Module):
Expand Down
73 changes: 73 additions & 0 deletions mart/attack/composer/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#
# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: BSD-3-Clause
#

from __future__ import annotations

import torch
import torchvision.transforms.functional as F

__all__ = [
"PertRectSize",
"PertExtractRect",
"PertRectPerspective",
]


class PertRectSize(torch.nn.Module):
"""Calculate the size of the smallest rectangle that can be transformed with the highest pixel
fidelity."""

@staticmethod
def get_smallest_rect(coords):
# Calculate the distance between two points.
coords_shifted = torch.cat([coords[1:], coords[0:1]])
w1, h2, w2, h1 = torch.sqrt(
torch.sum(torch.pow(torch.subtract(coords, coords_shifted), 2), dim=1)
)

height = int(max(h1, h2).round())
width = int(max(w1, w2).round())
return height, width

def forward(self, coords):
height, width = self.get_smallest_rect(coords)
return {"height": height, "width": width}


class PertExtractRect(torch.nn.Module):
"""Extract a small rectangular patch from the input size one."""

def forward(self, perturbation, height, width):
perturbation = perturbation[:, :height, :width]
return perturbation


class PertRectPerspective(torch.nn.Module):
"""Pad perturbation to input size, then perspective transform the top-left rectangle."""

def forward(self, perturbation, input, coords):
# Pad to the input size.
height_inp, width_inp = input.shape[-2:]
height_pert, width_pert = perturbation.shape[-2:]
height_pad = height_inp - height_pert
width_pad = width_inp - width_pert
perturbation = F.pad(perturbation, padding=[0, 0, width_pad, height_pad])

# F.perspective() requires startpoints and endpoints in CPU.
startpoints = torch.tensor(
[[0, 0], [width_pert, 0], [width_pert, height_pert], [0, height_pert]]
)
endpoints = coords.cpu()

perturbation = F.perspective(
img=perturbation,
startpoints=startpoints,
endpoints=endpoints,
interpolation=F.InterpolationMode.BILINEAR,
fill=0,
)

return perturbation
2 changes: 2 additions & 0 deletions mart/configs/attack/composer/modules/pert_extract_rect.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pert_extract_rect:
_target_: mart.attack.composer.PertExtractRect
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pert_rect_perspective:
_target_: mart.attack.composer.PertRectPerspective
2 changes: 2 additions & 0 deletions mart/configs/attack/composer/modules/pert_rect_size.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pert_rect_size:
_target_: mart.attack.composer.PertRectSize
44 changes: 44 additions & 0 deletions mart/configs/attack/object_detection_patch_adversary.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
defaults:
- adversary
- /optimizer@optimizer: adam
- enforcer: default
- composer: default
- composer/perturber/initializer: uniform
- composer/perturber/projector: range
- composer/modules:
[pert_rect_size, pert_extract_rect, pert_rect_perspective, overlay]
- gradient_modifier: sign
- gain: rcnn_training_loss
- objective: zero_ap
- override /callbacks@callbacks: [progress_bar, image_visualizer]

max_iters: ???
lr: ???

optimizer:
maximize: True
lr: ${..lr}

enforcer:
# No constraints with complex renderer in the pipeline.
# TODO: Constraint on digital perturbation?
constraints: {}

composer:
perturber:
initializer:
min: 0
max: 255
projector:
min: 0
max: 255
sequence:
seq010:
pert_rect_size: ["target.coords"]
seq020:
pert_extract_rect:
["perturbation", "pert_rect_size.height", "pert_rect_size.width"]
seq040:
pert_rect_perspective: ["pert_extract_rect", "input", "target.coords"]
seq050:
overlay: ["pert_rect_perspective", "input", "target.perturbable_mask"]
38 changes: 38 additions & 0 deletions mart/configs/datamodule/carla_patch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
defaults:
- default.yaml

train_dataset: null

val_dataset: null

test_dataset:
_target_: mart.datamodules.coco.CocoDetection
root: ???
annFile: ${.root}/kwcoco_annotations.json
modalities: ["rgb"]
transforms:
_target_: mart.transforms.Compose
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: mart.transforms.ConvertCocoPolysToMask
- _target_: mart.transforms.LoadPerturbableMask
perturb_mask_folder: ${....root}/foreground_mask/
- _target_: mart.transforms.LoadCoords
folder: ${....root}/patch_metadata/
- _target_: mart.transforms.Denormalize
center: 0
scale: 255
- _target_: torch.fake_quantize_per_tensor_affine
_partial_: true
# (x/1+0).round().clamp(0, 255) * 1
scale: 1
zero_point: 0
quant_min: 0
quant_max: 255

num_workers: 0
ims_per_batch: 1

collate_fn:
_target_: hydra.utils.get_method
path: mart.datamodules.coco.collate_fn
12 changes: 12 additions & 0 deletions mart/configs/optimizer/adam.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_target_: mart.optim.OptimizerFactory
optimizer:
_target_: hydra.utils.get_method
path: torch.optim.Adam
lr: ???
betas:
- 0.9
- 0.999
eps: 1e-08
weight_decay: 0
bias_decay: 0
norm_decay: 0
21 changes: 21 additions & 0 deletions mart/transforms/extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from typing import Dict, Optional, Tuple

import numpy as np
import torch
from PIL import Image, ImageOps
from torch import Tensor
Expand All @@ -26,6 +27,7 @@
"Lambda",
"SplitLambda",
"LoadPerturbableMask",
"LoadCoords",
"ConvertInstanceSegmentationToPerturbable",
"RandomHorizontalFlip",
"ConvertCocoPolysToMask",
Expand Down Expand Up @@ -139,6 +141,25 @@ def __call__(self, image, target):
return image, target


class LoadCoords(ExTransform):
"""Load perturbable masks and add to target."""

def __init__(self, folder) -> None:
self.folder = folder
self.to_tensor = T.ToTensor()

def __call__(self, image, target):
file_name = os.path.splitext(target["file_name"])[0]
coords_fname = f"{file_name}_coords.npy"
coords_fpath = os.path.join(self.folder, coords_fname)
coords = np.load(coords_fpath)

coords = self.to_tensor(coords)[0]
# Convert to float to be differentiable.
target["coords"] = coords
return image, target


class RandomHorizontalFlip(T.RandomHorizontalFlip, ExTransform):
"""Flip the image and annotations including boxes, masks, keypoints and the
perturable_masks."""
Expand Down

0 comments on commit 1360ab0

Please sign in to comment.