From 0d088e87ad394805941f1aeadbb195398e566e68 Mon Sep 17 00:00:00 2001 From: beta Date: Thu, 31 Mar 2022 03:31:02 +0900 Subject: [PATCH] (#6) Attack: Abstraction for attacker model --- src/typings/models.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/typings/models.py b/src/typings/models.py index a684840..da9bb87 100644 --- a/src/typings/models.py +++ b/src/typings/models.py @@ -1,9 +1,11 @@ from abc import ABC, abstractmethod from typing import List +from typings.dataset import Dataset import datetime import tensorflow as tf +import numpy as np keras = tf.keras @@ -141,3 +143,42 @@ def train(self, epochs: int = 100): self.__model.evaluate(self.data_test) self.post_train() + + +class Attack(ABC): + def __init__( + self, + model: Model, + dataset: Dataset, + accuracy_normal: keras.metrics.Accuracy = tf.metrics.SparseCategoricalAccuracy(), + accuracy_under_attack: keras.metrics.Accuracy = tf.metrics.SparseCategoricalAccuracy(), + ): + self.model = model + self.dataset = dataset + self.accuracy_normal = accuracy_normal + self.accuracy_under_attack = accuracy_under_attack + + self.model.compile() + self.model.load() + + @abstractmethod + def add_perturbation(self, x: np.array) -> np.array: + pass + + def attack(self): + _, test = self.dataset.dataset() + + progress = keras.utils.Progbar(test.cardinality().numpy()) + + for x, y in test: + x_attack = self.add_perturbation(x) + + y_attack = self.model.predict(x_attack) + y_normal = self.model.predict(x) + + self.accuracy_under_attack(y, y_attack) + self.accuracy_normal(y, y_normal) + + progress.add(1) + + return self.accuracy_under_attack, self.accuracy_normal