chjivan committed on
Commit
5a250ed
·
verified ·
1 Parent(s): c737d37

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +326 -0
  2. requirements.txt +6 -3
app.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import BertForSequenceClassification, BertTokenizerFast
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+ import time
6
+ import pandas as pd
7
+ import base64
8
+ from PIL import Image
9
+ import io
10
+
11
+ # Set page configuration
12
# Page-level settings, collected in one place and applied in a single call.
_PAGE_SETTINGS = {
    "page_title": "SMS Spam Guard",
    "page_icon": "🛡️",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
18
+
19
+ # Generate SafeTalk logo as base64 (blue shield with "ST" inside)
20
def create_logo():
    """Render the SafeTalk logo and return it as base64-encoded PNG text.

    The logo is a 200x200 image with a transparent background, a dark-blue
    six-point shield polygon, and the white letters "ST" drawn on top.

    Returns:
        str: Base64 text of the PNG bytes, ready to embed in a
        ``data:image/png;base64,...`` URI.
    """
    # ImageDraw/ImageFont are only needed here; io and base64 come from the
    # file-level imports.
    from PIL import Image, ImageDraw, ImageFont

    font_size = 80

    # Transparent canvas
    canvas = Image.new('RGBA', (200, 200), color=(0, 0, 0, 0))
    draw = ImageDraw.Draw(canvas)

    # Shield outline, filled dark blue
    shield_color = (30, 58, 138)
    points = [(100, 10), (180, 50), (160, 170), (100, 190), (40, 170), (20, 50)]
    draw.polygon(points, fill=shield_color)

    # Arial is usually missing on Linux hosts, so try a few common scalable
    # fonts before giving up; otherwise the letters end up in the tiny
    # fixed-size bitmap font and the requested size is ignored.
    font = None
    for font_name in ("arial.ttf", "DejaVuSans-Bold.ttf", "DejaVuSans.ttf"):
        try:
            font = ImageFont.truetype(font_name, font_size)
            break
        except OSError:
            continue
    if font is None:
        try:
            # Newer Pillow releases accept a size for the built-in font.
            font = ImageFont.load_default(size=font_size)
        except TypeError:
            # Older Pillow: only the fixed-size bitmap font is available.
            font = ImageFont.load_default()

    # "ST" text in white
    draw.text((70, 60), "ST", fill=(255, 255, 255), font=font)

    # Serialize to PNG in memory and base64-encode for embedding
    buffer = io.BytesIO()
    canvas.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()
49
+
50
+ # Custom CSS for styling
51
# Stylesheet for the custom HTML snippets rendered by this script (for
# example the main-header, result-card and footer classes).
_CUSTOM_CSS = """
<style>
.main-header {
    font-size: 2.5rem !important;
    color: #1E3A8A;
    font-weight: 700;
    margin-bottom: 0.5rem;
}
.sub-header {
    font-size: 1.1rem;
    color: #6B7280;
    margin-bottom: 2rem;
}
.highlight {
    background-color: #F3F4F6;
    padding: 1.5rem;
    border-radius: 0.5rem;
    margin-bottom: 1rem;
}
.result-card {
    background-color: #F0F9FF;
    padding: 1.5rem;
    border-radius: 0.5rem;
    border-left: 5px solid #3B82F6;
    margin-bottom: 1rem;
}
.spam-alert {
    background-color: #FEF2F2;
    border-left: 5px solid #EF4444;
}
.ham-alert {
    background-color: #ECFDF5;
    border-left: 5px solid #10B981;
}
.footer {
    text-align: center;
    margin-top: 3rem;
    font-size: 0.8rem;
    color: #9CA3AF;
}
.metrics-container {
    display: flex;
    justify-content: space-between;
    margin-top: 1rem;
}
.metric-item {
    text-align: center;
    padding: 1rem;
    background-color: #F9FAFB;
    border-radius: 0.5rem;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
.language-tag {
    display: inline-block;
    padding: 0.25rem 0.5rem;
    background-color: #E0E7FF;
    color: #4F46E5;
    border-radius: 9999px;
    font-size: 0.8rem;
    font-weight: 500;
    margin-right: 0.5rem;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
115
+
116
@st.cache_resource
def load_language_model():
    """Return the (tokenizer, model) pair used for language detection.

    Both objects are loaded from the same pretrained checkpoint; the
    tokenizer is loaded first, then the sequence-classification model.
    """
    checkpoint = "papluca/xlm-roberta-base-language-detection"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSequenceClassification.from_pretrained(checkpoint),
    )
123
+
124
@st.cache_resource
def load_spam_model():
    """Return the (tokenizer, model) pair of the fine-tuned BERT spam classifier.

    Both objects are loaded from the same checkpoint; the tokenizer is
    loaded first, then the sequence-classification model.
    """
    checkpoint = "chjivan/final"
    return (
        BertTokenizerFast.from_pretrained(checkpoint),
        BertForSequenceClassification.from_pretrained(checkpoint),
    )
131
+
132
def detect_language(text, tokenizer, model):
    """Detect the language of the input text.

    Args:
        text: The message to analyze.
        tokenizer: Callable that turns ``text`` into keyword inputs for
            ``model`` (called with ``return_tensors="pt"``, truncation and
            padding enabled).
        model: Callable classifier whose output exposes ``.logits`` and whose
            ``config.id2label`` maps class indices to language labels.

    Returns:
        tuple: ``(predicted_language, confidence, top_langs)`` where
        ``top_langs`` is a list of ``(label, probability)`` pairs for the
        most probable languages, highest first. It holds up to three entries,
        fewer when the model has fewer than three labels.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over the class dimension; [0] selects the single input row.
    probabilities = torch.softmax(outputs.logits, dim=1)[0]
    id2label = model.config.id2label

    # Predicted language and its probability
    predicted_class_id = torch.argmax(probabilities).item()
    predicted_language = id2label[predicted_class_id]
    confidence = probabilities[predicted_class_id].item()

    # Run top-k once, and clamp k so a model with fewer than three labels
    # does not make torch.topk raise.
    k = min(3, probabilities.numel())
    top = torch.topk(probabilities, k)
    top_3_langs = [
        (id2label[idx], prob)
        for idx, prob in zip(top.indices.tolist(), top.values.tolist())
    ]

    return predicted_language, confidence, top_3_langs
153
+
154
def classify_spam(text, tokenizer, model):
    """Label a message as spam or ham.

    The text is tokenized (truncated to at most 128 tokens), scored by
    ``model``, and the highest-probability class is taken as the prediction.
    Class index 1 is treated as spam; any other index as ham.

    Returns:
        tuple: ``(is_spam, confidence)`` where ``confidence`` is the softmax
        probability of the predicted class.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():
        scores = model(**encoded).logits

    # Probabilities for the single input row
    class_probs = torch.softmax(scores, dim=1)[0]
    winner = torch.argmax(class_probs).item()

    return winner == 1, class_probs[winner].item()
170
+
171
# Generate the logo once and reuse it across reruns. st.cache_data memoizes
# the base64 string, so the image is not redrawn on every widget interaction.
logo_base64 = st.cache_data(show_spinner=False)(create_logo)()
logo_html = f'<img src="data:image/png;base64,{logo_base64}" style="height:150px;">'

# Load both models (the loaders are wrapped in st.cache_resource)
with st.spinner("Loading models... This may take a moment."):
    lang_tokenizer, lang_model = load_language_model()
    spam_tokenizer, spam_model = load_spam_model()

# App header with logo
col1, col2 = st.columns([1, 5])
with col1:
    st.markdown(logo_html, unsafe_allow_html=True)
with col2:
    st.markdown('<h1 class="main-header">SMS Spam Guard</h1>', unsafe_allow_html=True)
    st.markdown('<p class="sub-header">智能短信垃圾过滤助手 by SafeTalk Communications Ltd.</p>', unsafe_allow_html=True)

# Sidebar
with st.sidebar:
    st.markdown(logo_html, unsafe_allow_html=True)
    st.markdown("### About SafeTalk")
    st.markdown("SafeTalk Communications Ltd. provides intelligent communication security solutions to protect users from spam and fraudulent messages.")
    st.markdown("#### Our Technology")
    st.markdown("- ✅ Advanced AI-powered spam detection")
    st.markdown("- 🌐 Multi-language support")
    st.markdown("- 🔒 Secure and private processing")
    st.markdown("- ⚡ Real-time analysis")

    st.markdown("---")
    st.markdown("### Sample Messages")

    # Button label -> text placed into the input box. Each button writes to
    # st.session_state["sms_input"], the key the message text area uses.
    sample_messages = {
        "Sample Spam (English)": "URGENT: You have won a $1,000 Walmart gift card. Go to http://bit.ly/claim-prize to claim now before it expires!",
        "Sample Legitimate (English)": "Your Amazon package will be delivered today. Thanks for ordering from Amazon!",
        "Sample Message (French)": "Bonjour! Votre réservation pour le restaurant est confirmée pour ce soir à 20h. À bientôt!",
        "Sample Message (Spanish)": "Hola, tu cita médica está programada para mañana a las 10:00. Por favor llega 15 minutos antes.",
    }
    for button_label, sample_text in sample_messages.items():
        if st.button(button_label):
            st.session_state.sms_input = sample_text
214
# Main Content
st.markdown('<div class="highlight">', unsafe_allow_html=True)
# Input form. The text area is bound to st.session_state["sms_input"] through
# its key, so the sidebar sample buttons can fill it. No `value=` is passed:
# giving a default as well as setting the same key through Session State
# makes Streamlit emit a conflict warning.
sms_input = st.text_area(
    "Enter the SMS message to analyze:",
    height=100,
    key="sms_input",
    help="Enter the SMS message you want to analyze for spam"
)

analyze_button = st.button("📱 Analyze Message", use_container_width=True)
st.markdown('</div>', unsafe_allow_html=True)
227
+
228
# Process input and display results
if analyze_button and not sms_input.strip():
    # Tell the user why nothing happened instead of silently ignoring the
    # click (also covers whitespace-only input).
    st.warning("Please enter a message to analyze.")
elif analyze_button:
    with st.spinner("Analyzing message..."):
        # Step 1: Language Detection (perf_counter is the clock intended for
        # measuring short intervals)
        lang_start_time = time.perf_counter()
        lang_code, lang_confidence, top_langs = detect_language(sms_input, lang_tokenizer, lang_model)
        lang_time = time.perf_counter() - lang_start_time

        # Mapping from language code to display name; codes missing from
        # this table are shown as-is.
        lang_names = {
            "ar": "Arabic",
            "bg": "Bulgarian",
            "de": "German",
            "el": "Greek",
            "en": "English",
            "es": "Spanish",
            "fr": "French",
            "hi": "Hindi",
            "it": "Italian",
            "ja": "Japanese",
            "nl": "Dutch",
            "pl": "Polish",
            "pt": "Portuguese",
            "ru": "Russian",
            "sw": "Swahili",
            "th": "Thai",
            "tr": "Turkish",
            "ur": "Urdu",
            "vi": "Vietnamese",
            "zh": "Chinese",
        }

        lang_name = lang_names.get(lang_code, lang_code)

        # Step 2: Spam Classification
        spam_start_time = time.perf_counter()
        is_spam, spam_confidence = classify_spam(sms_input, spam_tokenizer, spam_model)
        spam_time = time.perf_counter() - spam_start_time

    # Display results
    st.markdown("### Analysis Results")

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### 📊 Language Detection")
        st.markdown('<div class="result-card">', unsafe_allow_html=True)
        st.markdown(f'<span class="language-tag">{lang_name}</span> Detected with {lang_confidence:.1%} confidence', unsafe_allow_html=True)

        # Top languages; a separate loop variable keeps the detected
        # `lang_code` from being overwritten.
        st.markdown("##### Top language probabilities:")
        for candidate_code, prob in top_langs:
            lang_full = lang_names.get(candidate_code, candidate_code)
            st.markdown(f"- {lang_full}: {prob:.1%}")

        st.markdown(f"⏱️ Processing time: {lang_time:.3f} seconds")
        st.markdown('</div>', unsafe_allow_html=True)

    with col2:
        st.markdown("#### 🔍 Spam Detection")

        if is_spam:
            st.markdown('<div class="result-card spam-alert">', unsafe_allow_html=True)
            st.markdown(f"⚠️ **SPAM DETECTED** with {spam_confidence:.1%} confidence")
            st.markdown("This message appears to be spam and potentially harmful.")
        else:
            st.markdown('<div class="result-card ham-alert">', unsafe_allow_html=True)
            st.markdown(f"✅ **LEGITIMATE MESSAGE** with {spam_confidence:.1%} confidence")
            st.markdown("This message appears to be legitimate.")

        st.markdown(f"⏱️ Processing time: {spam_time:.3f} seconds")
        st.markdown('</div>', unsafe_allow_html=True)

    # Summary and Recommendations
    st.markdown("### 📋 Summary & Recommendations")
    if is_spam:
        st.warning("📵 **Recommended Action**: This message should be blocked or moved to spam folder.")
        st.markdown("""
        **Why this is likely spam:**
        - Contains suspicious language patterns
        - May include urgent calls to action
        - Could contain unsolicited offers
        """)
    else:
        st.success("✅ **Recommended Action**: This message can be delivered to the inbox.")

    # Chart for visualization. Plot the confidence of the predicted class in
    # both cases, so the bar matches the percentage shown in the text above
    # (previously a legitimate message was charted as 1 - confidence).
    st.markdown("### 📈 Confidence Visualization")
    chart_data = pd.DataFrame({
        'Task': ['Language Detection', 'Spam Classification'],
        'Confidence': [lang_confidence, spam_confidence]
    })
    st.bar_chart(chart_data.set_index('Task'))
321
+
322
# Footer: each entry is (markdown text, whether raw HTML is allowed).
footer_parts = (
    ('<div class="footer">', True),
    ("© 2023 SafeTalk Communications Ltd. | www.safetalk.com", False),
    ("SMS Spam Guard is an intelligent message filtering solution to protect users from unwanted communications.", False),
    ('</div>', True),
)
for footer_text, allow_html in footer_parts:
    st.markdown(footer_text, unsafe_allow_html=allow_html)
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
1
+ streamlit==1.32.0
2
+ torch==2.1.0
3
+ transformers==4.38.0
4
+ pandas==2.2.0
5
+ numpy==1.26.0
6
+ safetensors==0.4.5