Skip to content

Commit

Permalink
Modular Composer (#240)
Browse files Browse the repository at this point in the history
* Make a modular composer.

* Fix tests.

* Add shortcuts to common composers.

* Fix configs.

* Clean up.
  • Loading branch information
mzweilin authored Feb 1, 2024
1 parent 8087b62 commit 5ccf5f0
Show file tree
Hide file tree
Showing 15 changed files with 105 additions and 97 deletions.
76 changes: 27 additions & 49 deletions mart/attack/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,20 @@

from __future__ import annotations

import abc
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable

import torch

from mart.nn import SequentialDict

if TYPE_CHECKING:
from .perturber import Perturber


class Function(torch.nn.Module):
def __init__(self, *args, order=0, **kwargs) -> None:
"""A stackable function for Composer.
Args:
order (int, optional): The priority number. A smaller number makes a function run earlier than others in a sequence. Defaults to 0.
"""
super().__init__(*args, **kwargs)
self.order = order

@abc.abstractmethod
def forward(
self, perturbation: torch.Tensor, input: torch.Tensor, target: torch.Tensor | dict
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | dict]:
"""Returns the modified perturbation, modified input and target, so we can chain Functions
in a Composer."""
pass
__all__ = ["Composer"]


class Composer(torch.nn.Module):
def __init__(self, perturber: Perturber, functions: dict[str, Function]) -> None:
def __init__(self, perturber: Perturber, modules, sequence) -> None:
"""_summary_
Args:
Expand All @@ -47,11 +30,10 @@ def __init__(self, perturber: Perturber, functions: dict[str, Function]) -> None

self.perturber = perturber

# Sort functions by function.order and the name.
self.functions_dict = OrderedDict(
sorted(functions.items(), key=lambda name_fn: (name_fn[1].order, name_fn[0]))
)
self.functions = list(self.functions_dict.values())
# Convert dict sequences to list sequences by sorting keys
if isinstance(sequence, dict):
sequence = [sequence[key] for key in sorted(sequence)]
self.functions = SequentialDict(modules, {"composer": sequence})

def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]):
return self.perturber.configure_perturbation(input)
Expand Down Expand Up @@ -89,48 +71,44 @@ def _compose(
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
) -> torch.Tensor:
for function in self.functions:
perturbation, input, target = function(perturbation, input, target)
# A computational graph in SequentialDict().
output = self.functions(
input=input, target=target, perturbation=perturbation, step="composer"
)

# SequentialDict returns a dictionary DotDict,
# but we only need the return value of the most recently executed module.
last_added_key = next(reversed(output))
output = output[last_added_key]

# Return the composed input.
return input
return output


class Additive(Function):
class Additive(torch.nn.Module):
"""We assume an adversary adds perturbation to the input."""

def forward(self, perturbation, input, target):
def forward(self, perturbation, input):
input = input + perturbation
return perturbation, input, target

return input

class Mask(Function):
def __init__(self, *args, key="perturbable_mask", **kwargs):
super().__init__(*args, **kwargs)
self.key = key

def forward(self, perturbation, input, target):
mask = target[self.key]
class Mask(torch.nn.Module):
def forward(self, perturbation, mask):
perturbation = perturbation * mask
return perturbation, input, target
return perturbation


class Overlay(Function):
class Overlay(torch.nn.Module):
"""We assume an adversary overlays a patch to the input."""

def __init__(self, *args, key="perturbable_mask", **kwargs):
super().__init__(*args, **kwargs)
self.key = key

def forward(self, perturbation, input, target):
def forward(self, perturbation, input, mask):
# True is mutable, False is immutable.
mask = target[self.key]

# Convert mask to a Tensor with same torch.dtype and torch.device as input,
# because some data modules (e.g. Armory) gives binary mask.
mask = mask.to(input)

perturbation = perturbation * mask

input = input * (1 - mask) + perturbation
return perturbation, input, target
return input
2 changes: 1 addition & 1 deletion mart/configs/attack/adversary.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
defaults:
- /callbacks@callbacks: [progress_bar]
- composer: default

_target_: mart.attack.Adversary
_convert_: all
optimizer:
maximize: True
lr_scheduler: null
composer: ???
gain: ???
gradient_modifier: null
objective: null
Expand Down
2 changes: 1 addition & 1 deletion mart/configs/attack/classification_fgsm_linf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- adversary
- fgm
- linf
- composer/functions: additive
- composer: additive
- gradient_modifier: sign
- gain: cross_entropy
- objective: misclassification
Expand Down
2 changes: 1 addition & 1 deletion mart/configs/attack/classification_pgd_linf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- adversary
- pgd
- linf
- composer/functions: additive
- composer: additive
- gradient_modifier: sign
- gain: cross_entropy
- objective: misclassification
Expand Down
7 changes: 7 additions & 0 deletions mart/configs/attack/composer/additive.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- default
- modules: [additive]

sequence:
seq010:
additive: [perturbation, input]
7 changes: 6 additions & 1 deletion mart/configs/attack/composer/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@ defaults:
- perturber: default

_target_: mart.attack.Composer
functions: ???
modules:
???
# Example: additive, mask, overlay
sequence:
???
# Wire modules, with input,
12 changes: 6 additions & 6 deletions mart/configs/attack/composer/mask_additive.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
defaults:
- default
- functions: [mask, additive]
- modules: [mask, additive]

functions:
mask:
order: 0
additive:
order: 1
sequence:
seq010:
mask: [perturbation, target.perturbable_mask]
seq020:
additive: [mask, input]
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
additive:
_target_: mart.attack.composer.Additive
order: 0
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
mask:
_target_: mart.attack.composer.Mask
key: perturbable_mask
order: 0
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
overlay:
_target_: mart.attack.composer.Overlay
key: perturbable_mask
order: 0
7 changes: 7 additions & 0 deletions mart/configs/attack/composer/overlay.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- default
- modules: [overlay]

sequence:
seq010:
overlay: [perturbation, input, target.perturbable_mask]
2 changes: 1 addition & 1 deletion mart/configs/attack/object_detection_mask_adversary.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ defaults:
- adversary
- gradient_ascent
- mask
- composer: overlay
- composer/perturber/initializer: constant
- composer/functions: overlay
- gradient_modifier: sign
- gain: rcnn_training_loss
- objective: zero_ap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ defaults:
- adversary
- gradient_ascent
- mask
- composer: overlay
- composer/perturber/initializer: constant
- composer/functions: overlay
- gradient_modifier: sign
- gain: rcnn_class_background
- objective: object_detection_missed
Expand Down
60 changes: 35 additions & 25 deletions tests/test_adversary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
from functools import partial
from unittest.mock import Mock

import pytest
import torch
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch.optim import SGD

import mart
from mart.attack import Adversary, Composer, Perturber
from mart.attack import Adversary, Perturber
from mart.attack.composer import Additive, Composer
from mart.attack.gradient_modifier import Sign


def test_with_model(input_data, target_data, perturbation):
perturber = Mock(spec=Perturber, return_value=perturbation)
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)
gain = Mock()
enforcer = Mock()
attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0))
Expand Down Expand Up @@ -54,8 +54,9 @@ def test_hidden_params():
initializer = Mock()
projector = Mock()
perturber = Perturber(initializer=initializer, projector=projector)
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)

gain = Mock()
enforcer = Mock()
Expand All @@ -82,8 +83,9 @@ def test_hidden_params_after_forward(input_data, target_data, perturbation):
initializer = Mock()
projector = Mock()
perturber = Perturber(initializer=initializer, projector=projector)
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)

gain = Mock()
enforcer = Mock()
Expand Down Expand Up @@ -115,8 +117,9 @@ def test_loading_perturbation_from_state_dict():
initializer = Mock()
projector = Mock()
perturber = Perturber(initializer=initializer, projector=projector)
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)

gain = Mock()
enforcer = Mock()
Expand All @@ -141,8 +144,9 @@ def test_loading_perturbation_from_state_dict():

def test_perturbation(input_data, target_data, perturbation):
perturber = Mock(spec=Perturber, return_value=perturbation)
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)
gain = Mock()
enforcer = Mock()
attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0))
Expand Down Expand Up @@ -191,8 +195,9 @@ def initializer(x):
initializer=initializer,
projector=None,
)
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)

adversary = Adversary(
composer=composer,
Expand All @@ -216,8 +221,9 @@ def model(input, target):

def test_configure_optimizers():
perturber = Mock()
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)
optimizer = Mock(spec=mart.optim.OptimizerFactory)
gain = Mock()

Expand All @@ -235,8 +241,9 @@ def test_configure_optimizers():

def test_training_step(input_data, target_data, perturbation):
perturber = Mock(spec=Perturber, return_value=perturbation)
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)
optimizer = Mock(spec=mart.optim.OptimizerFactory)
gain = Mock(return_value=torch.tensor(1337))
model = Mock(spec="__call__", return_value={})
Expand All @@ -256,8 +263,9 @@ def test_training_step(input_data, target_data, perturbation):

def test_training_step_with_many_gain(input_data, target_data, perturbation):
perturber = Mock(spec=Perturber, return_value=perturbation)
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)
optimizer = Mock(spec=mart.optim.OptimizerFactory)
gain = Mock(return_value=torch.tensor([1234, 5678]))
model = Mock(spec="__call__", return_value={})
Expand All @@ -276,8 +284,9 @@ def test_training_step_with_many_gain(input_data, target_data, perturbation):

def test_training_step_with_objective(input_data, target_data, perturbation):
perturber = Mock(spec=Perturber, return_value=perturbation)
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)
optimizer = Mock(spec=mart.optim.OptimizerFactory)
gain = Mock(return_value=torch.tensor([1234, 5678]))
# The model has no attack_step() or training_step().
Expand All @@ -301,8 +310,9 @@ def test_training_step_with_objective(input_data, target_data, perturbation):

def test_configure_gradient_clipping():
perturber = Mock()
functions = {"additive": mart.attack.composer.Additive()}
composer = Composer(perturber=perturber, functions=functions)
modules = {"additive": Additive()}
sequence = {"seq010": {"additive": ["perturbation", "input"]}}
composer = Composer(perturber=perturber, modules=modules, sequence=sequence)

optimizer = Mock(
spec=mart.optim.OptimizerFactory,
Expand Down
Loading

0 comments on commit 5ccf5f0

Please sign in to comment.