import os

# Disable Streamlit's file watcher, which can misbehave in Docker containers
os.environ["STREAMLIT_SERVER_FILE_WATCHER_TYPE"] = "none"

# Set a writable cache directory for Hugging Face models
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
os.makedirs("./hf_cache", exist_ok=True)

# Point Streamlit at a writable home directory and opt out of usage stats
# to prevent permission errors in locked-down containers
os.environ["HOME"] = os.getcwd()
os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
os.makedirs(".streamlit", exist_ok=True)

import streamlit as st
import torch
import joblib
import numpy as np
import random
from PIL import Image
from transformers import AutoTokenizer, AutoModel, ViTModel, ViTImageProcessor

# CPU device only
device = torch.device("cpu")

# Define Swahili VQA Model
class SwahiliVQAModel(torch.nn.Module):
    def __init__(self, num_answers):
        super().__init__()
        self.vision_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.text_encoder = AutoModel.from_pretrained("benjamin/roberta-base-wechsel-swahili")
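        # Fuse the two 768-d CLS features into a single 512-d joint embedding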
        self.fusion = torch.nn.Sequential(
            torch.nn.Linear(768 + 768, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.LayerNorm(512)
        )
        self.classifier = torch.nn.Linear(512, num_answers)

    def forward(self, image, input_ids, attention_mask):
        # Use each encoder's [CLS] token embedding as a pooled global feature
        vision_outputs = self.vision_encoder(pixel_values=image)
        image_feats = vision_outputs.last_hidden_state[:, 0, :]
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_feats = text_outputs.last_hidden_state[:, 0, :]
        # Late fusion: concatenate image and text features, project, classify
        combined = torch.cat([image_feats, text_feats], dim=1)
        fused = self.fusion(combined)
        return self.classifier(fused)
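
# Expected tensor shapes for a batch of size B (assuming the 224x224 ViT-Base
# inputs and max_length=128 tokenization used below):
#   image:                     (B, 3, 224, 224) -> ViT CLS feature  (B, 768)
#   input_ids, attention_mask: (B, 128)         -> text CLS feature (B, 768)
#   fused:                     (B, 512)         -> logits           (B, num_answers)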

# Load the label encoder, model weights, tokenizer, and image processor once
# per process; st.cache_resource keeps them alive across Streamlit reruns
# instead of reloading them on every user interaction.
@st.cache_resource(show_spinner=False)
def load_resources():
    label_encoder = joblib.load("Vit_3895_label_encoder_best.pkl")
    vqa_model = SwahiliVQAModel(num_answers=len(label_encoder.classes_)).to(device)
    state_dict = torch.load("Vit_3895_best_model_epoch25.pth", map_location=device)
    vqa_model.load_state_dict(state_dict)
    vqa_model.eval()
    swahili_tokenizer = AutoTokenizer.from_pretrained("benjamin/roberta-base-wechsel-swahili")
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    return label_encoder, vqa_model, swahili_tokenizer, image_processor

le, model, tokenizer, vit_processor = load_resources()

# Streamlit UI
st.set_page_config(page_title="Swahili VQA", layout="wide")
st.title("🦜 Swahili Visual Question Answering (VQA)")
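
# UI strings are in Swahili; rough English glosses:
#   "Pakia picha hapa"       = "Upload an image here"
#   "Picha Iliyopakiwa"      = "Uploaded image"
#   "Andika swali lako hapa" = "Write your question here"
#   "Tuma"                   = "Send"
#   "Inachakata jibu"        = "Processing the answer"
#   "Majibu Yanayowezekana"  = "Possible answers"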

uploaded_image = st.file_uploader("📂 Pakia picha hapa:", type=["jpg", "jpeg", "png"])

def generate_random_color():
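    # Random bright-ish RGB string so each confidence bar gets a distinct color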
    return f"rgb({random.randint(150, 255)}, {random.randint(80, 200)}, {random.randint(80, 200)})"

col1, col2 = st.columns([1, 2], gap="large")

with col1:
    if uploaded_image:
        st.image(uploaded_image, caption="Picha Iliyopakiwa", use_container_width=True)
        st.markdown("<div style='margin-bottom: 25px;'></div>", unsafe_allow_html=True)

with col2:
    st.markdown("<div style='padding-top: 15px;'>", unsafe_allow_html=True)
    question = st.text_input("💬 Andika swali lako hapa:", key="question_input")
    submit_button = st.button("📩 Tuma")
    st.markdown("</div>", unsafe_allow_html=True)

    if submit_button and uploaded_image and question:
        with st.spinner("🔍 Inachakata jibu..."):
            image = Image.open(uploaded_image).convert("RGB")
            image_tensor = vit_processor(images=image, return_tensors="pt")["pixel_values"]

            inputs = tokenizer(question, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
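            # pixel_values is (1, 3, 224, 224); input_ids and attention_mask are (1, 128)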

            with torch.no_grad():
                logits = model(image_tensor, input_ids, attention_mask)
                probs = torch.softmax(logits, dim=1)

            top_probs, top_indices = torch.topk(probs, 5)
            decoded_answers = le.inverse_transform(top_indices.cpu().numpy()[0])

            # torch.topk already returns values sorted in descending order,
            # so no further sorting is needed
            results = [
                {"answer": ans, "confidence": round(prob * 100, 2)}
                for ans, prob in zip(decoded_answers, top_probs[0].tolist())
            ]

            st.subheader("Majibu Yanayowezekana:")
            max_confidence = max(result["confidence"] for result in results)

            for i, pred in enumerate(results):
                bar_width = (pred["confidence"] / max_confidence) * 70
                color = generate_random_color()
                st.markdown(
                    f"""
                    <div style="margin: 4px 0; padding: 2px 0; {'border-bottom: 1px solid rgba(150, 150, 150, 0.1);' if i < len(results)-1 else ''}">
                        <div style="font-size: 14px; font-weight: bold; margin-bottom: 2px;">
                            {pred['answer']}
                        </div>
                        <div style="display: flex; align-items: center; gap: 6px;">
                            <div style="width: {bar_width}%; height: 8px; border-radius: 3px; background: {color};"></div>
                            <div style="font-size: 13px; min-width: 45px;">
                                {pred['confidence']}%
                            </div>
                        </div>
                    </div>
                    """,
                    unsafe_allow_html=True
                )
    else:
        st.info("📥 Pakia picha na andika swali kisha bonyeza Tuma ili kupata jibu.")