mbwanaf committed
Commit 1a56ea5 · verified · 1 Parent(s): 33b934c

Upload 4 files

Vit_3895_best_model_epoch25.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:39065c8493865b8f783eb481307ba87dd494f740f8f9f379fd4e83b3f615cce0
+ size 849042907
Vit_3895_label_encoder_best.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a9416a32d4f5550387283d0fccc2fdbaa37cc71d18fc7f7e476a8f6aa0ee1131
+ size 37191
app.py ADDED
@@ -0,0 +1,121 @@
+ import os
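+ # Streamlit environment overrides: presumably these keep Streamlit's home/config inside
+ # the (writable) working directory on the hosting runtime and opt out of runtime metrics.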
+ os.environ["STREAMLIT_HOME"] = os.getcwd()
+ os.environ["STREAMLIT_RUNTIME_METRICS_ENABLED"] = "false"
+ os.makedirs(".streamlit", exist_ok=True)
+
+ import streamlit as st
+ import torch
+ import joblib
+ import numpy as np
+ import random
+ from PIL import Image
+ from transformers import AutoTokenizer, AutoModel, ViTModel, ViTImageProcessor
+
+ # CPU device only
+ device = torch.device("cpu")
+
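+ # Late-fusion architecture: the ViT [CLS] image embedding and the Swahili RoBERTa [CLS]
+ # question embedding (768 dims each) are concatenated, projected to 512 dims, and
+ # classified over the closed answer vocabulary defined by the label encoder.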
+ # Define Swahili VQA Model
+ class SwahiliVQAModel(torch.nn.Module):
+     def __init__(self, num_answers):
+         super().__init__()
+         self.vision_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
+         self.text_encoder = AutoModel.from_pretrained("benjamin/roberta-base-wechsel-swahili")
+         self.fusion = torch.nn.Sequential(
+             torch.nn.Linear(768 + 768, 512),
+             torch.nn.ReLU(),
+             torch.nn.Dropout(0.3),
+             torch.nn.LayerNorm(512)
+         )
+         self.classifier = torch.nn.Linear(512, num_answers)
+
+     def forward(self, image, input_ids, attention_mask):
+         vision_outputs = self.vision_encoder(pixel_values=image)
+         image_feats = vision_outputs.last_hidden_state[:, 0, :]
+         text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
+         text_feats = text_outputs.last_hidden_state[:, 0, :]
+         combined = torch.cat([image_feats, text_feats], dim=1)
+         fused = self.fusion(combined)
+         return self.classifier(fused)
+
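+ # The pickle appears to be a scikit-learn LabelEncoder (class index -> Swahili answer string),
+ # so scikit-learn needs to be importable at runtime for joblib.load to succeed.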
+ # Load label encoder
+ le = joblib.load("Vit_3895_label_encoder_best.pkl")
+
+ # Load model weights normally — no override
+ model = SwahiliVQAModel(num_answers=len(le.classes_)).to(device)
+ state_dict = torch.load("Vit_3895_best_model_epoch25.pth", map_location=device)
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ # Load tokenizer and processor
+ tokenizer = AutoTokenizer.from_pretrained("benjamin/roberta-base-wechsel-swahili")
+ vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
+
+ # Streamlit UI
+ st.set_page_config(page_title="Swahili VQA", layout="wide")
+ st.title("🦜 Swahili Visual Question Answering (VQA)")
+
+ uploaded_image = st.file_uploader("📂 Pakia picha hapa:", type=["jpg", "jpeg", "png"])
+
+ def generate_random_color():
+     return f"rgb({random.randint(150, 255)}, {random.randint(80, 200)}, {random.randint(80, 200)})"
+
+ col1, col2 = st.columns([1, 2], gap="large")
+
+ with col1:
+     if uploaded_image:
+         st.image(uploaded_image, caption="Picha Iliyopakiwa", use_container_width=True)
+         st.markdown("<div style='margin-bottom: 25px;'></div>", unsafe_allow_html=True)
+
+ with col2:
+     st.markdown("<div style='padding-top: 15px;'>", unsafe_allow_html=True)
+     question = st.text_input("💬Andika swali lako hapa:", key="question_input")
+     submit_button = st.button("📩Tuma")
+     st.markdown("</div>", unsafe_allow_html=True)
+
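+     # Single inference pass: preprocess the uploaded image and the question, score all
+     # candidate answers, then render the top-5 predictions with confidence bars.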
+     if submit_button and uploaded_image and question:
+         with st.spinner("🔍 Inachakata jibu..."):
+             image = Image.open(uploaded_image).convert("RGB")
+             image_tensor = vit_processor(images=image, return_tensors="pt")["pixel_values"]
+
+             inputs = tokenizer(question, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
+             input_ids = inputs["input_ids"]
+             attention_mask = inputs["attention_mask"]
+
+             with torch.no_grad():
+                 logits = model(image_tensor, input_ids, attention_mask)
+                 probs = torch.softmax(logits, dim=1)
+
+             top_probs, top_indices = torch.topk(probs, 5)
+             decoded_answers = le.inverse_transform(top_indices.cpu().numpy()[0])
+
+             results = [
+                 {"answer": ans, "confidence": round(prob * 100, 2)}
+                 for ans, prob in zip(decoded_answers, top_probs[0].tolist())
+             ]
+
+             results = sorted(results, key=lambda x: x["confidence"], reverse=True)
+
+             st.subheader("Majibu Yanayowezekana:")
+             max_confidence = max(result["confidence"] for result in results)
+
+             for i, pred in enumerate(results):
+                 bar_width = (pred["confidence"] / max_confidence) * 70
+                 color = generate_random_color()
+                 st.markdown(
+                     f"""
+                     <div style="margin: 4px 0; padding: 2px 0; {'border-bottom: 1px solid rgba(150, 150, 150, 0.1);' if i < len(results)-1 else ''}">
+                         <div style="font-size: 14px; font-weight: bold; margin-bottom: 2px;">
+                             {pred['answer']}
+                         </div>
+                         <div style="display: flex; align-items: center; gap: 6px;">
+                             <div style="width: {bar_width}%; height: 8px; border-radius: 3px; background: {color};"></div>
+                             <div style="font-size: 13px; min-width: 45px;">
+                                 {pred['confidence']}%
+                             </div>
+                         </div>
+                     </div>
+                     """,
+                     unsafe_allow_html=True
+                 )
+     else:
+         st.info("📥 Pakia picha na andika swali kisha bonyeza Tuma ili kupata jibu.")
requirements.txt CHANGED
@@ -1,3 +1,5 @@
- altair
- pandas
- streamlit
+ torch==2.1.2
+ transformers==4.36.2
+ streamlit==1.30.0
+ joblib==1.3.2
+ Pillow>=10.0.0