sergiopaniego commited on
Commit
ffcf8f2
·
verified ·
1 Parent(s): 5ca3297

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import spaces
3
- from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
4
- from qwen_vl_utils import process_vision_info
5
  import torch
6
  from PIL import Image
7
  from datetime import datetime
@@ -10,22 +9,24 @@ import os
10
 
11
 
12
  DESCRIPTION = """
13
- # Qwen2-VL-7B-trl-sft-ChartQA Demo
14
 
15
- This is a demo Space for a fine-tuned version of [Qwen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) trained using [ChatQA dataset](https://huggingface.co/datasets/HuggingFaceM4/ChartQA).
16
 
17
- The corresponding model is located [here](https://huggingface.co/sergiopaniego/qwen2-7b-instruct-trl-sft-ChartQA).
18
  """
19
 
20
- model_id = "Qwen/Qwen2-VL-7B-Instruct"
21
- model = Qwen2VLForConditionalGeneration.from_pretrained(
22
  model_id,
23
  device_map="auto",
24
  torch_dtype=torch.bfloat16,
 
25
  )
26
- adapter_path = "sergiopaniego/qwen2-7b-instruct-trl-sft-ChartQA"
 
 
27
  model.load_adapter(adapter_path)
28
- processor = Qwen2VLProcessor.from_pretrained(model_id)
29
 
30
  def array_to_image_path(image_array):
31
  if image_array is None:
@@ -101,7 +102,7 @@ css = """
101
 
102
  with gr.Blocks(css=css) as demo:
103
  gr.Markdown(DESCRIPTION)
104
- with gr.Tab(label="Qwen2-VL-7B-trl-sft-ChartQA Input"):
105
  with gr.Row():
106
  with gr.Column():
107
  input_img = gr.Image(label="Input Picture")
 
1
  import gradio as gr
2
  import spaces
3
+ from transformers import Idefics3ForConditionalGeneration, AutoProcessor
 
4
  import torch
5
  from PIL import Image
6
  from datetime import datetime
 
9
 
10
 
11
  DESCRIPTION = """
12
+ # SmolVLM-trl-sft-ChartQA Demo
13
 
14
+ This is a demo Space for a fine-tuned version of [SmolVLM](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) trained using [ChatQA dataset](https://huggingface.co/datasets/HuggingFaceM4/ChartQA).
15
 
16
+ The corresponding model is located [here](https://huggingface.co/sergiopaniego/smolvlm-instruct-trl-sft-ChartQA).
17
  """
18
 
19
+ model_id = "HuggingFaceTB/SmolVLM-Instruct"
20
+ model = Idefics3ForConditionalGeneration.from_pretrained(
21
  model_id,
22
  device_map="auto",
23
  torch_dtype=torch.bfloat16,
24
+ _attn_implementation="flash_attention_2",
25
  )
26
+
27
+ processor = AutoProcessor.from_pretrained(model_id)
28
+ adapter_path = "sergiopaniego/smolvlm-instruct-trl-sft-ChartQA"
29
  model.load_adapter(adapter_path)
 
30
 
31
  def array_to_image_path(image_array):
32
  if image_array is None:
 
102
 
103
  with gr.Blocks(css=css) as demo:
104
  gr.Markdown(DESCRIPTION)
105
+ with gr.Tab(label="SmolVLM-trl-sft-ChartQA Input"):
106
  with gr.Row():
107
  with gr.Column():
108
  input_img = gr.Image(label="Input Picture")