import dataclasses
from enum import Enum
from typing import Optional

import transformers
from transformers import AutoConfig, AutoTokenizer, WhisperConfig

from constants import IGNORE_INDEX
|
|
class VLFMConfig(transformers.PretrainedConfig):
    """Configuration for the babs-vlfm model: a Whisper-style speech encoder
    bridged to a text language model through a frame-stacking projector."""

    model_type = "babs-vlfm"

    def __init__(
        self,
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        *,
        ignore_index: int = IGNORE_INDEX,
        stack_factor: int = 8,
        encoder_ds_factor: int = 2,
        projector_act: str = "swiglu",
        projector_ln_mid: bool = True,
        max_audio_seconds: int = 30,
        audio_padding: str = "longest",
        tokenizer_padding_side: str = "right",
        hidden_size: Optional[int] = 4096,
        speech_encoder_hidden_size: Optional[int] = None,
        vocab_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id

        # Projector hyperparameters.
        self.ignore_index = ignore_index
        self.stack_factor = stack_factor
        self.ds_rate = encoder_ds_factor
        self.projector_act = projector_act
        self.projector_ln_mid = projector_ln_mid
        self.proj_hidden_dim = hidden_size

        # Audio preprocessing and tokenizer behaviour.
        self.max_seconds = max_audio_seconds
        self.audio_padding = audio_padding
        self.tokenizer_padding_side = tokenizer_padding_side

        # Sub-model configs are resolved from their checkpoints when model ids
        # are given; otherwise the caller must supply the sizes explicitly.
        self.audio_config = None
        self.text_config = None

        if audio_model_id:
            self.audio_config = WhisperConfig.from_pretrained(audio_model_id)
            self.speech_encoder_hidden_size = self.audio_config.hidden_size
        else:
            self.speech_encoder_hidden_size = speech_encoder_hidden_size

        if text_model_id:
            self.text_config = AutoConfig.from_pretrained(text_model_id)
            self.llm_hidden_size = self.text_config.hidden_size
            self.vocab_size = getattr(self.text_config, "vocab_size", vocab_size)
        else:
            self.llm_hidden_size = hidden_size
            self.vocab_size = vocab_size

        # RMSNorm defaults (hardcoded rather than configurable).
        self.rms_norm_eps = 1e-6
        self.rms_norm_init_factor = 0.4
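

# Illustrative helper, not part of the original API: estimate how many
# LLM-facing audio embeddings a clip produces. Assumes a Whisper-style encoder
# whose conv stem already applies the 2x downsampling that `encoder_ds_factor`
# records (so a full 30 s clip yields ~1500 encoder frames), after which the
# projector stacks `stack_factor` consecutive frames into one embedding.
def _approx_audio_token_len(num_encoder_frames: int, stack_factor: int = 8) -> int:
    """Ceil-divide encoder frames by the frame-stacking factor."""
    return -(-num_encoder_frames // stack_factor)


# e.g. _approx_audio_token_len(1500, stack_factor=8) == 188 for 30 s of audio.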
|
|
class LossFunction(str, Enum):
    CrossEntropy = "ce"
    KL_Divergence = "kl"
|
|
@dataclasses.dataclass
class LossConfig:
    loss_function: LossFunction = LossFunction.CrossEntropy
    kl_temperature: float = 2.0
    # Annotated so dataclasses registers this as a field; the original bare
    # `ce_weight = 0.5` was a plain class attribute, excluded from __init__.
    ce_weight: float = 0.5

    @property
    def requires_alt_fields(self) -> bool:
        # Only KL distillation consumes the alternative (teacher) fields.
        return self.loss_function == LossFunction.KL_Divergence
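

# Usage sketch (illustrative): KL distillation requires the alternative
# (teacher) fields in each batch, while plain cross-entropy does not.
#
#   assert LossConfig(loss_function=LossFunction.KL_Divergence).requires_alt_fields
#   assert not LossConfig().requires_alt_fields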
|
|
# Placeholder token marking where audio embeddings are spliced into the prompt.
AUDIO_PLACEHOLDER = "<|audio|>"
|
|
def build_tokenizer(text_model_id: str, padding_side: str = "right"):
    """Load the text model's tokenizer and register the audio placeholder."""
    tok = AutoTokenizer.from_pretrained(text_model_id)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = padding_side

    if AUDIO_PLACEHOLDER not in tok.get_vocab():
        # NOTE: when a token is added here, the caller must resize the language
        # model's input embeddings to match the enlarged vocabulary.
        tok.add_special_tokens({"additional_special_tokens": [AUDIO_PLACEHOLDER]})
    audio_token_id = tok.convert_tokens_to_ids(AUDIO_PLACEHOLDER)
    return tok, audio_token_id
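

# Smoke-test sketch; "gpt2" is only a stand-in checkpoint here, any Hugging
# Face causal-LM tokenizer should behave the same way.
if __name__ == "__main__":
    tok, audio_token_id = build_tokenizer("gpt2")
    ids = tok(f"Transcribe: {AUDIO_PLACEHOLDER}")["input_ids"]
    assert audio_token_id in ids  # the placeholder encodes as one token
    print(f"audio token id: {audio_token_id}; sample ids: {ids}")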
|
|
|