Visual Document Retrieval · Transformers · Safetensors · ColPali · English · colqwen2 · pretraining
tonywu71 committed (verified) · Commit 0d3e414 · Parent: d6c87db

Update README.md

Files changed (1): README.md (+30 −16)
README.md CHANGED
@@ -12,9 +12,6 @@ base_model:
 pipeline_tag: visual-document-retrieval
 ---
 
-> [!WARNING]
-> EXPERIMENTAL: Wait for https://github.com/huggingface/transformers/pull/35778 to be merged before using!
-
 > [!IMPORTANT]
 > This version of ColQwen2 should be loaded with the `transformers 🤗` release, not with `colpali-engine`.
 > It was converted using the `convert_colqwen2_weights_to_hf.py` script
@@ -44,6 +41,7 @@ A validation set is created with 2% of the samples to tune hyperparameters.
 ## Usage
 
 ```python
+import requests
 import torch
 from PIL import Image
 
@@ -51,39 +49,55 @@ from transformers import ColQwen2ForRetrieval, ColQwen2Processor
 from transformers.utils.import_utils import is_flash_attn_2_available
 
 
+# Load the model and the processor
 model_name = "vidore/colqwen2-v1.0-hf"
 
 model = ColQwen2ForRetrieval.from_pretrained(
     model_name,
     torch_dtype=torch.bfloat16,
-    device_map="cuda:0",  # or "mps" if on Apple Silicon
-    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
-).eval()
-
+    device_map="auto",  # "cpu", "cuda", or "mps" for Apple Silicon
+    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
+)
 processor = ColQwen2Processor.from_pretrained(model_name)
 
-# Your inputs (replace dummy images with screenshots of your documents)
+# The document page screenshots from your corpus
+url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
+url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"
+
 images = [
-    Image.new("RGB", (128, 128), color="white"),
-    Image.new("RGB", (64, 32), color="black"),
+    Image.open(requests.get(url1, stream=True).raw),
+    Image.open(requests.get(url2, stream=True).raw),
 ]
+
+# The queries you want to retrieve documents for
 queries = [
-    "What is the organizational structure for our R&D department?",
-    "Can you provide a breakdown of last year’s financial performance?",
+    "When was the United States Declaration of Independence proclaimed?",
+    "Who printed the edition of Romeo and Juliet?",
 ]
 
 # Process the inputs
-batch_images = processor(images=images).to(model.device)
-batch_queries = processor(text=queries).to(model.device)
+inputs_images = processor(images=images).to(model.device)
+inputs_text = processor(text=queries).to(model.device)
 
 # Forward pass
 with torch.no_grad():
-    image_embeddings = model(**batch_images).embeddings
-    query_embeddings = model(**batch_queries).embeddings
+    image_embeddings = model(**inputs_images).embeddings
+    query_embeddings = model(**inputs_text).embeddings
 
 # Score the queries against the images
 scores = processor.score_retrieval(query_embeddings, image_embeddings)
 
+print("Retrieval scores (query x image):")
+print(scores)
+```
+
+If you have issues loading the images with PIL, you can use the following code to create dummy images:
+
+```python
+images = [
+    Image.new("RGB", (128, 128), color="white"),
+    Image.new("RGB", (64, 32), color="black"),
+]
 ```
 
 ## Limitations
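
For readers adapting the updated snippet above: `processor.score_retrieval` returns one relevance score per (query, image) pair, computed with the late-interaction (MaxSim) scoring used by ColPali-style retrievers. The sketch below is an illustration only, not part of the model card diff; it assumes the `scores` and `queries` variables from the usage example and shows one way to rank the corpus images for each query.

```python
# Minimal sketch, assuming the updated usage snippet above has been run
# and `scores` is a (num_queries, num_images) tensor.
ranking = scores.argsort(dim=1, descending=True)  # per-query ranking of image indices
best_per_query = ranking[:, 0]                    # top-1 image index for each query

for i, query in enumerate(queries):
    top = best_per_query[i].item()
    print(f"{query!r} -> image {top} (score {scores[i, top].item():.2f})")
```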