import logging

import whisper
from torch import Tensor
import torch.nn.functional as F
from whisper.model import AudioEncoder

logger = logging.getLogger(__name__)


class WhisperAudioEncoder(AudioEncoder):
    """
    We inherit the original Whisper encoder and modify its 30-second fixed-length
    padding logic so that variable-length audio can be encoded directly, which
    improves training and inference efficiency.
    """

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
        # Record the encoder output dimension (n_state) for downstream use.
        self.audio_emb_dim = n_state

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = [B, T, n_mels]
            the mel spectrogram of the audio (variable length, no 30-second padding)
        """
        # Conv layers expect [B, n_mels, T].
        x = x.transpose(1, 2)
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        # Back to [B, T', n_state]; conv2 has stride 2, so T' is roughly T / 2.
        x = x.permute(0, 2, 1)

        # Slice the sinusoidal positional embedding to the actual sequence length
        # instead of assuming the fixed 30-second context.
        positional_embedding = self.positional_embedding[:x.shape[1], :]
        assert x.shape[1:] == positional_embedding.shape, "incorrect audio shape"
        x = (x + positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        whisper_model = whisper.load_model(model_path)
        # Rebuild the encoder with a 10x longer positional context than the
        # original 30-second window so longer audio can be encoded.
        audio_encoder = cls(
            whisper_model.dims.n_mels,
            whisper_model.dims.n_audio_ctx * 10,
            whisper_model.dims.n_audio_state,
            whisper_model.dims.n_audio_head,
            whisper_model.dims.n_audio_layer,
        ).to(whisper_model.device)
        state_dict = whisper_model.encoder.state_dict()
        # Drop the checkpoint's positional embedding: the new encoder re-creates a
        # longer sinusoidal buffer whose shape no longer matches the original one.
        state_dict.pop('positional_embedding')
        ret = audio_encoder.load_state_dict(state_dict, strict=False)
        logger.warning(f'Whisper encoder does not load `positional_embedding`; '
                       f'it is re-initialized for the extended context. {ret}')
        audio_encoder.audio_emb_dim = whisper_model.dims.n_audio_state
        return audio_encoder
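

# A minimal usage sketch, assuming the "base" checkpoint and a local file
# "example.wav" (both placeholders). whisper.load_audio and
# whisper.log_mel_spectrogram come from the openai-whisper package; their default
# 80 mel bins match the base model. The mel spectrogram is passed at its natural
# length, transposed to [B, T, n_mels], with no 30-second padding.
if __name__ == "__main__":
    import torch

    encoder = WhisperAudioEncoder.from_pretrained("base")
    encoder.eval()
    device = next(encoder.parameters()).device

    audio = whisper.load_audio("example.wav")    # 16 kHz waveform
    mel = whisper.log_mel_spectrogram(audio)     # [n_mels, T]
    mel = mel.unsqueeze(0).transpose(1, 2)       # [B, T, n_mels]

    with torch.no_grad():
        features = encoder(mel.to(device))
    # conv2 has stride 2, so the output is roughly [B, T / 2, audio_emb_dim].
    print(features.shape, encoder.audio_emb_dim)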