gabrielmotablima's picture
Update app.py
69c95d9 verified
raw
history blame
3.78 kB
import requests
from PIL import Image
from transformers import AutoTokenizer, AutoImageProcessor, VisionEncoderDecoderModel
import gradio as gr
import os
from concurrent.futures import ThreadPoolExecutor
# Load one captioning checkpoint together with its tokenizer and image processor.
def load_model_and_components(model_name):
    """Fetch the three pieces needed to caption images with *model_name*.

    Args:
        model_name: identifier passed unchanged to each ``from_pretrained`` call.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple, loaded in that order.

    Note:
        There is no error handling here. Any exception raised by a
        ``from_pretrained`` call propagates to the caller unchanged.
    """
    # Tuple elements are evaluated left to right, so the load order
    # (model, tokenizer, image processor) is preserved.
    return (
        VisionEncoderDecoderModel.from_pretrained(model_name),
        AutoTokenizer.from_pretrained(model_name),
        AutoImageProcessor.from_pretrained(model_name),
    )
# Identifiers of the captioning models offered in the UI.
MODEL_NAMES = ("laicsiifes/swin-distilbertimbau", "laicsiifes/swin-gportuguese-2")


def preload_models(model_names=MODEL_NAMES):
    """Load every model in *model_names* concurrently and index them by name.

    Args:
        model_names: iterable of identifiers accepted by
            ``load_model_and_components``. Defaults to ``MODEL_NAMES``.

    Returns:
        dict mapping each name to its ``(model, tokenizer, image_processor)``
        tuple. Insertion order follows *model_names*.

    Raises:
        The exception raised by ``load_model_and_components`` for the first
        failing name, in input order.
    """
    # Materialise once: the names are iterated twice (by executor.map and by
    # zip), which would silently drop entries if a one-shot iterator were passed.
    names = list(model_names)
    with ThreadPoolExecutor() as executor:
        # Executor.map yields results in input order, so zip pairs each name
        # with its own components. dict() drains the generator inside the
        # `with`, before the executor shuts down.
        return dict(zip(names, executor.map(load_model_and_components, names)))


# Loaded once at start-up and shared by every request.
models = preload_models()
# Function to process the image and generate a caption
def generate_caption(image, model_name):
    """Caption *image* using the preloaded model registered as *model_name*.

    Args:
        image: the picture to caption (a PIL image when called from the UI).
            PIL images in any mode other than RGB are converted to RGB first.
        model_name: a key of the module-level ``models`` dict.

    Returns:
        The first decoded sequence, with special tokens stripped.

    Raises:
        KeyError: if *model_name* was not preloaded.
    """
    model, tokenizer, image_processor = models[model_name]
    # The UI component is configured with image_mode="RGB", but direct callers
    # may pass grayscale ("L"), RGBA or palette images. Normalising here
    # guarantees the processor always receives three channels.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # batch_decode returns one string per generated sequence. Only one image
    # was submitted, so the first entry is the caption.
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Predefined images for selection. They are read once at start-up, in sorted
# file-name order, so the example gallery is stable across restarts
# (os.listdir order is arbitrary).
image_folder = "images"
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".ppm")


def _load_rgb_image(path):
    """Open *path*, return an RGB copy of it, and close the file handle."""
    with Image.open(path) as img:
        # convert() returns a fully loaded copy, which remains usable after
        # the `with` block closes the underlying file.
        return img.convert("RGB")


predefined_images = [
    _load_rgb_image(os.path.join(image_folder, fname))
    for fname in sorted(os.listdir(image_folder))
    if fname.lower().endswith(_IMAGE_EXTENSIONS)
]
# Event handlers wired to the Gradio components defined below
def handle_uploaded_image(image):
    """Return an RGB copy of *image* plus ``None`` for the caption slot.

    Args:
        image: the uploaded or selected image, or ``None`` when nothing is set.

    Returns:
        ``(rgb_image, None)``, where ``rgb_image`` is ``None`` if *image* was
        ``None``. The second element is always ``None``; the UI wiring maps it
        to the caption textbox, which clears any stale caption.
    """
    rgb_image = None if image is None else image.convert("RGB")
    return rgb_image, None
def handle_generate_button(image, selected_model):
    """Click handler for the Generate button.

    Args:
        image: the image currently in the preview, or ``None`` if it is empty.
        selected_model: the model name chosen in the dropdown.

    Returns:
        The generated caption, or a prompt asking the user to upload an image
        when the preview is empty.
    """
    has_image = image is not None
    if has_image:
        return generate_caption(image, selected_model)
    return "Please upload an image to generate a caption."
# Define UI
with gr.Blocks(theme=gr.themes.Citrus(primary_hue="blue", secondary_hue="orange")) as interface:
    gr.Markdown("""
# Welcome to the LAICSI-IFES space for Vision Encoder-Decoder (VED) demonstration
---
### Select an available model: Swin-DistilBERTimbau (168M) or Swin-GPorTuguese-2 (240M)
""")
    with gr.Row(variant='panel'):
        with gr.Column():
            model_selector = gr.Dropdown(
                choices=list(models.keys()),
                # Default to the first loaded model instead of repeating its
                # id here, so the default cannot drift out of sync with the
                # list of preloaded models.
                value=next(iter(models), None),
                label="Select Model"
            )
    gr.Markdown("""
---
### Upload image or example images below, and click `Generate`
""")
    with gr.Row(variant='panel'):
        with gr.Column():
            image_display = gr.Image(type="pil", label="Image Preview", image_mode="RGB", height=400)
        with gr.Column():
            output_text = gr.Textbox(label="Generated Caption")
            generate_button = gr.Button("Generate")
    gr.Markdown("""---""")
    with gr.Row(variant='panel'):
        examples = gr.Examples(
            examples=predefined_images,
            fn=handle_uploaded_image,
            inputs=[image_display],
            outputs=[image_display, output_text],
            label="Examples"
        )
    # Define actions
    # Switching models clears both the preview and the caption, so a caption
    # produced by the previous model is not left on screen.
    model_selector.change(fn=lambda: (None, None), outputs=[image_display, output_text])
    # A new upload is normalised to RGB and any stale caption is cleared.
    image_display.upload(fn=handle_uploaded_image, inputs=[image_display], outputs=[image_display, output_text])
    # Removing the image also removes its caption.
    image_display.clear(fn=lambda: None, outputs=[output_text])
    generate_button.click(fn=handle_generate_button, inputs=[image_display, model_selector], outputs=output_text)

# Launch only when run as a script (as `python app.py` does), so importing
# this module does not start a server.
if __name__ == "__main__":
    interface.launch(share=False)