diff --git a/src/defense/models.py b/src/defense/models.py index c45f103..1775fda 100644 --- a/src/defense/models.py +++ b/src/defense/models.py @@ -101,3 +101,66 @@ def post_train(self): def custom_callbacks(self) -> List[keras.callbacks.Callback]: pass + + +if __name__ == "__main__": + from utils.dataset import Mnist, Cifar10 + import argparse + + parser = argparse.ArgumentParser( + description="Logging images generated by defense models" + ) + + parser.add_argument( + "--dataset", + "-d", + metavar="DATASET", + type=str, + help="Dataset for testing", + required=True, + choices=["mnist", "cifar10"], + ) + + parser.add_argument( + "--defense", + "-f", + metavar="DEFENSE", + type=str, + help="Defense method", + required=True, + choices=["reformer", "denoiser", "motd"], + ) + + args = parser.parse_args() + + if args.dataset == "mnist": + input_shape = (28, 28, 1) + _, test_set = Mnist().dataset() + + else: + input_shape = (32, 32, 3) + _, test_set = Cifar10().dataset() + + if args.defense == "reformer": + defense_model = Reformer( + f"defense_reformer_{args.dataset}", input_shape=input_shape + ) + elif args.defense == "denoiser": + defense_model = Denoiser( + f"defense_denoiser_{args.dataset}", input_shape=input_shape + ) + else: + defense_model = Motd( + f"defense_motd_{args.dataset}", + input_shape=input_shape, + dataset=args.dataset, + ) + + progress = keras.utils.Progbar(1) + + with defense_model.tensorboard_file_writer().as_default(): + for idx, (x, _) in enumerate(test_set.take(1)): + y = defense_model.predict(x) + tf.summary.image(f"{args.defense.upper()} result - {idx}", [x, y]) + + progress.add(1)