Skip to content

Commit

Permalink
(#2) Victim: Add evaluate mode
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed Apr 18, 2022
1 parent 9ff0b51 commit ce6378a
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions src/victim/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,11 @@

keras = tf.keras


def train_mnist_classifier(epochs: int = 100):
train_set, test_set = Mnist().dataset()
classifier = Classifier(
"victim_classifier_mnist",
(28, 28, 1),
train_set,
test_set,
)
classifier.train(epochs)


def train_cifar10_classifier(epochs: int = 100):
train_set, test_set = Cifar10().dataset()
classifier = Classifier(
"victim_classifier_cifar10",
(32, 32, 3),
train_set,
test_set,
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Training or evaluating victim classifier models."
)
classifier.train(epochs)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Training victim classifier models.")
parser.add_argument(
"--dataset",
"-d",
Expand All @@ -50,11 +30,36 @@ def train_cifar10_classifier(epochs: int = 100):
required=False,
)

parser.add_argument(
"--evaluate",
"-v",
help="Evaluate trained model.",
action="store_true",
)

args = parser.parse_args()

e = args.epochs if args.epochs is not None else 500

if args.dataset == "mnist":
train_mnist_classifier(e)
train_set, test_set = Mnist().dataset()
classifier = Classifier(
"victim_classifier_mnist",
(28, 28, 1),
train_set,
test_set,
)

elif args.dataset == "cifar10":
train_cifar10_classifier(e)
train_set, test_set = Cifar10().dataset()
classifier = Classifier(
"victim_classifier_cifar10",
(32, 32, 3),
train_set,
test_set,
)

if args.evaluate:
print(classifier.evaluate())
else:
classifier.train(e)

0 comments on commit ce6378a

Please sign in to comment.