diff --git a/pl_bolts/models/gans/basic/basic_gan_module.py b/pl_bolts/models/gans/basic/basic_gan_module.py index 365c5d0fc8..8f91be200c 100644 --- a/pl_bolts/models/gans/basic/basic_gan_module.py +++ b/pl_bolts/models/gans/basic/basic_gan_module.py @@ -40,12 +40,11 @@ def __init__( ): """ Args: - datamodule: the datamodule (train, val, test splits) + input_channels: number of channels of an image + input_height: image height + input_width: image width latent_dim: emb dim for encoder - batch_size: the batch size learning_rate: the learning rate - data_dir: where to store data - num_workers: data workers """ super().__init__()