import os
import streamlit as st
import onnxruntime as ort
from transformers import AutoTokenizer, AutoProcessor
from PIL import Image
from io import BytesIO
# Download the quantized ONNX models if they are not already present locally
ONNX_BASE_URL = "https://huggingface.co/llava-hf/llava-interleave-qwen-0.5b-hf/resolve/main/onnx"
for onnx_file in (
    "vision_encoder_q4f16.onnx",
    "decoder_model_merged_q4f16.onnx",
    "embed_tokens_q4f16.onnx",
):
    if not os.path.exists(onnx_file):
        os.system(f"wget {ONNX_BASE_URL}/{onnx_file}")
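# Note (assumption, not used here): if wget is unavailable in the runtime, the same files can be
# fetched with huggingface_hub, e.g. hf_hub_download(repo_id="llava-hf/llava-interleave-qwen-0.5b-hf",
# filename="onnx/vision_encoder_q4f16.onnx"); this app keeps the wget approach above.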
# Load tokenizer and processor
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")
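# The processor wraps the model's image preprocessing (resizing, normalization) and
# produces the pixel_values array that the vision encoder expects.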
# Load ONNX sessions
vision_encoder_session = ort.InferenceSession("vision_encoder_q4f16.onnx")
decoder_session = ort.InferenceSession("decoder_model_merged_q4f16.onnx")
embed_tokens_session = ort.InferenceSession("embed_tokens_q4f16.onnx")
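# Without an explicit providers list, onnxruntime uses the default available execution
# provider (CPU on a standard Space). If the graph signatures are unclear, the input and
# output names can be inspected with session.get_inputs() / session.get_outputs().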
# Streamlit App Configuration
st.set_page_config(page_title="Vision-Based ONNX AI App", page_icon="πŸ€–", layout="wide")
st.title("πŸ–ΌοΈ Vision-Based ONNX AI Demo App")
st.markdown("<p style='text-align: center; font-size: 18px; color: #555;'>Upload an image and get a description</p>", unsafe_allow_html=True)
# User Input: Image Upload
uploaded_image = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
user_prompt = st.text_input("Enter your prompt", value="Describe this image in detail", placeholder="e.g., What is shown in the image?")
# Display uploaded image
def display_uploaded_image(uploaded_image):
    try:
        # Convert to RGB so images with an alpha channel or palette also work with the processor
        img = Image.open(uploaded_image).convert("RGB")
        st.image(img, caption="Uploaded Image", use_container_width=True)
        return img
    except Exception as e:
        st.error(f"❌ Unable to display image. Error: {e}")
        return None
# Process the uploaded image
if st.button("Get Description"):
    if uploaded_image and user_prompt:
        try:
            # Display the uploaded image
            img = display_uploaded_image(uploaded_image)
            if img is None:
                st.error("❌ Image processing failed.")
                st.stop()

            # Preprocess the image into the pixel_values tensor for the vision encoder
            processed_image = processor(images=img, return_tensors="np")

            # Generate image embeddings with the vision encoder
            vision_embeddings = vision_encoder_session.run(
                None, {"pixel_values": processed_image["pixel_values"]}
            )[0]

            # Tokenize the user prompt
            inputs = tokenizer(user_prompt, return_tensors="np")
            input_ids = inputs["input_ids"]

            # Embed the prompt tokens
            embedded_tokens = embed_tokens_session.run(
                None, {"input_ids": input_ids}
            )[0]

            # Run the decoder on the image and text embeddings.
            # The input names below must match the exported ONNX graph; if they differ,
            # check decoder_session.get_inputs() for the actual names.
            decoder_outputs = decoder_session.run(
                None, {
                    "vision_embeddings": vision_embeddings,
                    "embedded_tokens": embedded_tokens
                }
            )[0]

            # The decoder returns logits; take the argmax token ids before decoding
            token_ids = decoder_outputs.argmax(axis=-1)[0]
            description = tokenizer.decode(token_ids, skip_special_tokens=True)

            # Display the description
            st.subheader("📝 Model Response")
            st.markdown(f"**Description**: {description}")
        except Exception as e:
            st.error(f"❌ An error occurred: {e}")
    else:
        st.warning("⚠️ Please upload an image and enter a prompt.")
# UI Enhancements
st.markdown("""
<style>
    .stButton>button {
        background-color: #0072BB;
        color: white;
        font-size: 16px;
        border-radius: 10px;
        padding: 10px 20px;
        font-weight: bold;
        transition: background-color 0.3s;
    }
    .stButton>button:hover {
        background-color: #005f8a;
    }
    .stTextInput>div>div>input {
        padding: 10px;
        font-size: 16px;
        border-radius: 10px;
    }
    .stFileUploader>div>div {
        border-radius: 10px;
    }
    /* Center the uploaded image */
    .stImage {
        display: block;
        margin-left: auto;
        margin-right: auto;
    }
</style>
""", unsafe_allow_html=True)