diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index 5695480..169fd52 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -50,7 +50,10 @@ def __init__(
 
         Args:
             surf_vars (tuple[str, ...], optional): All surface-level variables supported by the
-                model. The model is sensitive to the order of `surf_vars`!
+                model. The model is sensitive to the order of `surf_vars`! Currently, adding
+                one more variable here causes the model to incorrectly load the static variables.
+                It is possible to hack around this. We are working on a more principled fix. Please
+                open an issue if this is a problem for you.
             static_vars (tuple[str, ...], optional): All static variables supported by the
                 model. The model is sensitive to the order of `static_vars`!
             atmos_vars (tuple[str, ...], optional): All atmospheric variables supported by the
@@ -58,16 +61,21 @@ def __init__(
             window_size (tuple[int, int, int], optional): Vertical height, height, and width of the
                 window of the underlying Swin transformer.
             encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
-            encoder_num_header (tuple[int, ...], optional) Number of attention heads in each encoder
-                layer.
+            encoder_num_heads (tuple[int, ...], optional): Number of attention heads in each encoder
+                layer. The embedding dimensionality doubles after every layer, so to keep the
+                dimensionality of every head constant, double the number of heads after every
+                layer too. The dimensionality of each head in the first layer is `embed_dim`
+                divided by the first value here. For all cases except one, this is `64`.
             decoder_depths (tuple[int, ...], optional): Number of blocks in each decoder layer.
-            decoder_num_header (tuple[int, ...], optional) Number of attention heads in each decoder
-                layer.
+                Generally, you want this to be the reversal of `encoder_depths`.
+            decoder_num_heads (tuple[int, ...], optional): Number of attention heads in each decoder
+                layer. Generally, you want this to be the reversal of `encoder_num_heads`.
             latent_levels (int, optional): Number of latent pressure levels.
             patch_size (int, optional): Patch size.
             embed_dim (int, optional): Patch embedding dimension.
             num_heads (int, optional): Number of attention heads in the aggregation and
-                deaggregation blocks.
+                deaggregation blocks. The dimensionality of these attention heads is equal to
+                `embed_dim` divided by this value.
             mlp_ratio (float, optional): Hidden dim. to embedding dim. ratio for MLPs.
             drop_rate (float, optional): Drop-out rate.
             drop_path (float, optional): Drop-path rate.
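
As a quick illustration of the head-dimension arithmetic described in the new docstring, here is a minimal sketch. The concrete numbers (embed_dim = 512, encoder_num_heads = (8, 16, 32)) are assumed defaults for illustration only; they are not taken from this diff.

# Minimal sketch of the head-dimension invariant from the docstring above.
embed_dim = 512                  # assumed patch embedding dimension
encoder_num_heads = (8, 16, 32)  # assumed heads per encoder layer
decoder_num_heads = tuple(reversed(encoder_num_heads))  # the reversal of encoder_num_heads

# The embedding dimensionality doubles after every encoder layer, so layer i
# operates at embed_dim * 2**i. Doubling the head count in lockstep keeps the
# per-head dimensionality constant at embed_dim // encoder_num_heads[0].
for i, heads in enumerate(encoder_num_heads):
    layer_dim = embed_dim * 2**i
    print(f"encoder layer {i}: dim={layer_dim}, heads={heads}, head_dim={layer_dim // heads}")
# -> head_dim is 64 in every layer: 512/8, 1024/16, 2048/32.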