Skip to content

Commit

Permalink
Add Lp-bounded patch adversary for object detection (#242)
Browse files Browse the repository at this point in the history
* Create a folder for attack.composer.

* Add composer modules for unbounded patch adversary.

* Add config of Adam optimizer.

* Add LoadCoords for patch adversary.

* Add a config of unbounded patch adversary.

* Add a datamodule config for carla patch adversary.

* Fix the simple Linf projection.

* Add composer module PertImageBase for Lp bounded patch adversary.

* Add config of lp-bounded patch adversary.

* Formatting
  • Loading branch information
mzweilin authored Feb 4, 2024
1 parent 1360ab0 commit 3862abe
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 5 deletions.
43 changes: 43 additions & 0 deletions mart/attack/composer/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

import torch
import torchvision.transforms.functional as F
from torchvision.io import read_image

__all__ = [
"PertRectSize",
"PertExtractRect",
"PertRectPerspective",
"PertImageBase",
]


Expand Down Expand Up @@ -71,3 +73,44 @@ def forward(self, perturbation, input, coords):
)

return perturbation


class PertImageBase(torch.nn.Module):
"""Resize an image and add to perturbation."""

def __init__(self, fpath):
super().__init__()
# RGBA -> RGB
self.image_orig = read_image(fpath)[:3, :, :]
self.image = None

# Project the result with the pixel value constraint.
self.image_clamp = FakeClamp(min=0, max=255)

def forward(self, perturbation):
# Initialize the image with new shape of perturbation.
if self.image is None or self.image.shape[-2:] != perturbation.shape[-2:]:
height, width = perturbation.shape[-2:]
self.image = F.resize(self.image_orig, size=[height, width])
self.image = self.image.to(device=perturbation.device)

perturbation = self.image + perturbation
perturbation = self.image_clamp(perturbation)

return perturbation


class FakeClamp(torch.nn.Module):
"""Clamp the data, but keep the gradient."""

def __init__(self, *, min, max):
super().__init__()
self.min = min
self.max = max

def forward(self, a):
with torch.no_grad():
delta = a.clamp(min=self.min, max=self.max) - a

a = a + delta
return a
4 changes: 3 additions & 1 deletion mart/attack/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def __init__(self, eps: int | float, p: int | float = torch.inf):
@torch.no_grad()
def project_(self, perturbation, *, input, target):
pert_norm = perturbation.norm(p=self.p)
if pert_norm > self.eps:
if self.p == torch.inf:
perturbation.clamp_(-self.eps, self.eps)
elif pert_norm > self.eps:
# We only upper-bound the norm.
perturbation.mul_(self.eps / pert_norm)

Expand Down
3 changes: 3 additions & 0 deletions mart/configs/attack/composer/modules/pert_image_base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pert_image_base:
_target_: mart.attack.composer.PertImageBase
fpath: ???
6 changes: 6 additions & 0 deletions mart/configs/attack/composer/perturber/projector/linf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_target_: mart.attack.projector.Lp
# p is actually torch.inf by default.
p:
_target_: builtins.float
_args_: ["inf"]
eps: ???
55 changes: 55 additions & 0 deletions mart/configs/attack/object_detection_lp_patch_adversary.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
defaults:
- adversary
- /optimizer@optimizer: adam
- enforcer: default
- composer: default
- composer/perturber/initializer: uniform
- composer/perturber/projector: linf
- composer/modules:
[
pert_rect_size,
pert_extract_rect,
pert_image_base,
pert_rect_perspective,
overlay,
]
- gradient_modifier: sign
- gain: rcnn_training_loss
- objective: zero_ap
- override /callbacks@callbacks: [progress_bar, image_visualizer]

max_iters: ???
lr: ???
eps: ???

optimizer:
maximize: True
lr: ${..lr}

enforcer:
# No constraints with complex renderer in the pipeline.
# TODO: Constraint on digital perturbation?
constraints: {}

composer:
perturber:
initializer:
min: ${negate:${....eps}}
max: ${....eps}
projector:
eps: ${....eps}
modules:
pert_image_base:
fpath: ???
sequence:
seq010:
pert_rect_size: ["target.coords"]
seq020:
pert_extract_rect:
["perturbation", "pert_rect_size.height", "pert_rect_size.width"]
seq030:
pert_image_base: ["pert_extract_rect"]
seq040:
pert_rect_perspective: ["pert_image_base", "input", "target.coords"]
seq050:
overlay: ["pert_rect_perspective", "input", "target.perturbable_mask"]
8 changes: 4 additions & 4 deletions tests/test_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_compose(input_data, target_data):
tensor.norm.return_value = 10
compose(tensor, input=input_data, target=target_data)

# RangeProjector, RangeAdditiveProjector, and LinfAdditiveRangeProjector calls `clamp_`
assert tensor.clamp_.call_count == 3
# LpProjector and MaskProjector calls `mul_`
assert tensor.mul_.call_count == 2
# RangeProjector, RangeAdditiveProjector, LpProjector_inf, and LinfAdditiveRangeProjector calls `clamp_`
assert tensor.clamp_.call_count == 4
# MaskProjector calls `mul_`
assert tensor.mul_.call_count == 1

0 comments on commit 3862abe

Please sign in to comment.