Upload modeling_vila.py
modeling_vila.py: +9 -1

modeling_vila.py CHANGED
@@ -212,6 +212,7 @@ class VILAPretrainedModel(PreTrainedModel):
         self.vision_tower = self.vision_tower.cuda()
         # set device_map auto can autoamtically shard llm to different devices
         self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
+        self.llm_model_embed_tokens = self.llm.model.embed_tokens

         # NOTE(ligeng): hard code to set padding_side to left
         self.tokenizer.padding_side = "left"

@@ -221,6 +222,12 @@ class VILAPretrainedModel(PreTrainedModel):
         self.post_config()
         self.is_loaded = True

+        self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
+        if self.llm_only_need_embed:
+            print("We only need the embed_tokens in llm.")
+            del self.llm
+            self.llm = None
+
         assert (
             self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
         ), "At least one of the components must be instantiated."

@@ -628,7 +635,7 @@ class VILAForCasualLM(VILAPretrainedModel):
                self.encoders[name].end_tokens = None

         # Extract text and media embeddings
-        text_embeds = self.llm.model.embed_tokens(input_ids)
+        text_embeds = self.llm_model_embed_tokens(input_ids)
         if media is not None:
             media_embeds = self.__embed_media_tokens(media, media_config)
         else:

@@ -712,6 +719,7 @@ class VILAForCasualLM(VILAPretrainedModel):
                    dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
                    embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
                    continue
+            media[name] = [a.to(torch.bfloat16) for a in media[name]]
             embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
         return embeds
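The new llm_only_need_embed flag drops the full LLM after caching a reference to its token-embedding layer in self.llm_model_embed_tokens, so a deployment that only needs multimodal embeddings (not text generation) does not pay for the decoder weights. Below is a minimal sketch of that pattern with a generic Hugging Face causal LM; the checkpoint name and the surrounding wiring are illustrative assumptions, not VILA's loading API.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any causal LM exposing `model.embed_tokens` works the same way.
name = "Qwen/Qwen2-1.5B"
tokenizer = AutoTokenizer.from_pretrained(name)
llm = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

# Keep only the token-embedding table, then release the rest of the LLM
# (mirrors what the commit does when llm_only_need_embed=True).
embed_tokens = llm.model.embed_tokens
del llm

input_ids = tokenizer("a photo of a cat", return_tensors="pt").input_ids
with torch.no_grad():
    text_embeds = embed_tokens(input_ids)  # shape: (1, seq_len, hidden_size)
print(text_embeds.shape, text_embeds.dtype)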
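The other functional change casts every media tensor to torch.bfloat16 before it reaches the encoder, which avoids PyTorch's dtype-mismatch error when float32 pixel values meet bfloat16 weights. A tiny standalone illustration of that failure mode and the cast; the linear layer is a stand-in for the vision encoder, not VILA code.

import torch

encoder = torch.nn.Linear(16, 8).to(torch.bfloat16)  # stand-in for a bf16 vision encoder
frames = [torch.randn(4, 16) for _ in range(2)]      # float32 by default

try:
    encoder(frames[0])                 # float32 input into bf16 weights fails
except RuntimeError as err:
    print("dtype mismatch:", err)

frames = [a.to(torch.bfloat16) for a in frames]       # the same cast the commit adds
print(encoder(frames[0]).dtype)                       # torch.bfloat16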