Skip to content

Commit

Permalink
(#3) Defense: Refine exformer
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed May 25, 2022
1 parent eb70dcd commit 9c570b5
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/attack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typings.models import Attack, Defense
from utils.logging import concat_batch_images
from models import FgsmMnist, FgsmCifar, PgdMnist, PgdCifar, Cw, NormalNoise
from defense.models import Reformer, Denoiser, Motd, ExMotd
from defense.models import Reformer, Exformer, Denoiser, Motd, ExMotd
from victim.models import Classifier

from utils.dataset import Mnist, Cifar10
Expand Down Expand Up @@ -111,7 +111,7 @@
defense_model.compile()
defense_model.load()
elif args.defense == "exformer":
defense_model = Reformer(
defense_model = Exformer(
f"defense_exformer_{args.dataset}",
input_shape=input_shape,
intensity=args.intensity[0],
Expand Down
54 changes: 52 additions & 2 deletions src/defense/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(
data_train: tf.data.Dataset = None,
data_test: tf.data.Dataset = None,
optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam(),
loss: keras.losses.Loss = keras.losses.MeanSquaredError(),
loss: keras.losses.Loss = keras.losses.BinaryCrossentropy(),
accuracy: keras.metrics.Accuracy = keras.metrics.CategoricalAccuracy(
name="accuracy"
),
Expand All @@ -36,6 +36,7 @@ def __init__(
accuracy,
checkpoint_filepath,
tensorboard_log_path,
is_functional=True,
)

def _model(self) -> keras.Model:
Expand Down Expand Up @@ -90,6 +91,55 @@ def predict(epoch, logs):
return [reduce_lr, keras.callbacks.LambdaCallback(on_epoch_end=predict)]


class Exformer(Reformer):
def _model(self) -> keras.Model:
"""
Reference: https://www.kaggle.com/code/tarunk04/autoencoder-denoising-image-mnist-cifar10/notebook#Denoising-Cifar10-Data
Denoising Autoencoder with Skip Connection
"""

inputs = keras.layers.Input(shape=self.input_shape())
# Encoder - 1
x = keras.layers.Conv2D(32, 3, activation="relu", padding="same")(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D()(x)
x = keras.layers.Dropout(0.5)(x)
# Encoder - 2
skip = keras.layers.Conv2D(32, 3, padding="same")(x)
x = keras.layers.LeakyReLU()(skip)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D()(x)
x = keras.layers.Dropout(0.5)(x)
# Encoder - Finalize
x = keras.layers.Conv2D(64, 3, activation="relu", padding="same")(x)
x = keras.layers.BatchNormalization()(x)
encoded = keras.layers.MaxPool2D()(x)

# Decoder - 1
x = keras.layers.Conv2DTranspose(
64, 3, activation="relu", strides=(2, 2), padding="same"
)(encoded)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dropout(0.5)(x)
# Decoder - 2
x = keras.layers.Conv2DTranspose(
32, 3, activation="relu", strides=(2, 2), padding="same"
)(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dropout(0.5)(x)
# Decoder - 3
x = keras.layers.Conv2DTranspose(32, 3, padding="same")(x)
x = keras.layers.Add()([skip, x])
x = keras.layers.LeakyReLU()(x)
x = keras.layers.BatchNormalization()(x)
# Decoder - Finalize
decoded = keras.layers.Conv2DTranspose(
3, 3, activation="sigmoid", strides=(2, 2), padding="same"
)(x)

return keras.Model(inputs, decoded)


class Denoiser(Defense):
def _model(self) -> keras.Model:
return keras.Sequential([SlqLayer()])
Expand Down Expand Up @@ -158,7 +208,7 @@ def __init__(
input_shape=input_shape,
intensity=intensities[0],
)
self.exformer = Reformer(
self.exformer = Exformer(
f"defense_exformer_{dataset}",
input_shape=input_shape,
intensity=intensities[1],
Expand Down
4 changes: 2 additions & 2 deletions src/defense/process.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from utils.logging import concat_batch_images
from defense.models import Reformer, Denoiser, Motd, ExMotd
from defense.models import Reformer, Exformer, Denoiser, Motd, ExMotd

import tensorflow as tf

Expand Down Expand Up @@ -60,7 +60,7 @@
intensity=args.intensity[0],
)
elif args.defense == "exformer":
defense_model = Reformer(
defense_model = Exformer(
f"defense_exformer_{args.dataset}",
input_shape=input_shape,
intensity=args.intensity[0],
Expand Down
6 changes: 3 additions & 3 deletions src/defense/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from utils.dataset import NoisyMnist, NoisyCifar10, ExCifar10
from models import Reformer
from models import Reformer, Exformer

import argparse
import tensorflow as tf
Expand Down Expand Up @@ -33,14 +33,14 @@ def train_cifar10_reformer(epochs: int = 100):

def train_excifar10_exformer(epochs: int = 100):
train_set, test_set = ExCifar10().dataset()
reformer = Reformer(
exformer = Exformer(
"defense_exformer_cifar10",
(32, 32, 3),
1.0,
train_set,
test_set,
)
reformer.train(epochs)
exformer.train(epochs)


if __name__ == "__main__":
Expand Down

0 comments on commit 9c570b5

Please sign in to comment.