Commit
Simplify training step in av-asr recipe (pytorch#3598)
* Simplify training step in av-asr recipe

* Run pre-commit
Pingchuan Ma authored Sep 8, 2023
1 parent 3e1d8f3 commit 5e893d6
Showing 3 changed files with 1 addition and 24 deletions.
12 changes: 0 additions & 12 deletions examples/avsr/lightning.py
@@ -84,8 +84,6 @@ def __init__(self, args=None, sp_model=None, pretrained_model_path=None):
             betas=(0.9, 0.98),
         )

-        self.automatic_optimization = False
-
     def _step(self, batch, _, step_type):
         if batch is None:
             return None
@@ -123,20 +121,10 @@ def forward(self, batch):
         return post_process_hypos(hypotheses, self.sp_model)[0][0]

     def training_step(self, batch, batch_idx):
-        opt = self.optimizers()
-        opt.zero_grad()
         loss = self._step(batch, batch_idx, "train")
         batch_size = batch.inputs.size(0)
         batch_sizes = self.all_gather(batch_size)
-
         loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / batch size
-        self.manual_backward(loss)
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
-        opt.step()
-
-        sch = self.lr_schedulers()
-        sch.step()
-
         self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))

         return loss
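With the manual optimization code removed, the retained lines above reduce training_step to loss computation and logging. A minimal sketch of the resulting method, reconstructed from the kept diff lines rather than copied verbatim from the repository; it assumes the module runs under Lightning's default automatic optimization, so backward(), optimizer stepping, scheduler stepping, and gradient clipping are driven by the Trainer:

# Method of the LightningModule defined in examples/avsr/lightning.py (sketch).
# Assumes `import torch` at module level, as in the original file.
def training_step(self, batch, batch_idx):
    loss = self._step(batch, batch_idx, "train")
    batch_size = batch.inputs.size(0)
    batch_sizes = self.all_gather(batch_size)  # per-rank batch sizes across DDP workers
    loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / global batch size
    self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))
    return loss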
12 changes: 0 additions & 12 deletions examples/avsr/lightning_av.py
@@ -80,8 +80,6 @@ def __init__(self, args=None, sp_model=None):
             betas=(0.9, 0.98),
         )

-        self.automatic_optimization = False
-
     def _step(self, batch, _, step_type):
         if batch is None:
             return None
@@ -128,20 +126,10 @@ def forward(self, batch):
         return post_process_hypos(hypotheses, self.sp_model)[0][0]

     def training_step(self, batch, batch_idx):
-        opt = self.optimizers()
-        opt.zero_grad()
         loss = self._step(batch, batch_idx, "train")
         batch_size = batch.videos.size(0)
         batch_sizes = self.all_gather(batch_size)
-
         loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / batch size
-        self.manual_backward(loss)
-        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
-        opt.step()
-
-        sch = self.lr_schedulers()
-        sch.step()
-
         self.log("monitoring_step", torch.tensor(self.global_step, dtype=torch.float32))

         return loss
1 change: 1 addition & 0 deletions examples/avsr/train.py
@@ -36,6 +36,7 @@ def get_trainer(args):
         strategy=DDPStrategy(find_unused_parameters=False),
         callbacks=callbacks,
         reload_dataloaders_every_n_epochs=1,
+        gradient_clip_val=10.0,
     )


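The single addition in train.py is what keeps the behavior equivalent: under automatic optimization, passing gradient_clip_val to the Trainer makes Lightning clip gradients after backward(), replacing the deleted torch.nn.utils.clip_grad_norm_(..., 10) calls. A toy, self-contained sketch of that mechanism (unrelated to the AV-ASR models themselves; assumes pytorch_lightning is installed):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class ToyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # No manual_backward / clip_grad_norm_ / optimizer.step() here: returning the
        # loss lets Lightning run backward, clip to gradient_clip_val, step, and zero_grad.
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3, betas=(0.9, 0.98))


if __name__ == "__main__":
    data = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    trainer = Trainer(max_epochs=1, gradient_clip_val=10.0)  # clip gradient norm at 10, as in the recipe
    trainer.fit(ToyModule(), DataLoader(data, batch_size=16))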
