From b5e94e55b5234ebefb5131881995cfd1d56b7c21 Mon Sep 17 00:00:00 2001 From: MaximumEntropy Date: Tue, 6 Dec 2022 13:59:22 -0800 Subject: [PATCH] Remove broadcast Signed-off-by: MaximumEntropy --- .../megatron_t5_prompt_learning_model.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py index 3e668347ce14..643841c09f14 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py @@ -464,15 +464,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A else: encoder_input = torch.zeros((batch_size, seq_length, self.hidden_size), dtype=self.autocast_dtype).cuda() - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - # Broadcasting encoder inputs to all ranks for now, but this is inefficent. - # TODO: Make Enc-Dec improvement to only boardcast encoder_ids/embeddings when needed - torch.distributed.broadcast( - encoder_input, - parallel_state.get_pipeline_model_parallel_first_rank(), - group=parallel_state.get_pipeline_model_parallel_group(), - ) - predicted_token_ids, log_probs = self.frozen_model.decode( tokens_enc=input_ids, enc_mask=enc_mask,