lukiod committed
Commit 5781f7f
1 Parent(s): 2dd3a18

Update app.py

Files changed (1)
  app.py  +129 -29
app.py CHANGED
@@ -1,76 +1,176 @@
  import streamlit as st
  import torch
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
- from PIL import Image
- from byaldi import RAGMultiModalModel
  from qwen_vl_utils import process_vision_info

- # Model and processor names
- RAG_MODEL = "vidore/colpali"
- QWN_MODEL = "Qwen/Qwen2-VL-7B-Instruct"

  @st.cache_resource
  def load_models():
-     RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL)

-     model = Qwen2VLForConditionalGeneration.from_pretrained(
-         QWN_MODEL,
          torch_dtype=torch.bfloat16,
          attn_implementation="flash_attention_2",
          device_map="auto",
          trust_remote_code=True
-     ).eval()

-     processor = AutoProcessor.from_pretrained(QWN_MODEL, trust_remote_code=True)

-     return RAG, model, processor

- RAG, model, processor = load_models()

- def document_rag(text_query, image):
      messages = [
          {
              "role": "user",
              "content": [
                  {
                      "type": "image",
                      "image": image,
                  },
-                 {"type": "text", "text": text_query},
              ],
          }
      ]
-     text = processor.apply_chat_template(
          messages, tokenize=False, add_generation_prompt=True
      )
      image_inputs, video_inputs = process_vision_info(messages)
-     inputs = processor(
          text=[text],
          images=image_inputs,
          videos=video_inputs,
          padding=True,
          return_tensors="pt",
      )
-     inputs = inputs.to(model.device)
-     generated_ids = model.generate(**inputs, max_new_tokens=50)
      generated_ids_trimmed = [
          out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
      ]
-     output_text = processor.batch_decode(
          generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
      )
-     return output_text[0]

- st.title("Document Processor")

- uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
- text_query = st.text_input("Enter your text query")

- if uploaded_file is not None and text_query:
      image = Image.open(uploaded_file)
-
-     if st.button("Process Document"):
          with st.spinner("Processing..."):
-             result = document_rag(text_query, image)
-             st.success("Processing complete!")
-             st.write("Result:", result)

  import streamlit as st
  import torch
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
  from qwen_vl_utils import process_vision_info
+ from byaldi import RAGMultiModalModel
+ from PIL import Image
+ import io
+ import time
+ import nltk
+ from nltk.translate.bleu_score import sentence_bleu

+ # Download NLTK data for BLEU score calculation
+ nltk.download('punkt', quiet=True)
 
+ # Load models and processors
  @st.cache_resource
  def load_models():
+     RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")

+     qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
+         "Qwen/Qwen2-VL-7B-Instruct",
          torch_dtype=torch.bfloat16,
          attn_implementation="flash_attention_2",
          device_map="auto",
          trust_remote_code=True
+     ).eval()  # device_map="auto" already places the weights; an explicit .cuda() here would conflict with it

+     qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)  # processor checkpoint matched to the 7B model

+     return RAG, qwen_model, qwen_processor
+
+ RAG, qwen_model, qwen_processor = load_models()
+
+ # Function to get current CUDA memory usage
+ def get_cuda_memory_usage():
+     return torch.cuda.memory_allocated() / 1024**2  # convert bytes to MB
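+ # Note: memory_allocated() counts only live tensor allocations, so the deltas
+ # reported by the helpers below are rough per-step costs, not peaks; peak usage
+ # is tracked separately by torch.cuda.max_memory_allocated().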

+ # Define processing functions
+ def extract_text_with_colpali(image):
+     start_time = time.time()
+     start_memory = get_cuda_memory_usage()
+
+     # NOTE: assumes the byaldi wrapper exposes a text-extraction helper;
+     # its documented interface is retrieval (index/search), not OCR.
+     extracted_text = RAG.extract_text(image)
+
+     end_time = time.time()
+     end_memory = get_cuda_memory_usage()
+
+     return extracted_text, {
+         'time': end_time - start_time,
+         'memory': end_memory - start_memory
+     }

+ def process_with_qwen(query, extracted_text, image, extract_mode=False):
+     start_time = time.time()
+     start_memory = get_cuda_memory_usage()
+
+     if extract_mode:
+         instruction = "Extract and list all text visible in this image, including both printed and handwritten text."
+     else:
+         instruction = f"Context: {extracted_text}\n\nQuery: {query}"
+
      messages = [
          {
              "role": "user",
              "content": [
+                 {
+                     "type": "text",
+                     "text": instruction
+                 },
                  {
                      "type": "image",
                      "image": image,
                  },
              ],
          }
      ]
+     text = qwen_processor.apply_chat_template(
          messages, tokenize=False, add_generation_prompt=True
      )
      image_inputs, video_inputs = process_vision_info(messages)
+     inputs = qwen_processor(
          text=[text],
          images=image_inputs,
          videos=video_inputs,
          padding=True,
          return_tensors="pt",
      )
+     inputs = inputs.to(qwen_model.device)  # follow device_map placement rather than hard-coding "cuda"
+     generated_ids = qwen_model.generate(**inputs, max_new_tokens=200)
      generated_ids_trimmed = [
          out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
      ]
+     output_text = qwen_processor.batch_decode(
          generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
      )
+
+     end_time = time.time()
+     end_memory = get_cuda_memory_usage()
+
+     return output_text[0], {
+         'time': end_time - start_time,
+         'memory': end_memory - start_memory
+     }
 
+ # Function to calculate BLEU score
+ def calculate_bleu(reference, hypothesis):
+     reference_tokens = nltk.word_tokenize(reference.lower())
+     hypothesis_tokens = nltk.word_tokenize(hypothesis.lower())
+     return sentence_bleu([reference_tokens], hypothesis_tokens)

+ # Streamlit UI
+ st.title("Document Processing with ColPali and Qwen")

+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+ query = st.text_input("Enter your query:")
+
+ if uploaded_file is not None and query:
      image = Image.open(uploaded_file)
+     st.image(image, caption="Uploaded Image", use_column_width=True)
+
+     if st.button("Process"):
          with st.spinner("Processing..."):
+             # Extract text using ColPali
+             colpali_extracted_text, colpali_metrics = extract_text_with_colpali(image)
+
+             # Extract text using Qwen
+             qwen_extracted_text, qwen_extract_metrics = process_with_qwen("", "", image, extract_mode=True)
+
+             # Process the query with Qwen2, using both extracted text and image
+             qwen_response, qwen_response_metrics = process_with_qwen(query, colpali_extracted_text, image)
+
+             # Calculate BLEU score between ColPali and Qwen extractions
+             bleu_score = calculate_bleu(colpali_extracted_text, qwen_extracted_text)
+
+             # Display results
+             st.subheader("Results")
+             st.write("ColPali Extracted Text:")
+             st.write(colpali_extracted_text)
+
+             st.write("Qwen Extracted Text:")
+             st.write(qwen_extracted_text)
+
+             st.write("Qwen Response:")
+             st.write(qwen_response)
+
+             # Display metrics
+             st.subheader("Metrics")
+
+             st.write("ColPali Extraction:")
+             st.write(f"Time: {colpali_metrics['time']:.2f} seconds")
+             st.write(f"Memory: {colpali_metrics['memory']:.2f} MB")
+
+             st.write("Qwen Extraction:")
+             st.write(f"Time: {qwen_extract_metrics['time']:.2f} seconds")
+             st.write(f"Memory: {qwen_extract_metrics['memory']:.2f} MB")
+
+             st.write("Qwen Response:")
+             st.write(f"Time: {qwen_response_metrics['time']:.2f} seconds")
+             st.write(f"Memory: {qwen_response_metrics['memory']:.2f} MB")
+
+             st.write(f"BLEU Score: {bleu_score:.4f}")
+
+ st.markdown("""
+ ## How to Use
+
+ 1. Upload an image containing text or a document.
+ 2. Enter your query about the document.
+ 3. Click 'Process' to see the results.
+
+ The app will display:
+ - Text extracted by ColPali
+ - Text extracted by Qwen
+ - Qwen's response to your query
+ - Performance metrics for each step
+ - BLEU score comparing ColPali and Qwen extractions
+ """)