Better Perturber #146

Merged: 25 commits, Jun 1, 2023
44 changes: 28 additions & 16 deletions mart/attack/adversary.py
@@ -7,16 +7,19 @@
 from __future__ import annotations
 
 from collections import OrderedDict
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import torch
 
 from .callbacks import Callback
-from .composer import Composer
-from .enforcer import Enforcer
-from .gain import Gain
-from .objective import Objective
-from .perturber import BatchPerturber, Perturber
+
+if TYPE_CHECKING:
+    from .composer import Composer
+    from .enforcer import Enforcer
+    from .gain import Gain
+    from .gradient_modifier import GradientModifier
+    from .objective import Objective
+    from .perturber import Perturber
 
 __all__ = ["Adversary", "Attacker"]
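Note: moving these imports under TYPE_CHECKING means they are evaluated only by static type checkers and never at runtime, which breaks potential import cycles between the attack modules; the existing `from __future__ import annotations` keeps the names legal in annotations. A minimal sketch of the pattern (module and class names here are only illustrative):

from __future__ import annotations  # annotations become lazy strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright only, skipped at runtime, so importing
    # .composer here cannot create a circular import.
    from .composer import Composer

def attack(composer: Composer) -> None:  # annotation is never evaluated at runtime
    ...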

@@ -81,18 +84,19 @@ class Attacker(AttackerCallbackHookMixin, torch.nn.Module):
     def __init__(
         self,
         *,
-        perturber: BatchPerturber | Perturber,
+        perturber: Perturber,
         composer: Composer,
         optimizer: torch.optim.Optimizer,
         max_iters: int,
         gain: Gain,
         objective: Objective | None = None,
         callbacks: dict[str, Callback] | None = None,
+        gradient_modifier: GradientModifier | None = None,
     ):
         """_summary_
 
         Args:
-            perturber (BatchPerturber | Perturber): A module that stores perturbations.
+            perturber (Perturber): A module that stores perturbations.
             composer (Composer): A module which composes adversarial examples from input and perturbation.
             optimizer (torch.optim.Optimizer): A PyTorch optimizer.
             max_iters (int): The max number of attack iterations.
@@ -102,24 +106,26 @@ def __init__(
         """
         super().__init__()
 
-        self.perturber = perturber
+        # Hide the perturber module in a list, so that perturbation is not exported as a parameter in the model checkpoint.
+        self._perturber = [perturber]
         self.composer = composer
         self.optimizer_fn = optimizer
 
         self.max_iters = max_iters
         self.callbacks = OrderedDict()
 
-        # Register perturber as callback if it implements Callback interface
-        if isinstance(self.perturber, Callback):
-            # FIXME: Use self.perturber.__class__.__name__ as key?
-            self.callbacks["_perturber"] = self.perturber
-
         if callbacks is not None:
             self.callbacks.update(callbacks)
 
         self.objective_fn = objective
         # self.gain is a tensor.
         self.gain_fn = gain
+        self.gradient_modifier = gradient_modifier
 
+    @property
+    def perturber(self) -> Perturber:
+        # Hide the perturber module in a list, so that perturbation is not exported as a parameter in the model checkpoint.
+        return self._perturber[0]
+
     @property
     def done(self) -> bool:
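Note: hiding the perturber in a plain list works because torch.nn.Module.__setattr__ registers only values that are themselves Modules (or Parameters); anything tucked inside a list stays out of state_dict() and hence out of checkpoints. A minimal sketch demonstrating the behavior (class names are illustrative):

import torch

class Wrapper(torch.nn.Module):
    def __init__(self, child: torch.nn.Module):
        super().__init__()
        self.registered = child  # picked up by nn.Module.__setattr__
        self._hidden = [child]   # a plain Python list is not registered

w = Wrapper(torch.nn.Linear(2, 2))
print(list(w.state_dict().keys()))  # only 'registered.weight' and 'registered.bias'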
@@ -157,7 +163,8 @@ def on_run_start(
         self.cur_iter = 0
 
         # param_groups with learning rate and other optim params.
-        param_groups = self.perturber.parameter_groups()
+        self.perturber.configure_perturbation(input)
+        param_groups = self.perturber.parameters()
 
         self.opt = self.optimizer_fn(param_groups)
@@ -290,6 +297,11 @@ def advance(
         # Do not flip the gain value, because we set maximize=True in optimizer.
         self.total_gain.backward()
 
+        if self.gradient_modifier is not None:
+            for param_group in self.opt.param_groups:
+                for param in param_group["params"]:
+                    self.gradient_modifier(param)
+
         self.opt.step()
 
     def forward(
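Note: this hook applies an in-place transform to every parameter's gradient between backward() and opt.step(). Any callable that accepts a parameter fits the loop above; a sign modifier in the PGD/FGSM style might look like the following sketch (not necessarily the class shipped in mart/attack/gradient_modifier.py):

import torch

class Sign:  # hypothetical GradientModifier; matches the call signature used above
    @torch.no_grad()
    def __call__(self, param: torch.Tensor) -> None:
        if param.grad is not None:
            param.grad.sign_()  # keep only the sign of the gradient, in place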
@@ -299,7 +311,7 @@ def forward(
         target: torch.Tensor | dict[str, Any] | tuple,
         **kwargs,
     ):
-        perturbation = self.perturber(input, target)
+        perturbation = self.perturber(input=input, target=target)
         output = self.composer(perturbation, input=input, target=target)
 
         return output
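Note: the keyword-only call is forced by the new Perturber.forward(self, **batch) signature below, which accepts no positional arguments:

perturbation = perturber(input, target)                # TypeError under the new signature
perturbation = perturber(input=input, target=target)   # batch entries must be passed by name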
101 changes: 101 additions & 0 deletions mart/attack/perturber.py
@@ -0,0 +1,101 @@
+#
+# Copyright (C) 2022 Intel Corporation
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable
+
+import torch
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from .projector import Projector
+
+if TYPE_CHECKING:
+    from .initializer import Initializer
+
+__all__ = ["Perturber"]
+
+
+class Perturber(torch.nn.Module):
+    def __init__(
+        self,
+        *,
+        initializer: Initializer,
+        projector: Projector | None = None,
+    ):
+        """_summary_
+
+        Args:
+            initializer (Initializer): To initialize the perturbation.
+            projector (Projector): To project the perturbation into some space.
+        """
+        super().__init__()
+
+        self.initializer_ = initializer
+        self.projector_ = projector or Projector()
+
+        self.perturbation = None
+
+    def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]):
+        def matches(input, perturbation):
+            if perturbation is None:
+                return False
+
+            if isinstance(input, torch.Tensor) and isinstance(perturbation, torch.Tensor):
+                return input.shape == perturbation.shape
+
+            if isinstance(input, Iterable) and isinstance(perturbation, Iterable):
+                if len(input) != len(perturbation):
+                    return False
+
+                return all(
+                    [
+                        matches(input_i, perturbation_i)
+                        for input_i, perturbation_i in zip(input, perturbation)
+                    ]
+                )
+
+            return False
+
+        def create_from_tensor(tensor):
+            if isinstance(tensor, torch.Tensor):
+                return torch.nn.Parameter(
+                    torch.empty_like(tensor, dtype=torch.float, requires_grad=True)
+                )
+            elif isinstance(tensor, Iterable):
+                return torch.nn.ParameterList([create_from_tensor(t) for t in tensor])
+            else:
+                raise NotImplementedError
+
+        # If we have never created a perturbation before or perturbation does not match input, then
+        # create a new perturbation.
+        if not matches(input, self.perturbation):
+            self.perturbation = create_from_tensor(input)
+
+        # Always (re)initialize perturbation.
+        self.initializer_(self.perturbation)
+
+    def named_parameters(self, *args, **kwargs):
+        if self.perturbation is None:
+            raise MisconfigurationException("You need to call configure_perturbation before fit.")
+
+        return super().named_parameters(*args, **kwargs)
+
+    def parameters(self, *args, **kwargs):
+        if self.perturbation is None:
+            raise MisconfigurationException("You need to call configure_perturbation before fit.")
+
+        return super().parameters(*args, **kwargs)
+
+    def forward(self, **batch):
+        if self.perturbation is None:
+            raise MisconfigurationException(
+                "You need to call configure_perturbation before forward."
+            )
+
+        self.projector_(self.perturbation, **batch)
+
+        return self.perturbation
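Note: the intended lifecycle is construct, then configure_perturbation on a concrete batch, then call. An end-to-end sketch under stated assumptions: ZeroInit is a hypothetical stand-in for a real Initializer, and the default Projector is assumed to be a no-op that tolerates the batch keywords:

import torch
from mart.attack.perturber import Perturber

class ZeroInit:  # hypothetical Initializer: zero whatever was allocated
    @torch.no_grad()
    def __call__(self, perturbation):
        params = [perturbation] if isinstance(perturbation, torch.Tensor) else perturbation
        for p in params:
            p.zero_()

perturber = Perturber(initializer=ZeroInit())
x = torch.rand(3, 32, 32)
perturber.configure_perturbation(x)      # allocates a Parameter shaped like x
delta = perturber(input=x, target=None)
assert delta.shape == x.shape            # a new input shape would trigger reallocation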
2 changes: 0 additions & 2 deletions mart/attack/perturber/__init__.py

This file was deleted.

82 changes: 0 additions & 82 deletions mart/attack/perturber/batch.py

This file was deleted.
