Skip to content

Commit

Permalink
(#3) Defense: Logging images generated by defense models
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed May 2, 2022
1 parent 76a2ed3 commit b874d70
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions src/defense/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,66 @@ def post_train(self):

def custom_callbacks(self) -> List[keras.callbacks.Callback]:
pass


if __name__ == "__main__":
from utils.dataset import Mnist, Cifar10
import argparse

parser = argparse.ArgumentParser(
description="Logging images generated by defense models"
)

parser.add_argument(
"--dataset",
"-d",
metavar="DATASET",
type=str,
help="Dataset for testing",
required=True,
choices=["mnist", "cifar10"],
)

parser.add_argument(
"--defense",
"-f",
metavar="DEFENSE",
type=str,
help="Defense method",
required=True,
choices=["reformer", "denoiser", "motd"],
)

args = parser.parse_args()

if args.dataset == "mnist":
input_shape = (28, 28, 1)
_, test_set = Mnist().dataset()

else:
input_shape = (32, 32, 3)
_, test_set = Cifar10().dataset()

if args.defense == "reformer":
defense_model = Reformer(
f"defense_reformer_{args.dataset}", input_shape=input_shape
)
elif args.defense == "denoiser":
defense_model = Denoiser(
f"defense_denoiser_{args.dataset}", input_shape=input_shape
)
else:
defense_model = Motd(
f"defense_motd_{args.dataset}",
input_shape=input_shape,
dataset=args.dataset,
)

progress = keras.utils.Progbar(1)

with defense_model.tensorboard_file_writer().as_default():
for idx, (x, _) in enumerate(test_set.take(1)):
y = defense_model.predict(x)
tf.summary.image(f"{args.defense.upper()} result - {idx}", [x, y])

progress.add(1)

0 comments on commit b874d70

Please sign in to comment.