import os

# βœ… Use /tmp for all cache & runtime folders (writable on Hugging Face Spaces)
os.environ["STREAMLIT_HOME"] = "/tmp"
os.environ["STREAMLIT_RUNTIME_METRICS_ENABLED"] = "false"
os.environ["STREAMLIT_WATCHED_MODULES"] = ""
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/huggingface"

import streamlit as st
import torch
import joblib
import numpy as np
import random
from PIL import Image
from transformers import AutoTokenizer, AutoModel, ViTModel, ViTImageProcessor

# Use CPU only
device = torch.device("cpu")

# === Define Swahili VQA Model ===
class SwahiliVQAModel(torch.nn.Module):
    """Late-fusion VQA model: ViT image encoder + Swahili RoBERTa text encoder."""

    def __init__(self, num_answers):
        super().__init__()
        self.vision_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.text_encoder = AutoModel.from_pretrained("benjamin/roberta-base-wechsel-swahili")
        self.fusion = torch.nn.Sequential(
            torch.nn.Linear(768 + 768, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.LayerNorm(512)
        )
        self.classifier = torch.nn.Linear(512, num_answers)

    def forward(self, image, input_ids, attention_mask):
        # Encode the image and keep the [CLS] token embedding (768-d)
        vision_outputs = self.vision_encoder(pixel_values=image)
        image_feats = vision_outputs.last_hidden_state[:, 0, :]
        # Encode the question and keep its <s> (CLS) token embedding (768-d)
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_feats = text_outputs.last_hidden_state[:, 0, :]
        # Late fusion: concatenate both embeddings, project to 512-d, then classify
        combined = torch.cat([image_feats, text_feats], dim=1)
        fused = self.fusion(combined)
        return self.classifier(fused)
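
# Quick shape check for the fusion architecture (illustrative sketch, not part
# of the app flow; uncommenting it would download both pretrained encoders):
#
#   m = SwahiliVQAModel(num_answers=10)
#   dummy_image = torch.zeros(1, 3, 224, 224)          # one 224x224 RGB image
#   dummy_ids = torch.zeros(1, 128, dtype=torch.long)  # one 128-token question
#   dummy_mask = torch.ones(1, 128, dtype=torch.long)
#   m(dummy_image, dummy_ids, dummy_mask).shape        # -> torch.Size([1, 10])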

# === Load model and encoders ===
le = joblib.load("Vit_3895_label_encoder_best.pkl")  # maps class indices to Swahili answer strings
model = SwahiliVQAModel(num_answers=len(le.classes_)).to(device)

# Load the trained weights (full state dict, including the classifier head)
state_dict = torch.load("Vit_3895_best_model_epoch25.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("benjamin/roberta-base-wechsel-swahili")
vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
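
# Preprocessing sanity sketch (illustrative; "Kuna nini pichani?" means "What
# is in the picture?"). The processor resizes any image to ViT's 224x224 input:
#
#   enc = tokenizer("Kuna nini pichani?", return_tensors="pt")
#   px = vit_processor(images=Image.new("RGB", (640, 480)), return_tensors="pt")["pixel_values"]
#   px.shape   # -> torch.Size([1, 3, 224, 224])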

# === Streamlit App ===
st.set_page_config(page_title="Swahili VQA", layout="wide")
st.title("🦜 Swahili Visual Question Answering App")
st.info("πŸ“₯ Pakia picha na andika swali kisha bonyeza Tuma ili kupata jibu.")  # "Upload an image and type a question, then press Tuma (Send) to get an answer."

uploaded_image = st.file_uploader("πŸ“‚ Pakia picha hapa:", type=["jpg", "jpeg", "png"])  # "Upload an image here:"

def generate_random_color():
    """Return a random bright RGB color string for a confidence bar."""
    return f"rgb({random.randint(150, 255)}, {random.randint(80, 200)}, {random.randint(80, 200)})"

col1, col2 = st.columns([1, 2], gap="large")

with col1:
    if uploaded_image:
        st.image(uploaded_image, caption="Picha Iliyopakiwa")  # "Uploaded Image"

with col2:
    st.markdown("<div style='padding-top: 15px;'>", unsafe_allow_html=True)
    question = st.text_input("πŸ’¬ Andika swali lako hapa:")  # "Type your question here:"
    submit_button = st.button("πŸ“© Tuma")  # "Send"
    st.markdown("</div>", unsafe_allow_html=True)

    if submit_button and uploaded_image and question:
        with st.spinner("πŸ” Inachakata jibu..."):  # "Processing the answer..."
            image = Image.open(uploaded_image).convert("RGB")
            image_tensor = vit_processor(images=image, return_tensors="pt")["pixel_values"]

            inputs = tokenizer(question, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]

            with torch.no_grad():
                logits = model(image_tensor, input_ids, attention_mask)
                probs = torch.softmax(logits, dim=1)

            top_probs, top_indices = torch.topk(probs, 5)  # 5 highest-probability classes
            decoded_answers = le.inverse_transform(top_indices.cpu().numpy()[0])  # indices -> answer strings

            results = [
                {"answer": ans, "confidence": round(prob * 100, 2)}
                for ans, prob in zip(decoded_answers, top_probs[0].tolist())
            ]

            # (torch.topk already returns answers in descending order of confidence)
            st.subheader("πŸ”Ž Majibu Yanayowezekana:")  # "Possible Answers:"

            # Scale each bar relative to the top answer (widest bar spans 70% of the row)
            max_confidence = max(result["confidence"] for result in results)
            for i, pred in enumerate(results):
                bar_width = (pred["confidence"] / max_confidence) * 70
                color = generate_random_color()
                st.markdown(
                    f"""
                    <div style="margin: 4px 0; padding: 2px 0; {'border-bottom: 1px solid rgba(150, 150, 150, 0.1);' if i < len(results)-1 else ''}">
                        <div style="font-size: 14px; font-weight: bold; margin-bottom: 2px;">
                            {pred['answer']}
                        </div>
                        <div style="display: flex; align-items: center; gap: 6px;">
                            <div style="width: {bar_width}%; height: 8px; border-radius: 3px; background: {color};"></div>
                            <div style="font-size: 13px; min-width: 45px;">
                                {pred['confidence']}%
                            </div>
                        </div>
                    </div>
                    """,
                    unsafe_allow_html=True
                )
                
    elif submit_button:
        st.warning("⚠️ Tafadhali pakia picha na andika swali kwanza.")  # "Please upload an image and type a question first."

# === Sidebar Footer Description (ALWAYS VISIBLE) ===
st.sidebar.markdown("""
---
## Swahili VQA App

This app lets users ask questions about images in **Swahili**. It is powered by a multimodal model that fuses ViT image features with Swahili RoBERTa text features.

## πŸ›  How to Use

1. πŸ“€ Upload an image.  
2. πŸ’¬ Type a question in Swahili.  
3. πŸ“© Click **Tuma**.  
4. πŸ€– The model predicts the top 5 possible answers with confidence scores.

## πŸ“Œ Note
πŸŽ“ Designed for educational and research purposes.
""")