import spaces
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for Donut model, processor, and dataset
donut_model = None
donut_processor = None
dataset = None


def load_merit_dataset():
    """Load the MERIT dataset (en-digital-seq train split) once and cache it."""
    global dataset
    if dataset is None:
        dataset = load_dataset("de-Rodrigo/merit", name="en-digital-seq", split="train")
    return dataset


def get_image_from_dataset(index):
    """Return the image stored at the given dataset index."""
    global dataset
    if dataset is None:
        dataset = load_merit_dataset()
    image_data = dataset[int(index)]["image"]
    return image_data


def get_collection_models(tag: str) -> List[str]:
    """List de-Rodrigo models on the Hugging Face Hub that carry the given tag."""
    models = list_models(author="de-Rodrigo")
    return [model.modelId for model in models if tag in model.tags]


@spaces.GPU
def get_donut():
    """Lazily load the Donut model and processor, moving the model to GPU."""
    global donut_model, donut_processor
    if donut_model is None or donut_processor is None:
        try:
            donut_model = VisionEncoderDecoderModel.from_pretrained(
                "de-Rodrigo/donut-merit"
            )
            donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
            donut_model = donut_model.to("cuda")
            logger.info("Donut model loaded successfully on GPU")
        except Exception as e:
            logger.error(f"Error loading Donut model: {str(e)}")
            raise
    return donut_model, donut_processor


@spaces.GPU
def process_image_donut(model, processor, image):
    """Run Donut on a single image and return the parsed output as a JSON string."""
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")

        # Seed decoding with a start token. If the fine-tuned model defines a
        # task-specific prompt token, it should be used here instead.
        task_prompt = "<s>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to("cuda")

        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # Drop the leading task-start token before converting to JSON
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        result = processor.token2json(sequence)
        return json.dumps(result, indent=2)
    except Exception as e:
        logger.error(f"Error processing image with Donut: {str(e)}")
        return f"Error: {str(e)}"


@spaces.GPU
def process_image(model_name, image=None, dataset_image_index=None):
    # Fall back to a dataset image only when the user has not uploaded one
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        model, processor = get_donut()
        result = process_image_donut(model, processor, image)
    else:
        # Here you should implement processing for other models
        result = f"Processing for model {model_name} not implemented"

    return image, result


if __name__ == "__main__":
    # Load the dataset
    load_merit_dataset()

    models = get_collection_models("saliency")
    models.append("de-Rodrigo/donut-merit")

    demo = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Dropdown(choices=models, label="Select Model"),
            gr.Image(type="pil", label="Upload Image"),
            gr.Slider(
                minimum=0, maximum=len(dataset) - 1, step=1, label="Dataset Image Index"
            ),
        ],
        outputs=[gr.Image(label="Processed Image"), gr.Textbox(label="Result")],
title="Document Understanding with Donut", description="Upload an image or select one from the dataset to process with the selected model.", ) demo.launch()