Update README.md
README.md CHANGED
@@ -9,12 +9,12 @@ base_model:
 pipeline_tag: text-classification
 library_name: transformers
 ---
-
+# **Spam Detection with BERT**
 
-This implementation
+This repository contains an implementation of a **Spam Detection** model using **BERT (Bidirectional Encoder Representations from Transformers)** for binary classification (Spam / Ham). The model is trained on the **`prithivMLmods/Spam-Text-Detect-Analysis` dataset** and leverages **Weights & Biases (wandb)** for comprehensive experiment tracking.
 
 ---
-
+## **Summary of Uploaded Files**
 
 | **File Name** | **Size** | **Description** | **Upload Status** |
 |------------------------------------|-----------|-----------------------------------------------------|-------------------|
@@ -49,9 +49,7 @@ Results were obtained using BERT and the provided training dataset:
 - **Precision:** **0.9931**
 - **Recall:** **0.9597**
 - **F1 Score:** **0.9761**
-
 ---
-
 ## **Model Training Details**
 
 ### **Model Architecture:**
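The metrics above are reported without the code that produced them. For reference, here is a minimal sketch of how precision, recall, and F1 are typically computed with scikit-learn (one of the dependencies installed below); `y_true` and `y_pred` are illustrative placeholders, not values from this repository:

```python
# Minimal sketch: computing precision/recall/F1 with scikit-learn.
# y_true and y_pred are placeholders, not this repository's actual data.
from sklearn.metrics import precision_recall_fscore_support

y_true = [1, 0, 1, 1, 0, 1]  # gold labels (1 = Spam, 0 = Ham)
y_pred = [1, 0, 1, 0, 0, 1]  # model predictions

precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="binary"
)
print(f"Precision: {precision:.4f}  Recall: {recall:.4f}  F1: {f1:.4f}")
```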
@@ -62,74 +60,27 @@ The model uses `bert-base-uncased` as the pre-trained backbone and is fine-tuned
 - **Batch Size:** 16
 - **Epochs:** 3
 - **Loss:** Cross-Entropy
-
 ---
-##
-
-```python
-import gradio as gr
-import torch
-from transformers import BertTokenizer, BertForSequenceClassification
-
-# Load the pre-trained BERT model and tokenizer
-MODEL_PATH = "prithivMLmods/Spam-Bert-Uncased"
-tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
-model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
-
-# Function to predict if a given text is Spam or Ham
-def predict_spam(text):
-    # Tokenize the input text
-    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
-
-    # Perform inference
-    with torch.no_grad():
-        outputs = model(**inputs)
-    logits = outputs.logits
-    prediction = torch.argmax(logits, axis=-1).item()
-
-    # Map prediction to label
-    if prediction == 1:
-        return "Spam"
-    else:
-        return "Ham"
-
-
-# Gradio UI - Input and Output components
-inputs = gr.Textbox(label="Enter Text", placeholder="Type a message to check if it's Spam or Ham...")
-outputs = gr.Label(label="Prediction")
-
-# List of example inputs
-examples = [
-    ["Win $1000 gift cards now by clicking here!"],
-    ["You have been selected for a lottery."],
-    ["Hello, how was your day?"],
-    ["Earn money without any effort. Click here."],
-    ["Meeting tomorrow at 10 AM. Don't be late."],
-    ["Claim your free prize now!"],
-    ["Are we still on for dinner tonight?"],
-    ["Exclusive offer just for you, act now!"],
-    ["Let's catch up over coffee soon."],
-    ["Congratulations, you've won a new car!"]
-]
-
-# Create the Gradio interface
-gr_interface = gr.Interface(
-    fn=predict_spam,
-    inputs=inputs,
-    outputs=outputs,
-    examples=examples,
-    title="Spam Detection with BERT",
-    description="Type a message in the text box to check if it's Spam or Ham using a pre-trained BERT model."
-)
-
-# Launch the application
-gr_interface.launch()
+## **How to Use the Model**
 
+### **1. Clone the Repository**
+```bash
+git clone <repository-url>
+cd <project-directory>
 ```
-### Train Details
 
+### **2. Install Dependencies**
+Install all necessary dependencies.
+```bash
+pip install -r requirements.txt
+```
+or manually:
+```bash
+pip install transformers datasets wandb scikit-learn
+```
+### **3. Train the Model**
+Assuming you have a script like `train.py`, run:
 ```python
-
 # Import necessary libraries
 from datasets import load_dataset, ClassLabel
 from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
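The new section defers training to `train.py`, whose script this hunk only begins to show. For orientation, a self-contained sketch of a `Trainer` setup matching the stated hyperparameters (batch size 16, 3 epochs; cross-entropy is the loss `BertForSequenceClassification` applies by default) might look like the following; the two-sentence toy dataset is purely illustrative:

```python
# Illustrative Trainer setup matching the hyperparameters stated above.
# The toy dataset stands in for Spam-Text-Detect-Analysis.
from datasets import Dataset
from transformers import (BertForSequenceClassification, BertTokenizer,
                          Trainer, TrainingArguments)

raw = Dataset.from_dict({
    "text": ["Win a free prize now!", "See you at lunch tomorrow."],
    "label": [1, 0],  # 1 = Spam, 0 = Ham
})

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
train_dataset = raw.map(
    lambda batch: tokenizer(batch["text"], truncation=True,
                            padding="max_length", max_length=64),
    batched=True,
)

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,  # Batch Size: 16
    num_train_epochs=3,              # Epochs: 3
)

# Cross-entropy is applied internally when labels are provided.
Trainer(model=model, args=args, train_dataset=train_dataset).train()
```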
@@ -235,33 +186,14 @@ def predict(text):
 example_text = "Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."
 print("Prediction:", predict(example_text))
 ```
-
-## **How to Train the Model**
-
-1. **Clone Repository:**
-   ```bash
-   git clone <repository-url>
-   cd <project-directory>
-   ```
-
-2. **Install Dependencies:**
-   Install all necessary dependencies.
-   ```bash
-   pip install -r requirements.txt
-   ```
-   or manually:
-   ```bash
-   pip install transformers datasets wandb scikit-learn
-   ```
-
-3. **Train the Model:**
-   Assuming you have a script like `train.py`, run:
-   ```python
-   from train import main
-   ```
-
 ---
+## **Dataset Information**
+The training dataset comes from **Spam-Text-Detect-Analysis**, available on Hugging Face:
+- **Dataset Link:** [Spam Text Detection Dataset - Hugging Face](https://huggingface.co/datasets/prithivMLmods/Spam-Text-Detect-Analysis)
 
+Dataset size:
+- **5.57k entries**
+---
 ## **Weights & Biases Integration**
 
 ### Why Use wandb?
|
|
| 275 |
import wandb
|
| 276 |
wandb.init(project="spam-detection")
|
| 277 |
```
|
| 278 |
-
|
| 279 |
---
|
| 280 |
-
|
| 281 |
-
## π **Directory Structure**
|
| 282 |
|
| 283 |
The directory is organized to ensure scalability and clear separation of components:
|
| 284 |
|
|
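The snippet above only initializes a run. When training with the Hugging Face `Trainer`, metrics can also be streamed to wandb through `report_to`; a sketch, assuming `wandb login` has already been run and the run name is illustrative:

```python
# Sketch: route Trainer logs to Weights & Biases.
import wandb
from transformers import TrainingArguments

wandb.init(project="spam-detection")

args = TrainingArguments(
    output_dir="./results",
    report_to=["wandb"],             # stream training metrics to wandb
    run_name="bert-spam-detection",  # illustrative run name
    logging_steps=50,
)
```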
@@ -292,14 +222,57 @@ project-directory/
 ├── requirements.txt      # List of dependencies
 └── train.py              # Main script for training the model
 ```
-
 ---
+## **Gradio Interface**
 
-
-The training dataset comes from **Spam-Text-Detect-Analysis** available on Hugging Face:
-- **Dataset Link:** [Spam Text Detection Dataset - Hugging Face](https://huggingface.co/datasets)
+A Gradio interface is provided to test the model interactively. The interface allows users to input text and get predictions on whether the text is **Spam** or **Ham**.
 
-
-
+### **Example Usage**
+```python
+import gradio as gr
+import torch
+from transformers import BertTokenizer, BertForSequenceClassification
+
+# Load the pre-trained BERT model and tokenizer
+MODEL_PATH = "prithivMLmods/Spam-Bert-Uncased"
+tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
+model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
+
+# Function to predict if a given text is Spam or Ham
+def predict_spam(text):
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    logits = outputs.logits
+    prediction = torch.argmax(logits, axis=-1).item()
+    return "Spam" if prediction == 1 else "Ham"
 
+# Gradio UI
+inputs = gr.Textbox(label="Enter Text", placeholder="Type a message to check if it's Spam or Ham...")
+outputs = gr.Label(label="Prediction")
+
+examples = [
+    ["Win $1000 gift cards now by clicking here!"],
+    ["You have been selected for a lottery."],
+    ["Hello, how was your day?"],
+    ["Earn money without any effort. Click here."],
+    ["Meeting tomorrow at 10 AM. Don't be late."],
+    ["Claim your free prize now!"],
+    ["Are we still on for dinner tonight?"],
+    ["Exclusive offer just for you, act now!"],
+    ["Let's catch up over coffee soon."],
+    ["Congratulations, you've won a new car!"]
+]
+
+gr_interface = gr.Interface(
+    fn=predict_spam,
+    inputs=inputs,
+    outputs=outputs,
+    examples=examples,
+    title="Spam Detection with BERT",
+    description="Type a message in the text box to check if it's Spam or Ham using a pre-trained BERT model."
+)
+
+gr_interface.launch()
+```
 ---
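For a quick check without launching the Gradio app, the same checkpoint can presumably also be queried through the `transformers` pipeline API; note that the emitted label names depend on the checkpoint's config (e.g. `LABEL_0`/`LABEL_1`):

```python
from transformers import pipeline

# Checkpoint ID taken from the example above.
classifier = pipeline("text-classification", model="prithivMLmods/Spam-Bert-Uncased")
print(classifier("Claim your free prize now!"))
# e.g. [{'label': 'LABEL_1', 'score': 0.99}]  (label names depend on the config)
```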