diff --git a/whisper/decoding.py b/whisper/decoding.py
index 26ba3ceff..87a9a9bc2 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -133,16 +133,26 @@ def __init__(self, model: "Whisper", initial_token_length: int):
         self.model: "Whisper" = model
         self.initial_token_length = initial_token_length
         self.kv_cache = None
-        self.export_onnx = False
+        if model.type == "tiny.en":
+            self.kv_cache_size = lambda x, y: [8, x, y, 384]
+        elif model.type == "base.en":
+            self.kv_cache_size = lambda x, y: [12, x, y, 512]
+        elif model.type == "small.en":
+            self.kv_cache_size = lambda x, y: [24, x, y, 768]
+        elif model.type == "medium.en":
+            self.kv_cache_size = lambda x, y: [48, x, y, 1024]
+        else:
+            raise ValueError(f"Unsupported model type: {model.type}")

     def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+        n_group = tokens.shape[0]
         if self.kv_cache is None:
-            # hard code for decoder layer 4, 6, 8, 10
-            self.kv_cache = np.zeros([8, 5, self.initial_token_length, 384], dtype=np.float32)
+            self.kv_cache = np.zeros(
+                self.kv_cache_size(n_group, self.initial_token_length), dtype=np.float32)
             offset = 0
         else:
             offset = self.kv_cache.shape[2]
-            new_kv_cache = np.zeros([8, 5, offset + 1, 384], dtype=np.float32)
+            new_kv_cache = np.zeros(self.kv_cache_size(n_group, offset + 1), dtype=np.float32)
             new_kv_cache[:, :, :-1, :] = self.kv_cache
             self.kv_cache = new_kv_cache

@@ -150,24 +160,29 @@ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
             # only need to use the last token except in the first forward pass
             tokens = tokens[:, -1:]

-        if self.export_onnx and self.kv_cache.shape[2] > self.initial_token_length:
+        # export decoder as onnx
+        if False and self.kv_cache.shape[2] > self.initial_token_length:
+            print(f"tokens: {tokens.shape}")
+            print(f"audio_features: {audio_features.shape}")
+            print(f"kv_cache: {self.kv_cache.shape}")
             torch.onnx.export(
                 self.model.decoder,
                 (tokens, audio_features, torch.from_numpy(self.kv_cache), torch.tensor(offset)),
                 "decoder.onnx",
-                verbose=True,
+                verbose=False,
                 opset_version=13,
                 input_names=["tokens", "audio_features", "kv_cache", "offset"],
                 output_names=["logits", "output_kv_cache"],
                 dynamic_axes={
-                    "tokens": [1],
-                    "kv_cache": [2],
+                    "tokens": [0, 1],
+                    "audio_features": [0],
+                    "kv_cache": [1, 2],
                     "output_kv_cache": [2],
                 }
             )
             exit()

-        output, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache, offset=offset)
-        # output, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=torch.from_numpy(self.kv_cache), offset=torch.tensor(offset))
+        # output, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache, offset=offset)
+        output, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=torch.from_numpy(self.kv_cache), offset=torch.tensor(offset))
         return output

     def cleanup_caching(self):
@@ -578,6 +593,7 @@ def _get_audio_features(self, mel: Tensor):
             # encoded audio features are given; skip audio encoding
             audio_features = mel
         else:
+            # # export encoder as onnx
             # torch.onnx.export(
             #     self.model.encoder,
             #     (mel),
@@ -615,6 +631,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
         try:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)
+                print(f"step: {i}, logits: {logits}", flush=True)
                 if i == 0 and self.tokenizer.no_speech is not None:
                     # save no_speech_probs
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
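Note: the per-model branches in the __init__ hunk above all follow one pattern.
The cache holds one key plane and one value plane per decoder layer, so its
shape is [2 * n_text_layer, n_group, seq_len, n_text_state] (tiny.en: 2*4=8 and
384, base.en: 2*6=12 and 512, small.en: 2*12=24 and 768, medium.en: 2*24=48 and
1024). A minimal sketch, not part of the patch, that derives the size from the
model's ModelDimensions instead of hard-coding each model name (the helper name
and the `dims` parameter are hypothetical):

    # Sketch only: `dims` is assumed to be the ModelDimensions object that
    # Whisper.__init__ receives in model.py.
    def kv_cache_size(dims, n_group: int, seq_len: int) -> list:
        # one key plane and one value plane per decoder layer
        return [2 * dims.n_text_layer, n_group, seq_len, dims.n_text_state]

    # e.g. medium.en: n_text_layer=24, n_text_state=1024
    #   -> [48, n_group, seq_len, 1024], matching the branch above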
diff --git a/whisper/model.py b/whisper/model.py
index 5b19db883..90671ebba 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -84,7 +84,12 @@ def forward(
         k = self.key(x if xa is None else xa)
         v = self.value(x if xa is None else xa)
         if kv_cache is not None and k.shape[1] <= self.n_ctx:
-            key_id = self.layer_id - 4
+            # hard-coded offset, equal to n_text_layer for the model in use:
+            # tiny.en: 4
+            # base.en: 6
+            # small.en: 12
+            # medium.en: 24
+            key_id = self.layer_id - 24
             value_id = key_id + 1
             size = k.shape[1]
             kv_cache[key_id, :, -size:, :] = k
@@ -210,8 +215,10 @@ def __init__(self, model: str):
         self.core = Core()

         self._model = self.core.read_model(
-            hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
-            hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
+            # hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
+            # hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
+            "encoder.xml",
+            "encoder.bin",
         )
         self.model = self.core.compile_model(self._model, "CPU")

@@ -226,8 +233,10 @@ def __init__(self, model: str):
         self.core = Core()

         self._model = self.core.read_model(
-            hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
-            hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
+            # hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
+            # hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
+            "decoder.xml",
+            "decoder.bin",
         )
         self.model = self.core.compile_model(self._model, "CPU")

@@ -246,23 +255,24 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Tensor, offset: int):
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions, model: str):
         super().__init__()
+        self.type = model
         self.dims = dims
-        # self.encoder = AudioEncoder(
-        #     self.dims.n_mels,
-        #     self.dims.n_audio_ctx,
-        #     self.dims.n_audio_state,
-        #     self.dims.n_audio_head,
-        #     self.dims.n_audio_layer,
-        # )
-        # self.decoder = TextDecoder(
-        #     self.dims.n_vocab,
-        #     self.dims.n_text_ctx,
-        #     self.dims.n_text_state,
-        #     self.dims.n_text_head,
-        #     self.dims.n_text_layer,
-        # )
-        self.encoder = OpenVinoAudioEncoder(model=model)
-        self.decoder = OpenVinoTextDecoder(model=model)
+        self.encoder = AudioEncoder(
+            self.dims.n_mels,
+            self.dims.n_audio_ctx,
+            self.dims.n_audio_state,
+            self.dims.n_audio_head,
+            self.dims.n_audio_layer,
+        )
+        self.decoder = TextDecoder(
+            self.dims.n_vocab,
+            self.dims.n_text_ctx,
+            self.dims.n_text_state,
+            self.dims.n_text_head,
+            self.dims.n_text_layer,
+        )
+        # self.encoder = OpenVinoAudioEncoder(model=model)
+        # self.decoder = OpenVinoTextDecoder(model=model)

     def embed_audio(self, mel: torch.Tensor):
         return self.encoder.forward(mel)
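Note: with this patch, the read_model() calls load encoder.xml/encoder.bin and
decoder.xml/decoder.bin from the working directory instead of downloading them
from the zhuzilin/whisper-openvino-* Hugging Face repos. A minimal sketch, not
part of the patch, of producing those IR files from the exported ONNX models
(assumes an OpenVINO 2022.x runtime, and that both encoder.onnx and
decoder.onnx have been exported; the encoder export block is still commented
out in decoding.py above):

    # Sketch only: convert the exported ONNX models to OpenVINO IR so the
    # patched read_model() calls can find the .xml/.bin pairs locally.
    from openvino.runtime import Core, serialize

    core = Core()
    for name in ("encoder", "decoder"):
        model = core.read_model(f"{name}.onnx")         # OpenVINO reads ONNX directly
        serialize(model, f"{name}.xml", f"{name}.bin")  # writes the IR pair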