prithivMLmods committed
Commit 92e002a · verified · 1 Parent(s): 080d79a

Update app.py

Files changed (1)
  1. app.py +144 -72
app.py CHANGED
@@ -4,25 +4,30 @@ from threading import Thread
  import time
  import torch
  import spaces
- from PIL import Image
- import requests
- from io import BytesIO
  import cv2
  import numpy as np
+ from PIL import Image
  from transformers import (
      Qwen2VLForConditionalGeneration,
      AutoProcessor,
      TextIteratorStreamer,
      AutoModelForImageTextToText,
  )
+ from transformers import Qwen2_5_VLForConditionalGeneration

- # Helper function to return a progress bar HTML snippet.
- def progress_bar_html(label: str) -> str:
+ # ---------------------------
+ # Helper Functions
+ # ---------------------------
+ def progress_bar_html(label: str, primary_color: str = "#FF69B4", secondary_color: str = "#FFB6C1") -> str:
+     """
+     Returns an HTML snippet for a thin animated progress bar with a label.
+     Colors can be customized; default colors are used for Qwen2VL/Aya-Vision.
+     """
      return f'''
      <div style="display: flex; align-items: center;">
          <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-         <div style="width: 110px; height: 5px; background-color: #FFB6C1; border-radius: 2px; overflow: hidden;">
-             <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+         <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
+             <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
          </div>
      </div>
      <style>
@@ -33,13 +38,19 @@ def progress_bar_html(label: str) -> str:
      </style>
      '''

- # Helper function to downsample a video into 10 evenly spaced frames.
  def downsample_video(video_path):
+     """
+     Downsamples a video file by extracting 10 evenly spaced frames.
+     Returns a list of tuples (PIL.Image, timestamp).
+     """
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
-     # Sample 10 evenly spaced frames.
+     if total_frames <= 0 or fps <= 0:
+         vidcap.release()
+         return frames
+     # Determine 10 evenly spaced frame indices.
      frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
      for i in frame_indices:
          vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
@@ -52,10 +63,9 @@ def downsample_video(video_path):
      vidcap.release()
      return frames

- # Model and processor setups
-
- # Setup for Qwen2VL OCR branch (default).
- QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # or use "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
+ # Model and Processor Setup
+ # Qwen2VL OCR (default branch)
+ QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
  qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
  qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
      QV_MODEL_ID,
@@ -63,22 +73,31 @@ qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
      torch_dtype=torch.float16
  ).to("cuda").eval()

- # Setup for Aya-Vision branch.
+ # Aya-Vision branch (for @aya-vision and @video-infer)
  AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
  aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID)
  aya_model = AutoModelForImageTextToText.from_pretrained(
      AYA_MODEL_ID, device_map="auto", torch_dtype=torch.float16
  )

- # ---------------------------
+ # RolmOCR branch (@RolmOCR)
+ ROLMOCR_MODEL_ID = "reducto/RolmOCR"
+ rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
+ rolmocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     ROLMOCR_MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16
+ ).to("cuda").eval()
+
  # Main Inference Function
- # ---------------------------
  @spaces.GPU
  def model_inference(input_dict, history):
      text = input_dict["text"].strip()
      files = input_dict.get("files", [])
-
-     # Branch for video inference with Aya-Vision using @video-infer.
+
+     # ---------------------------
+     # Aya-Vision Video Inference (@video-infer)
+     # ---------------------------
      if text.lower().startswith("@video-infer"):
          prompt = text[len("@video-infer"):].strip()
          if not files:
@@ -89,16 +108,12 @@ def model_inference(input_dict, history):
          if not frames:
              yield "Error: Could not extract frames from the video."
              return
-         # Build messages: start with the prompt then add each frame with its timestamp.
-         content_list = []
-         content_list.append({"type": "text", "text": prompt})
+         # Build the message with the text prompt followed by each frame (with timestamp label).
+         content_list = [{"type": "text", "text": prompt}]
          for frame, timestamp in frames:
              content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
              content_list.append({"type": "image", "image": frame})
-         messages = [{
-             "role": "user",
-             "content": content_list,
-         }]
+         messages = [{"role": "user", "content": content_list}]
          inputs = aya_processor.apply_chat_template(
              messages,
              padding=True,
@@ -126,50 +141,114 @@ def model_inference(input_dict, history):
              yield buffer
          return

-     # Branch for single image inference with Aya-Vision using @aya-vision.
+     # Aya-Vision Image Inference (@aya-vision)
      if text.lower().startswith("@aya-vision"):
          text_prompt = text[len("@aya-vision"):].strip()
          if not files:
              yield "Error: Please provide an image for the @aya-vision feature."
              return
+         image = load_image(files[0])
+         yield progress_bar_html("Processing with Aya-Vision-8b")
+         messages = [{
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": text_prompt},
+             ],
+         }]
+         inputs = aya_processor.apply_chat_template(
+             messages,
+             padding=True,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_dict=True,
+             return_tensors="pt"
+         ).to(aya_model.device)
+         streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = dict(
+             inputs,
+             streamer=streamer,
+             max_new_tokens=1024,
+             do_sample=True,
+             temperature=0.3
+         )
+         thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+         return
+
+     # RolmOCR Inference (@RolmOCR)
+     if text.lower().startswith("@rolmocr"):
+         # Remove the tag from the query.
+         text_prompt = text[len("@rolmocr"):].strip()
+         # Check if a video is provided for inference.
+         if files and isinstance(files[0], str) and files[0].lower().endswith((".mp4", ".avi", ".mov")):
+             video_path = files[0]
+             frames = downsample_video(video_path)
+             if not frames:
+                 yield "Error: Could not extract frames from the video."
+                 return
+             # Build the message: prompt followed by each frame with its timestamp.
+             content_list = [{"type": "text", "text": text_prompt}]
+             for image, timestamp in frames:
+                 content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
+                 content_list.append({"type": "image", "image": image})
+             messages = [{"role": "user", "content": content_list}]
+             # For video, extract images only.
+             video_images = [image for image, _ in frames]
+             prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+             inputs = rolmocr_processor(
+                 text=[prompt_full],
+                 images=video_images,
+                 return_tensors="pt",
+                 padding=True,
+             ).to("cuda")
          else:
-             # Use the first provided image.
-             image = load_image(files[0])
-             yield progress_bar_html("Processing with Aya-Vision-8b")
+             # Assume image(s) or text query.
+             if len(files) > 1:
+                 images = [load_image(image) for image in files]
+             elif len(files) == 1:
+                 images = [load_image(files[0])]
+             else:
+                 images = []
+             if text_prompt == "" and not images:
+                 yield "Error: Please input a text query and/or provide an image for the @RolmOCR feature."
+                 return
              messages = [{
                  "role": "user",
                  "content": [
-                     {"type": "image", "image": image},
+                     *[{"type": "image", "image": image} for image in images],
                      {"type": "text", "text": text_prompt},
                  ],
              }]
-             inputs = aya_processor.apply_chat_template(
-                 messages,
+             prompt_full = rolmocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+             inputs = rolmocr_processor(
+                 text=[prompt_full],
+                 images=images if images else None,
+                 return_tensors="pt",
                  padding=True,
-                 add_generation_prompt=True,
-                 tokenize=True,
-                 return_dict=True,
-                 return_tensors="pt"
-             ).to(aya_model.device)
-             streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
-             generation_kwargs = dict(
-                 inputs,
-                 streamer=streamer,
-                 max_new_tokens=1024,
-                 do_sample=True,
-                 temperature=0.3
-             )
-             thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
-             thread.start()
-             buffer = ""
-             for new_text in streamer:
-                 buffer += new_text
-                 buffer = buffer.replace("<|im_end|>", "")
-                 time.sleep(0.01)
-                 yield buffer
-             return
-
-     # Default branch: Use Qwen2VL OCR for text (with optional images).
+             ).to("cuda")
+         streamer = TextIteratorStreamer(rolmocr_processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
+         thread = Thread(target=rolmocr_model.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         # Use a different color scheme for RolmOCR (purple-themed).
+         yield progress_bar_html("Processing with Qwen2.5VL (RolmOCR)", primary_color="#4B0082", secondary_color="#9370DB")
+         for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+         return
+
+     # Default Inference: Qwen2VL OCR
+     # Process files: support multiple images.
      if len(files) > 1:
          images = [load_image(image) for image in files]
      elif len(files) == 1:
@@ -178,7 +257,7 @@ def model_inference(input_dict, history):
          images = []

      if text == "" and not images:
-         yield "Error: Please input a query and optionally image(s)."
+         yield "Error: Please input a text query and optionally image(s)."
          return
      if text == "" and images:
          yield "Error: Please input a text query along with the image(s)."
@@ -191,23 +270,17 @@ def model_inference(input_dict, history):
              {"type": "text", "text": text},
          ],
      }]
-
-     prompt = qwen_processor.apply_chat_template(
-         messages, tokenize=False, add_generation_prompt=True
-     )
+     prompt_full = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      inputs = qwen_processor(
-         text=[prompt],
+         text=[prompt_full],
          images=images if images else None,
          return_tensors="pt",
          padding=True,
      ).to("cuda")
-
      streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
      generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-
      thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
      thread.start()
-
      buffer = ""
      yield progress_bar_html("Processing with Qwen2VL OCR")
      for new_text in streamer:
@@ -216,32 +289,31 @@ def model_inference(input_dict, history):
          time.sleep(0.01)
          yield buffer

-
- # Gradio Interface Setup
-
+ # Gradio Interface
  examples = [
+     [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
+     [{"text": "@RolmOCR OCR the Image", "files": ["rolm/2.jpeg"]}],
      [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
      [{"text": "@aya-vision Summarize the letter", "files": ["examples/1.png"]}],
      [{"text": "@aya-vision Extract JSON from the image", "files": ["example_images/document.jpg"]}],
-     [{"text": "@video-infer Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
+     [{"text": "@video-infer Explain what is happening in this video?", "files": ["examples/oreo.mp4"]}],
      [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
      [{"text": "@aya-vision Describe the photo", "files": ["examples/3.png"]}],
      [{"text": "@aya-vision Summarize the full image in detail", "files": ["examples/2.jpg"]}],
      [{"text": "@aya-vision Describe this image.", "files": ["example_images/campeones.jpg"]}],
      [{"text": "@aya-vision What is this UI about?", "files": ["example_images/s2w_example.png"]}],
      [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
-     [{"text": "@aya-vision Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
  ]

  demo = gr.ChatInterface(
      fn=model_inference,
-     description="# **Multimodal OCR `@aya-vision for image, @video-infer for video`**",
+     description="# **Multimodal OCR `@RolmOCR, @aya-vision for image, @video-infer for video`**",
      examples=examples,
      textbox=gr.MultimodalTextbox(
          label="Query Input",
          file_types=["image", "video"],
          file_count="multiple",
-         placeholder="Tag @aya-vision for Aya-Vision image infer, @video-infer for Aya-Vision video infer, default runs Qwen2VL OCR"
+         placeholder="Tag @aya-vision for AyaVision, @video-infer for video, for RolmOCR, or leave blank for default Qwen2VL OCR"
      ),
      stop_btn="Stop Generation",
      multimodal=True,
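
For quick verification outside the Gradio UI, the handler can be driven directly, since model_inference is a generator that yields the branch's progress-bar HTML first and then the growing response buffer. A minimal sketch (not part of the commit), assuming the models above are loaded on a CUDA device and reusing the rolm/1.jpeg example file referenced in examples:

    # Hypothetical smoke test for the new @RolmOCR routing; not part of app.py.
    query = {"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}

    final = None
    for chunk in model_inference(query, history=[]):  # history is not used by the code shown
        final = chunk  # first chunk is the purple progress bar, later chunks are the OCR text
    print(final)

Because the branch check is text.lower().startswith("@rolmocr"), the @RolmOCR spelling used in the examples matches; untagged queries still fall through to the default Qwen2VL OCR branch.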