import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import (
AutoConfig, AutoModel,
AutoModelForCausalLM, WhisperModel)
from configs import VLFMConfig, LossFunction, LossConfig, build_tokenizer
from projector import VLFMProjector
from constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from typing import Optional, Tuple, List, Union
class VLFMModel(transformers.LlamaPreTrainedModel):
config_class = VLFMConfig
    def __init__(self, config, torch_dtype=torch.bfloat16):
        super().__init__(config)
        # Speech encoder: only the Whisper encoder is kept; the decoder is discarded.
        whisper = WhisperModel.from_pretrained(config.audio_model_id, torch_dtype=torch_dtype)
        self.encoder = whisper.encoder
        # The projector maps encoder features into the language model's embedding space.
        self.projector = VLFMProjector(config)
        self.language_model = AutoModelForCausalLM.from_pretrained(config.text_model_id, torch_dtype=torch_dtype)
        # Freeze the encoder and language model; only the projector is trainable.
        self._train_module(self.encoder, False)
        self._train_module(self.language_model, False)
        self._train_module(self.projector, True)
        self.encoder.to(dtype=torch_dtype)
        self.language_model.to(dtype=torch_dtype)
        self.projector.to(dtype=torch_dtype)
self.tokenizer, self.audio_token_id = build_tokenizer(config.text_model_id, config.tokenizer_padding_side)
self.tokenizer_model_max_length = self.tokenizer.model_max_length
self._resize_token_embeddings(self.tokenizer)
self.get_input_embeddings().to(dtype=self.language_model.dtype)
if hasattr(self.language_model, "get_output_embeddings") and self.language_model.get_output_embeddings() is not None:
self.language_model.get_output_embeddings().to(dtype=self.language_model.dtype)
self.loss_config = LossConfig(LossFunction.KL_Divergence)
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, new_emb):
return self.language_model.set_input_embeddings(new_emb)
@property
def embed_tokens(self):
return self.language_model.get_input_embeddings()
def _train_module(self, module, trainable: bool):
for param in module.parameters():
            param.requires_grad = trainable
    def _audio_iter(self, audio_batch_size):
        # Yield (batch_index, flat_chunk_index) pairs; audio_batch_size[i] is the
        # number of audio chunks belonging to batch sample i.
        audio_index = 0
        for i_b, count in enumerate(audio_batch_size.view(-1).tolist()):
            for _ in range(int(count)):
                yield i_b, audio_index
                audio_index += 1
def _resize_token_embeddings(self, tokenizer, pad_to_multiple_of=None):
model_embeds = self.language_model.resize_token_embeddings(len(tokenizer))
self.config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
    def _encode_speech(self, audio_values):
        # Run the frozen Whisper encoder without gradients; the trainable projector is
        # applied outside the no_grad block so it can receive gradients.
        with torch.no_grad():
            encoder_outputs = self.encoder(audio_values, output_hidden_states=False)
            audio_embeds = encoder_outputs.last_hidden_state
        downsampled_embeds = self.projector(audio_embeds)  # (B, T, D)
        return downsampled_embeds
    def _splice_chunks(self, text_embeds, audio_embeds, audio_token_start_idx, audio_token_len, audio_batch_size):
        # Overwrite each sample's reserved audio-token span in `text_embeds` (in place)
        # with the corresponding projected audio embeddings.
        for i_b, i_chunk in self._audio_iter(audio_batch_size):
            start = int(audio_token_start_idx[i_chunk].item())
            span = int(audio_token_len[i_chunk].item())
            a = audio_embeds[i_chunk]
            use = min(a.size(0), span)
            text_embeds[i_b, start:start + use, :] = a[:use].to(text_embeds.dtype)
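    # Illustrative example (hypothetical values): if sample 0 reserves an audio span of
    # 188 placeholder tokens starting at position 5 and its projected chunk has 188
    # frames, text_embeds[0, 5:193, :] is overwritten with that chunk; a shorter chunk
    # fills only the first min(T_audio, span) positions and leaves the rest untouched.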
def _compute_kl_loss(
self,
*,
student_logits: torch.Tensor,
labels: torch.Tensor,
alt_input_ids: torch.Tensor,
alt_attention_mask: torch.Tensor,
alt_labels: torch.Tensor,
past_key_values=None,
**kwargs,
):
        # Teacher pass: run the (frozen) language model on the text-only `alt_*` inputs
        # in eval mode and without gradients, then restore its previous mode.
        lm_was_training = self.language_model.training
        self.language_model.eval()
with torch.no_grad():
alt_input_embeds = self.language_model.get_input_embeddings()(alt_input_ids)
teacher_out = self.language_model(
inputs_embeds=alt_input_embeds,
attention_mask=alt_attention_mask,
use_cache=False,
return_dict=True,
past_key_values=past_key_values,
)
teacher_logits = teacher_out.logits
if lm_was_training:
self.language_model.train()
        # Soften both distributions with the configured temperature and compare only the
        # supervised (non-ignored) positions.
        T = self.loss_config.kl_temperature
        student = F.log_softmax(student_logits[labels != IGNORE_INDEX] / T, dim=-1)
        teacher = F.softmax(teacher_logits[alt_labels != IGNORE_INDEX] / T, dim=-1)
        kl = F.kl_div(student, teacher, reduction="batchmean")
return kl
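    # Training objective (used in `forward`): the final loss is a convex combination of
    # the cross-entropy loss on the audio-spliced inputs and the KL term above,
    #     total_loss = ce_weight * CE + (1 - ce_weight) * KL(student || teacher),
    # with both distributions softened by `kl_temperature` and restricted to positions
    # whose labels are not IGNORE_INDEX.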
def forward(
self,
input_ids,
attention_mask,
labels=None,
*,
input_features=None,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
        alt_input_ids=None,
        alt_attention_mask=None,
        alt_labels=None,
        return_dict=True,
**kwargs):
tok = self.language_model.get_input_embeddings()
text_embeds = tok(input_ids)
if input_features is not None and audio_token_start_idx is not None:
audio_embeds = self._encode_speech(input_features)
self._splice_chunks(
text_embeds,
audio_embeds,
audio_token_start_idx,
audio_token_len,
audio_batch_size
)
        out = self.language_model(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
            use_cache=True,
        )
logits = out.logits
ce_loss = out.loss
alpha = self.loss_config.ce_weight
kl = None
if (
self.training
and alt_input_ids is not None
and alt_attention_mask is not None
and alt_labels is not None
):
kl = self._compute_kl_loss(
student_logits=logits,
labels=labels,
alt_input_ids=alt_input_ids,
alt_attention_mask=alt_attention_mask,
alt_labels=alt_labels,
past_key_values=None,
)
total_loss = alpha * ce_loss + (1 - alpha) * kl
else:
total_loss = ce_loss
        return {
            "loss": total_loss,
            "loss_ce": ce_loss.detach() if ce_loss is not None else None,
            "loss_kl": kl.detach() if kl is not None else None,
            "logits": logits,
        }
def _prepare_inputs_embeds(
self,
input_ids,
attention_mask,
*,
        input_features=None,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
):
"""
Returns:
inputs_embeds: [B, L, D] with audio spliced in
attention_mask: [B, L] (unchanged)
"""
tok = self.language_model.get_input_embeddings()
inputs_embeds = tok(input_ids) # [B, L, D]
if input_features is not None and audio_token_start_idx is not None:
# Normalize shapes: treat "one audio per sample" as N_chunks == B
feats = input_features
if feats.dim() == 3 and feats.size(0) == input_ids.size(0):
audio_batch_size = torch.ones(input_ids.size(0), dtype=torch.long, device=input_ids.device)
assert audio_batch_size is not None, "audio_batch_size required when splicing audio."
# Encode + project, then splice
            audio_embeds = self._encode_speech(feats)  # [N_chunks, T_audio, D]
self._splice_chunks(
text_embeds=inputs_embeds,
audio_embeds=audio_embeds,
audio_token_start_idx=audio_token_start_idx,
audio_token_len=audio_token_len,
audio_batch_size=audio_batch_size,
)
return inputs_embeds, attention_mask
@torch.no_grad()
def generate(
self,
input_ids, # [B, L]
attention_mask, # [B, L]
*,
input_features,
        audio_token_start_idx=None,
        audio_token_len=None,
        audio_batch_size=None,
**gen_kwargs,
):
"""
Build spliced embeddings and call the base LM's generate"""
self.eval()
inputs_embeds, attn_mask = self._prepare_inputs_embeds(
input_ids=input_ids,
attention_mask=attention_mask,
input_features=input_features,
audio_token_start_idx=audio_token_start_idx,
audio_token_len=audio_token_len,
audio_batch_size=audio_batch_size,
)
return self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attn_mask,
**gen_kwargs,
)
AutoConfig.register("babs-vlfm", VLFMConfig)
AutoModel.register(VLFMConfig, VLFMModel)
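

# Minimal usage sketch (illustrative, not part of the model code): it assumes the
# config and weights are published under the upstream repo id "babs/vlfm-v3-3B" and
# that the prompt/feature tensors below match what the accompanying processor would
# produce; all token ids, shapes, and index values are placeholders.
if __name__ == "__main__":
    model = AutoModel.from_pretrained("babs/vlfm-v3-3B", trust_remote_code=True)
    model.eval()

    # One sample whose audio placeholder span starts at position 1 and is 1 token long,
    # plus a single chunk of Whisper log-mel features (80 mel bins x 3000 frames).
    input_ids = torch.tensor([[1, model.audio_token_id, 2]])
    attention_mask = torch.ones_like(input_ids)
    input_features = torch.zeros(1, 80, 3000, dtype=torch.bfloat16)

    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        input_features=input_features,
        audio_token_start_idx=torch.tensor([1]),
        audio_token_len=torch.tensor([1]),
        audio_batch_size=torch.tensor([1]),
        max_new_tokens=32,
    )
    print(model.tokenizer.batch_decode(output_ids, skip_special_tokens=True))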