File size: 3,296 Bytes
f969c1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from enum import Enum
import dataclasses
from typing import Optional
import transformers
from transformers import WhisperConfig, AutoConfig
from transformers import AutoTokenizer
from constants import IGNORE_INDEX
class VLFMConfig(transformers.PretrainedConfig):
    """Configuration for the ``babs-vlfm`` audio-language model.

    Ties together a Whisper-style speech-encoder config (``audio_model_id``),
    a causal-LM text config (``text_model_id``), and the projector
    hyper-parameters that bridge the two.
    """

    model_type = "babs-vlfm"

    def __init__(
        self,
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        *,
        ignore_index: int = IGNORE_INDEX,
        stack_factor: int = 8,
        encoder_ds_factor: int = 2,
        projector_act: str = "swiglu",
        projector_ln_mid: bool = True,
        max_audio_seconds: int = 30,
        audio_padding: str = "longest",
        tokenizer_padding_side: str = "right",
        hidden_size: Optional[int] = 4096,
        speech_encoder_hidden_size: Optional[int] = None,
        vocab_size: Optional[int] = None,
        **kwargs,
    ):
        """Build the config.

        Args:
            audio_model_id: Pretrained Whisper checkpoint id; when set, the
                speech-encoder width is read from its config.
            text_model_id: Pretrained LM checkpoint id; when set, the LLM
                width and vocab size are read from its config.
            ignore_index: Label id masked out of the LM loss.
            stack_factor: Number of encoder frames stacked before projection.
            encoder_ds_factor: Encoder downsampling rate (stored as ``ds_rate``).
            projector_act: Activation used inside the audio projector.
            projector_ln_mid: Whether the projector applies a mid LayerNorm.
            max_audio_seconds: Maximum audio clip length (stored as ``max_seconds``).
            audio_padding: Padding strategy for audio features.
            tokenizer_padding_side: Padding side passed to the tokenizer.
            hidden_size: Projector hidden width; also the LLM-width fallback
                when no ``text_model_id`` is given.
            speech_encoder_hidden_size: Explicit encoder width fallback when
                no ``audio_model_id`` is given.
            vocab_size: Vocab-size fallback when the text config lacks one.
        """
        super().__init__(**kwargs)
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.ignore_index = ignore_index
        self.stack_factor = stack_factor
        self.ds_rate = encoder_ds_factor
        self.projector_act = projector_act
        self.projector_ln_mid = projector_ln_mid
        self.proj_hidden_dim = hidden_size
        self.max_seconds = max_audio_seconds
        self.audio_padding = audio_padding
        self.tokenizer_padding_side = tokenizer_padding_side

        # Resolve the speech-encoder width: prefer the pretrained Whisper
        # config when an id is given, otherwise use the explicit override.
        self.audio_config = None
        self.text_config = None
        if audio_model_id:
            self.audio_config = WhisperConfig.from_pretrained(audio_model_id)
            self.speech_encoder_hidden_size = self.audio_config.hidden_size
        else:
            self.speech_encoder_hidden_size = speech_encoder_hidden_size

        # Resolve the LLM width and vocab the same way; fall back to
        # `hidden_size` / `vocab_size` when no text model id is provided.
        if text_model_id:
            self.text_config = AutoConfig.from_pretrained(text_model_id)
            self.llm_hidden_size = self.text_config.hidden_size
            self.vocab_size = getattr(self.text_config, "vocab_size", vocab_size)
        else:
            self.llm_hidden_size = hidden_size
            self.vocab_size = vocab_size

        # RMSNorm settings used by the projector.
        self.rms_norm_eps = 1e-6
        self.rms_norm_init_factor = 0.4
class LossFunction(str, Enum):
    """Supported training objectives."""

    CrossEntropy = "ce"   # standard next-token cross-entropy
    KL_Divergence = "kl"  # distillation against a teacher distribution


@dataclasses.dataclass
class LossConfig:
    """Hyper-parameters selecting and weighting the training loss."""

    loss_function: LossFunction = LossFunction.CrossEntropy
    kl_temperature: float = 2.0
    # BUG FIX: the original `ce_weight = 0.5` had no type annotation, so the
    # dataclass machinery treated it as a plain class attribute rather than a
    # field — it could not be set via the constructor and was absent from
    # asdict()/replace(). The annotation makes it a real field; the default
    # value is unchanged, so existing callers are unaffected.
    ce_weight: float = 0.5

    @property
    def requires_alt_fields(self) -> bool:
        """True when KL distillation needs the teacher's alt_* batch fields."""
        return self.loss_function == LossFunction.KL_Divergence
AUDIO_PLACEHOLDER = "<|audio|>"


def build_tokenizer(text_model_id: str, padding_side: str = "right"):
    """Load the text tokenizer and register the audio placeholder token.

    Args:
        text_model_id: Pretrained checkpoint id to load the tokenizer from.
        padding_side: Side on which the tokenizer pads sequences.

    Returns:
        A ``(tokenizer, audio_token_id)`` pair, where ``audio_token_id`` is
        the vocabulary id of ``AUDIO_PLACEHOLDER``.
    """
    tokenizer = AutoTokenizer.from_pretrained(text_model_id)

    # Models that ship without a dedicated pad token reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = padding_side

    # Register the audio placeholder as a special token if it is not
    # already part of the vocabulary.
    if AUDIO_PLACEHOLDER not in tokenizer.get_vocab():
        tokenizer.add_special_tokens(
            {"additional_special_tokens": [AUDIO_PLACEHOLDER]}
        )

    return tokenizer, tokenizer.convert_tokens_to_ids(AUDIO_PLACEHOLDER)
|