Skip to content

Commit

Permalink
Fix docstring and error message (#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
marksgraham authored Jan 31, 2023
1 parent 98df275 commit b0e3937
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions generative/networks/nets/vqvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,26 +100,24 @@ class VQVAE(nn.Module):
spatial_dims: number of spatial spatial_dims.
in_channels: number of input channels.
out_channels: number of output channels.
num_levels: number of levels that the network has. Defaults to 3.
num_levels: number of levels that the network has.
downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
following information stride (int), kernel_size (int), dilation(int) and padding (int).
Defaults to ((2,4,1,1),(2,4,1,1),(2,4,1,1)).
following information stride (int), kernel_size (int), dilation (int) and padding (int).
upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
following information stride (int), kernel_size (int), dilation(int), padding (int), output_padding (int).
following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
If use_subpixel_conv is True, only the stride will be used for the last conv as the scale_factor.
Defaults to ((2,4,1,1,0),(2,4,1,1,0),(2,4,1,1,0)).
num_res_layers: number of sequential residual layers at each level. Defaults to 3.
num_channels: number of channels at the deepest level, besides that is num_channels//2 . Defaults to 192.
num_res_channels: number of channels in the residual layers. Defaults to 64.
num_embeddings: VectorQuantization number of atomic elements in the codebook. Defaults to 32.
embedding_dim: VectorQuantization number of channels of the input and atomic elements. Defaults to 64.
commitment_cost: VectorQuantization commitment_cost. Defaults to 0.25.
decay: VectorQuantization decay. Defaults to 0.5.
epsilon: VectorQuantization epsilon. Defaults to 1e-5 as.
adn_ordering: a string representing the ordering of activation, normalization, and dropout. Defaults to "NDA".
act: activation type and arguments. Defaults to Relu.
dropout: dropout ratio. Defaults to 0.1.
ddp_sync: whether to synchronize the codebook across processes. Defaults to True.
num_res_layers: number of sequential residual layers at each level.
num_channels: number of channels at each level.
num_res_channels: number of channels in the residual layers at each level.
num_embeddings: VectorQuantization number of atomic elements in the codebook.
embedding_dim: VectorQuantization number of channels of the input and atomic elements.
commitment_cost: VectorQuantization commitment_cost.
decay: VectorQuantization decay.
epsilon: VectorQuantization epsilon.
adn_ordering: a string representing the ordering of activation, normalization, and dropout, e.g. "NDA".
act: activation type and arguments.
dropout: dropout ratio.
ddp_sync: whether to synchronize the codebook across processes.
"""

# < Python 3.9 TorchScript requirement for ModuleList
Expand Down Expand Up @@ -166,7 +164,7 @@ def __init__(
), (
f"downsample_parameters, upsample_parameters, num_channels and num_res_channels must have the same number of"
f" elements as num_levels. But got {len(downsample_parameters)}, {len(upsample_parameters)}, "
f"{len(num_res_channels)} and {len(num_res_channels)} instead of {num_levels}."
f"{len(num_channels)} and {len(num_res_channels)} instead of {num_levels}."
)

self.num_levels = num_levels
Expand Down

0 comments on commit b0e3937

Please sign in to comment.