merve HF staff commited on
Commit
ee97a5d
β€’
1 Parent(s): 83088b7

migrate to zero

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -2,7 +2,9 @@ import gradio as gr
2
  import requests
3
  from PIL import Image
4
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
 
5
 
 
6
  def infer_infographics(image, question):
7
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-ai2d-base").to("cuda")
8
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
@@ -12,6 +14,7 @@ def infer_infographics(image, question):
12
  predictions = model.generate(**inputs)
13
  return processor.decode(predictions[0], skip_special_tokens=True)
14
 
 
15
  def infer_ui(image, question):
16
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-screen2words-base").to("cuda")
17
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-screen2words-base")
@@ -21,6 +24,7 @@ def infer_ui(image, question):
21
  predictions = model.generate(**inputs)
22
  return processor.decode(predictions[0], skip_special_tokens=True)
23
 
 
24
  def infer_chart(image, question):
25
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-chartqa-base").to("cuda")
26
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-chartqa-base")
@@ -30,6 +34,7 @@ def infer_chart(image, question):
30
  predictions = model.generate(**inputs)
31
  return processor.decode(predictions[0], skip_special_tokens=True)
32
 
 
33
  def infer_doc(image, question):
34
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-base").to("cuda")
35
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-base")
 
2
  import requests
3
  from PIL import Image
4
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
5
+ import spaces
6
 
7
+ @spaces.GPU
8
  def infer_infographics(image, question):
9
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-ai2d-base").to("cuda")
10
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
 
14
  predictions = model.generate(**inputs)
15
  return processor.decode(predictions[0], skip_special_tokens=True)
16
 
17
+ @spaces.GPU
18
  def infer_ui(image, question):
19
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-screen2words-base").to("cuda")
20
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-screen2words-base")
 
24
  predictions = model.generate(**inputs)
25
  return processor.decode(predictions[0], skip_special_tokens=True)
26
 
27
+ @spaces.GPU
28
  def infer_chart(image, question):
29
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-chartqa-base").to("cuda")
30
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-chartqa-base")
 
34
  predictions = model.generate(**inputs)
35
  return processor.decode(predictions[0], skip_special_tokens=True)
36
 
37
+ @spaces.GPU
38
  def infer_doc(image, question):
39
  model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-base").to("cuda")
40
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-base")