Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: Ssofja <[email protected]>
  • Loading branch information
Ssofja committed Jan 17, 2025
1 parent e74bcea commit fb47acc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
3 changes: 1 addition & 2 deletions nemo/collections/asr/models/classification_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def _update_decoder_config(self, labels, cfg):
cfg.num_classes = len(labels)

OmegaConf.set_struct(cfg, True)

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._update_decoder_config(cfg.labels, cfg.decoder)
super().__init__(cfg, trainer)
Expand Down Expand Up @@ -562,7 +562,6 @@ def change_labels(self, new_labels: List[str]):
logging.info(f"Changed decoder output to {self.decoder.num_classes} labels.")



class EncDecRegressionModel(_EncDecBaseModel):
"""Encoder decoder class for speech regression models.
Model class creates training, validation methods for setting up data
Expand Down
10 changes: 8 additions & 2 deletions tests/collections/asr/test_asr_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def speech_classification_model():

decoder = {
'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
'params': {'feat_in': 32, 'num_classes': 30,},
'params': {
'feat_in': 32,
'num_classes': 30,
},
}

modelConfig = DictConfig(
Expand Down Expand Up @@ -95,7 +98,10 @@ def frame_classification_model():

decoder = {
'cls': 'nemo.collections.common.parts.MultiLayerPerceptron',
'params': {'hidden_size': 32, 'num_classes': 5,},
'params': {
'hidden_size': 32,
'num_classes': 5,
},
}

modelConfig = DictConfig(
Expand Down

0 comments on commit fb47acc

Please sign in to comment.