Skip to content

Commit

Permalink
(#3) Defense: Impl. Motd
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed May 2, 2022
1 parent 90c2813 commit 76a2ed3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
10 changes: 8 additions & 2 deletions src/attack/attack.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Tuple, Union
from typings.models import Attack, Model
from models import Fgsm, Pgd, Cw
from defense.models import Reformer, Denoiser
from defense.models import Reformer, Denoiser, Motd
from victim.models import Classifier

from utils.dataset import Mnist, Cifar10
Expand Down Expand Up @@ -44,7 +44,7 @@
type=str,
help="Defense method",
required=True,
choices=["reformer", "denoiser", "none"],
choices=["reformer", "denoiser", "motd", "none"],
)

args = parser.parse_args()
Expand Down Expand Up @@ -83,6 +83,12 @@
defense_model = Denoiser(
f"defense_denoiser_{args.dataset}", input_shape=input_shape
)
elif args.defense == "motd":
defense_model = Motd(
f"defense_motd_{args.dataset}",
input_shape=input_shape,
dataset=args.dataset,
)
else:
defense_model = None

Expand Down
34 changes: 33 additions & 1 deletion src/defense/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Literal
from typings.models import Model
from utils.layers import SlqLayer

Expand Down Expand Up @@ -69,3 +69,35 @@ def post_train(self):

def custom_callbacks(self) -> List[keras.callbacks.Callback]:
pass


class Motd(Model):
def __init__(
self,
name: str,
input_shape: tuple,
dataset: Literal["mnist", "cifar10"],
**kwargs,
):
self.reformer = Reformer(f"defense_reformer_{dataset}", input_shape=input_shape)
self.denoiser = Denoiser(f"defense_denoiser_{dataset}", input_shape=input_shape)

super().__init__(name, input_shape, **kwargs)

def _model(self) -> keras.Model:
self.reformer.compile()
self.reformer.load()

return keras.Model(
self.denoiser.model().inputs,
self.reformer.model()(self.denoiser.model().outputs),
)

def pre_train(self):
pass

def post_train(self):
pass

def custom_callbacks(self) -> List[keras.callbacks.Callback]:
pass

0 comments on commit 76a2ed3

Please sign in to comment.