finish exporting .en models
zhuzilin committed Sep 30, 2022
1 parent df8cf85 commit 3a0e935
Showing 2 changed files with 58 additions and 31 deletions.
37 changes: 27 additions & 10 deletions whisper/decoding.py
@@ -133,41 +133,56 @@ def __init__(self, model: "Whisper", initial_token_length: int):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = None
self.export_onnx = False
if model.type == "tiny.en":
self.kv_cache_size = lambda x, y: [8, x, y, 384]
elif model.type == "base.en":
self.kv_cache_size = lambda x, y: [12, x, y, 512]
elif model.type == "small.en":
self.kv_cache_size = lambda x, y: [24, x, y, 768]
elif model.type == "medium.en":
self.kv_cache_size = lambda x, y: [48, x, y, 1024]
else:
raise ValueError(f"Unsupported model type: {model.type}")

def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
n_group = tokens.shape[0]
if self.kv_cache is None:
# hard-coded for tiny.en, whose decoder self-attention layers have layer_id 4, 6, 8, 10
self.kv_cache = np.zeros([8, 5, self.initial_token_length, 384], dtype=np.float32)
self.kv_cache = np.zeros(
self.kv_cache_size(n_group, self.initial_token_length), dtype=np.float32)
offset = 0
else:
offset = self.kv_cache.shape[2]
new_kv_cache = np.zeros([8, 5, offset + 1, 384], dtype=np.float32)
new_kv_cache = np.zeros(self.kv_cache_size(n_group, offset + 1), dtype=np.float32)
new_kv_cache[:, :, :-1, :] = self.kv_cache
self.kv_cache = new_kv_cache

if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]

if self.export_onnx and self.kv_cache.shape[2] > self.initial_token_length:
# export decoder as onnx
if False and self.kv_cache.shape[2] > self.initial_token_length:
print(f"tokens: {tokens.shape}")
print(f"audio_features: {audio_features.shape}")
print(f"kv_cache: {self.kv_cache.shape}")
torch.onnx.export(
self.model.decoder,
(tokens, audio_features, torch.from_numpy(self.kv_cache), torch.tensor(offset)),
"decoder.onnx",
verbose=True,
verbose=False,
opset_version=13,
input_names=["tokens", "audio_features", "kv_cache", "offset"],
output_names=["logits", "output_kv_cache"],
dynamic_axes={
"tokens": [1],
"kv_cache": [2],
"tokens": [0, 1],
"audio_features": [0],
"kv_cache": [1, 2],
"output_kv_cache": [2],
}
)
exit()
output, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache, offset=offset)
# output, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=torch.from_numpy(self.kv_cache), offset=torch.tensor(offset))
#output, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache, offset=offset)
output, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=torch.from_numpy(self.kv_cache), offset=torch.tensor(offset))
return output

def cleanup_caching(self):
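
Note on the shapes above: all four cases follow one pattern, [2 * n_text_layer, n_group, n_tokens, n_text_state], i.e. one key slab plus one value slab per decoder block (tiny.en: 4 layers, width 384; base.en: 6/512; small.en: 12/768; medium.en: 24/1024). A minimal sketch that derives the lambda from the model dims rather than the model name; make_kv_cache_size is a hypothetical helper, not part of this commit:

from typing import Callable, List

def make_kv_cache_size(n_text_layer: int, n_text_state: int) -> Callable[[int, int], List[int]]:
    # One key slab and one value slab per decoder block, hence 2 * n_text_layer.
    # tiny.en: (4, 384) -> [8, x, y, 384]; medium.en: (24, 1024) -> [48, x, y, 1024].
    return lambda n_group, n_tokens: [2 * n_text_layer, n_group, n_tokens, n_text_state]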
@@ -578,6 +593,7 @@ def _get_audio_features(self, mel: Tensor):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
# # export encoder as onnx
# torch.onnx.export(
# self.model.encoder,
# (mel),
@@ -615,6 +631,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
print(f"step: {i}, logits: {logits}", flush=True)

if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
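
For reference, a minimal sketch of running the exported decoder.onnx with onnxruntime, using the input and output names from the torch.onnx.export call above. The tiny.en shapes (8 cache slabs, width 384, 1500 audio positions) and the zero-filled inputs are illustrative assumptions, not something this commit ships:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("decoder.onnx")

n_group, n_tokens = 5, 3  # hypothetical beam size and prompt length
feeds = {
    "tokens": np.zeros((n_group, n_tokens), dtype=np.int64),
    "audio_features": np.zeros((n_group, 1500, 384), dtype=np.float32),
    "kv_cache": np.zeros((8, n_group, n_tokens, 384), dtype=np.float32),
    "offset": np.array(0, dtype=np.int64),  # scalar, mirroring torch.tensor(offset)
}
# The dynamic_axes declared at export time let n_group and n_tokens vary here.
logits, output_kv_cache = session.run(["logits", "output_kv_cache"], feeds)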
52 changes: 31 additions & 21 deletions whisper/model.py
@@ -84,7 +84,12 @@ def forward(
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
if kv_cache is not None and k.shape[1] <= self.n_ctx:
key_id = self.layer_id - 4
# hard-coded offset per model type (the number of attention layers ahead of
# the decoder's first self-attention block):
# tiny.en: 4
# base.en: 6
# small.en: 12
# medium.en: 24
key_id = self.layer_id - 24
value_id = key_id + 1
size = k.shape[1]
kv_cache[key_id, :, -size:, :] = k
@@ -210,8 +215,10 @@ def __init__(self, model: str):

self.core = Core()
self._model = self.core.read_model(
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
"encoder.xml",
"encoder.bin",
)
self.model = self.core.compile_model(self._model, "CPU")

@@ -226,8 +233,10 @@ def __init__(self, model: str):

self.core = Core()
self._model = self.core.read_model(
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
"decoder.xml",
"decoder.bin",
)
self.model = self.core.compile_model(self._model, "CPU")
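
The local "decoder.xml" / "decoder.bin" paths above assume the ONNX export has already been converted to OpenVINO IR. A sketch of that step with the OpenVINO 2022 Python API; converting via Model Optimizer and reading the ONNX file directly are both assumptions about the intended workflow, not spelled out in the commit:

# Convert the export to IR once on the command line:
#   mo --input_model decoder.onnx
# which writes decoder.xml / decoder.bin. Alternatively, read the ONNX directly:
from openvino.runtime import Core

core = Core()
model = core.read_model("decoder.onnx")  # or core.read_model("decoder.xml", "decoder.bin")
compiled = core.compile_model(model, "CPU")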

@@ -246,23 +255,24 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Tensor, offset: int):
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions, model: str):
super().__init__()
self.type = model
self.dims = dims
# self.encoder = AudioEncoder(
# self.dims.n_mels,
# self.dims.n_audio_ctx,
# self.dims.n_audio_state,
# self.dims.n_audio_head,
# self.dims.n_audio_layer,
# )
# self.decoder = TextDecoder(
# self.dims.n_vocab,
# self.dims.n_text_ctx,
# self.dims.n_text_state,
# self.dims.n_text_head,
# self.dims.n_text_layer,
# )
self.encoder = OpenVinoAudioEncoder(model=model)
self.decoder = OpenVinoTextDecoder(model=model)
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
# self.encoder = OpenVinoAudioEncoder(model=model)
# self.decoder = OpenVinoTextDecoder(model=model)

def embed_audio(self, mel: torch.Tensor):
return self.encoder.forward(mel)
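
Since this commit switches between the PyTorch and OpenVINO modules by commenting blocks in and out, a hypothetical cleanup (not in the repo) would select the backend with a flag instead:

def build_backends(dims: ModelDimensions, model: str, use_openvino: bool):
    # Hypothetical helper: pick the encoder/decoder pair without editing source.
    if use_openvino:
        return OpenVinoAudioEncoder(model=model), OpenVinoTextDecoder(model=model)
    encoder = AudioEncoder(dims.n_mels, dims.n_audio_ctx, dims.n_audio_state,
                           dims.n_audio_head, dims.n_audio_layer)
    decoder = TextDecoder(dims.n_vocab, dims.n_text_ctx, dims.n_text_state,
                          dims.n_text_head, dims.n_text_layer)
    return encoder, decoder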
