From 670b9971836a12f5c9ed3e21a65d7bb3142e1de5 Mon Sep 17 00:00:00 2001 From: beta Date: Thu, 31 Mar 2022 12:46:40 +0900 Subject: [PATCH] (#6) Attack: Impl. attack with defense model --- src/attack/models.py | 27 +++++++++------------------ src/typings/models.py | 38 ++++++++++++++++++++++++++------------ 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/src/attack/models.py b/src/attack/models.py index 358267d..7f6a74f 100644 --- a/src/attack/models.py +++ b/src/attack/models.py @@ -1,37 +1,28 @@ -from typings.models import Attack, Model -from typings.dataset import Dataset +from typings.models import Attack from cleverhans.tf2.attacks.fast_gradient_method import fast_gradient_method from victim.models import Classifier +from defense.models import Reformer + from utils.dataset import Cifar10 import numpy as np import tensorflow as tf - keras = tf.keras class Fgsm(Attack): - def __init__( - self, - model: Model, - dataset: Dataset, - accuracy_normal: keras.metrics.Accuracy = tf.metrics.SparseCategoricalAccuracy(), - accuracy_under_attack: keras.metrics.Accuracy = tf.metrics.SparseCategoricalAccuracy(), - ): - super(Fgsm, self).__init__( - model, dataset, accuracy_normal, accuracy_under_attack - ) - def add_perturbation(self, x: np.array) -> np.array: - return fast_gradient_method(self.model.model(), x, 0.05, np.inf) + return fast_gradient_method(self.victim_model.model(), x, 0.05, np.inf) if __name__ == "__main__": f = Fgsm( - Classifier(name="victim_classifier_cifar10", input_shape=(32, 32, 3)), Cifar10() + Classifier(name="victim_classifier_cifar10", input_shape=(32, 32, 3)), + Cifar10(), + defense_model=Reformer("defense_reformer_cifar10", input_shape=(32, 32, 3)), ) - acc_with_attack, acc = f.attack() - print(acc_with_attack.result(), acc.result()) + acc, acc_under_attack, acc_with_defense = f.attack() + print(acc.result(), acc_under_attack.result(), acc_with_defense.result()) diff --git a/src/typings/models.py b/src/typings/models.py index da9bb87..d1a0055 100644 --- a/src/typings/models.py +++ b/src/typings/models.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Tuple, Optional from typings.dataset import Dataset import datetime @@ -148,24 +148,30 @@ def train(self, epochs: int = 100): class Attack(ABC): def __init__( self, - model: Model, + victim_model: Model, dataset: Dataset, + defense_model: Model = None, accuracy_normal: keras.metrics.Accuracy = tf.metrics.SparseCategoricalAccuracy(), accuracy_under_attack: keras.metrics.Accuracy = tf.metrics.SparseCategoricalAccuracy(), + accuracy_with_defense: keras.metrics.Accuracy = tf.metrics.SparseCategoricalAccuracy(), ): - self.model = model - self.dataset = dataset - self.accuracy_normal = accuracy_normal - self.accuracy_under_attack = accuracy_under_attack + self.victim_model: Model = victim_model + self.defense_model: Model = defense_model + self.dataset: Dataset = dataset + self.accuracy_normal: tf.metrics.Accuracy = accuracy_normal + self.accuracy_under_attack: tf.metrics.Accuracy = accuracy_under_attack + self.accuracy_with_defense: tf.metrics.Accuracy = accuracy_with_defense - self.model.compile() - self.model.load() + self.victim_model.compile() + self.victim_model.load() @abstractmethod def add_perturbation(self, x: np.array) -> np.array: pass - def attack(self): + def attack( + self, + ) -> Tuple[tf.metrics.Accuracy, tf.metrics.Accuracy, Optional[tf.metrics.Accuracy]]: _, test = self.dataset.dataset() progress = keras.utils.Progbar(test.cardinality().numpy()) @@ -173,12 +179,20 @@ def attack(self): for x, y in test: x_attack = self.add_perturbation(x) - y_attack = self.model.predict(x_attack) - y_normal = self.model.predict(x) + y_attack = self.victim_model.predict(x_attack) + y_normal = self.victim_model.predict(x) self.accuracy_under_attack(y, y_attack) self.accuracy_normal(y, y_normal) + if self.defense_model is not None: + self.accuracy_with_defense( + y, self.victim_model.predict(self.defense_model.predict(x_attack)) + ) progress.add(1) - return self.accuracy_under_attack, self.accuracy_normal + return ( + self.accuracy_normal, + self.accuracy_under_attack, + self.accuracy_with_defense, + )