Skip to content

Commit

Permalink
Fix Python type hints according to Python Docs (NVIDIA#5370)
Browse files Browse the repository at this point in the history
* Remove duplicated type annotations

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tuple annotations in function return types

Signed-off-by: Vladimir Bataev <[email protected]>

* Add necessary imports

Signed-off-by: Vladimir Bataev <[email protected]>

* Add necessary imports

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix types in obvious places

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix types in obvious places

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix unused import (avoid quotes in type annotations)

Signed-off-by: Vladimir Bataev <[email protected]>

* Revert "Fix unused import (avoid quotes in type annotations)"

This reverts commit ea433ef.

Signed-off-by: Vladimir Bataev <[email protected]>

* Remove problematic import

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix list_available_models method type

Signed-off-by: Vladimir Bataev <[email protected]>

* Revert some changes

Signed-off-by: Vladimir Bataev <[email protected]>

* Revert quotes in list_available_models

Signed-off-by: Vladimir Bataev <[email protected]>

Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
  • Loading branch information
artbataev authored and Hainan Xu committed Nov 29, 2022
1 parent 3383e04 commit f3e6e3c
Show file tree
Hide file tree
Showing 33 changed files with 71 additions and 64 deletions.
4 changes: 2 additions & 2 deletions examples/slu/speech_intent_slot/eval_utils/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
# limitations under the License.

import ast
from typing import Dict, List, Union
from typing import Dict, List, Tuple, Union

from .evaluation.metrics.metrics import ErrorMetric


def parse_semantics_str2dict(semantics_str: Union[List[str], str, Dict]) -> Dict:
def parse_semantics_str2dict(semantics_str: Union[List[str], str, Dict]) -> Tuple[Dict, bool]:
"""
This function parse the input string to a valid python dictionary for later evaluation.
Part of this function is adapted from
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/data/audio_to_diar_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data
from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel
from nemo.core.classes import Dataset
from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType
from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType


def get_scale_mapping_list(uniq_timestamps):
Expand Down
8 changes: 4 additions & 4 deletions nemo/collections/asr/data/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import io
import math
import os
from typing import Callable, Dict, Iterable, List, Optional, Union
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import braceexpand
import numpy as np
Expand Down Expand Up @@ -138,16 +138,16 @@ def __init__(
self.bos_id = bos_id
self.pad_id = pad_id

def process_text_by_id(self, index: int) -> (List[int], int):
def process_text_by_id(self, index: int) -> Tuple[List[int], int]:
sample = self.collection[index]
return self.process_text_by_sample(sample)

def process_text_by_file_id(self, file_id: str) -> (List[int], int):
def process_text_by_file_id(self, file_id: str) -> Tuple[List[int], int]:
manifest_idx = self.collection.mapping[file_id][0]
sample = self.collection[manifest_idx]
return self.process_text_by_sample(sample)

def process_text_by_sample(self, sample: collections.ASRAudioText.OUTPUT_TYPE) -> (List[int], int):
def process_text_by_sample(self, sample: collections.ASRAudioText.OUTPUT_TYPE) -> Tuple[List[int], int]:
t, tl = sample.text_tokens, len(sample.text_tokens)

if self.bos_id is not None:
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/metrics/rnnt_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import copy
from abc import abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import editdistance
import numpy as np
Expand Down Expand Up @@ -342,7 +342,7 @@ def rnnt_decoder_predictions_tensor(
encoded_lengths: torch.Tensor,
return_hypotheses: bool = False,
partial_hypotheses: Optional[List[Hypothesis]] = None,
) -> (List[str], Optional[List[List[str]]], Optional[Union[Hypothesis, NBestHypotheses]]):
) -> Tuple[List[str], Optional[List[List[str]]], Optional[Union[Hypothesis, NBestHypotheses]]]:
"""
Decode an encoder output by autoregressive decoding of the Decoder+Joint networks.
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/metrics/wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from abc import abstractmethod
from dataclasses import dataclass, is_dataclass
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import editdistance
import numpy as np
Expand Down Expand Up @@ -228,7 +228,7 @@ def ctc_decoder_predictions_tensor(
decoder_lengths: torch.Tensor = None,
fold_consecutive: bool = True,
return_hypotheses: bool = False,
) -> (List[str], Optional[List[List[str]]], Optional[Union[Hypothesis, NBestHypotheses]]):
) -> Tuple[List[str], Optional[List[List[str]]], Optional[Union[Hypothesis, NBestHypotheses]]]:
"""
Decodes a sequence of labels to words
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from typing import List

import torch

from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import AccessMixin
from nemo.utils import cast_all, logging, model_utils
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import copy
import os
from math import isclose
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union

import torch
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
Expand Down Expand Up @@ -371,7 +371,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
def list_available_models(cls) -> List[PretrainedModelInfo]:
"""
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
return temporary_datalayer

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
def list_available_models(cls) -> List[PretrainedModelInfo]:
"""
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/k2_sequence_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._init_k2()

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
"""
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
Expand Down Expand Up @@ -112,7 +112,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._init_k2()

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
"""
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import tempfile
from math import ceil, isclose
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
from omegaconf import DictConfig, OmegaConf, open_dict
Expand Down Expand Up @@ -217,7 +217,7 @@ def transcribe(
partial_hypothesis: Optional[List['Hypothesis']] = None,
num_workers: int = 0,
channel_selector: Optional[ChannelSelectorType] = None,
) -> (List[str], Optional[List['Hypothesis']]):
) -> Tuple[List[str], Optional[List['Hypothesis']]]:
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
Expand Down Expand Up @@ -996,7 +996,7 @@ def decoder_joint(self):
return RNNTDecoderJoint(self.decoder, self.joint)

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
def list_available_models(cls) -> List[PretrainedModelInfo]:
"""
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/ssl_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from math import ceil
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -44,7 +44,7 @@ class SpeechEncDecSelfSupervisedModel(ModelPT, ASRModuleMixin, AccessMixin):
"""Base class for encoder-decoder models used for self-supervised encoder pre-training"""

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
def list_available_models(cls) -> List[PretrainedModelInfo]:
"""
This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
Expand Down
12 changes: 6 additions & 6 deletions nemo/collections/asr/modules/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def predict(
state: Optional[torch.Tensor] = None,
add_sos: bool = True,
batch_size: Optional[int] = None,
) -> (torch.Tensor, List[torch.Tensor]):
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Stateful prediction of scores and state for a tokenset.
Expand Down Expand Up @@ -249,7 +249,7 @@ def _predict_modules(self, **kwargs):

def score_hypothesis(
self, hypothesis: rnnt_utils.Hypothesis, cache: Dict[Tuple[int], Any]
) -> (torch.Tensor, List[torch.Tensor], torch.Tensor):
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""
Similar to the predict() method, instead this method scores a Hypothesis during beam search.
Hypothesis is a dataclass representing one hypothesis in a Beam Search.
Expand Down Expand Up @@ -400,7 +400,7 @@ def batch_copy_states(

def batch_score_hypothesis(
self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
) -> (torch.Tensor, List[torch.Tensor], torch.Tensor):
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""
Used for batched beam search algorithms. Similar to score_hypothesis method.
Expand Down Expand Up @@ -626,7 +626,7 @@ def predict(
state: Optional[List[torch.Tensor]] = None,
add_sos: bool = True,
batch_size: Optional[int] = None,
) -> (torch.Tensor, List[torch.Tensor]):
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Stateful prediction of scores and state for a (possibly null) tokenset.
This method takes various cases into consideration :
Expand Down Expand Up @@ -803,7 +803,7 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:

def score_hypothesis(
self, hypothesis: rnnt_utils.Hypothesis, cache: Dict[Tuple[int], Any]
) -> (torch.Tensor, List[torch.Tensor], torch.Tensor):
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""
Similar to the predict() method, instead this method scores a Hypothesis during beam search.
Hypothesis is a dataclass representing one hypothesis in a Beam Search.
Expand Down Expand Up @@ -856,7 +856,7 @@ def score_hypothesis(

def batch_score_hypothesis(
self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
) -> (torch.Tensor, List[torch.Tensor], torch.Tensor):
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""
Used for batched beam search algorithms. Similar to score_hypothesis method.
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/asr/modules/rnnt_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def predict(
state: Optional[torch.Tensor] = None,
add_sos: bool = False,
batch_size: Optional[int] = None,
) -> (torch.Tensor, List[torch.Tensor]):
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Stateful prediction of scores and state for a (possibly null) tokenset.
This method takes various cases into consideration :
Expand Down Expand Up @@ -165,7 +165,7 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
@abstractmethod
def score_hypothesis(
self, hypothesis: Hypothesis, cache: Dict[Tuple[int], Any]
) -> (torch.Tensor, List[torch.Tensor], torch.Tensor):
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""
Similar to the predict() method, instead this method scores a Hypothesis during beam search.
Hypothesis is a dataclass representing one hypothesis in a Beam Search.
Expand All @@ -184,7 +184,7 @@ def score_hypothesis(

def batch_score_hypothesis(
self, hypotheses: List[Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
) -> (torch.Tensor, List[torch.Tensor], torch.Tensor):
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""
Used for batched beam search algorithms. Similar to score_hypothesis method.
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/k2/ml_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union
from typing import Optional, Tuple, Union

import torch
from omegaconf import DictConfig
Expand Down Expand Up @@ -105,7 +105,7 @@ def forward(
targets: torch.Tensor,
input_lengths: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.blank != 0:
# rearrange log_probs to put blank at the first place
# and shift targets to emulate blank = 0
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/mixins/asr_adapter_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import List, Optional, Tuple

from omegaconf import DictConfig, open_dict

Expand Down Expand Up @@ -247,7 +247,7 @@ def check_valid_model_with_adapter_support_(self):
f'{self.joint.__class__.__name__} does not implement `AdapterModuleMixin`', mode=logging_mode.ONCE
)

def resolve_adapter_module_name_(self, name: str) -> (str, str):
def resolve_adapter_module_name_(self, name: str) -> Tuple[str, str]:
"""
Utility method to resolve a given global/module adapter name to its components.
Always returns a tuple representing (module_name, adapter_name). ":" is used as the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def score_forward(

return self.compute_cost_and_score(acts, None, costs, pad_labels, label_lengths, input_lengths)

def _prepare_workspace(self) -> (int, Tuple[torch.Tensor]):
def _prepare_workspace(self) -> Tuple[int, Tuple[torch.Tensor, ...]]:
"""
Helper method that uses the workspace and constructs slices of it that can be used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


import math
from typing import Optional
from typing import Optional, Tuple

import torch
from numba import cuda
Expand Down Expand Up @@ -117,7 +117,7 @@ def compute_costs_data(source: torch.Tensor, dest: torch.Tensor, fastemit_lambda

def get_workspace_size(
maxT: int, maxU: int, minibatch: int, gpu: bool
) -> (Optional[int], global_constants.RNNTStatus):
) -> Tuple[Optional[int], global_constants.RNNTStatus]:

if minibatch <= 0 or maxT <= 0 or maxU <= 0:
return (None, global_constants.RNNTStatus.RNNT_STATUS_INVALID_VALUE)
Expand Down
3 changes: 1 addition & 2 deletions nemo/collections/asr/parts/submodules/jasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ def _get_act_dropout_layer(self, drop_prob=0.2, activation=None):
layers = [activation, nn.Dropout(p=drop_prob)]
return layers

def forward(self, input_: Tuple[List[Tensor], Optional[Tensor]]):
def forward(self, input_: Tuple[List[Tensor], Optional[Tensor]]) -> Tuple[List[Tensor], Optional[Tensor]]:
"""
Forward pass of the module.
Expand All @@ -984,7 +984,6 @@ def forward(self, input_: Tuple[List[Tensor], Optional[Tensor]]):
The output of the block after processing the input through `repeat` number of sub-blocks,
as well as the lengths of the encoded audio after padding/striding.
"""
# type: (Tuple[List[Tensor], Optional[Tensor]]) -> Tuple[List[Tensor], Optional[Tensor]] # nopep8
lens_orig = None
xs = input_[0]
if len(input_) == 2:
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -181,7 +181,7 @@ def _pred_step(
hidden: Optional[torch.Tensor],
add_sos: bool = False,
batch_size: Optional[int] = None,
) -> (torch.Tensor, torch.Tensor):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Common prediction step based on the AbstractRNNTDecoder implementation.
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List

from tqdm.auto import tqdm
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/common/parts/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def __init__(self, factor: int):
super().__init__()
self.factor = int(factor)

def forward(self, x: List[Tuple[torch.Tensor]]) -> (torch.Tensor, torch.Tensor):
def forward(self, x: List[Tuple[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
# T, B, U
x, x_lens = x
seq = [x]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def predict(
return result

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
def list_available_models(cls) -> List[PretrainedModelInfo]:
"""
This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
Expand Down
Loading

0 comments on commit f3e6e3c

Please sign in to comment.