Update README.md
Browse files
README.md
CHANGED
@@ -51,7 +51,8 @@ from transformers import AutoModel, AutoTokenizer
|
|
51 |
import torch
|
52 |
import torch.nn.functional as F
|
53 |
from PIL import Image
|
54 |
-
import
|
|
|
55 |
|
56 |
def weighted_mean_pooling(hidden, attention_mask):
|
57 |
attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
|
@@ -83,20 +84,25 @@ def encode(text_or_image_list):
|
|
83 |
embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
|
84 |
return embeddings
|
85 |
|
86 |
-
tokenizer = AutoTokenizer.from_pretrained("
|
87 |
-
model = AutoModel.from_pretrained("
|
88 |
model.eval()
|
89 |
|
90 |
-
script_dir = os.path.dirname(os.path.realpath(__file__))
|
91 |
queries = ["What does a dog look like?"]
|
92 |
-
passages = [
|
93 |
-
Image.open(os.path.join(script_dir, 'test_image/cat.jpeg')).convert('RGB'),
|
94 |
-
Image.open(os.path.join(script_dir, 'test_image/dog.jpg')).convert('RGB'),
|
95 |
-
]
|
96 |
-
|
97 |
INSTRUCTION = "Represent this query for retrieving relevant documents: "
|
98 |
queries = [INSTRUCTION + query for query in queries]
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
embeddings_query = encode(queries)
|
101 |
embeddings_doc = encode(passages)
|
102 |
|
|
|
51 |
import torch
|
52 |
import torch.nn.functional as F
|
53 |
from PIL import Image
|
54 |
+
import requests
|
55 |
+
from io import BytesIO
|
56 |
|
57 |
def weighted_mean_pooling(hidden, attention_mask):
|
58 |
attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
|
|
|
84 |
embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
|
85 |
return embeddings
|
86 |
|
87 |
+
tokenizer = AutoTokenizer.from_pretrained("/mnt/data/user/tc_agi/klara/datasets/visrag_ret/visrag_ret", trust_remote_code=True)
|
88 |
+
model = AutoModel.from_pretrained("/mnt/data/user/tc_agi/klara/datasets/visrag_ret/visrag_ret", torch_dtype=torch.bfloat16, trust_remote_code=True)
|
89 |
model.eval()
|
90 |
|
|
|
91 |
queries = ["What does a dog look like?"]
|
|
|
|
|
|
|
|
|
|
|
92 |
INSTRUCTION = "Represent this query for retrieving relevant documents: "
|
93 |
queries = [INSTRUCTION + query for query in queries]
|
94 |
|
95 |
+
print("Downloading images...")
|
96 |
+
passages = [
|
97 |
+
Image.open(BytesIO(requests.get(
|
98 |
+
'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/cat.jpeg'
|
99 |
+
).content)).convert('RGB'),
|
100 |
+
Image.open(BytesIO(requests.get(
|
101 |
+
'https://github.com/OpenBMB/VisRAG/raw/refs/heads/master/scripts/demo/retriever/test_image/dog.jpg'
|
102 |
+
).content)).convert('RGB')
|
103 |
+
]
|
104 |
+
print("Images downloaded.")
|
105 |
+
|
106 |
embeddings_query = encode(queries)
|
107 |
embeddings_doc = encode(passages)
|
108 |
|