# vlfm-v3-3B-copy-copy2 / configs.py
# Duplicated from babs/vlfm-v3-3B (commit f969c1d, verified) by Guilherme34.
from enum import Enum
import dataclasses
from typing import Optional
import transformers
from transformers import WhisperConfig, AutoConfig
from transformers import AutoTokenizer
from constants import IGNORE_INDEX
class VLFMConfig(transformers.PretrainedConfig):
    """Configuration for the VLFM speech+text multimodal model.

    Combines a Whisper-style audio-encoder config with a causal-LM text
    config, plus the settings of the projector that maps stacked audio
    frames into the LLM embedding space.
    """

    model_type = "babs-vlfm"

    def __init__(
        self,
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        *,
        ignore_index: int = IGNORE_INDEX,
        stack_factor: int = 8,
        encoder_ds_factor: int = 2,
        projector_act: str = "swiglu",
        projector_ln_mid: bool = True,
        max_audio_seconds: int = 30,
        audio_padding: str = "longest",
        tokenizer_padding_side: str = "right",
        hidden_size: Optional[int] = 4096,
        speech_encoder_hidden_size: Optional[int] = None,
        vocab_size: Optional[int] = None,
        **kwargs,
    ):
        """Build the config, optionally resolving sub-configs from the Hub.

        Args:
            audio_model_id: Hub id of the speech encoder. When set, its config
                is loaded and its hidden size overrides
                ``speech_encoder_hidden_size``.
            text_model_id: Hub id of the language model. When set, its config
                is loaded and overrides ``hidden_size``/``vocab_size``.
            ignore_index: Label value ignored by the LM loss.
            stack_factor: Number of consecutive encoder frames stacked into a
                single projector input token.
            encoder_ds_factor: Additional downsampling rate applied by the
                encoder (stored as ``ds_rate``).
            projector_act: Activation function used inside the projector.
            projector_ln_mid: Whether the projector applies a mid LayerNorm.
            max_audio_seconds: Maximum audio clip length (stored as
                ``max_seconds``).
            audio_padding: Padding strategy for audio feature extraction.
            tokenizer_padding_side: Side on which the tokenizer pads.
            hidden_size: Projector hidden dim; also the LLM hidden-size
                fallback when ``text_model_id`` is not given.
            speech_encoder_hidden_size: Encoder hidden-size fallback when
                ``audio_model_id`` is not given.
            vocab_size: Vocab-size fallback when no text config (or a text
                config without ``vocab_size``) is available.
            **kwargs: Forwarded to ``transformers.PretrainedConfig``.
        """
        super().__init__(**kwargs)
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.ignore_index = ignore_index
        self.stack_factor = stack_factor
        self.ds_rate = encoder_ds_factor
        self.projector_act = projector_act
        self.projector_ln_mid = projector_ln_mid
        self.proj_hidden_dim = hidden_size
        self.max_seconds = max_audio_seconds
        self.audio_padding = audio_padding
        self.tokenizer_padding_side = tokenizer_padding_side

        # NOTE: loading a sub-config hits the Hub (or the local cache) at
        # construction time; both ids are optional so that a bare config can
        # still be built offline from explicit sizes.
        self.audio_config = None
        self.text_config = None
        if audio_model_id:
            self.audio_config = WhisperConfig.from_pretrained(audio_model_id)
            self.speech_encoder_hidden_size = self.audio_config.hidden_size
        else:
            self.speech_encoder_hidden_size = speech_encoder_hidden_size
        if text_model_id:
            self.text_config = AutoConfig.from_pretrained(text_model_id)
            self.llm_hidden_size = self.text_config.hidden_size
            # Some text configs may not expose vocab_size; fall back to the
            # explicit argument in that case.
            self.vocab_size = getattr(self.text_config, "vocab_size", vocab_size)
        else:
            self.llm_hidden_size = hidden_size
            self.vocab_size = vocab_size

        # RMSNorm hyperparameters used by the projector.
        self.rms_norm_eps = 1e-6
        self.rms_norm_init_factor = 0.4
class LossFunction(str, Enum):
    """Supported training objectives (string-valued for easy serialization)."""

    CrossEntropy = "ce"
    KL_Divergence = "kl"


@dataclasses.dataclass
class LossConfig:
    """Hyperparameters for the training loss.

    Attributes:
        loss_function: Which objective to optimize.
        kl_temperature: Softmax temperature used for KL distillation.
        ce_weight: Weight of the cross-entropy term when losses are mixed.
    """

    loss_function: LossFunction = LossFunction.CrossEntropy
    kl_temperature: float = 2.0
    # BUG FIX: the original `ce_weight = 0.5` had no annotation, so
    # dataclasses treated it as a plain class attribute — it was missing from
    # __init__, repr, equality, and asdict(). Annotating makes it a real field
    # while keeping the same default.
    ce_weight: float = 0.5

    @property
    def requires_alt_fields(self) -> bool:
        # KL distillation needs the teacher's alternative input/label fields.
        return self.loss_function == LossFunction.KL_Divergence
AUDIO_PLACEHOLDER = "<|audio|>"


def build_tokenizer(text_model_id: str, padding_side: str = "right"):
    """Load the text tokenizer and guarantee the audio placeholder exists.

    Returns a ``(tokenizer, audio_token_id)`` pair, where ``audio_token_id``
    is the id of ``AUDIO_PLACEHOLDER`` in the (possibly extended) vocabulary.
    """
    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
    # Models shipped without a pad token fall back to EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = padding_side
    # Register the placeholder as a special token only on first use.
    if AUDIO_PLACEHOLDER not in tokenizer.get_vocab():
        tokenizer.add_special_tokens({"additional_special_tokens": [AUDIO_PLACEHOLDER]})
    return tokenizer, tokenizer.convert_tokens_to_ids(AUDIO_PLACEHOLDER)