kosung commited on
Commit
89146e4
·
verified ·
1 Parent(s): 40ed72b

Update custom_st.py

Browse files
Files changed (1) hide show
  1. custom_st.py +8 -2
custom_st.py CHANGED
@@ -1,10 +1,13 @@
 
 
1
  from io import BytesIO
2
  from typing import Any, Dict, Optional, List
3
  import torch
4
  from PIL import Image
5
  from sentence_transformers.models import Transformer as BaseTransformer
6
  from transformers import AutoModelForVision2Seq, AutoProcessor
7
-
 
8
 
9
  class MultiModalTransformer(BaseTransformer):
10
  def __init__(
@@ -51,7 +54,10 @@ class MultiModalTransformer(BaseTransformer):
51
  self, features: Dict[str, torch.Tensor], **kwargs
52
  ) -> Dict[str, torch.Tensor]:
53
  if features.get("inputs_embeds", None) is None:
54
- features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
 
 
 
55
  if features.get("pixel_values", None) is not None:
56
  features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
57
  image_embeds = self.auto_model.visual(
 
1
+ import math
2
+ import logging
3
  from io import BytesIO
4
  from typing import Any, Dict, Optional, List
5
  import torch
6
  from PIL import Image
7
  from sentence_transformers.models import Transformer as BaseTransformer
8
  from transformers import AutoModelForVision2Seq, AutoProcessor
9
+ from packaging import version
10
+ import transformers
11
 
12
  class MultiModalTransformer(BaseTransformer):
13
  def __init__(
 
54
  self, features: Dict[str, torch.Tensor], **kwargs
55
  ) -> Dict[str, torch.Tensor]:
56
  if features.get("inputs_embeds", None) is None:
57
+ if version.parse(transformers.__version__) >= version.parse("4.52.0"):
58
+ features["inputs_embeds"] = self.auto_model.base_model.language_model.embed_tokens(features["input_ids"])
59
+ else:
60
+ features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
61
  if features.get("pixel_values", None) is not None:
62
  features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
63
  image_embeds = self.auto_model.visual(