mbwanaf committed on
Commit
b57ccdf
·
verified ·
1 Parent(s): b8fed84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -121
app.py CHANGED
@@ -1,121 +1,130 @@
1
- import os
2
- os.environ["STREAMLIT_HOME"] = os.getcwd()
3
- os.environ["STREAMLIT_RUNTIME_METRICS_ENABLED"] = "false"
4
- os.makedirs(".streamlit", exist_ok=True)
5
-
6
- import streamlit as st
7
- import torch
8
- import joblib
9
- import numpy as np
10
- import random
11
- from PIL import Image
12
- from transformers import AutoTokenizer, AutoModel, ViTModel, ViTImageProcessor
13
-
14
- # CPU device only
15
- device = torch.device("cpu")
16
-
17
- # Define Swahili VQA Model
18
- class SwahiliVQAModel(torch.nn.Module):
19
- def __init__(self, num_answers):
20
- super().__init__()
21
- self.vision_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
22
- self.text_encoder = AutoModel.from_pretrained("benjamin/roberta-base-wechsel-swahili")
23
- self.fusion = torch.nn.Sequential(
24
- torch.nn.Linear(768 + 768, 512),
25
- torch.nn.ReLU(),
26
- torch.nn.Dropout(0.3),
27
- torch.nn.LayerNorm(512)
28
- )
29
- self.classifier = torch.nn.Linear(512, num_answers)
30
-
31
- def forward(self, image, input_ids, attention_mask):
32
- vision_outputs = self.vision_encoder(pixel_values=image)
33
- image_feats = vision_outputs.last_hidden_state[:, 0, :]
34
- text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
35
- text_feats = text_outputs.last_hidden_state[:, 0, :]
36
- combined = torch.cat([image_feats, text_feats], dim=1)
37
- fused = self.fusion(combined)
38
- return self.classifier(fused)
39
-
40
- # Load label encoder
41
- le = joblib.load("Vit_3895_label_encoder_best.pkl")
42
-
43
- # Load model weights normally — no override
44
- model = SwahiliVQAModel(num_answers=len(le.classes_)).to(device)
45
- state_dict = torch.load("Vit_3895_best_model_epoch25.pth", map_location=device)
46
- model.load_state_dict(state_dict)
47
- model.eval()
48
-
49
- # Load tokenizer and processor
50
- tokenizer = AutoTokenizer.from_pretrained("benjamin/roberta-base-wechsel-swahili")
51
- vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
52
-
53
- # Streamlit UI
54
- st.set_page_config(page_title="Swahili VQA", layout="wide")
55
- st.title("🦜 Swahili Visual Question Answering (VQA)")
56
-
57
- uploaded_image = st.file_uploader("📂 Pakia picha hapa:", type=["jpg", "jpeg", "png"])
58
-
59
- def generate_random_color():
60
- return f"rgb({random.randint(150, 255)}, {random.randint(80, 200)}, {random.randint(80, 200)})"
61
-
62
- col1, col2 = st.columns([1, 2], gap="large")
63
-
64
- with col1:
65
- if uploaded_image:
66
- st.image(uploaded_image, caption="Picha Iliyopakiwa", use_container_width=True)
67
- st.markdown("<div style='margin-bottom: 25px;'></div>", unsafe_allow_html=True)
68
-
69
- with col2:
70
- st.markdown("<div style='padding-top: 15px;'>", unsafe_allow_html=True)
71
- question = st.text_input("💬Andika swali lako hapa:", key="question_input")
72
- submit_button = st.button("📩Tuma")
73
- st.markdown("</div>", unsafe_allow_html=True)
74
-
75
- if submit_button and uploaded_image and question:
76
- with st.spinner("🔍 Inachakata jibu..."):
77
- image = Image.open(uploaded_image).convert("RGB")
78
- image_tensor = vit_processor(images=image, return_tensors="pt")["pixel_values"]
79
-
80
- inputs = tokenizer(question, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
81
- input_ids = inputs["input_ids"]
82
- attention_mask = inputs["attention_mask"]
83
-
84
- with torch.no_grad():
85
- logits = model(image_tensor, input_ids, attention_mask)
86
- probs = torch.softmax(logits, dim=1)
87
-
88
- top_probs, top_indices = torch.topk(probs, 5)
89
- decoded_answers = le.inverse_transform(top_indices.cpu().numpy()[0])
90
-
91
- results = [
92
- {"answer": ans, "confidence": round(prob * 100, 2)}
93
- for ans, prob in zip(decoded_answers, top_probs[0].tolist())
94
- ]
95
-
96
- results = sorted(results, key=lambda x: x["confidence"], reverse=True)
97
-
98
- st.subheader("Majibu Yanayowezekana:")
99
- max_confidence = max(result["confidence"] for result in results)
100
-
101
- for i, pred in enumerate(results):
102
- bar_width = (pred["confidence"] / max_confidence) * 70
103
- color = generate_random_color()
104
- st.markdown(
105
- f"""
106
- <div style="margin: 4px 0; padding: 2px 0; {'border-bottom: 1px solid rgba(150, 150, 150, 0.1);' if i < len(results)-1 else ''}">
107
- <div style="font-size: 14px; font-weight: bold; margin-bottom: 2px;">
108
- {pred['answer']}
109
- </div>
110
- <div style="display: flex; align-items: center; gap: 6px;">
111
- <div style="width: {bar_width}%; height: 8px; border-radius: 3px; background: {color};"></div>
112
- <div style="font-size: 13px; min-width: 45px;">
113
- {pred['confidence']}%
114
- </div>
115
- </div>
116
- </div>
117
- """,
118
- unsafe_allow_html=True
119
- )
120
- else:
121
- st.info("📥 Pakia picha na andika swali kisha bonyeza Tuma ili kupata jibu.")
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Avoid Streamlit file watcher issues in Docker
4
+ os.environ["STREAMLIT_WATCHED_MODULES"] = ""
5
+
6
+ # Set a writable cache directory for Hugging Face models
7
+ os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
8
+ os.makedirs("./hf_cache", exist_ok=True)
9
+
10
+ # Set Streamlit-specific options to prevent permission errors
11
+ os.environ["STREAMLIT_HOME"] = os.getcwd()
12
+ os.environ["STREAMLIT_RUNTIME_METRICS_ENABLED"] = "false"
13
+ os.makedirs(".streamlit", exist_ok=True)
14
+
15
+ import streamlit as st
16
+ import torch
17
+ import joblib
18
+ import numpy as np
19
+ import random
20
+ from PIL import Image
21
+ from transformers import AutoTokenizer, AutoModel, ViTModel, ViTImageProcessor
22
+
23
+ # CPU device only
24
+ device = torch.device("cpu")
25
+
26
+ # Define Swahili VQA Model
27
+ class SwahiliVQAModel(torch.nn.Module):
28
+ def __init__(self, num_answers):
29
+ super().__init__()
30
+ self.vision_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
31
+ self.text_encoder = AutoModel.from_pretrained("benjamin/roberta-base-wechsel-swahili")
32
+ self.fusion = torch.nn.Sequential(
33
+ torch.nn.Linear(768 + 768, 512),
34
+ torch.nn.ReLU(),
35
+ torch.nn.Dropout(0.3),
36
+ torch.nn.LayerNorm(512)
37
+ )
38
+ self.classifier = torch.nn.Linear(512, num_answers)
39
+
40
+ def forward(self, image, input_ids, attention_mask):
41
+ vision_outputs = self.vision_encoder(pixel_values=image)
42
+ image_feats = vision_outputs.last_hidden_state[:, 0, :]
43
+ text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
44
+ text_feats = text_outputs.last_hidden_state[:, 0, :]
45
+ combined = torch.cat([image_feats, text_feats], dim=1)
46
+ fused = self.fusion(combined)
47
+ return self.classifier(fused)
48
+
49
+ # Load label encoder
50
+ le = joblib.load("Vit_3895_label_encoder_best.pkl")
51
+
52
+ # Load model weights normally — no override
53
+ model = SwahiliVQAModel(num_answers=len(le.classes_)).to(device)
54
+ state_dict = torch.load("Vit_3895_best_model_epoch25.pth", map_location=device)
55
+ model.load_state_dict(state_dict)
56
+ model.eval()
57
+
58
+ # Load tokenizer and processor
59
+ tokenizer = AutoTokenizer.from_pretrained("benjamin/roberta-base-wechsel-swahili")
60
+ vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
61
+
62
+ # Streamlit UI
63
+ st.set_page_config(page_title="Swahili VQA", layout="wide")
64
+ st.title("🦜 Swahili Visual Question Answering (VQA)")
65
+
66
+ uploaded_image = st.file_uploader("📂 Pakia picha hapa:", type=["jpg", "jpeg", "png"])
67
+
68
+ def generate_random_color():
69
+ return f"rgb({random.randint(150, 255)}, {random.randint(80, 200)}, {random.randint(80, 200)})"
70
+
71
+ col1, col2 = st.columns([1, 2], gap="large")
72
+
73
+ with col1:
74
+ if uploaded_image:
75
+ st.image(uploaded_image, caption="Picha Iliyopakiwa", use_container_width=True)
76
+ st.markdown("<div style='margin-bottom: 25px;'></div>", unsafe_allow_html=True)
77
+
78
+ with col2:
79
+ st.markdown("<div style='padding-top: 15px;'>", unsafe_allow_html=True)
80
+ question = st.text_input("💬Andika swali lako hapa:", key="question_input")
81
+ submit_button = st.button("📩Tuma")
82
+ st.markdown("</div>", unsafe_allow_html=True)
83
+
84
+ if submit_button and uploaded_image and question:
85
+ with st.spinner("🔍 Inachakata jibu..."):
86
+ image = Image.open(uploaded_image).convert("RGB")
87
+ image_tensor = vit_processor(images=image, return_tensors="pt")["pixel_values"]
88
+
89
+ inputs = tokenizer(question, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
90
+ input_ids = inputs["input_ids"]
91
+ attention_mask = inputs["attention_mask"]
92
+
93
+ with torch.no_grad():
94
+ logits = model(image_tensor, input_ids, attention_mask)
95
+ probs = torch.softmax(logits, dim=1)
96
+
97
+ top_probs, top_indices = torch.topk(probs, 5)
98
+ decoded_answers = le.inverse_transform(top_indices.cpu().numpy()[0])
99
+
100
+ results = [
101
+ {"answer": ans, "confidence": round(prob * 100, 2)}
102
+ for ans, prob in zip(decoded_answers, top_probs[0].tolist())
103
+ ]
104
+
105
+ results = sorted(results, key=lambda x: x["confidence"], reverse=True)
106
+
107
+ st.subheader("Majibu Yanayowezekana:")
108
+ max_confidence = max(result["confidence"] for result in results)
109
+
110
+ for i, pred in enumerate(results):
111
+ bar_width = (pred["confidence"] / max_confidence) * 70
112
+ color = generate_random_color()
113
+ st.markdown(
114
+ f"""
115
+ <div style="margin: 4px 0; padding: 2px 0; {'border-bottom: 1px solid rgba(150, 150, 150, 0.1);' if i < len(results)-1 else ''}">
116
+ <div style="font-size: 14px; font-weight: bold; margin-bottom: 2px;">
117
+ {pred['answer']}
118
+ </div>
119
+ <div style="display: flex; align-items: center; gap: 6px;">
120
+ <div style="width: {bar_width}%; height: 8px; border-radius: 3px; background: {color};"></div>
121
+ <div style="font-size: 13px; min-width: 45px;">
122
+ {pred['confidence']}%
123
+ </div>
124
+ </div>
125
+ </div>
126
+ """,
127
+ unsafe_allow_html=True
128
+ )
129
+ else:
130
+ st.info("📥 Pakia picha na andika swali kisha bonyeza Tuma ili kupata jibu.")