diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index 1f55043..82f37dd 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -63,14 +63,14 @@ def __init__(
             window_size (tuple[int, int, int], optional): Vertical height, height, and width of
                 the window of the underlying Swin transformer.
             encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
-            encoder_num_heads (tuple[int, ...], optional) Number of attention heads in each encoder
+            encoder_num_heads (tuple[int, ...], optional): Number of attention heads in each encoder
                 layer. The dimensionality doubles after every layer. To keep the dimensionality of
                 every head constant, you want to double the number of heads after every layer. The
                 dimensionality of attention head of the first layer is determined by `embed_dim`
                 divided by the value here. For all cases except one, this is equal to `64`.
             decoder_depths (tuple[int, ...], optional): Number of blocks in each decoder layer.
                 Generally, you want this to be the reversal of `encoder_depths`.
-            decoder_num_heads (tuple[int, ...], optional) Number of attention heads in each decoder
+            decoder_num_heads (tuple[int, ...], optional): Number of attention heads in each decoder
                 layer. Generally, you want this to be the reversal of `encoder_num_heads`.
             latent_levels (int, optional): Number of latent pressure levels.
             patch_size (int, optional): Patch size.
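
For reviewers, here is a minimal sketch of the head-dimension rule the corrected docstring describes: the layer dimensionality doubles at each encoder layer, so doubling the head count keeps the per-head dimensionality constant. The values `embed_dim=512` and `encoder_num_heads=(8, 16, 32)` below are illustrative assumptions chosen to yield a head dimensionality of `64`, not values taken from this diff.

```python
# Sketch of the docstring's rule. embed_dim=512 and
# encoder_num_heads=(8, 16, 32) are illustrative assumptions,
# not defaults confirmed by this diff.
embed_dim = 512                  # dimensionality of the first encoder layer
encoder_num_heads = (8, 16, 32)  # heads double as the dimensionality doubles

for i, num_heads in enumerate(encoder_num_heads):
    layer_dim = embed_dim * 2**i       # dimensionality doubles after every layer
    head_dim = layer_dim // num_heads  # per-head dimensionality
    print(f"layer {i}: dim={layer_dim}, heads={num_heads}, head_dim={head_dim}")
# With these values, head_dim stays constant at 64 in every layer.
```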