import os
from io import BytesIO

import onnxruntime as ort
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, AutoProcessor

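# Download the quantized ONNX model files from the Hugging Face Hub if they are not already present locally.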
if not os.path.exists("vision_encoder_q4f16.onnx"): |
|
|
os.system('wget https://huggingface.co/llava-hf/llava-interleave-qwen-0.5b-hf/resolve/main/onnx/vision_encoder_q4f16.onnx') |
|
|
if not os.path.exists("decoder_model_merged_q4f16.onnx"): |
|
|
os.system('wget https://huggingface.co/llava-hf/llava-interleave-qwen-0.5b-hf/resolve/main/onnx/decoder_model_merged_q4f16.onnx') |
|
|
if not os.path.exists("embed_tokens_q4f16.onnx"): |
|
|
os.system('wget https://huggingface.co/llava-hf/llava-interleave-qwen-0.5b-hf/resolve/main/onnx/embed_tokens_q4f16.onnx') |
|
|
|
|
|
|
|
|
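# Load the tokenizer and processor that match the LLaVA-Interleave-Qwen-0.5B checkpoint.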
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")

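# Create an ONNX Runtime inference session for each exported model component.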
vision_encoder_session = ort.InferenceSession("vision_encoder_q4f16.onnx")
decoder_session = ort.InferenceSession("decoder_model_merged_q4f16.onnx")
embed_tokens_session = ort.InferenceSession("embed_tokens_q4f16.onnx")

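# Streamlit page configuration and header.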
st.set_page_config(page_title="Vision-Based ONNX AI App", page_icon="🤖", layout="wide")
st.title("🖼️ Vision-Based ONNX AI Demo App")
st.markdown("<p style='text-align: center; font-size: 18px; color: #555;'>Upload an image and get a description</p>", unsafe_allow_html=True)

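# User inputs: an image upload and a free-text prompt.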
uploaded_image = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
user_prompt = st.text_input("Enter your prompt", value="Describe this image in detail", placeholder="e.g., What is shown in the image?")

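# Helper: render the uploaded file in the app and return it as a PIL image.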
def display_uploaded_image(uploaded_image):
    try:
        img = Image.open(uploaded_image)
        st.image(img, caption="Uploaded Image", use_container_width=True)
        return img
    except Exception as e:
        st.error(f"❌ Unable to display image. Error: {e}")
        return None

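# Run the full vision-language pipeline when the user clicks the button.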
if st.button("Get Description"): |
|
|
if uploaded_image and user_prompt: |
|
|
try: |
|
|
|
|
|
img = display_uploaded_image(uploaded_image) |
|
|
if img is None: |
|
|
st.error("β Image processing failed.") |
|
|
st.stop() |
|
|
|
|
|
|
|
|
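            # Serialize the image to an in-memory PNG buffer (img_bytes is not used further below).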
            img_buffer = BytesIO()
            img.save(img_buffer, format="PNG")
            img_bytes = img_buffer.getvalue()

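            # Preprocess the image into the pixel_values tensor expected by the vision encoder.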
            processed_image = processor(images=img, return_tensors="np")

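            # Run the vision encoder to obtain image embeddings.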
            vision_embeddings = vision_encoder_session.run(
                None, {"pixel_values": processed_image["pixel_values"]}
            )[0]

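            # Tokenize the user prompt.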
            inputs = tokenizer(user_prompt, return_tensors="np")
            input_ids = inputs["input_ids"]

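            # Map the prompt token ids to embeddings with the embed_tokens model.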
            embedded_tokens = embed_tokens_session.run(
                None, {"input_ids": input_ids}
            )[0]

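            # Run the decoder on the vision and text embeddings. The input names below follow
            # the original script; the exported model's actual input names can be inspected
            # with decoder_session.get_inputs() if they differ.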
            decoder_outputs = decoder_session.run(
                None, {
                    "vision_embeddings": vision_embeddings,
                    "embedded_tokens": embedded_tokens
                }
            )[0]

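            # Decode the decoder output back into readable text.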
            description = tokenizer.decode(decoder_outputs, skip_special_tokens=True)

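            # Show the generated description in the app.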
            st.subheader("📝 Model Response")
            st.markdown(f"**Description**: {description}")

        except Exception as e:
            st.error(f"❌ An error occurred: {e}")
    else:
        st.warning("⚠️ Please upload an image and enter a prompt.")

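# Custom CSS to style the button, text input, file uploader, and image display.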
st.markdown(""" |
|
|
<style> |
|
|
.stButton>button { |
|
|
background-color: #0072BB; |
|
|
color: white; |
|
|
font-size: 16px; |
|
|
border-radius: 10px; |
|
|
padding: 10px 20px; |
|
|
font-weight: bold; |
|
|
transition: background-color 0.3s; |
|
|
} |
|
|
.stButton>button:hover { |
|
|
background-color: #005f8a; |
|
|
} |
|
|
|
|
|
.stTextInput>div>div>input { |
|
|
padding: 10px; |
|
|
font-size: 16px; |
|
|
border-radius: 10px; |
|
|
} |
|
|
|
|
|
.stFileUploader>div>div { |
|
|
border-radius: 10px; |
|
|
} |
|
|
|
|
|
/* Center the image */ |
|
|
.stImage { |
|
|
display: block; |
|
|
margin-left: auto; |
|
|
margin-right: auto; |
|
|
} |
|
|
</style> |
|
|
""", unsafe_allow_html=True) |
|
|
|