from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
from torch import Tensor
from torch import nn
from huggingface_hub import hf_hub_download
from openvino.runtime import Core

from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class OpenVinoAudioEncoder(nn.Module):
    """Whisper audio encoder running on a pre-converted OpenVINO IR model."""

    def __init__(self, model: str):
        super().__init__()
        self.core = Core()
        # Fetch the converted encoder graph (.xml) and weights (.bin) from the Hugging Face Hub.
        self._model = self.core.read_model(
            hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
            hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
        )
        self.model = self.core.compile_model(self._model, "CPU")

    def forward(self, x: Tensor) -> Tensor:
        # Run inference on CPU and wrap the single output back into a torch tensor.
        result = self.model.infer_new_request(x.numpy())
        return torch.from_numpy(next(iter(result.values())))


class OpenVinoTextDecoder(nn.Module):
    """Whisper text decoder running on a pre-converted OpenVINO IR model."""

    def __init__(self, model: str):
        super().__init__()
        self.core = Core()
        self._model = self.core.read_model(
            hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
            hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
        )
        self.model = self.core.compile_model(self._model, "CPU")

    def forward(self, x: Tensor, xa: Union[Tensor, np.ndarray], kv_cache: np.ndarray, offset: int):
        if torch.is_tensor(xa):
            xa = xa.numpy()
        # The exported decoder takes the token ids, the encoder output, the
        # pre-allocated key/value cache, and the current decoding offset.
        output, kv_cache = self.model.infer_new_request(
            {
                "tokens": x.numpy(),
                "audio_features": xa,
                "kv_cache": kv_cache,
                "offset": np.array(offset, dtype=int),
            }
        ).values()
        return torch.from_numpy(output), kv_cache


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions, model: str):
        super().__init__()
        self.type = model
        self.dims = dims
        self.encoder = OpenVinoAudioEncoder(model=model)
        self.decoder = OpenVinoTextDecoder(model=model)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder.forward(mel)

    def logits(self, tokens: torch.Tensor, audio_features: Union[torch.Tensor, np.ndarray]):
        # A fresh cache is allocated here: logits are computed for the full
        # token sequence in one pass rather than incrementally.
        kv_cache = self.new_kv_cache(tokens.shape[0], tokens.shape[-1])
        output, _ = self.decoder.forward(tokens, audio_features, kv_cache=kv_cache, offset=0)
        return output

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        kv_cache = self.new_kv_cache(tokens.shape[0], tokens.shape[-1])
        output, _ = self.decoder(tokens, self.encoder(mel), kv_cache=kv_cache, offset=0)
        return output

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865

    def new_kv_cache(self, n_group: int, length: int):
        # Cache shape is [2 * n_text_layer, n_group, length, n_text_state]:
        # one key slot and one value slot per decoder self-attention layer.
        if self.type in ("tiny", "tiny.en"):
            size = [8, n_group, length, 384]
        elif self.type in ("base", "base.en"):
            size = [12, n_group, length, 512]
        elif self.type in ("small", "small.en"):
            size = [24, n_group, length, 768]
        elif self.type in ("medium", "medium.en"):
            size = [48, n_group, length, 1024]
        elif self.type == "large":
            size = [64, n_group, length, 1280]
        else:
            raise ValueError(f"Unsupported model type: {self.type}")
        return np.zeros(size, dtype=np.float32)

    # Bind the high-level helpers from the sibling modules as methods.
    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
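

# Example usage (a minimal sketch, not part of the original module): the
# dimension values below are the published ones for the multilingual "base"
# checkpoint, and `transcribe` is assumed to follow the upstream
# openai/whisper signature (audio path or waveform in, result dict out).
#
#     dims = ModelDimensions(
#         n_mels=80, n_audio_ctx=1500, n_audio_state=512, n_audio_head=8,
#         n_audio_layer=6, n_vocab=51865, n_text_ctx=448, n_text_state=512,
#         n_text_head=8, n_text_layer=6,
#     )
#     model = Whisper(dims, "base")
#     result = model.transcribe("audio.wav")
#     print(result["text"])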