import os
from io import BytesIO

import numpy as np
import torch
from torch import nn
import streamlit as st
from PIL import Image
import transformers
from transformers import VisionEncoderDecoderModel, DonutProcessor

from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids
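
# Run Chipper on a single document image. The prompt token selects the task:
# "<s_pretraining>" produces plain OCR output, "<s_hierarchical>" produces
# output with element annotations.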
def run_prediction(sample, model, processor, mode):
    skip_tokens = get_table_token_ids(processor)
    no_repeat_ngram_size = 15

    if mode == "OCR":
        prompt = "<s><s_pretraining>"
    else:
        prompt = "<s><s_hierarchical>"

    print("prompt:", prompt)
    print("no_repeat_ngram_size:", no_repeat_ngram_size)
    pixel_values = processor(
        np.array(sample, np.float32), return_tensors="pt"
    ).pixel_values
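
    # Fix the seed so top-p/top-k sampling is reproducible across runs.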
    transformers.set_seed(42)
    with torch.no_grad():
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=processor.tokenizer(
                prompt, add_special_tokens=False, return_tensors="pt"
            ).input_ids.to(device),
            # Custom processor: blocks repeated n-grams but exempts the table
            # tokens in skip_tokens, which legitimately repeat.
            logits_processor=[
                NoRepeatNGramLogitsProcessor(no_repeat_ngram_size, skip_tokens)
            ],
            do_sample=True,
            top_p=0.92,
            top_k=5,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_beams=3,
            output_attentions=False,
            output_hidden_states=False,
        )

    # Decode the generated token ids back into a string.
    prediction = processor.batch_decode(outputs)[0]
    print(prediction)

    return prediction
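

# Streamlit page: logo, description, upload sidebar, and results.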
logo = Image.open("./rsz_unstructured_logo.png")
st.image(logo)

st.markdown('''
### Chipper
Chipper is an OCR-free document-understanding transformer. It was pre-trained on over 1M documents from public sources and fine-tuned on a wide range of document types.

At [Unstructured.io](https://github.com/Unstructured-IO/unstructured) we are on a mission to build custom preprocessing pipelines for labeling, training, and production ML workflows.
Come join us in our public repos and contribute! Every contribution and piece of feedback is highly valuable to the community.
''')

image_upload = None

with st.sidebar:
    # Read the uploaded file as bytes and decode it into a PIL image.
    uploaded_file = st.file_uploader("Upload a document")
    if uploaded_file is not None:
        image_bytes_data = uploaded_file.getvalue()
        image_upload = Image.open(BytesIO(image_bytes_data))

    mode = st.selectbox('Mode', ('OCR', 'Element annotation'), index=1)
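
# Fall back to the bundled sample document when nothing has been uploaded.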
if image_upload:
    image = image_upload
else:
    image = Image.open("./document.png")

st.image(image, caption='Your target document')
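
# Load the processor and model once; the model is cached in session state so
# Streamlit reruns do not reload the weights.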
with st.spinner('Processing the document ...'):
    pre_trained_model = "unstructuredio/chipper-v3"
    processor = DonutProcessor.from_pretrained(pre_trained_model, token=os.environ['HF_TOKEN'])

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if 'model' in st.session_state:
        model = st.session_state['model']
    else:
        model = VisionEncoderDecoderModel.from_pretrained(pre_trained_model, token=os.environ['HF_TOKEN'])

        from huggingface_hub import hf_hub_download

        lm_head_file = hf_hub_download(
            repo_id=pre_trained_model, filename="lm_head.pth", token=os.environ['HF_TOKEN']
        )
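
        # The checkpoint ships its LM head separately as a low-rank
        # factorization: the dense vocabulary projection is replaced by a
        # rank-128 bottleneck (hidden -> 128 -> 128 -> vocab) before loading
        # the matching weights from lm_head.pth.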
        rank = 128
        model.decoder.lm_head = nn.Sequential(
            nn.Linear(model.decoder.lm_head.weight.shape[1], rank, bias=False),
            nn.Linear(rank, rank, bias=False),
            nn.Linear(rank, model.decoder.lm_head.weight.shape[0], bias=True),
        )
        # map_location keeps the load working on CPU-only hosts; the module
        # is moved to the target device right after.
        model.decoder.lm_head.load_state_dict(torch.load(lm_head_file, map_location="cpu"))

        model.eval()
        model.to(device)

        st.session_state['model'] = model

    st.info('Parsing document')
    parsed_info = run_prediction(image.convert("RGB"), model, processor, mode)

    st.text('\nDocument:')
    st.text_area('Output text', value=parsed_info, height=800)