Skip to content

Commit

Permalink
(#6) Attack: Abstraction for attacker model
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed Mar 30, 2022
1 parent 6d137bf commit 0d088e8
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/typings/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from abc import ABC, abstractmethod
from typing import List
from typings.dataset import Dataset

import datetime

import tensorflow as tf
import numpy as np

keras = tf.keras

Expand Down Expand Up @@ -141,3 +143,42 @@ def train(self, epochs: int = 100):
self.__model.evaluate(self.data_test)

self.post_train()


class Attack(ABC):
def __init__(
self,
model: Model,
dataset: Dataset,
accuracy_normal: keras.metrics.Accuracy = tf.metrics.SparseCategoricalAccuracy(),
accuracy_under_attack: keras.metrics.Accuracy = tf.metrics.SparseCategoricalAccuracy(),
):
self.model = model
self.dataset = dataset
self.accuracy_normal = accuracy_normal
self.accuracy_under_attack = accuracy_under_attack

self.model.compile()
self.model.load()

@abstractmethod
def add_perturbation(self, x: np.array) -> np.array:
pass

def attack(self):
_, test = self.dataset.dataset()

progress = keras.utils.Progbar(test.cardinality().numpy())

for x, y in test:
x_attack = self.add_perturbation(x)

y_attack = self.model.predict(x_attack)
y_normal = self.model.predict(x)

self.accuracy_under_attack(y, y_attack)
self.accuracy_normal(y, y_normal)

progress.add(1)

return self.accuracy_under_attack, self.accuracy_normal

0 comments on commit 0d088e8

Please sign in to comment.