import os
# Use /tmp for all cache & runtime folders (Hugging Face safe)
os.environ["STREAMLIT_HOME"] = "/tmp"
os.environ["STREAMLIT_RUNTIME_METRICS_ENABLED"] = "false"
os.environ["STREAMLIT_WATCHED_MODULES"] = ""
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/huggingface"
import streamlit as st
import torch
import joblib
import numpy as np
import random
from PIL import Image
from transformers import AutoTokenizer, AutoModel, ViTModel, ViTImageProcessor
# Use CPU only
device = torch.device("cpu")
# === Define Swahili VQA Model ===
class SwahiliVQAModel(torch.nn.Module):
def __init__(self, num_answers):
super().__init__()
self.vision_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
self.text_encoder = AutoModel.from_pretrained("benjamin/roberta-base-wechsel-swahili")
self.fusion = torch.nn.Sequential(
torch.nn.Linear(768 + 768, 512),
torch.nn.ReLU(),
torch.nn.Dropout(0.3),
torch.nn.LayerNorm(512)
)
self.classifier = torch.nn.Linear(512, num_answers)
def forward(self, image, input_ids, attention_mask):
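        # Encode the image with ViT; the [CLS] token embedding serves as the pooled visual feature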
vision_outputs = self.vision_encoder(pixel_values=image)
image_feats = vision_outputs.last_hidden_state[:, 0, :]
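        # Encode the Swahili question and take its [CLS] embedding as the text feature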
text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
text_feats = text_outputs.last_hidden_state[:, 0, :]
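        # Concatenate the two 768-d features, fuse them, and score every candidate answer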
combined = torch.cat([image_feats, text_feats], dim=1)
fused = self.fusion(combined)
return self.classifier(fused)
# === Load model and encoders ===
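# The label encoder maps answer strings to class indices; its class count sets the classifier output size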
le = joblib.load("Vit_3895_label_encoder_best.pkl")
model = SwahiliVQAModel(num_answers=len(le.classes_)).to(device)
# Load full state dict (already trained classifier)
state_dict = torch.load("Vit_3895_best_model_epoch25.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("benjamin/roberta-base-wechsel-swahili")
vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
# === Streamlit App ===
st.set_page_config(page_title="Swahili VQA", layout="wide")
st.title("Swahili Visual Question Answering App")
#st.markdown("**Pakia picha na uliza swali kwa Kiswahili**")
st.info("Pakia picha na andika swali kisha bonyeza Tuma ili kupata jibu.")
uploaded_image = st.file_uploader("Pakia picha hapa:", type=["jpg", "jpeg", "png"])
def generate_random_color():
return f"rgb({random.randint(150, 255)}, {random.randint(80, 200)}, {random.randint(80, 200)})"
col1, col2 = st.columns([1, 2], gap="large")
with col1:
if uploaded_image:
st.image(uploaded_image, caption="Picha Iliyopakiwa")
with col2:
st.markdown("<div style='padding-top: 15px;'>", unsafe_allow_html=True)
    question = st.text_input("Andika swali lako hapa:")
    submit_button = st.button("Tuma")
st.markdown("</div>", unsafe_allow_html=True)
if submit_button and uploaded_image and question:
    with st.spinner("Inachakata jibu..."):
image = Image.open(uploaded_image).convert("RGB")
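        # The ViT processor resizes to 224x224 and normalizes, returning pixel_values of shape [1, 3, 224, 224]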
image_tensor = vit_processor(images=image, return_tensors="pt")["pixel_values"]
inputs = tokenizer(question, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
with torch.no_grad():
logits = model(image_tensor, input_ids, attention_mask)
probs = torch.softmax(logits, dim=1)
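            # Keep the five most probable answer classes and their probabilities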
top_probs, top_indices = torch.topk(probs, 5)
decoded_answers = le.inverse_transform(top_indices.cpu().numpy()[0])
results = [
{"answer": ans, "confidence": round(prob * 100, 2)}
for ans, prob in zip(decoded_answers, top_probs[0].tolist())
]
results = sorted(results, key=lambda x: x["confidence"], reverse=True)
        st.subheader("Majibu Yanayowezekana:")
max_confidence = max(result["confidence"] for result in results)
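        # Render each answer as a horizontal bar scaled relative to the top confidence (widest bar is 70% of the row)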
for i, pred in enumerate(results):
bar_width = (pred["confidence"] / max_confidence) * 70
color = generate_random_color()
st.markdown(
f"""
<div style="margin: 4px 0; padding: 2px 0; {'border-bottom: 1px solid rgba(150, 150, 150, 0.1);' if i < len(results)-1 else ''}">
<div style="font-size: 14px; font-weight: bold; margin-bottom: 2px;">
{pred['answer']}
</div>
<div style="display: flex; align-items: center; gap: 6px;">
<div style="width: {bar_width}%; height: 8px; border-radius: 3px; background: {color};"></div>
<div style="font-size: 13px; min-width: 45px;">
{pred['confidence']}%
</div>
</div>
</div>
""",
unsafe_allow_html=True
)
# === Sidebar Footer Description (ALWAYS VISIBLE) ===
st.sidebar.markdown("""
---
## Swahili VQA App
This app allows users to ask questions about images in **Swahili**. It is powered by a multimodal AI model trained on visual and textual data.
## How to Use
1. Upload an image.
2. Type a question in Swahili.
3. Click **Tuma**.
4. The model predicts the top 5 possible answers with confidence scores.
## Note
Designed for educational and research purposes.
""")