Spaces:
Runtime error
Runtime error
Prompt option
Browse files
app.py
CHANGED
|
@@ -8,7 +8,7 @@ from PIL import Image
|
|
| 8 |
from io import BytesIO
|
| 9 |
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
|
| 10 |
|
| 11 |
-
def run_prediction(sample, model, processor):
|
| 12 |
|
| 13 |
pixel_values = processor(np.array(
|
| 14 |
sample,
|
|
@@ -18,7 +18,7 @@ def run_prediction(sample, model, processor):
|
|
| 18 |
with torch.no_grad():
|
| 19 |
outputs = model.generate(
|
| 20 |
pixel_values.to(device),
|
| 21 |
-
decoder_input_ids=processor.tokenizer(
|
| 22 |
do_sample=True,
|
| 23 |
top_p=0.92,
|
| 24 |
top_k=5,
|
|
@@ -52,7 +52,9 @@ with st.sidebar:
|
|
| 52 |
if uploaded_file is not None:
|
| 53 |
# To read file as bytes:
|
| 54 |
image_bytes_data = uploaded_file.getvalue()
|
| 55 |
-
image_upload = Image.open(BytesIO(image_bytes_data))
|
|
|
|
|
|
|
| 56 |
|
| 57 |
if image_upload:
|
| 58 |
image = image_upload
|
|
@@ -87,6 +89,6 @@ with st.spinner(f'Processing the document ...'):
|
|
| 87 |
model.to(device)
|
| 88 |
|
| 89 |
st.info(f'Parsing document')
|
| 90 |
-
parsed_info = run_prediction(image.convert("RGB"), model, processor)
|
| 91 |
st.text(f'\nDocument:')
|
| 92 |
st.text_area('Output text', value=parsed_info, height=800)
|
|
|
|
| 8 |
from io import BytesIO
|
| 9 |
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
|
| 10 |
|
| 11 |
+
def run_prediction(sample, model, processor, prompt):
|
| 12 |
|
| 13 |
pixel_values = processor(np.array(
|
| 14 |
sample,
|
|
|
|
| 18 |
with torch.no_grad():
|
| 19 |
outputs = model.generate(
|
| 20 |
pixel_values.to(device),
|
| 21 |
+
decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
|
| 22 |
do_sample=True,
|
| 23 |
top_p=0.92,
|
| 24 |
top_k=5,
|
|
|
|
| 52 |
if uploaded_file is not None:
|
| 53 |
# To read file as bytes:
|
| 54 |
image_bytes_data = uploaded_file.getvalue()
|
| 55 |
+
image_upload = Image.open(BytesIO(image_bytes_data))
|
| 56 |
+
|
| 57 |
+
prompt = st.selectbox('Prompt', ('<s><s_pretraining>', '<s><s_plain>', '<s><s_hierarchical>'), index=2)
|
| 58 |
|
| 59 |
if image_upload:
|
| 60 |
image = image_upload
|
|
|
|
| 89 |
model.to(device)
|
| 90 |
|
| 91 |
st.info(f'Parsing document')
|
| 92 |
+
parsed_info = run_prediction(image.convert("RGB"), model, processor, prompt)
|
| 93 |
st.text(f'\nDocument:')
|
| 94 |
st.text_area('Output text', value=parsed_info, height=800)
|