diff --git a/examples/avsr/lightning.py b/examples/avsr/lightning.py index cbca0d976a..7d00366ff5 100644 --- a/examples/avsr/lightning.py +++ b/examples/avsr/lightning.py @@ -84,8 +84,6 @@ def __init__(self, args=None, sp_model=None, pretrained_model_path=None): betas=(0.9, 0.98), ) - self.automatic_optimization = False - def _step(self, batch, _, step_type): if batch is None: return None @@ -123,20 +121,10 @@ def forward(self, batch): return post_process_hypos(hypotheses, self.sp_model)[0][0] def training_step(self, batch, batch_idx): - opt = self.optimizers() - opt.zero_grad() loss = self._step(batch, batch_idx, "train") batch_size = batch.inputs.size(0) batch_sizes = self.all_gather(batch_size) - loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size - self.manual_backward(loss) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10) - opt.step() - - sch = self.lr_schedulers() - sch.step() - self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32)) return loss diff --git a/examples/avsr/lightning_av.py b/examples/avsr/lightning_av.py index 94529c0229..5be0907705 100644 --- a/examples/avsr/lightning_av.py +++ b/examples/avsr/lightning_av.py @@ -80,8 +80,6 @@ def __init__(self, args=None, sp_model=None): betas=(0.9, 0.98), ) - self.automatic_optimization = False - def _step(self, batch, _, step_type): if batch is None: return None @@ -128,20 +126,10 @@ def forward(self, batch): return post_process_hypos(hypotheses, self.sp_model)[0][0] def training_step(self, batch, batch_idx): - opt = self.optimizers() - opt.zero_grad() loss = self._step(batch, batch_idx, "train") batch_size = batch.videos.size(0) batch_sizes = self.all_gather(batch_size) - loss *= batch_sizes.size(0) / batch_sizes.sum() # world size / batch size - self.manual_backward(loss) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10) - opt.step() - - sch = self.lr_schedulers() - sch.step() - self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32)) return loss diff --git a/examples/avsr/train.py b/examples/avsr/train.py index cf4c07c9de..04e285c9e1 100644 --- a/examples/avsr/train.py +++ b/examples/avsr/train.py @@ -36,6 +36,7 @@ def get_trainer(args): strategy=DDPStrategy(find_unused_parameters=False), callbacks=callbacks, reload_dataloaders_every_n_epochs=1, + gradient_clip_val=10.0, )