import whisper
import torch
import torch.nn.functional as F

from typing import Optional
from torch import Tensor
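
# Patched replacements for whisper.model.AudioEncoder.forward and
# whisper.model.TextDecoder.forward. They follow the stock implementations,
# except that the encoder slices its positional embedding to the input length
# instead of asserting the full audio context; presumably this lets the
# distilled checkpoint run on mel inputs shorter than the standard 30-second
# window (the snippet itself ships no explanation, so treat this as an
# assumption).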
def forward(self, x: torch.Tensor):
    """
    x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
        the mel spectrogram of the audio
    """
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)

    # Slice the positional embedding to the actual sequence length so inputs
    # shorter than the model's full audio context are accepted.
    x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)

    for block in self.blocks:
        x = block(x)

    x = self.ln_post(x)
    return x
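
# Patched TextDecoder.forward. During incremental decoding only the newest
# token(s) flow through when a kv_cache is passed, so `offset` (the length of
# the cached sequence) shifts the positional-embedding slice to the correct
# positions.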
def forward_decoder(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
    """
    x : torch.LongTensor, shape = (batch_size, <= n_ctx)
        the text tokens
    xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
        the encoded audio features to be attended on
    """
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    x = (
        self.token_embedding(x)
        + self.positional_embedding[offset : offset + x.shape[-1]]
    )
    x = x.to(xa.dtype)

    for block in self.blocks:
        x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

    x = self.ln(x)
    # Tied output projection: the token embedding matrix doubles as the
    # vocabulary projection.
    logits = (
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()

    return logits
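
# Install the patched methods on the library classes; every model instance
# created after this point, including the one loaded below, picks them up.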
whisper.model.AudioEncoder.forward = forward
whisper.model.TextDecoder.forward = forward_decoder

audio = 'mid.wav'
model = whisper.load_model("v1-distill/distill-whisper-large-v2-multi-hans-epoch-6-avg-8.pt")

result = model.transcribe(audio, language='zh', without_timestamps=True, task="transcribe", beam_size=4)
print(result)
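
# The returned value is a dict in the usual openai-whisper layout, with the
# plain transcript under "text" and per-segment details under "segments":
print(result['text'])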