diff --git a/src/utils/dataset.py b/src/utils/dataset.py index 68eea81..bdb419c 100644 --- a/src/utils/dataset.py +++ b/src/utils/dataset.py @@ -30,14 +30,17 @@ def postprocess( return (x_train, y_train), (x_test, y_test) - x_train_noisy, x_test_noisy = noisy(x_train), noisy(x_test) +class Cifar10(ImageDataset): + def __init__(self): + super().__init__(keras.datasets.cifar10) - train_ds = ( - tf.data.Dataset.from_tensor_slices((x_train_noisy, x_train)) - .shuffle(10000) - .batch(32) - ) - test_ds = tf.data.Dataset.from_tensor_slices((x_test_noisy, x_test)).batch(32) +class NoisyCifar10(ImageDataset, NoisyMixin): + def __init__(self): + super().__init__(keras.datasets.cifar10) - return train_ds, test_ds + def postprocess( + self, train: NumpyDataset, test: NumpyDataset + ) -> Tuple[NumpyDataset, NumpyDataset]: + (x_train, y_train), (x_test, y_test) = NoisyMixin.process(train, test) + return (x_train, y_train), (x_test, y_test)