Merge remote-tracking branch 'upstream/r1.13.0' into radtts-1.13
borisfom committed Nov 18, 2022
2 parents e208fc4 + c170e03 · commit 2565734
Showing 3 changed files with 6 additions and 0 deletions.
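
All six additions make the same change: each call into the pipeline-parallel schedule functions (forward_backward_pipelining_without_interleaving and forward_backward_no_pipelining) now forwards a sync_batch_comm flag read from the model config, falling back to False when the key is unset.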
File 1 of 3:

@@ -585,6 +585,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
                 dtype=self.autocast_dtype,
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                 sequence_parallel_enabled=self.cfg.get("sequence_parallel", False),
+                sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False),
             )
         else:
             losses_reduced_per_micro_batch = forward_backward_no_pipelining(
@@ -595,6 +596,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
                 tensor_shape=tensor_shape,
                 dtype=self.autocast_dtype,
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
+                sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False),
             )

         # only the last stages of the pipeline return losses

File 2 of 3:

@@ -203,6 +203,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
                 dtype=self.autocast_dtype,
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                 sequence_parallel_enabled=False,
+                sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False),
             )
         else:
             losses_reduced_per_micro_batch = forward_backward_no_pipelining(
@@ -214,6 +215,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
                 decoder_sequence_length=dec_seq_length,
                 dtype=self.autocast_dtype,
                 grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
+                sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False),
             )

         # only the last stages of the pipeline return losses

File 3 of 3:

@@ -61,6 +61,7 @@ def forward_step(self, batch, tensor_shape):
                 forward_only=True,
                 tensor_shape=tensor_shape,
                 dtype=self.model.autocast_dtype,
+                sync_batch_comm=self.model.cfg.get('sync_batch_comm', False),
             )
         else:
             output_tensor = forward_backward_no_pipelining(
@@ -70,6 +71,7 @@ def forward_step(self, batch, tensor_shape):
                 forward_only=True,
                 tensor_shape=tensor_shape,
                 dtype=self.model.autocast_dtype,
+                sync_batch_comm=self.model.cfg.get('sync_batch_comm', False),
             )
         return output_tensor
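
For illustration, here is a minimal sketch of the cfg.get lookup the new lines rely on, assuming an OmegaConf config (the library NeMo uses for model configs); the config contents below are hypothetical:

    from omegaconf import OmegaConf

    # Hypothetical model config without 'sync_batch_comm' set.
    model_cfg = OmegaConf.create({"precision": 16})

    # Mirrors self.frozen_model.cfg.get('sync_batch_comm', False) in the diff:
    # returns the configured value if present, otherwise the False default.
    print(model_cfg.get("sync_batch_comm", False))  # False (key absent)

    model_cfg.sync_batch_comm = True
    print(model_cfg.get("sync_batch_comm", False))  # True

Defaulting to False keeps existing configs working unchanged; only configs that explicitly set sync_batch_comm opt into the new behavior.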
