Skip to content

Commit

Permalink
(#5) TensorBoard: Logging pre/post training
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed Mar 23, 2022
1 parent 7e93f8d commit e082c0a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/defense/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typings.models import Model
from utils.dataset import NoisyMnist

import numpy as np
import tensorflow as tf

keras = tf.keras
Expand Down Expand Up @@ -53,6 +54,14 @@ def layer_conv2d():
]
)

def pre_train(self):
with self.tensorboard_file_writer().as_default():
x = np.concatenate([x for x, y in self.data_test.take(1)], axis=0)
tf.summary.image(f"{self.name()} test input", x, step=0)

def post_train(self):
pass


if __name__ == "__main__":
train_set, test_set = NoisyMnist()
Expand Down
9 changes: 9 additions & 0 deletions src/typings/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def _model(self) -> keras.Model:
pass

@abstractmethod
def pre_train(self):
pass

@abstractmethod
def post_train(self):
pass


def name(self) -> str:
return self._name
Expand All @@ -94,6 +101,7 @@ def tensorboard_file_writer(self) -> tf.summary.SummaryWriter:


def train(self, epochs: int = 100):
self.pre_train()

self.__model.compile(
optimizer=self.optimizer,
Expand All @@ -116,3 +124,4 @@ def train(self, epochs: int = 100):

self.__model.evaluate(self.data_test)

self.post_train()
10 changes: 10 additions & 0 deletions src/victim/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typings.models import Model
from utils.dataset import Mnist as MnistDataset

import numpy as np
import tensorflow as tf

keras = tf.keras
Expand Down Expand Up @@ -46,6 +47,15 @@ def _model(self) -> keras.Model:
)



def pre_train(self):
with self.tensorboard_file_writer().as_default():
x = np.concatenate([x for x, y in self.data_test.take(1)], axis=0)
tf.summary.image(f"{self.name()} test input", x, max_outputs=25, step=0)

def post_train(self):
pass

if __name__ == "__main__":
train_set, test_set = MnistDataset()
mnist = Mnist(train_set, test_set, (28, 28, 1))
Expand Down

0 comments on commit e082c0a

Please sign in to comment.