from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F import math from transformers import GPT2Tokenizer import tiktoken from transformers import GPT2LMHeadModel from transformers import PretrainedConfig @dataclass class GPTConfig(PretrainedConfig): visual_size: int = 1024 vocab_size: int = 50257 block_size: int = 1024 tags_embd: int = 400 n_embd: int = 768 n_layer: int = 6 n_head: int = 12 def __init__(self,**kwargs): super().__init__(**kwargs) self.hidden_size = self.n_embd class CasualSelfAttention(nn.Module): def __init__(self, config: GPTConfig): super().__init__() assert config.n_embd % config.n_head == 0 self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3) self.visual_attn = nn.Linear(config.visual_size, config.n_embd * 2) self.tags_attn = nn.Linear(config.tags_embd, config.n_embd * 2) self.c_proj = nn.Linear(config.n_embd, config.n_embd) self.n_head = config.n_head self.n_embed = config.n_embd self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.register_buffer( 'bias', torch.tril(torch.ones(199, 353)) .view(1, 1, 199, 353) ) def forward(self, x: torch.Tensor, visual_features: torch.Tensor = None, tags_embedding: torch.Tensor = None) -> torch.Tensor: B, T, C = x.size() visual_features=visual_features.to(self.device) tags_embedding=tags_embedding.to(self.device) qkv = self.c_attn(x) # the error happens here q, k, v = qkv.split(self.n_embed, dim=2) q = q.view(B, T, self.n_head, self.n_embed // self.n_head).transpose(1, 2) k = k.view(B, T, self.n_head, self.n_embed // self.n_head).transpose(1, 2) v = v.view(B, T, self.n_head, self.n_embed // self.n_head).transpose(1, 2) # Handle visual input if provided if visual_features is not None: visual_kv = self.visual_attn(visual_features) visual_k, visual_v = visual_kv.split(self.n_embed, dim=2) visual_k = visual_k.view(B, visual_features.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2) visual_v = visual_v.view(B, visual_features.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2) k = torch.cat([k, visual_k], dim=-2) v = torch.cat([v, visual_v], dim=-2) if tags_embedding is not None: tags_kv = self.tags_attn(tags_embedding) tags_k, tags_v = tags_kv.split(self.n_embed, dim=2) tags_k = tags_k.view(B, tags_embedding.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2) tags_v = tags_v.view(B, tags_embedding.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2) k = torch.cat([k, tags_k], dim=-2) v = torch.cat([v, tags_v], dim=-2) # Causal self-attention computation att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) device = att.device query_seq_len, key_seq_len = T, k.size(-2) # Text can attend to: previous text + all visual/tag tokens text_mask = torch.tril(torch.ones(T, T, device=device)) # Text-to-text causal non_text_mask = torch.ones(T, key_seq_len - T, device=device) # Text-to-other full combined_mask = torch.cat([text_mask, non_text_mask], dim=1) # Reshape for broadcasting combined_mask = combined_mask.view(1, 1, T, key_seq_len) att = att.masked_fill(combined_mask == 0, float('-inf')) att = F.softmax(att, dim=-1) visual_att = att[..., :T, T:].mean().item() # Text → Visual attention y = att @ v y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * (self.n_embed // self.n_head)) y = self.c_proj(y) return y class MLP(nn.Module): def __init__(self, config: GPTConfig): super(MLP, self).__init__() self.c_fc = nn.Linear(config.n_embd, config.n_embd * 4) # c_fc means fully connected layer and c is for context self.gelu = nn.GELU() self.c_proj = nn.Linear(config.n_embd * 4, config.n_embd) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.c_fc(x) x = self.gelu(x) x = self.c_proj(x) return x class Block(nn.Module): def __init__(self, config: GPTConfig): super(Block, self).__init__() self.ln_1 = nn.LayerNorm(config.n_embd) self.attn = CasualSelfAttention(config) self.ln_2 = nn.LayerNorm(config.n_embd) self.mlp = MLP(config) def forward(self, x: torch.Tensor,visual_features: torch.Tensor, tags_embedding: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.ln_1(x),visual_features, tags_embedding) x = x + self.mlp(self.ln_2(x)) return x class DistilGPT2(GPT2LMHeadModel): def __init__(self, config: GPTConfig): super(DistilGPT2, self).__init__(config) self.config = config self.transformer = nn.ModuleDict( { 'wte': nn.Embedding(config.vocab_size, config.n_embd), 'wpe': nn.Embedding(config.block_size, config.n_embd), 'h': nn.ModuleList( [ Block(config) for _ in range(config.n_layer) ] ), # transformer blocks 'ln_f': nn.LayerNorm(config.n_embd) # final layer normalization } ) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # linear layer for projection from embedding to vocab size def forward(self, idx: torch.Tensor, visual_features: torch.Tensor = None, tags_embedding: torch.Tensor = None, return_dict: bool = False) -> torch.Tensor: idx=idx.to(self.device) B, T = idx.size() assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, Block size is {self.config.block_size}" # forward the token and positional embeddings pos = torch.arange(0, T, dtype=torch.long, device=idx.device) pos_emb = self.transformer['wpe'](pos) tok_emb = self.transformer['wte'](idx) x = tok_emb + pos_emb # forward the transformer for block in self.transformer['h']: x = block(x, visual_features=visual_features, tags_embedding=tags_embedding) # forward the head x = self.transformer['ln_f'](x) logits = self.lm_head(x) if return_dict: return {'logits': logits} else: return logits @classmethod def from_pretrained(cls, model_type: str): """Loads pre-trained GPT-2 model weights from Hugging Face and handles custom layers.""" from transformers import GPT2LMHeadModel print(f"Loading weights from pre-trained GPT: {model_type}") # Ensure the model type is supported assert model_type in {'distilgpt2', 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} # Define configurations based on the model type config_args = { 'distilgpt2': dict(n_layer=6, n_head=12, n_embd=768), 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), }[model_type] config_args['vocab_size'] = 50257 config_args['block_size'] = 1024 # Initialize the custom model with the given configuration config = GPTConfig(**config_args) from transformers import GPT2Config config = GPT2Config.from_pretrained('distilgpt2') config.visual_size=1024 config.block_size=1024 config.tags_embd=400 config.n_embd=768 config.n_layer=6 config.n_head=12 model = cls(config) # Load state dictionary from Hugging Face model model_hf = GPT2LMHeadModel.from_pretrained(model_type) sd_hf = model_hf.state_dict() # State dictionary of the custom model sd = model.state_dict() # Filter out custom keys that are not in the pre-trained model custom_keys = {k for k in sd if 'visual_attn' in k or 'tags_attn' in k} sd_keys_filtered = [k for k in sd if k not in custom_keys] # Load matching keys for k in sd_keys_filtered: if k in sd_hf and sd_hf[k].shape == sd[k].shape: with torch.no_grad(): sd[k].copy_(sd_hf[k]) # Initialize custom layers separately for k in custom_keys: with torch.no_grad(): print(f"Initializing custom layer: {k}") sd[k].normal_(0.0, 0.02) # Adjust initialization method as needed # Update the model's state dictionary model.load_state_dict(sd, strict=False) return model def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): # Prepare inputs for autoregressive generation inputs = {"idx": input_ids} if past: inputs["past_key_values"] = past # Include past key values for caching # Include additional features like visual and tags if provided if "visual_features" in kwargs: inputs["visual_features"] = kwargs["visual_features"] if "tags_embedding" in kwargs: inputs["tags_embedding"] = kwargs["tags_embedding"] return inputs def generate( self, input_ids: torch.Tensor = None, max_length: int = None, min_length: int = None, do_sample: bool = None, early_stopping: bool = None, num_beams: int = None, temperature: float = None, top_k: int = None, top_p: float = None, repetition_penalty: float = None, bos_token_id: int = None, pad_token_id: int = None, eos_token_ids: int = None, length_penalty: float = None, no_repeat_ngram_size: int = None, num_return_sequences: int = None, attention_mask: torch.Tensor = None, visual_features: torch.Tensor = None, tags_embedding: torch.Tensor = None, ): """ Generate sequences using autoregressive decoding. Args: input_ids (torch.Tensor): Input tensor of token IDs. max_length (int): Maximum length of the generated sequence. min_length (int): Minimum length of the generated sequence. do_sample (bool): Whether to use sampling; if False, uses greedy decoding. early_stopping (bool): Whether to stop when all beams have finished. num_beams (int): Number of beams for beam search. temperature (float): Sampling temperature. top_k (int): Top-k sampling. top_p (float): Top-p (nucleus) sampling. repetition_penalty (float): Penalty for repeated n-grams. bos_token_id (int): Beginning of sequence token ID. pad_token_id (int): Padding token ID. eos_token_ids (int): End of sequence token ID. length_penalty (float): Beam search length penalty. no_repeat_ngram_size (int): Size of n-grams not to repeat. num_return_sequences (int): Number of sequences to return. attention_mask (torch.Tensor): Attention mask for padding tokens. visual_features (torch.Tensor): Visual features for the transformer. tags_embedding (torch.Tensor): Tags embeddings for the transformer. Returns: torch.Tensor: Generated sequences of token IDs. """ # Default values for unspecified parameters max_length = max_length or self.config.block_size min_length = min_length or 0 do_sample = do_sample or False early_stopping = early_stopping or False num_beams = num_beams or 1 temperature = temperature or 1.0 top_k = top_k or 0 top_p = top_p or 1.0 repetition_penalty = repetition_penalty or 1.0 bos_token_id = bos_token_id or self.config.bos_token_id pad_token_id = pad_token_id or self.config.pad_token_id eos_token_ids = eos_token_ids or self.config.eos_token_ids length_penalty = length_penalty or 1.0 no_repeat_ngram_size = no_repeat_ngram_size or 0 num_return_sequences = num_return_sequences or 1 if input_ids is not None: batch_size=input_ids.shape[0] else: batch_size=1 if input_ids is None: assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( "You should either supply a context to complete as `input_ids` input " "or a `bos_token_id` (integer >= 0) as a first token to start the generation." ) input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long) else: assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." # Avoid duplicate outputs when greedy decoding if not do_sample: if num_beams == 1: assert num_return_sequences == 1, ( "Greedy decoding will always produce the same output for num_beams == 1 " "and num_return_sequences > 1. Please set num_return_sequences = 1." ) else: assert num_beams >= num_return_sequences, ( "Greedy beam search decoding cannot return more sequences than it has beams. " "Please set num_beams >= num_return_sequences." ) # Create attention mask if necessary if attention_mask is None: if pad_token_id is not None and pad_token_id in input_ids: attention_mask = (input_ids != pad_token_id).long() else: attention_mask = torch.ones_like(input_ids) # Set pad_token_id if not provided and eos_token_ids is available if pad_token_id is None and eos_token_ids is not None: pad_token_id = eos_token_ids print(f"Setting `pad_token_id` to {pad_token_id} (first `eos_token_ids`) to generate sequence.") # Current sequence length and vocabulary size cur_len = input_ids.size(1) vocab_size = self.config.vocab_size # Adjust effective batch size and multiplier for sampling if do_sample: effective_batch_size = batch_size * num_return_sequences effective_batch_mult = num_return_sequences else: effective_batch_size = batch_size effective_batch_mult = 1 # Expand input_ids and attention_mask for beam search or multiple return sequences if num_return_sequences > 1 or num_beams > 1: input_ids_len = input_ids.size(-1) # Expand dimensions and repeat for each beam and return sequence input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) # Reshape to combine batch and beam dimensions input_ids = input_ids.reshape(effective_batch_size * num_beams, input_ids_len) attention_mask = attention_mask.reshape(effective_batch_size * num_beams, input_ids_len) if num_beams > 1: output = self._generate_beam_search( input_ids=input_ids, attention_mask=attention_mask, visual_features=visual_features, tags_embedding=tags_embedding, cur_len=input_ids.size(1), max_length=max_length, min_length=min_length, do_sample=do_sample, early_stopping=early_stopping, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, pad_token_id=pad_token_id, eos_token_ids=eos_token_ids, length_penalty=length_penalty, num_return_sequences=num_return_sequences, num_beams=num_beams, ) else: output = self._generate_no_beam_search( input_ids=input_ids, attention_mask=attention_mask, visual_features=visual_features, tags_embedding=tags_embedding, cur_len=input_ids.size(1), max_length=max_length, min_length=min_length, do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, pad_token_id=pad_token_id, eos_token_ids=eos_token_ids, batch_size=batch_size, vocab_size=vocab_size, ) return output def _generate_no_beam_search( self, input_ids, visual_features, tags_embedding, cur_len, max_length, min_length, do_sample, temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size, pad_token_id, eos_token_ids, batch_size, vocab_size, attention_mask, ): """ Generate sequences for each example without beam search (num_beams == 1). All returned sequences are generated independently. """ # Track unfinished sentences and their lengths unfinished_sents=torch.ones_like(input_ids[:,0]) sent_lengths=torch.ones_like(input_ids[:,0])*max_length past=None while cur_len < max_length: if past is None: inputs = input_ids else: inputs = input_ids[:, -1].unsqueeze(1) model_inputs = self.prepare_inputs_for_generation( inputs, past=past, visual_features=visual_features, tags_embedding=tags_embedding ) outputs = self(**model_inputs) # next_token_logits = outputs[0][-1, :] # Extract logits for the last token, shape: [batch_size, vocab_size] next_token_logits = outputs[:, -1, :] # next_token_logits = next_token_logits.unsqueeze(0) # Add a new dimension: [1, batch_size, vocab_size] next_token_logits = next_token_logits.expand(batch_size, vocab_size) # Expand to match batch size: [batch_size, vocab_size] # if self._do_output_past(outputs): # we dont have this function implemented # past = outputs[1] # Apply repetition penalty if repetition_penalty != 1.0: next_token_logits_penalties=self._create_next_token_logits_penalties(input_ids,next_token_logits,repetition_penalty) next_token_logits=next_token_logits @ next_token_logits_penalties.T # .T de mn 3ndy # Prevent repetition of n-grams if no_repeat_ngram_size > 0: # not checked generated by chat banned_tokens=self.calc_banned_ngram_tokens(input_ids,batch_size,no_repeat_ngram_size,cur_len) # not checked generated by chat banned_tokens_indices_mask=[] for banned_tokens_slice in banned_tokens: banned_tokens_indices_mask.append( [True if token in banned_tokens_slice else False for token in range(vocab_size)] ) banned_tokens_indices_mask=torch.tensor(banned_tokens_indices_mask,dtype=bool) next_token_logits[banned_tokens_indices_mask]= -float('inf') # Min length constraint for EOS if eos_token_ids is not None and cur_len < min_length: # create eos_token_id boolean mask is_token_logit_eos_token = torch.arange(vocab_size, device=next_token_logits.device) == eos_token_ids eos_token_indices_mask = is_token_logit_eos_token.unsqueeze(0).expand(batch_size, -1) # next_token_logits=next_token_logits.unsqueeze(0).expand(batch_size,vocab_size) next_token_logits = next_token_logits.masked_fill(eos_token_indices_mask, -float("inf")) # Sampling or greedy decoding if do_sample: if temperature != 1.0: next_token_logits = next_token_logits / temperature next_token_logits=self.top_k_top_p_filtering(next_token_logits,top_k=top_k,top_p=top_p) next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1) else: next_token=torch.argmax(next_token_logits,dim=-1) if eos_token_ids is not None: unfinished_sents=unfinished_sents.to(self.device) tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents) else: tokens_to_add = next_token input_ids=input_ids.to(self.device) input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=1) if eos_token_ids is not None: eos_in_sents = tokens_to_add == eos_token_ids # If sentence is unfinished and the token to add is eos, sent_lengths is filled with current length is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents * eos_in_sents.int() sent_lengths=sent_lengths.to(self.device) sent_lengths = ( sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos) + cur_len * is_sents_unfinished_and_token_to_add_is_eos ) # Unfinished sentences are set to zero if eos is in the sentence unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos # Stop if there is a in each sentence, or if we exceed the maximum length if torch.max(unfinished_sents) == 0: # => this line is what keeps it stopping at 57 etc.. break cur_len += 1 # Pad sequences if necessary min_sent_length = sent_lengths.min() max_sent_length = sent_lengths.max() if min_sent_length != max_sent_length: assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths" padding = torch.ones((batch_size, max_sent_length), dtype=torch.int) * pad_token_id broad_casted_sent_lengths = sent_lengths.unsqueeze(-1).expand(batch_size, max_sent_length) broad_casted_range = torch.arange(max_sent_length).unsqueeze(0).expand(batch_size, max_sent_length).T # Use torch.where to apply padding where necessary decoded = torch.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding) else: decoded = input_ids return decoded def _create_next_token_logits_penalties(self,input_ids, logits, repetition_penalty): """ Create logit penalties for already seen input_ids based on repetition penalty. Args: input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing input token IDs. logits (torch.Tensor): Tensor of shape (batch_size, vocab_size) containing next-token logits. repetition_penalty (float): The penalty to apply for repeated tokens. Returns: torch.Tensor: Tensor of shape (batch_size, vocab_size) with applied penalties. """ token_penalties=torch.ones_like(logits) prev_input_ids=[torch.unique(input_id) for input_id in input_ids] for i, prev_input_id in enumerate(prev_input_ids): logits_penalized=logits[i][prev_input_ids] logit_penalties=torch.zeros_like(logits_penalized) logit_penalties[logits_penalized<0]=repetition_penalty logit_penalties[logits_penalized>0]=1/repetition_penalty token_penalties[i].scatter_(0,prev_input_id,logit_penalties) return token_penalties def top_k_top_p_filtering(self,logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering. Args: logits: Logits distribution of shape (batch size, vocabulary size). top_k (int): Keep only top k tokens with the highest probability. top_p (float): Keep the top tokens with cumulative probability >= top_p (nucleus filtering). filter_value (float): Value to assign to filtered logits. min_tokens_to_keep (int): Ensure at least this many tokens are kept. Returns: torch.Tensor: Filtered logits. """ logits_shape = logits.size() # Top-k filtering if top_k > 0: top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check # Remove all tokens with a probability less than the last token of the top-k top_k_values, _ = torch.topk(logits, top_k, dim=-1) min_top_k_values = top_k_values[:, -1].unsqueeze(-1) # Minimum logit in top-k logits = torch.where(logits < min_top_k_values, torch.full_like(logits, filter_value), logits) # Top-p (nucleus) filtering if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative probability above the threshold sorted_indices_to_remove = cumulative_probs > top_p if min_tokens_to_keep > 1: # Ensure we keep at least min_tokens_to_keep tokens sorted_indices_to_remove[:, :min_tokens_to_keep] = 0 # Shift the indices to the right to keep also the first token above the threshold sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1) sorted_indices_to_remove[:, 0] = 0 # Scatter sorted indices back to original indexing indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits = torch.where(indices_to_remove, torch.full_like(logits, filter_value), logits) return logits def calc_banned_ngram_tokens(self,prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len): """ Calculate banned n-gram tokens for no-repeat n-gram constraints. Args: prev_input_ids (torch.Tensor): Tensor of shape (num_hypos, seq_len) containing token sequences. num_hypos (int): Number of hypotheses in the batch. no_repeat_ngram_size (int): Size of the n-grams to avoid repeating. cur_len (int): Current length of the sequence being generated. Returns: List[List[int]]: List of banned tokens for each hypothesis. """ if cur_len + 1 < no_repeat_ngram_size: # Return no banned tokens if not enough tokens have been generated return [[] for _ in range(num_hypos)] # Dictionary to store generated n-grams for each hypothesis generated_ngrams = [{} for _ in range(num_hypos)] # Populate the n-grams for idx in range(num_hypos): gen_tokens = prev_input_ids[idx].tolist() generated_ngram = generated_ngrams[idx] for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): prev_ngram_tuple = tuple(ngram[:-1]) generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] def _get_generated_ngrams(hypo_idx): # Get n-grams that have already appeared start_idx = cur_len + 1 - no_repeat_ngram_size ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) return generated_ngrams[hypo_idx].get(ngram_idx, []) # Calculate banned tokens for each hypothesis banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] return banned_tokens