import os

import numpy as np
import onnxruntime as ort
import streamlit as st
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

MODEL_REPO = "llava-hf/llava-interleave-qwen-0.5b-hf"
ONNX_FILES = [
    "vision_encoder_q4f16.onnx",
    "decoder_model_merged_q4f16.onnx",
    "embed_tokens_q4f16.onnx",
]

# Download the ONNX models if they are not already present (requires wget)
for filename in ONNX_FILES:
    if not os.path.exists(filename):
        os.system(f"wget https://huggingface.co/{MODEL_REPO}/resolve/main/onnx/{filename}")

# Load the tokenizer and processor
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
processor = AutoProcessor.from_pretrained(MODEL_REPO)

# Create one ONNX Runtime session per exported graph
vision_encoder_session = ort.InferenceSession("vision_encoder_q4f16.onnx")
decoder_session = ort.InferenceSession("decoder_model_merged_q4f16.onnx")
embed_tokens_session = ort.InferenceSession("embed_tokens_q4f16.onnx")

# Streamlit app configuration
st.set_page_config(page_title="Vision-Based ONNX AI App", page_icon="🤖", layout="wide")
st.title("🖼️ Vision-Based ONNX AI Demo App")
st.markdown("<p>Upload an image and get a description</p>", unsafe_allow_html=True)

# User input: image upload and prompt
uploaded_image = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
user_prompt = st.text_input(
    "Enter your prompt",
    value="Describe this image in detail",
    placeholder="e.g., What is shown in the image?",
)


def display_uploaded_image(uploaded_image):
    """Render the uploaded file and return it as a PIL image, or None on failure."""
    try:
        img = Image.open(uploaded_image)
        st.image(img, caption="Uploaded Image", use_container_width=True)
        return img
    except Exception as e:
        st.error(f"❌ Unable to display image. Error: {e}")
        return None


# Process the uploaded image when the user clicks the button
if st.button("Get Description"):
    if uploaded_image and user_prompt:
        try:
            img = display_uploaded_image(uploaded_image)
            if img is None:
                st.error("❌ Image processing failed.")
                st.stop()

            # Preprocess the image into the pixel tensor the vision encoder expects
            processed_image = processor(images=img, return_tensors="np")

            # Encode the image into patch embeddings
            vision_embeddings = vision_encoder_session.run(
                None, {"pixel_values": processed_image["pixel_values"]}
            )[0]

            # Tokenize the prompt and look up its token embeddings. A full
            # LLaVA pipeline would format the prompt with the model's chat
            # template and its <image> placeholder token.
            input_ids = tokenizer(user_prompt, return_tensors="np")["input_ids"]
            token_embeddings = embed_tokens_session.run(
                None, {"input_ids": input_ids}
            )[0]

            # Build a single embedding sequence for the decoder. NOTE: the
            # input names below ("inputs_embeds", "attention_mask") follow the
            # usual convention for these merged LLaVA exports, but verify them
            # against decoder_session.get_inputs(); the merged decoder may also
            # require past_key_values tensors.
            inputs_embeds = np.concatenate([vision_embeddings, token_embeddings], axis=1)
            attention_mask = np.ones(inputs_embeds.shape[:2], dtype=np.int64)
            logits = decoder_session.run(
                None,
                {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask},
            )[0]

            # The decoder returns logits, not token ids: take the argmax per
            # position before decoding. A production app would instead run an
            # autoregressive generation loop (see the sketch below).
            token_ids = logits[0].argmax(axis=-1)
            description = tokenizer.decode(token_ids, skip_special_tokens=True)

            # Display the description
            st.subheader("📝 Model Response")
            st.markdown(f"**Description**: {description}")
        except Exception as e:
            st.error(f"❌ An error occurred: {e}")
    else:
        st.warning("⚠️ Please upload an image and enter a prompt.")