import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import (
    AutoConfig, AutoModel,
    AutoModelForCausalLM, WhisperModel)

from configs import VLFMConfig, LossFunction, LossConfig, build_tokenizer
from projector import VLFMProjector
from constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from typing import Optional, Tuple, List, Union

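# VLFMModel couples a frozen Whisper audio encoder with a frozen causal
# language model: audio features are encoded, projected into the LM embedding
# space by VLFMProjector (the only trainable module), and spliced into the
# text embedding sequence over the positions reserved for audio placeholder
# tokens. During training, an optional KL-divergence term distils a text-only
# teacher pass of the LM into the audio-conditioned student pass.
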

class VLFMModel(transformers.LlamaPreTrainedModel):
    config_class = VLFMConfig

    def __init__(self, config, torch_dtype=torch.bfloat16):
        super().__init__(config)

        whisper = WhisperModel.from_pretrained(config.audio_model_id,
                                               torch_dtype=torch_dtype)

        self.encoder = whisper.encoder
        self.projector = VLFMProjector(config)
        self.language_model = AutoModelForCausalLM.from_pretrained(config.text_model_id,
                                                                   torch_dtype=torch_dtype)

        # Freeze the audio encoder and the language model; only the projector is trained.
        self._train_module(self.encoder, False)
        self._train_module(self.language_model, False)
        self._train_module(self.projector, True)

        self.encoder.to(dtype=torch_dtype)
        self.language_model.to(dtype=torch_dtype)
        self.projector.to(dtype=torch_dtype)

        self.tokenizer, self.audio_token_id = build_tokenizer(config.text_model_id, config.tokenizer_padding_side)

        self.tokenizer_model_max_length = self.tokenizer.model_max_length
        self._resize_token_embeddings(self.tokenizer)
        self.get_input_embeddings().to(dtype=self.language_model.dtype)
        if hasattr(self.language_model, "get_output_embeddings") and self.language_model.get_output_embeddings() is not None:
            self.language_model.get_output_embeddings().to(dtype=self.language_model.dtype)

        self.loss_config = LossConfig(LossFunction.KL_Divergence)

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, new_emb):
        return self.language_model.set_input_embeddings(new_emb)

    @property
    def embed_tokens(self):
        return self.language_model.get_input_embeddings()

    def _train_module(self, module, trainable: bool):
        for param in module.parameters():
            param.requires_grad = trainable

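    # _audio_iter yields (batch_index, flat_audio_index) pairs so that each
    # batch sample claims its own consecutive slice of the flattened audio
    # tensor. Example: audio_batch_size = [2, 0, 1] yields (0, 0), (0, 1),
    # (2, 2): sample 0 owns two audio chunks, sample 1 owns none, and sample 2
    # owns the third.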
    def _audio_iter(self, audio_batch_size):
        audio_index = 0
        for i_b, count in enumerate(audio_batch_size.view(-1).tolist()):
            for _ in range(int(count)):
                yield i_b, audio_index
                audio_index += 1

    def _resize_token_embeddings(self, tokenizer, pad_to_multiple_of=None):
        # Pass pad_to_multiple_of through instead of silently ignoring it.
        model_embeds = self.language_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=pad_to_multiple_of)
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _encode_speech(self, audio_values):
        # The Whisper encoder is frozen, so run it without building a graph;
        # only the projector participates in backprop.
        with torch.no_grad():
            encoder_outputs = self.encoder(audio_values, output_hidden_states=False)
            audio_embeds = encoder_outputs.last_hidden_state
        downsampled_embeds = self.projector(audio_embeds)
        return downsampled_embeds

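    # _splice_chunks overwrites the placeholder embeddings in-place with the
    # projected audio embeddings. E.g. a chunk with audio_token_start_idx = 5
    # and audio_token_len = 60 fills positions 5..64 of that sample's sequence;
    # if the projector produced fewer frames than the reserved span, only the
    # available frames are copied.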
    def _splice_chunks(self, text_embeds, audio_embeds, audio_token_start_idx, audio_token_len, audio_batch_size):
        for i_b, i_chunk in self._audio_iter(audio_batch_size):
            start = int(audio_token_start_idx[i_chunk].item())
            span = int(audio_token_len[i_chunk].item())
            a = audio_embeds[i_chunk]
            Ta = a.size(0)
            use = min(Ta, span)
            text_embeds[i_b, start:start+use, :] = a[:use].to(text_embeds.dtype)

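    # _compute_kl_loss runs a teacher pass of the language model on the alt_*
    # inputs (the text-only counterpart of the example) with gradients
    # disabled, then matches the audio-conditioned student distribution to it
    # with a temperature-scaled KL divergence,
    #     KL( softmax(teacher_logits / T) || softmax(student_logits / T) ),
    # evaluated only at supervised positions (labels != IGNORE_INDEX). This
    # assumes student and teacher expose the same number of supervised
    # positions; otherwise the element-wise KL would not line up.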
    def _compute_kl_loss(
        self,
        *,
        student_logits: torch.Tensor,
        labels: torch.Tensor,
        alt_input_ids: torch.Tensor,
        alt_attention_mask: torch.Tensor,
        alt_labels: torch.Tensor,
        past_key_values=None,
        **kwargs,
    ):
        # Run the teacher pass in eval mode and without gradients, then restore
        # the language model's previous training mode.
        lm_was_training = self.language_model.training
        self.language_model.eval()
        with torch.no_grad():
            alt_input_embeds = self.language_model.get_input_embeddings()(alt_input_ids)
            teacher_out = self.language_model(
                inputs_embeds=alt_input_embeds,
                attention_mask=alt_attention_mask,
                use_cache=False,
                return_dict=True,
                past_key_values=past_key_values,
            )
            teacher_logits = teacher_out.logits
        if lm_was_training:
            self.language_model.train()

        T = self.loss_config.kl_temperature
        student = F.log_softmax(student_logits[labels != IGNORE_INDEX] / T, dim=-1)
        teacher = F.softmax(teacher_logits[alt_labels != IGNORE_INDEX] / T, dim=-1)
        kl = F.kl_div(student, teacher, reduction="batchmean")
        return kl

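    # forward embeds the token ids, splices projected audio embeddings over the
    # reserved placeholder spans when audio inputs are present, runs the frozen
    # LM on the fused sequence, and mixes the losses as
    #     total = ce_weight * CE + (1 - ce_weight) * KL,
    # where the KL term is only added during training when the alt_* teacher
    # inputs are provided.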
    def forward(
            self,
            input_ids,
            attention_mask,
            labels=None,
            *,
            input_features=None,
            audio_token_start_idx=None,
            audio_token_len=None,
            audio_batch_size=None,
            alt_input_ids=None,
            alt_attention_mask=None,
            alt_labels=None,
            return_dict=True,
            **kwargs):
        tok = self.language_model.get_input_embeddings()
        text_embeds = tok(input_ids)

        if input_features is not None and audio_token_start_idx is not None:
            audio_embeds = self._encode_speech(input_features)
            self._splice_chunks(
                text_embeds,
                audio_embeds,
                audio_token_start_idx,
                audio_token_len,
                audio_batch_size,
            )

        out = self.language_model(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
            use_cache=True,
        )

        logits = out.logits
        ce_loss = out.loss

        alpha = self.loss_config.ce_weight

        kl = None
        if (
            self.training
            and alt_input_ids is not None
            and alt_attention_mask is not None
            and alt_labels is not None
        ):
            kl = self._compute_kl_loss(
                student_logits=logits,
                labels=labels,
                alt_input_ids=alt_input_ids,
                alt_attention_mask=alt_attention_mask,
                alt_labels=alt_labels,
                past_key_values=None,
            )
            total_loss = alpha * ce_loss + (1 - alpha) * kl
        else:
            total_loss = ce_loss

        return {
            "loss": total_loss,
            "loss_ce": ce_loss.detach() if ce_loss is not None else None,
            "loss_kl": kl.detach() if kl is not None else None,
            "logits": logits,
        }

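    # _prepare_inputs_embeds mirrors the splicing logic of forward() so that
    # generate() can hand the base language model a ready-made inputs_embeds
    # tensor.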
    def _prepare_inputs_embeds(
        self,
        input_ids,
        attention_mask,
        *,
        input_features=None,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
    ):
        """
        Returns:
          inputs_embeds: [B, L, D] with audio spliced in
          attention_mask: [B, L] (unchanged)
        """
        tok = self.language_model.get_input_embeddings()
        inputs_embeds = tok(input_ids)

        if input_features is not None and audio_token_start_idx is not None:
            feats = input_features
            # Default to one audio chunk per sample when the caller did not
            # provide audio_batch_size and the features are batched per sample.
            if audio_batch_size is None and feats.dim() == 3 and feats.size(0) == input_ids.size(0):
                audio_batch_size = torch.ones(input_ids.size(0), dtype=torch.long, device=input_ids.device)
            assert audio_batch_size is not None, "audio_batch_size required when splicing audio."

            audio_embeds = self._encode_speech(feats)
            self._splice_chunks(
                text_embeds=inputs_embeds,
                audio_embeds=audio_embeds,
                audio_token_start_idx=audio_token_start_idx,
                audio_token_len=audio_token_len,
                audio_batch_size=audio_batch_size,
            )

        return inputs_embeds, attention_mask

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask,
        *,
        input_features,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
        **gen_kwargs,
    ):
        """
        Build spliced embeddings and call the base LM's generate.
        """
        self.eval()
        inputs_embeds, attn_mask = self._prepare_inputs_embeds(
            input_ids=input_ids,
            attention_mask=attention_mask,
            input_features=input_features,
            audio_token_start_idx=audio_token_start_idx,
            audio_token_len=audio_token_len,
            audio_batch_size=audio_batch_size,
        )
        return self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attn_mask,
            **gen_kwargs,
        )


AutoConfig.register("babs-vlfm", VLFMConfig)
AutoModel.register(VLFMConfig, VLFMModel)
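

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint names, the VLFMConfig
# keyword arguments, and the hand-built placeholder span are assumptions about
# the surrounding configs/collator, not part of this file; adapt them to the
# real pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    from transformers import WhisperFeatureExtractor

    # Hypothetical config values; VLFMConfig is assumed to accept these fields.
    config = VLFMConfig(
        audio_model_id="openai/whisper-small",
        text_model_id="meta-llama/Llama-3.2-1B-Instruct",
        tokenizer_padding_side="left",
    )
    model = VLFMModel(config).eval()

    # 30 s of silence -> Whisper log-mel features of shape [1, 80, 3000],
    # cast to match the bfloat16 default used in __init__.
    extractor = WhisperFeatureExtractor.from_pretrained(config.audio_model_id)
    audio = np.zeros(16000 * 30, dtype=np.float32)
    input_features = extractor(audio, sampling_rate=16000, return_tensors="pt").input_features
    input_features = input_features.to(torch.bfloat16)

    # A prompt in which `span` positions are reserved for audio (filled with
    # the placeholder id) starting right after the text prompt. In training
    # this bookkeeping comes from the collator; here it is built by hand.
    prompt_ids = model.tokenizer("Transcribe the audio:", return_tensors="pt").input_ids
    span = 60  # hypothetical number of projected audio frames reserved
    placeholder = torch.full((1, span), model.audio_token_id, dtype=torch.long)
    input_ids = torch.cat([prompt_ids, placeholder], dim=1)
    attention_mask = torch.ones_like(input_ids)

    out_ids = model.generate(
        input_ids,
        attention_mask,
        input_features=input_features,
        audio_token_start_idx=torch.tensor([prompt_ids.size(1)]),
        audio_token_len=torch.tensor([span]),
        audio_batch_size=torch.tensor([1]),
        max_new_tokens=32,
    )
    print(model.tokenizer.batch_decode(out_ids, skip_special_tokens=True))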