Yukang committed
Commit 26a9012 · verified · 1 Parent(s): 9c32499

Upload modeling_vila.py

Files changed (1):
  1. modeling_vila.py (+9 -1)
modeling_vila.py CHANGED
@@ -212,6 +212,7 @@ class VILAPretrainedModel(PreTrainedModel):
         self.vision_tower = self.vision_tower.cuda()
         # setting device_map to "auto" can automatically shard the llm across devices
         self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
+        self.llm_model_embed_tokens = self.llm.model.embed_tokens
 
         # NOTE(ligeng): hard code to set padding_side to left
         self.tokenizer.padding_side = "left"
@@ -221,6 +222,12 @@ class VILAPretrainedModel(PreTrainedModel):
         self.post_config()
         self.is_loaded = True
 
+        self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
+        if self.llm_only_need_embed:
+            print("We only need the embed_tokens in llm.")
+            del self.llm
+            self.llm = None
+
         assert (
             self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
         ), "At least one of the components must be instantiated."
@@ -628,7 +635,7 @@ class VILAForCasualLM(VILAPretrainedModel):
             self.encoders[name].end_tokens = None
 
         # Extract text and media embeddings
-        text_embeds = self.llm.model.embed_tokens(input_ids)
+        text_embeds = self.llm_model_embed_tokens(input_ids)
         if media is not None:
             media_embeds = self.__embed_media_tokens(media, media_config)
         else:
@@ -712,6 +719,7 @@ class VILAForCasualLM(VILAPretrainedModel):
                 dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
                 embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
                 continue
+            media[name] = [a.to(torch.bfloat16) for a in media[name]]
             embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
         return embeds
 
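In sum, the commit caches a direct reference to the LLM's embedding table (self.llm_model_embed_tokens) before the language model can be discarded, adds an llm_only_need_embed flag that frees everything in the LLM except that table, routes text embedding through the cached reference, and casts incoming media tensors to bfloat16 (presumably to match the encoders' dtype). Below is a minimal usage sketch, not part of the commit: the checkpoint path is a placeholder, and it assumes the extra kwarg is forwarded from from_pretrained into VILAPretrainedModel.__init__, where the diff reads it via kwargs.get("llm_only_need_embed", False).

import torch
from transformers import AutoModel

# Sketch under stated assumptions: "path/to/vila-checkpoint" is a placeholder,
# and `llm_only_need_embed` is assumed to reach __init__ through **kwargs.
model = AutoModel.from_pretrained(
    "path/to/vila-checkpoint",
    trust_remote_code=True,        # modeling_vila.py ships alongside the checkpoint
    llm_only_need_embed=True,      # drop the LLM, keep only its embedding table
)

# model.llm is now None, but the cached embedding table still works:
input_ids = torch.tensor([[1, 2, 3]], device=model.device)
text_embeds = model.llm_model_embed_tokens(input_ids)

The trade-off is memory for generation ability: with the LLM freed, the model can no longer decode, but it can still embed text (and, via the encoders, media), which is all the embedding path changed above requires.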