From df900b6f99ea55ae4da205be3abba2e6917d4232 Mon Sep 17 00:00:00 2001 From: Goran Katalinic Date: Fri, 9 Jun 2023 10:38:16 +0100 Subject: [PATCH] Fix IPUWhisperTimeStampLogitsProcessor for beam search --- optimum/graphcore/generation/logits_process.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/graphcore/generation/logits_process.py b/optimum/graphcore/generation/logits_process.py index 2d7ab3845..95a68b405 100644 --- a/optimum/graphcore/generation/logits_process.py +++ b/optimum/graphcore/generation/logits_process.py @@ -194,6 +194,8 @@ def from_model(cls, inst, vocab_size: int): def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, absolute_step: torch.IntTensor ) -> torch.FloatTensor: + input_ids = input_ids.view(-1, input_ids.shape[-1]) + no_timestamps_mask = self.no_timestamps_mask.to(scores.device) scores = no_timestamps_mask * scores + (1 - no_timestamps_mask) * VERY_LARGE_NEGATIVE_CONST