import whisper
import torch
import torch.nn.functional as F
from typing import Dict, Iterable, Optional
import numpy as np
from torch import Tensor, nn
from whisper.model import LayerNorm, ResidualAttentionBlock
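
# The two forward functions below replace the stock whisper implementations
# (see the assignments near the bottom of the file). The encoder version
# slices the positional embedding to the actual number of frames, presumably
# so that mel inputs shorter than the full 30-second context can be encoded;
# the decoder version adds the positional embedding a second time to match
# the fine-tuned checkpoint (see the WARNING comment inside).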
def forward(self, x: torch.Tensor):
    """
    x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
        the mel spectrogram of the audio
    """
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)
    x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)

    for block in self.blocks:
        x = block(x)

    x = self.ln_post(x)
    return x
def forward_decoder(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
    """
    x : torch.LongTensor, shape = (batch_size, <= n_ctx)
        the text tokens
    xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
        the encoded audio features to be attended on
    """
    # offset = number of tokens already stored in the kv_cache, so that newly
    # decoded tokens receive the correct positions during incremental decoding
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    x = (
        self.token_embedding(x)
        + self.positional_embedding[offset : offset + x.shape[-1]]
    )
    # WARNING: the original fine-tuned model has a bug here; to use the
    # pretrained model, we have to add the positional embedding a second time.
    x = x + self.positional_embedding[offset : offset + x.shape[1]]
    x = x.to(xa.dtype)

    for block in self.blocks:
        x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

    x = self.ln(x)
    logits = (
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()

    return logits
# Patch the modified forward passes into the installed whisper package.
whisper.model.AudioEncoder.forward = forward
whisper.model.TextDecoder.forward = forward_decoder
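
# Quick end-to-end test: transcribe a local wav file with the distilled checkpoint.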
audio = 'mid.wav'
model = whisper.load_model("v1-distill/distill-whisper-large-v2-multi-hans-epoch-6-avg-8.pt")
result = model.transcribe(audio, language='zh', without_timestamps=True, task="transcribe", beam_size=4)
print(result)
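
# Optional sanity check (a sketch, not part of the original test): run a
# 10-second clip through the patched encoder directly to confirm that the
# sliced positional embedding accepts inputs shorter than the 30-second
# context. Reuses the 'mid.wav' path from above; nothing else is assumed.
wav = whisper.load_audio(audio)
wav = whisper.pad_or_trim(wav, length=16000 * 10)  # pad/trim to 10 s instead of 30 s
mel = whisper.log_mel_spectrogram(wav).to(model.device)
mel = mel.to(model.encoder.conv1.weight.dtype)  # match fp16/fp32 of the loaded weights
with torch.no_grad():
    audio_features = model.encoder(mel.unsqueeze(0))
print(audio_features.shape)  # expected: (1, 500, n_audio_state)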
# model = whisper.load_model("large-v2")
# checkpoint = torch.load(
# "v1-distill/epoch-6-avg-8.pt", map_location="cpu"
# )
# model.load_state_dict(checkpoint, strict=True)
# audio = whisper.load_audio(audio)
# audio = whisper.pad_or_trim(audio, length=16000*10)
# mel = whisper.log_mel_spectrogram(audio).to(model.device)
# options = whisper.DecodingOptions(
# task="transcribe",
# language="zh",
# without_timestamps=True,
# beam_size=1,
# )
# result = model.decode(mel, options)