# Ming-Lite-Omni / modeling_whisper_encoder.py
import logging

import torch.nn.functional as F
from torch import Tensor

import whisper
from whisper.model import AudioEncoder

logger = logging.getLogger(__name__)


class WhisperAudioEncoder(AudioEncoder):
"""
We inherited the original Whisper encoder and modified its 30-second fixed-length padding logic to
improve training and inference efficiency.
"""
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
self.audio_emb_dim = n_state
    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = [B, T, n_mels]
            The mel spectrogram of the audio; T may be shorter than the
            fixed 30-second context of the original Whisper encoder.
        """
        # Conv1d expects channels first: [B, T, n_mels] -> [B, n_mels, T].
        x = x.transpose(1, 2)
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        # Back to sequence-first layout: [B, n_state, T'] -> [B, T', n_state].
        x = x.permute(0, 2, 1)

        # Slice the sinusoidal positional embedding to the actual sequence
        # length instead of asserting the fixed 30-second shape.
        positional_embedding = self.positional_embedding[:x.shape[1], :]
        assert x.shape[1:] == positional_embedding.shape, "incorrect audio shape"
        x = (x + positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x
    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        whisper_model = whisper.load_model(model_path)
        # Allocate a 10x longer positional context so the encoder can accept
        # audio well beyond Whisper's default 30-second window.
        audio_encoder = cls(
            whisper_model.dims.n_mels,
            whisper_model.dims.n_audio_ctx * 10,
            whisper_model.dims.n_audio_state,
            whisper_model.dims.n_audio_head,
            whisper_model.dims.n_audio_layer,
        ).to(whisper_model.device)

        # Drop the pretrained positional embedding: its shape no longer matches
        # the extended context, and the sinusoidal buffer created in __init__
        # already covers the longer range.
        state_dict = whisper_model.encoder.state_dict()
        state_dict.pop('positional_embedding')
        ret = audio_encoder.load_state_dict(state_dict, strict=False)
        logger.warning(f'whisper encoder does not load `positional_embedding`. {ret}')

        audio_encoder.audio_emb_dim = whisper_model.dims.n_audio_state
        return audio_encoder
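

# --- Usage sketch (illustrative, not part of the released module) ----------
# A minimal example of loading the encoder and running it on variable-length
# audio, assuming the `openai-whisper` package is installed. The checkpoint
# name "base" and the input file "example.wav" are hypothetical placeholders;
# whisper.load_audio and whisper.log_mel_spectrogram are standard whisper APIs.
if __name__ == "__main__":
    encoder = WhisperAudioEncoder.from_pretrained("base")

    # Compute a log-mel spectrogram without padding to 30 seconds.
    # For "base" the default 80 mel bins match the encoder's conv1 input.
    audio = whisper.load_audio("example.wav")
    mel = whisper.log_mel_spectrogram(audio)  # [n_mels, T]

    # The modified forward expects [B, T, n_mels] on the encoder's device.
    mel = mel.unsqueeze(0).transpose(1, 2).to(next(encoder.parameters()).device)
    features = encoder(mel)  # roughly [1, T / 2, n_state]; conv2 has stride 2
    print(features.shape)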