# app.py by ighoshsubho (commit 9aaa3f4)
import logging
import os

import gradio as gr
import torch
from PIL import Image
from transformers import BitsAndBytesConfig
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
def load_model(repo_name="ighoshsubho/pali-gamma-finetuned-json"):
    """Load the fine-tuned PaliGemma model and its processor.

    On a machine with CUDA the model is loaded in 4-bit NF4 quantization
    (double quantization, bfloat16 compute) and placed on the GPU. Without
    CUDA it is loaded unquantized on the CPU, because bitsandbytes 4-bit
    loading typically requires a CUDA GPU and would otherwise fail here.

    Args:
        repo_name: Hugging Face Hub repo id (or local path) passed to
            ``from_pretrained`` for both the processor and the model.

    Returns:
        A ``(model, processor)`` tuple.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    processor = PaliGemmaProcessor.from_pretrained(repo_name)

    model_kwargs = {"device_map": device}
    if device == "cuda":
        # Quantize only when a GPU is present; passing a 4-bit config on a
        # CPU-only host makes from_pretrained raise instead of falling back.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        model_kwargs["torch_dtype"] = torch.bfloat16
    # On CPU no torch_dtype is passed, matching the original `None`
    # (i.e. the library's default dtype).

    model = PaliGemmaForConditionalGeneration.from_pretrained(
        repo_name,
        **model_kwargs,
    )
    return model, processor
# Load the model and processor once, at import time. process_image reads the
# module-level `model` and `processor` names bound here on every request.
print("Loading model...")
model, processor = load_model()
print("Model loaded successfully!")
def process_image(image, prompt):
    """Run the model on one image with a text prompt and return its answer.

    Args:
        image: The uploaded image. The Gradio input uses ``type="pil"``, so
            this is normally a ``PIL.Image.Image``, or ``None`` when the user
            submits without uploading. Any other value is passed to
            ``Image.open`` (e.g. a file path).
        prompt: Instruction text placed after the ``<image>`` token.

    Returns:
        The decoded generated text, excluding the echoed prompt. On failure,
        returns an ``"Error processing image: ..."`` message instead of
        raising, so the Gradio output box shows it.
    """
    if image is None:
        return "Error processing image: no image was provided."

    try:
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        # Uploads may be RGBA, palette or grayscale; the processor expects
        # 3-channel input, so normalize before preprocessing.
        image = image.convert("RGB")

        inputs = processor(
            text=[f"<image>{prompt}"],
            images=[image],
            return_tensors="pt",
            padding="longest",
        ).to(model.device)
        # generate() returns prompt tokens followed by the continuation;
        # remember where the prompt ends so only new tokens are decoded.
        prompt_length = inputs["input_ids"].shape[-1]

        # max_new_tokens bounds only the generated text (max_length would
        # also count the image and prompt tokens). Beam search here is
        # non-sampling, so no temperature is passed (it would be ignored).
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=5,
            do_sample=False,
        )

        return processor.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )
    except Exception as e:
        # UI boundary: keep the full traceback in the server log while the
        # user sees a short message in the output box.
        logging.getLogger(__name__).exception("Error processing image")
        return f"Error processing image: {e}"
# Create the Gradio interface: an image upload plus a prompt box, with the
# model's decoded answer shown as text. The description only claims 4-bit
# mode when the loaded model reports it (transformers sets
# `is_loaded_in_4bit` on bitsandbytes 4-bit models).
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            value="extract data in JSON format",
        ),
    ],
    outputs=gr.Textbox(label="Generated Output"),
    title="PaLI-GAMMA Image Analysis",
    description=(
        "Upload an image and get structured data extracted in JSON format. "
        + (
            "The model is running in 4-bit quantization mode."
            if getattr(model, "is_loaded_in_4bit", False)
            else "The model is running without quantization."
        )
    ),
)
# Start the Gradio server only when this file is executed directly, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()