diff --git a/src/victim/train.py b/src/victim/train.py index 214ffc2..f335444 100644 --- a/src/victim/train.py +++ b/src/victim/train.py @@ -6,31 +6,11 @@ keras = tf.keras - -def train_mnist_classifier(epochs: int = 100): - train_set, test_set = Mnist().dataset() - classifier = Classifier( - "victim_classifier_mnist", - (28, 28, 1), - train_set, - test_set, - ) - classifier.train(epochs) - - -def train_cifar10_classifier(epochs: int = 100): - train_set, test_set = Cifar10().dataset() - classifier = Classifier( - "victim_classifier_cifar10", - (32, 32, 3), - train_set, - test_set, +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Training or evaluating victim classifier models." ) - classifier.train(epochs) - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Training victim classifier models.") parser.add_argument( "--dataset", "-d", @@ -50,11 +30,36 @@ def train_cifar10_classifier(epochs: int = 100): required=False, ) + parser.add_argument( + "--evaluate", + "-v", + help="Evaluate trained model.", + action="store_true", + ) + args = parser.parse_args() e = args.epochs if args.epochs is not None else 500 if args.dataset == "mnist": - train_mnist_classifier(e) + train_set, test_set = Mnist().dataset() + classifier = Classifier( + "victim_classifier_mnist", + (28, 28, 1), + train_set, + test_set, + ) + elif args.dataset == "cifar10": - train_cifar10_classifier(e) + train_set, test_set = Cifar10().dataset() + classifier = Classifier( + "victim_classifier_cifar10", + (32, 32, 3), + train_set, + test_set, + ) + + if args.evaluate: + print(classifier.evaluate()) + else: + classifier.train(e)