---
language: en
license: mit
pipeline_tag: text-classification
tags:
- text-classification
- transformers
- pytorch
- multi-label-classification
- multi-class-classification
- emotion
- bert
- go_emotions
- emotion-classification
datasets:
- google-research-datasets/go_emotions
metrics:
- f1
- precision
- recall
widget:
- text: I’m just chilling today.
  example_title: Neutral Example
- text: Thank you for saving my life!
  example_title: Gratitude Example
- text: I’m nervous about my exam tomorrow.
  example_title: Nervousness Example
base_model:
- google-bert/bert-base-uncased
---
|
|
|
|
|
# GoEmotions BERT Classifier

Fine-tuned [BERT-base-uncased](https://huggingface.co/bert-base-uncased) on [go_emotions](https://huggingface.co/datasets/go_emotions) for multi-label emotion classification across 28 emotion labels.

## Model Details

- **Architecture**: BERT-base-uncased (110M parameters)
- **Training Data**: [GoEmotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) (58k Reddit comments, 28 emotion labels)
- **Loss Function**: Focal Loss (gamma=2; sketched below)
- **Optimizer**: AdamW (lr=2e-5, weight_decay=0.01)
- **Epochs**: 5
- **Hardware**: 2x T4 GPUs (Kaggle)
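
The exact training code is not included in this card; as a reference, here is a minimal sketch of the standard multi-label focal loss (BCE-with-logits reweighted by `(1 - p_t)^gamma`) with the stated gamma=2. The formulation is an assumption based on the common multi-label variant:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Multi-label focal loss (assumed formulation): down-weights easy
    examples so rare emotion labels contribute more to the gradient."""
    # Per-element BCE, kept unreduced so each label can be reweighted
    bce = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction="none")
    p_t = torch.exp(-bce)  # model's probability for the true label
    return ((1.0 - p_t) ** gamma * bce).mean()
```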
|
|
|
|
|
## Try It Out |
|
|
To get predictions with the optimized per-label thresholds already applied, try the [Gradio demo](https://logasanjeev-goemotions-bert-demo.hf.space).
|
|
|
|
|
## Performance |
|
|
|
|
|
- **Micro F1**: 0.6025 (optimized thresholds)
- **Macro F1**: 0.5266
- **Precision**: 0.5425
- **Recall**: 0.6775
- **Hamming Loss**: 0.0372
- **Avg. Positive Predictions per Example**: 1.4564
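
The evaluation script is not included in the card; the reported numbers are consistent with a standard scikit-learn multilabel evaluation. A minimal sketch, assuming `probs` are the model's sigmoid outputs, `labels` the binary ground truth, and `thresholds` the per-label cutoffs from `thresholds.json`:

```python
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss

def evaluate(probs, labels, thresholds):
    """Aggregate multi-label metrics over (N, 28) probability/label arrays."""
    preds = (probs >= np.asarray(thresholds)).astype(int)
    return {
        "micro_f1": f1_score(labels, preds, average="micro"),
        "macro_f1": f1_score(labels, preds, average="macro"),
        "precision": precision_score(labels, preds, average="micro"),
        "recall": recall_score(labels, preds, average="micro"),
        "hamming_loss": hamming_loss(labels, preds),
        "avg_positive_predictions": float(preds.sum(axis=1).mean()),
    }
```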
|
|
|
|
|
### Class-Wise Performance |
|
|
The following table shows per-class metrics on the test set using optimized thresholds (see `thresholds.json`): |
|
|
|
|
|
| Emotion        | F1 Score | Precision | Recall | Support |
|----------------|----------|-----------|--------|---------|
| admiration     | 0.7022   | 0.6980    | 0.7063 | 504     |
| amusement      | 0.8171   | 0.7692    | 0.8712 | 264     |
| anger          | 0.5123   | 0.5000    | 0.5253 | 198     |
| annoyance      | 0.3820   | 0.2908    | 0.5563 | 320     |
| approval       | 0.4112   | 0.3485    | 0.5014 | 351     |
| caring         | 0.4601   | 0.4045    | 0.5333 | 135     |
| confusion      | 0.4488   | 0.4533    | 0.4444 | 153     |
| curiosity      | 0.5721   | 0.4402    | 0.8169 | 284     |
| desire         | 0.4068   | 0.6857    | 0.2892 | 83      |
| disappointment | 0.3476   | 0.3220    | 0.3775 | 151     |
| disapproval    | 0.4126   | 0.3433    | 0.5169 | 267     |
| disgust        | 0.4950   | 0.6329    | 0.4065 | 123     |
| embarrassment  | 0.5000   | 0.7368    | 0.3784 | 37      |
| excitement     | 0.4084   | 0.4432    | 0.3786 | 103     |
| fear           | 0.6311   | 0.5078    | 0.8333 | 78      |
| gratitude      | 0.9173   | 0.9744    | 0.8665 | 352     |
| grief          | 0.2500   | 0.5000    | 0.1667 | 6       |
| joy            | 0.6246   | 0.5798    | 0.6770 | 161     |
| love           | 0.8110   | 0.7630    | 0.8655 | 238     |
| nervousness    | 0.3830   | 0.3750    | 0.3913 | 23      |
| optimism       | 0.5777   | 0.5856    | 0.5699 | 186     |
| pride          | 0.4138   | 0.4615    | 0.3750 | 16      |
| realization    | 0.2421   | 0.5111    | 0.1586 | 145     |
| relief         | 0.5385   | 0.4667    | 0.6364 | 11      |
| remorse        | 0.6797   | 0.5361    | 0.9286 | 56      |
| sadness        | 0.5391   | 0.6900    | 0.4423 | 156     |
| surprise       | 0.5724   | 0.5570    | 0.5887 | 141     |
| neutral        | 0.6895   | 0.5826    | 0.8444 | 1787    |
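
The card ships the tuned cutoffs in `thresholds.json` but does not document how they were chosen. A common approach, shown here as a hypothetical sketch, is a per-class grid search that maximizes F1 on a held-out validation split:

```python
import numpy as np
from sklearn.metrics import f1_score

def tune_thresholds(val_probs, val_labels, grid=np.arange(0.05, 0.95, 0.01)):
    """Hypothetical per-class grid search: for each emotion, pick the
    cutoff that maximizes F1 on validation probabilities/labels."""
    best = []
    for c in range(val_labels.shape[1]):
        scores = [f1_score(val_labels[:, c], val_probs[:, c] >= t, zero_division=0)
                  for t in grid]
        best.append(float(grid[int(np.argmax(scores))]))
    return best
```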
|
|
|
|
|
## Usage |
|
|
|
|
|
The model applies per-label thresholds stored in `thresholds.json` at prediction time. Example in Python:
|
|
|
|
|
```python
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import json
import requests

# Load model and tokenizer
repo_id = "logasanjeev/goemotions-bert"
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)

# Load per-label thresholds and the label order they correspond to
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
thresholds_data = json.loads(requests.get(thresholds_url).text)
emotion_labels = thresholds_data["emotion_labels"]
thresholds = thresholds_data["thresholds"]

# Predict: sigmoid turns logits into independent per-label probabilities
text = "I’m just chilling today."
encodings = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
with torch.no_grad():
    probs = torch.sigmoid(model(**encodings).logits).numpy()[0]

# Keep every label whose probability clears its tuned threshold
predictions = [(emotion_labels[i], round(float(p), 4))
               for i, (p, t) in enumerate(zip(probs, thresholds)) if p >= t]
print(sorted(predictions, key=lambda x: x[1], reverse=True))
# Output: [('neutral', 0.8147)]
```
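
Per-label thresholds matter here because the label distribution is heavily skewed (6 test examples for grief vs. 1,787 for neutral), so a single global cutoff such as 0.5 tends to under-predict rare emotions.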