from enum import Enum
import dataclasses
from typing import Optional

import transformers
from transformers import WhisperConfig, AutoConfig
from transformers import AutoTokenizer
from constants import IGNORE_INDEX

class VLFMConfig(transformers.PretrainedConfig):
    model_type = "babs-vlfm"

    def __init__(
        self,
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        *,
        ignore_index: int = IGNORE_INDEX,
        stack_factor: int = 8,          # audio embedding frames stacked together before projection
        encoder_ds_factor: int = 2,     # downsampling already applied inside the speech encoder
        projector_act: str = "swiglu",
        projector_ln_mid: bool = True,
        max_audio_seconds: int = 30,
        audio_padding: str = "longest",
        tokenizer_padding_side: str = "right",
        hidden_size: Optional[int] = 4096,    # projector hidden dim; also the LLM hidden-size fallback
        speech_encoder_hidden_size: Optional[int] = None,  # inferred from the Whisper config when audio_model_id is set
        vocab_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id

        self.ignore_index = ignore_index
        self.stack_factor = stack_factor
        self.ds_rate = encoder_ds_factor
        self.projector_act = projector_act
        self.projector_ln_mid = projector_ln_mid
        self.proj_hidden_dim = hidden_size

        self.max_seconds = max_audio_seconds
        self.audio_padding = audio_padding
        self.tokenizer_padding_side = tokenizer_padding_side
        self.audio_config = None
        self.text_config = None

        if audio_model_id:
            self.audio_config = WhisperConfig.from_pretrained(audio_model_id)
            self.speech_encoder_hidden_size = self.audio_config.hidden_size
        else:
            self.speech_encoder_hidden_size = speech_encoder_hidden_size

        if text_model_id:
            self.text_config = AutoConfig.from_pretrained(text_model_id)
            self.llm_hidden_size = self.text_config.hidden_size
            self.vocab_size = getattr(self.text_config, "vocab_size", vocab_size)
        else:
            self.llm_hidden_size = hidden_size
            self.vocab_size = vocab_size

        self.rms_norm_eps = 1e-6
        self.rms_norm_init_factor = 0.4


class LossFunction(str, Enum):
    CrossEntropy = "ce"
    KL_Divergence = "kl"


@dataclasses.dataclass
class LossConfig:
    loss_function: LossFunction = LossFunction.CrossEntropy
    kl_temperature: float = 2.0
    ce_weight: float = 0.5

    @property
    def requires_alt_fields(self) -> bool:
        return self.loss_function == LossFunction.KL_Divergence

AUDIO_PLACEHOLDER = "<|audio|>"

def build_tokenizer(text_model_id: str, padding_side: str = "right"):
    tok = AutoTokenizer.from_pretrained(text_model_id)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = padding_side
    # Add audio placeholder if missing
    if AUDIO_PLACEHOLDER not in tok.get_vocab():
        tok.add_special_tokens({"additional_special_tokens": [AUDIO_PLACEHOLDER]})
    audio_token_id = tok.convert_tokens_to_ids(AUDIO_PLACEHOLDER)
    return tok, audio_token_id
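

# Minimal usage sketch (illustrative only): the checkpoint IDs below are
# placeholder assumptions, not fixed by this module; any Whisper encoder and any
# causal-LM checkpoint whose config exposes `hidden_size`/`vocab_size` should
# plug in the same way.
if __name__ == "__main__":
    config = VLFMConfig(
        audio_model_id="openai/whisper-small",  # placeholder audio encoder
        text_model_id="Qwen/Qwen2.5-0.5B",      # placeholder text backbone
    )
    tok, audio_token_id = build_tokenizer(
        config.text_model_id, padding_side=config.tokenizer_padding_side
    )
    print(config.speech_encoder_hidden_size, config.llm_hidden_size, audio_token_id)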