prithivMLmods committed (verified)
Commit d2b791d · 1 Parent(s): 3c2995e

Update app.py

Files changed (1)
  1. app.py +6 -109
app.py CHANGED
@@ -11,14 +11,13 @@ from transformers import (
     Qwen2VLForConditionalGeneration,
     AutoProcessor,
     TextIteratorStreamer,
-    AutoModelForImageTextToText,
 )
 from transformers import Qwen2_5_VLForConditionalGeneration
 
 # ---------------------------
 # Helper Functions
 # ---------------------------
-def progress_bar_html(label: str, primary_color: str = "#FF0000", secondary_color: str = "#FF4500") -> str:
+def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
     """
     Returns an HTML snippet for a thin animated progress bar with a label.
     Colors can be customized; default colors are used for Qwen2VL/Aya‑Vision.
@@ -65,7 +64,7 @@ def downsample_video(video_path):
 
 # Model and Processor Setup
 # Qwen2VL OCR (default branch)
-QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" #[or] prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
+QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # [or] prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
 qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
 qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     QV_MODEL_ID,
@@ -73,13 +72,6 @@ qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# Aya-Vision branch (for @aya-vision and @video-infer)
-AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
-aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID)
-aya_model = AutoModelForImageTextToText.from_pretrained(
-    AYA_MODEL_ID, device_map="auto", torch_dtype=torch.float16
-)
-
 # RolmOCR branch (@RolmOCR)
 ROLMOCR_MODEL_ID = "reducto/RolmOCR"
 rolmocr_processor = AutoProcessor.from_pretrained(ROLMOCR_MODEL_ID, trust_remote_code=True)
@@ -95,93 +87,6 @@ def model_inference(input_dict, history):
     text = input_dict["text"].strip()
     files = input_dict.get("files", [])
 
-    # ---------------------------
-    # Aya-Vision Video Inference (@video-infer)
-    # ---------------------------
-    if text.lower().startswith("@video-infer"):
-        prompt = text[len("@video-infer"):].strip()
-        if not files:
-            yield "Error: Please provide a video for the @video-infer feature."
-            return
-        video_path = files[0]
-        frames = downsample_video(video_path)
-        if not frames:
-            yield "Error: Could not extract frames from the video."
-            return
-        # Build the message with the text prompt followed by each frame (with timestamp label).
-        content_list = [{"type": "text", "text": prompt}]
-        for frame, timestamp in frames:
-            content_list.append({"type": "text", "text": f"Frame {timestamp}:"})
-            content_list.append({"type": "image", "image": frame})
-        messages = [{"role": "user", "content": content_list}]
-        inputs = aya_processor.apply_chat_template(
-            messages,
-            padding=True,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt"
-        ).to(aya_model.device)
-        streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = dict(
-            inputs,
-            streamer=streamer,
-            max_new_tokens=1024,
-            do_sample=True,
-            temperature=0.3
-        )
-        thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        yield progress_bar_html("Processing video with Aya-Vision-8b")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-        return
-
-    # Aya-Vision Image Inference (@aya-vision)
-    if text.lower().startswith("@aya-vision"):
-        text_prompt = text[len("@aya-vision"):].strip()
-        if not files:
-            yield "Error: Please provide an image for the @aya-vision feature."
-            return
-        image = load_image(files[0])
-        yield progress_bar_html("Processing with Aya-Vision-8b")
-        messages = [{
-            "role": "user",
-            "content": [
-                {"type": "image", "image": image},
-                {"type": "text", "text": text_prompt},
-            ],
-        }]
-        inputs = aya_processor.apply_chat_template(
-            messages,
-            padding=True,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt"
-        ).to(aya_model.device)
-        streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = dict(
-            inputs,
-            streamer=streamer,
-            max_new_tokens=1024,
-            do_sample=True,
-            temperature=0.3
-        )
-        thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-        return
-
     # RolmOCR Inference (@RolmOCR)
     if text.lower().startswith("@rolmocr"):
         # Remove the tag from the query.
@@ -239,14 +144,14 @@ def model_inference(input_dict, history):
         thread.start()
         buffer = ""
         # Use a different color scheme for RolmOCR (purple-themed).
-        yield progress_bar_html("Processing with Qwen2.5VL (RolmOCR)", primary_color="#4B0082", secondary_color="#9370DB")
+        yield progress_bar_html("Processing with Qwen2.5VL (RolmOCR)")
         for new_text in streamer:
             buffer += new_text
             buffer = buffer.replace("<|im_end|>", "")
             time.sleep(0.01)
             yield buffer
         return
-
+
     # Default Inference: Qwen2VL OCR
     # Process files: support multiple images.
     if len(files) > 1:
@@ -294,26 +199,18 @@ examples = [
     [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
     [{"text": "@RolmOCR OCR the Image", "files": ["rolm/3.jpeg"]}],
     [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
-    [{"text": "@aya-vision Summarize the letter", "files": ["examples/1.png"]}],
-    [{"text": "@aya-vision Extract JSON from the image", "files": ["example_images/document.jpg"]}],
-    [{"text": "@video-infer Explain what is happening in this video briefly by understanding", "files": ["examples/oreo.mp4"]}],
     [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
-    [{"text": "@aya-vision Describe the photo", "files": ["examples/3.png"]}],
-    [{"text": "@aya-vision Summarize the full image in detail", "files": ["examples/2.jpg"]}],
-    [{"text": "@aya-vision Describe this image.", "files": ["example_images/campeones.jpg"]}],
-    [{"text": "@aya-vision What is this UI about?", "files": ["example_images/s2w_example.png"]}],
-    [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
 ]
 
 demo = gr.ChatInterface(
     fn=model_inference,
-    description="# **Multimodal OCR `@RolmOCR, @aya-vision for image, @video-infer for video`**",
+    description="# **Multimodal OCR `@RolmOCR` and Default Qwen2VL OCR**",
     examples=examples,
     textbox=gr.MultimodalTextbox(
         label="Query Input",
         file_types=["image", "video"],
         file_count="multiple",
-        placeholder="Use tag @RolmOCR @aya-vision for Image, @video-infer for video, or leave blank for default Qwen2VL OCR"
+        placeholder="Use tag @RolmOCR for RolmOCR, or leave blank for default Qwen2VL OCR"
     ),
     stop_btn="Stop Generation",
     multimodal=True,
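For reference, here is a minimal sketch of the streaming-inference pattern that app.py keeps using after this change: an AutoProcessor plus Qwen2VLForConditionalGeneration, with model.generate() run in a background Thread and partial output read from a TextIteratorStreamer. This is not code from the commit; the image path, the prompt text, and the trust_remote_code flag on the model load are illustrative assumptions.

# Sketch only: mirrors the streaming pattern visible in the diff above.
from threading import Thread

import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
    TextIteratorStreamer,
)

MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16
).to("cuda").eval()

image = Image.open("example_images/dogs.jpg")  # placeholder input image
messages = [{"role": "user", "content": [
    {"type": "image"},
    {"type": "text", "text": "OCR the text in the image."},  # placeholder prompt
]}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to("cuda")

# Generate in a background thread and consume tokens as they arrive,
# the same way model_inference() yields partial buffers to the Gradio UI.
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=1024)).start()

buffer = ""
for new_text in streamer:
    buffer += new_text
print(buffer)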
 