import os
# Avoid Streamlit file watcher issues in Docker
os.environ["STREAMLIT_WATCHED_MODULES"] = ""
# Set a writable cache directory for Hugging Face models
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
os.makedirs("./hf_cache", exist_ok=True)
# Set Streamlit-specific options to prevent permission errors
os.environ["STREAMLIT_HOME"] = os.getcwd()
os.environ["STREAMLIT_RUNTIME_METRICS_ENABLED"] = "false"
os.makedirs(".streamlit", exist_ok=True)
import streamlit as st
import torch
import joblib
import numpy as np
import random
from PIL import Image
from transformers import AutoTokenizer, AutoModel, ViTModel, ViTImageProcessor
# CPU device only
device = torch.device("cpu")
# Define the Swahili VQA model: ViT image encoder + Swahili RoBERTa text encoder
class SwahiliVQAModel(torch.nn.Module):
    def __init__(self, num_answers):
        super().__init__()
        self.vision_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.text_encoder = AutoModel.from_pretrained("benjamin/roberta-base-wechsel-swahili")
        # Fuse the two 768-dim [CLS] embeddings into a single 512-dim vector
        self.fusion = torch.nn.Sequential(
            torch.nn.Linear(768 + 768, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.LayerNorm(512)
        )
        self.classifier = torch.nn.Linear(512, num_answers)

    def forward(self, image, input_ids, attention_mask):
        # Pooled image feature: [CLS] token of the ViT encoder
        vision_outputs = self.vision_encoder(pixel_values=image)
        image_feats = vision_outputs.last_hidden_state[:, 0, :]
        # Pooled text feature: first (<s>) token of the Swahili RoBERTa encoder
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_feats = text_outputs.last_hidden_state[:, 0, :]
        combined = torch.cat([image_feats, text_feats], dim=1)
        fused = self.fusion(combined)
        return self.classifier(fused)
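# Quick shape check (illustrative only; num_answers=10 and the all-ones token
# tensors are arbitrary values, not ones used by the app):
#   m = SwahiliVQAModel(num_answers=10)
#   logits = m(torch.randn(1, 3, 224, 224),           # one 224x224 RGB image
#              torch.ones(1, 128, dtype=torch.long),  # token ids
#              torch.ones(1, 128, dtype=torch.long))  # attention mask
#   logits.shape  # -> torch.Size([1, 10]), one logit per candidate answer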
# Load the label encoder that maps class indices back to answer strings
le = joblib.load("Vit_3895_label_encoder_best.pkl")

# Build the model and load the trained weights as-is (no key remapping)
model = SwahiliVQAModel(num_answers=len(le.classes_)).to(device)
state_dict = torch.load("Vit_3895_best_model_epoch25.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()
# Load tokenizer and processor
tokenizer = AutoTokenizer.from_pretrained("benjamin/roberta-base-wechsel-swahili")
vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
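# Note: Streamlit re-executes this script on every user interaction, so the
# loads above run again on each rerun. A common idiom (a sketch only, not
# wired in here) is to cache heavy resources once per process:
#
#   @st.cache_resource
#   def load_model(num_answers):
#       m = SwahiliVQAModel(num_answers=num_answers).to(device)
#       m.load_state_dict(torch.load("Vit_3895_best_model_epoch25.pth",
#                                    map_location=device))
#       return m.eval()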
# Streamlit UI
st.set_page_config(page_title="Swahili VQA", layout="wide")
st.title("🦜 Swahili Visual Question Answering (VQA)")
uploaded_image = st.file_uploader("📂 Pakia picha hapa:", type=["jpg", "jpeg", "png"])  # "Upload an image here"
def generate_random_color():
    # Random bright color so each confidence bar gets a distinct hue
    return f"rgb({random.randint(150, 255)}, {random.randint(80, 200)}, {random.randint(80, 200)})"
col1, col2 = st.columns([1, 2], gap="large")
with col1:
    if uploaded_image:
        st.image(uploaded_image, caption="Picha Iliyopakiwa", use_container_width=True)
        st.markdown("<div style='margin-bottom: 25px;'></div>", unsafe_allow_html=True)

with col2:
    st.markdown("<div style='padding-top: 15px;'>", unsafe_allow_html=True)
    question = st.text_input("💬 Andika swali lako hapa:", key="question_input")  # "Type your question here"
    submit_button = st.button("📩 Tuma")  # "Send"
    st.markdown("</div>", unsafe_allow_html=True)
if submit_button and uploaded_image and question:
    with st.spinner("🔍 Inachakata jibu..."):  # "Processing the answer..."
        # Preprocess the image and tokenize the question
        image = Image.open(uploaded_image).convert("RGB")
        image_tensor = vit_processor(images=image, return_tensors="pt")["pixel_values"]
        inputs = tokenizer(question, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        # Run inference and keep the top-5 most likely answers
        with torch.no_grad():
            logits = model(image_tensor, input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)
            top_probs, top_indices = torch.topk(probs, 5)

        decoded_answers = le.inverse_transform(top_indices.cpu().numpy()[0])
        results = [
            {"answer": ans, "confidence": round(prob * 100, 2)}
            for ans, prob in zip(decoded_answers, top_probs[0].tolist())
        ]
        results = sorted(results, key=lambda x: x["confidence"], reverse=True)
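        # `results` is now the top-5 list sorted by confidence, e.g.
        # (illustrative values only):
        #   [{"answer": "ndiyo", "confidence": 87.41},
        #    {"answer": "hapana", "confidence": 6.02}, ...]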
        st.subheader("Majibu Yanayowezekana:")  # "Possible answers"
        max_confidence = max(result["confidence"] for result in results)
        for i, pred in enumerate(results):
            # Scale each bar relative to the top answer (widest bar = 70%)
            bar_width = (pred["confidence"] / max_confidence) * 70
            color = generate_random_color()
            st.markdown(
                f"""
                <div style="margin: 4px 0; padding: 2px 0; {'border-bottom: 1px solid rgba(150, 150, 150, 0.1);' if i < len(results) - 1 else ''}">
                    <div style="font-size: 14px; font-weight: bold; margin-bottom: 2px;">
                        {pred['answer']}
                    </div>
                    <div style="display: flex; align-items: center; gap: 6px;">
                        <div style="width: {bar_width}%; height: 8px; border-radius: 3px; background: {color};"></div>
                        <div style="font-size: 13px; min-width: 45px;">
                            {pred['confidence']}%
                        </div>
                    </div>
                </div>
                """,
                unsafe_allow_html=True
            )
else:
    st.info("📥 Pakia picha na andika swali kisha bonyeza Tuma ili kupata jibu.")  # "Upload an image and type a question, then press Tuma to get an answer."