Skip to content

Commit

Permalink
(#6) Dataset: add CIFAR-10
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed Mar 24, 2022
1 parent c5d237b commit 12f29be
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@ def postprocess(
return (x_train, y_train), (x_test, y_test)


x_train_noisy, x_test_noisy = noisy(x_train), noisy(x_test)
class Cifar10(ImageDataset):
def __init__(self):
super().__init__(keras.datasets.cifar10)

train_ds = (
tf.data.Dataset.from_tensor_slices((x_train_noisy, x_train))
.shuffle(10000)
.batch(32)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test_noisy, x_test)).batch(32)
class NoisyCifar10(ImageDataset, NoisyMixin):
def __init__(self):
super().__init__(keras.datasets.cifar10)

return train_ds, test_ds
def postprocess(
self, train: NumpyDataset, test: NumpyDataset
) -> Tuple[NumpyDataset, NumpyDataset]:
(x_train, y_train), (x_test, y_test) = NoisyMixin.process(train, test)
return (x_train, y_train), (x_test, y_test)

0 comments on commit 12f29be

Please sign in to comment.