import torch
import torch.nn.functional as F
from transformers import AutoConfig

from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, mask_by_random_topk
from .phi import PhiForCausalLM

try:
    import xformers.ops as xops
    is_xformers_available = True
except ImportError:
    is_xformers_available = False
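
# xformers is optional; this flag only records availability (assumption: the
# attention modules elsewhere in this package consult `is_xformers_available`
# before enabling memory-efficient attention kernels).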

class Showo(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    # register_to_config records the init arguments on self.config
    # (e.g. self.config.w_clip_vit, which is read in mmu_generate below).
    @register_to_config
    def __init__(
        self,
        w_clip_vit,
        vocab_size,
        llm_vocab_size,
        llm_model_path='',
        codebook_size=8192,
        num_vq_tokens=256,
        **kwargs,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.register_to_config(mask_token_id=vocab_size - 1)

        config = AutoConfig.from_pretrained(llm_model_path)
        self.showo = PhiForCausalLM(config)
        self.showo.resize_token_embeddings(self.vocab_size)
        self.output_size = self.vocab_size

        # When CLIP-ViT features are used as the vision input, project them
        # into the LLM embedding space with a two-layer MLP.
        if self.w_clip_vit:
            self.mm_projector = torch.nn.Sequential(
                torch.nn.Linear(1024, 2048),
                torch.nn.GELU(),
                torch.nn.Linear(2048, 2048),
            )
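
    # Note: ModelMixin/ConfigMixin come from this repo's diffusers-style
    # modeling_utils, so Showo instances are typically created with
    # Showo.from_pretrained(...), restoring the init kwargs above from the
    # saved config (an assumption based on the mixin names, not this file).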

    def _set_gradient_checkpointing(self, module, value=False):
        # Honor the requested value instead of unconditionally enabling it.
        self.gradient_checkpointing = value

    def forward(
        self,
        input_ids,
        input_embeddings=None,
        attention_mask=None,
        labels=None,
        label_smoothing=0.0,
        config=None,
        labels_mask_text=None,
        labels_mask_image=None,
        **kwargs,
    ):
        # Run the backbone either on token ids or on precomputed embeddings.
        if input_embeddings is None:
            logits = self.showo(input_ids=input_ids, attention_mask=attention_mask)['logits']
        else:
            logits = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)['logits']

        # Loss computation is not implemented in this code path.
        if labels is not None:
            raise NotImplementedError
        return logits

    def t2i_generate(
        self,
        input_ids: torch.LongTensor = None,
        uncond_input_ids: torch.LongTensor = None,
        attention_mask=None,
        temperature=1.0,
        timesteps=18,  # ideal number of steps is 18 in the MaskGIT paper
        guidance_scale=0,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
        uni_prompting=None,
        config=None,
        **kwargs,
    ):
        """
        Generate image tokens 1:1 with the original MaskGIT parallel-decoding scheme:
        https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
        (A usage sketch follows this method.)
        """
        # Begin with all image token ids masked.
        mask_token_id = self.config.mask_token_id
        seq_len = config.model.showo.num_vq_tokens

        # Slice out the image-token region and map ids from the unified
        # vocabulary back into codebook space; the offset skips the text
        # vocabulary plus the 10 special tokens placed before the codebook.
        input_ids_minus_lm_vocab_size = input_ids[:, -(seq_len + 1):-1].clone()
        input_ids_minus_lm_vocab_size = torch.where(
            input_ids_minus_lm_vocab_size == mask_token_id,
            mask_token_id,
            input_ids_minus_lm_vocab_size - config.model.showo.llm_vocab_size - 10,
        )

        if uncond_input_ids is not None:
            uncond_prefix = uncond_input_ids[:, :config.dataset.preprocessing.max_seq_length + 1]

        for step in range(timesteps):
            if uncond_input_ids is not None and guidance_scale > 0:
                # Classifier-free guidance: run the conditional and the
                # unconditional prompts in one batched forward pass.
                uncond_input_ids = torch.cat(
                    [uncond_prefix, input_ids[:, config.dataset.preprocessing.max_seq_length + 1:]], dim=1)
                model_input = torch.cat([input_ids, uncond_input_ids])
                cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
                # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
                # it seems that muse uses a different cfg formulation
                logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
                logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]
            else:
                logits = self(input_ids, attention_mask=attention_mask)
                logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]

            # Sample one codebook id per position from the categorical distribution.
            probs = logits.softmax(dim=-1)
            sampled = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            # Keep previously committed tokens; only masked positions are resampled.
            unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)

            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))

            # Computes the probability of each selected token.
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)

            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            # Gets mask lens for each sample in the batch according to the mask ratio.
            mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one prediction in this round and also masks out at
            # least one token for the next iteration.
            mask_len = torch.max(
                torch.tensor([1], device=logits.device),
                torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
            )

            # Adds noise for randomness; note the reassignment compounds the
            # temperature annealing across steps.
            temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

            # Masks tokens with lower confidence and writes the survivors back
            # into the unified vocabulary range of input_ids.
            input_ids[:, -(seq_len + 1):-1] = torch.where(
                masking, mask_token_id, sampled_ids + config.model.showo.llm_vocab_size + 10)
            input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids
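
    # Usage sketch (an assumption, not taken from this file): a typical call
    # from an inference script, where `model` is a loaded Showo instance,
    # `input_ids` / `uncond_input_ids` / `attention_mask` come from the
    # prompting helper, and `config` is the OmegaConf-style config read above.
    #
    #   gen_token_ids = model.t2i_generate(
    #       input_ids=input_ids,
    #       uncond_input_ids=uncond_input_ids,
    #       attention_mask=attention_mask,
    #       guidance_scale=5.0,
    #       temperature=1.0,
    #       timesteps=18,
    #       noise_schedule=cosine_schedule,
    #       config=config,
    #   )
    #   # gen_token_ids are codebook indices in [0, codebook_size); decode
    #   # them to pixels with the VQ tokenizer's decoder.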

    def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None,
                     max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b, t))
        and complete the sequence max_new_tokens times, feeding the predictions
        back into the model each time. Most likely you'll want to make sure to
        be in model.eval() mode of operation for this. (A usage sketch follows
        this method.)
        """
        # Either token ids or precomputed embeddings may be supplied.
        device = idx.device if idx is not None else input_embeddings.device

        result = []
        for _ in range(max_new_tokens):
            # Forward the model to get the logits for the next position.
            logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)

            # Grow the 2D additive attention mask by one row and one column so
            # the newly generated token can attend to the full prefix.
            # (Assumes batch size 1: squeeze() drops the batch dimensions.)
            L = attention_mask.shape[-1]
            attention_mask = attention_mask.squeeze()
            attention_mask_a = torch.hstack(
                [
                    attention_mask,  # (L, L)
                    torch.zeros((L, 1)).to(device) + torch.finfo(logits.dtype).min,
                ]
            )
            attention_mask_b = torch.vstack(
                [
                    attention_mask_a,  # (L, L + 1)
                    torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0),
                ]
            )
            attention_mask = attention_mask_b

            # Pluck the logits at the final step and scale by the desired temperature.
            logits = logits[:, -1, :] / temperature
            # Optionally crop the logits to only the top-k options.
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # Apply softmax to convert logits to (normalized) probabilities.
            probs = F.softmax(logits, dim=-1)
            # Sample from the distribution.
            idx_next = torch.multinomial(probs, num_samples=1)
            result.append(idx_next[0][0])

            # Append the sampled index to the running sequence and continue.
            if self.config.w_clip_vit:
                # With CLIP-ViT inputs the running sequence is tracked in
                # embedding space rather than as token ids.
                idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
                input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
            else:
                idx = torch.cat((idx, idx_next), dim=1)

            if eot_token is not None and idx_next.cpu() == eot_token:
                break
        return result
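
    # Usage sketch (an assumption, not taken from this file): multimodal
    # understanding on the text-only conditioning path (w_clip_vit disabled),
    # where `input_ids` / `attention_mask` come from the prompting helper and
    # `eot_id` is the end-of-turn token id:
    #
    #   output_ids = model.mmu_generate(
    #       idx=input_ids,
    #       attention_mask=attention_mask,  # 2D additive mask, batch size 1
    #       max_new_tokens=100,
    #       temperature=0.8,
    #       top_k=1,
    #       eot_token=eot_id,
    #   )
    #   text = tokenizer.decode(torch.stack(output_ids))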