---
library_name: transformers
base_model: cardiffnlp/twitter-xlm-roberta-base-sentiment
tags:
- text-classification
- multi-label-classification
- multi-head-classification
- disaster-response
- humanitarian-aid
- social-media
- twitter
- generated_from_trainer
model-index:
- name: xlm-roberta-sentiment-requests
  results:
  - task:
      type: text-classification
    dataset:
      name: community-datasets/disaster_response_messages
      type: community-datasets
      config: default
      split: evaluation
    metrics:
    - name: F1 Micro
      type: f1
      value: 0.7240
    - name: F1 Macro
      type: f1
      value: 0.3505
    - name: Subset Accuracy
      type: accuracy
      value: 0.2588
datasets:
- community-datasets/disaster_response_messages
pipeline_tag: text-classification
language:
- en
- multilingual
---
<!-- This model card has been generated automatically and then completed by a human. -->
# xlm-roberta-sentiment-requests
This model is a fine-tuned version of [cardiffnlp/twitter-xlm-roberta-base-sentiment](https://huggingface.co/cardiffnlp/twitter-xlm-roberta-base-sentiment) on the [community-datasets/disaster_response_messages](https://huggingface.co/datasets/community-datasets/disaster_response_messages) dataset. It has been adapted into a **multi-head classification model** for analyzing social media messages sent during disaster events.
It achieves the following results on the evaluation set:
- Loss: 0.1465
- F1 Micro: 0.7240
- F1 Macro: 0.3505
- Subset Accuracy: 0.2588
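
A large gap between F1 Micro and F1 Macro is typical when many labels are rare: micro-averaging pools true and false positives over all labels, while macro-averaging gives every label equal weight. Subset accuracy counts a message as correct only if every label matches. The sketch below uses made-up toy data (not outputs of this model) to show how the three metrics could be computed; the metric code actually used during training is not shown in this card.

```python
from typing import Dict, List


def f1(tp: int, fp: int, fn: int) -> float:
    """F1 from raw counts; defined as 0.0 when the label never occurs or fires."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def multilabel_metrics(y_true: List[List[int]], y_pred: List[List[int]]) -> Dict[str, float]:
    n_labels = len(y_true[0])
    tp, fp, fn = [0] * n_labels, [0] * n_labels, [0] * n_labels
    for t_row, p_row in zip(y_true, y_pred):
        for j, (t, p) in enumerate(zip(t_row, p_row)):
            tp[j] += int(t == 1 and p == 1)
            fp[j] += int(t == 0 and p == 1)
            fn[j] += int(t == 1 and p == 0)
    micro = f1(sum(tp), sum(fp), sum(fn))  # pool counts over all labels
    macro = sum(f1(tp[j], fp[j], fn[j]) for j in range(n_labels)) / n_labels  # average per-label F1
    subset = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)  # exact-match rate
    return {"f1_micro": micro, "f1_macro": macro, "subset_accuracy": subset}


# Toy data: label 0 is common and always predicted; labels 1 and 2 are rare and always missed.
y_true = [[1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
y_pred = [[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]]
metrics = multilabel_metrics(y_true, y_pred)
print(metrics)
```

In this toy case the micro score stays high (0.8) because the common label dominates the pooled counts, while the macro score drops to one third because two of the three labels are never detected.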
## Model description
This model uses a shared `XLM-RoBERTa` base to encode input text. The resulting text representation is then fed into two separate, independent classification layers (heads):
* A **Sentiment Head (Frozen from pre-trained model)** with 3 outputs for `positive`, `neutral`, and `negative` classes.
* A **Multi-Label Head (Newly created and fine-tuned)** with 41 outputs, which are decoded into predictions for 37 different disaster-related categories.
This dual-head architecture allows for a nuanced understanding of a message, capturing both its emotional content and its specific, actionable information.
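
As a consistency check, the 41 outputs can be reconciled with the 37 categories if each two-class task uses a single output and each three-class task (`genre`, `related`) uses one output per class. This layout is an assumption inferred from the label definitions and the `{task_name}_0` column lookup in the prediction code in this card; the authoritative layout lives in the repository's `metadata.json`.

```python
# Task names and class counts copied from the label definitions in this card.
BINARY_TASKS = [
    "request", "offer", "aid_related", "medical_help", "medical_products",
    "search_and_rescue", "security", "military", "child_alone", "water",
    "food", "shelter", "clothing", "money", "missing_people", "refugees",
    "death", "other_aid", "infrastructure_related", "transport", "buildings",
    "electricity", "tools", "hospitals", "shops", "aid_centers",
    "other_infrastructure", "weather_related", "floods", "storm", "fire",
    "earthquake", "cold", "other_weather", "direct_report",
]
MULTICLASS_TASKS = {"genre": 3, "related": 3}

num_categories = len(BINARY_TASKS) + len(MULTICLASS_TASKS)

# Assumption: one output per binary task, one output per class for multi-class tasks.
num_outputs = len(BINARY_TASKS) + sum(MULTICLASS_TASKS.values())

print(num_categories, num_outputs)  # 37 41
```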
## Intended uses & limitations
This model is intended for organizations and researchers involved in humanitarian aid and disaster response. Potential applications include:
* **Automated Triage**: Quickly sorting through thousands of social media messages to identify the most urgent requests for help.
* **Situational Awareness**: Building a real-time map of needs by aggregating categorized messages.
* **Resource Allocation**: Directing resources more effectively by understanding the specific types of aid being requested.
**Important**: Due to its custom architecture, this model **cannot** be used with the standard `pipeline("text-classification")` function. Please see the usage code below for the correct implementation.
### How to Use
This model requires custom code to handle its two-headed output. The following is a complete, self-contained Python script to run inference. You will need to have `transformers`, `torch`, `safetensors`, and `huggingface_hub` installed (`pip install transformers torch safetensors huggingface_hub`).
The script automatically downloads all necessary files, including the model weights and metadata. Simply copy the code blocks below and run the script.
The script is broken into logical blocks:
1. **Model Architecture**: A Python class that defines the model's structure. This blueprint is required to load the saved weights.
2. **Label Definitions**: A "decoder ring" of functions to translate the model's numerical outputs into human-readable labels.
3. **Setup & Loading**: A function that handles all the one-time setup.
4. **Prediction Function**: The core logic that takes text and produces a dictionary of predictions.
5. **Main Execution**: An example of how to run the script.
Copying code blocks 1 through 5, in order, into a single file gives you the complete inference script.
***
1. **Model Architecture**: We define the necessary imports and the model architecture.
```python
import torch
from torch import nn
from transformers import AutoTokenizer, AutoConfig, AutoModel, PreTrainedModel
from huggingface_hub import hf_hub_download
from typing import Dict, Any
from safetensors.torch import load_file
import json


class MultiHeadClassificationModel(PreTrainedModel):
    """Shared encoder with two independent linear heads (sentiment and multi-label)."""

    def __init__(self, config, **kwargs):
        super().__init__(config)
        num_multilabels = kwargs.get("num_multilabels")
        if num_multilabels is None:
            raise ValueError("`num_multilabels` must be provided to initialize the model.")
        self.backbone = AutoModel.from_config(config)
        self.sentiment_classifier = nn.Linear(config.hidden_size, config.num_sentiment_labels)
        self.multilabel_classifier = nn.Linear(config.hidden_size, num_multilabels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.backbone(input_ids, attention_mask=attention_mask, **kwargs)
        # Use the hidden state of the first token as the sentence representation.
        cls_token_output = outputs.last_hidden_state[:, 0, :]
        sentiment_logits = self.sentiment_classifier(cls_token_output)
        multilabel_logits = self.multilabel_classifier(cls_token_output)
        return {"sentiment_logits": sentiment_logits, "multilabel_logits": multilabel_logits}
```
***
2. **Label Definitions**: We embed the label definitions, which are essential for interpreting the model's output.
```python
def get_all_labels() -> Dict[str, Dict[int, str]]:
    return {
        'sentiment': get_sentiment_labels(), 'genre': get_genre_labels(), 'related': get_related_labels(),
        'request': get_request_labels(), 'offer': get_offer_labels(), 'aid_related': get_aid_related_labels(),
        'medical_help': get_medical_help_labels(), 'medical_products': get_medical_products_labels(),
        'search_and_rescue': get_search_and_rescue_labels(), 'security': get_security_labels(),
        'military': get_military_labels(), 'child_alone': get_child_alone_labels(), 'water': get_water_labels(),
        'food': get_food_labels(), 'shelter': get_shelter_labels(), 'clothing': get_clothing_labels(),
        'money': get_money_labels(), 'missing_people': get_missing_people_labels(),
        'refugees': get_refugees_labels(), 'death': get_death_labels(), 'other_aid': get_other_aid_labels(),
        'infrastructure_related': get_infrastructure_related_labels(), 'transport': get_transport_labels(),
        'buildings': get_buildings_labels(), 'electricity': get_electricity_labels(), 'tools': get_tools_labels(),
        'hospitals': get_hospitals_labels(), 'shops': get_shops_labels(), 'aid_centers': get_aid_centers_labels(),
        'other_infrastructure': get_other_infrastructure_labels(), 'weather_related': get_weather_related_labels(),
        'floods': get_floods_labels(), 'storm': get_storm_labels(), 'fire': get_fire_labels(),
        'earthquake': get_earthquake_labels(), 'cold': get_cold_labels(), 'other_weather': get_other_weather_labels(),
        'direct_report': get_direct_report_labels(),
    }
def get_genre_labels() -> Dict[int, str]: return {0: 'direct', 1: 'news', 2: 'social'}
def get_related_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes', 2: 'maybe'}
def get_request_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_offer_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_aid_related_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_medical_help_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_medical_products_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_search_and_rescue_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_security_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_military_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_child_alone_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_water_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_food_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_shelter_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_clothing_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_money_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_missing_people_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_refugees_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_death_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_other_aid_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_infrastructure_related_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_transport_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_buildings_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_electricity_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_tools_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_hospitals_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_shops_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_aid_centers_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_other_infrastructure_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_weather_related_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_floods_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_storm_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_fire_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_earthquake_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_cold_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_other_weather_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_direct_report_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_sentiment_labels() -> Dict[int, str]: return {0: 'negative', 1: 'neutral', 2: 'positive'}
```
***
3. **Setup & Loading**: This setup function downloads and loads all components, including `metadata.json`, from the Hub.
```python
def load_essentials():
    print("Loading model, tokenizer, and metadata... (This may take a moment on first run)")
    hub_repo_id = "spencercdz/xlm-roberta-sentiment-requests"
    subfolder = "final_model"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load the model's output structure from the metadata.json file.
    metadata_path = hf_hub_download(repo_id=hub_repo_id, filename="metadata.json", subfolder=subfolder)
    with open(metadata_path, "r") as f:
        file_metadata = json.load(f)

    # Use the metadata to define the number of output neurons for the classification heads.
    binary_tasks = file_metadata["binary_tasks"]
    multiclass_tasks = file_metadata["multiclass_tasks"]
    multilabel_column_names = file_metadata["multilabel_column_names"]
    num_multilabels = len(multilabel_column_names)
    num_sentiment_labels = len(get_sentiment_labels())

    # Load the standard tokenizer and config.
    tokenizer = AutoTokenizer.from_pretrained(hub_repo_id, subfolder=subfolder)
    config = AutoConfig.from_pretrained(hub_repo_id, subfolder=subfolder)
    # Add our custom sentiment label count to the config.
    config.num_sentiment_labels = num_sentiment_labels

    # Manually load the custom model, as it's not a standard transformers architecture.
    # Create a model 'shell' with our custom architecture.
    model_shell = MultiHeadClassificationModel(config=config, num_multilabels=num_multilabels)
    # Download and load the trained weights.
    weights_path = hf_hub_download(repo_id=hub_repo_id, filename="model.safetensors", subfolder=subfolder)
    state_dict = load_file(weights_path, device="cpu")
    # Apply weights to the shell. `strict=False` is required for loading custom heads.
    model_shell.load_state_dict(state_dict, strict=False)

    # Move model to the target device and set to evaluation mode.
    model = model_shell.to(device)
    model.eval()

    # Package all components for use in the predict function.
    metadata_for_prediction = {
        "binary_tasks": binary_tasks,
        "multiclass_tasks": multiclass_tasks,
        "multilabel_column_names": multilabel_column_names,
        "all_labels": get_all_labels(),
        "device": device
    }
    print("Loading complete.")
    return model, tokenizer, metadata_for_prediction
```
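
Because `strict=False` lets `load_state_dict` skip parameter names that do not line up, a renamed or misspelled head would stay randomly initialized without raising an error. PyTorch's `load_state_dict` returns an object exposing `missing_keys` and `unexpected_keys` that can be inspected for this. The sketch below mimics that comparison with plain sets so it runs without downloading anything; the helper and the key names are illustrative and are not read from the actual checkpoint.

```python
from typing import Dict, Iterable, List


def diff_state_dict_keys(model_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> Dict[str, List[str]]:
    """Report parameters the model expects but the checkpoint lacks, and vice versa."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    return {
        "missing_keys": sorted(model_set - ckpt_set),      # would keep their random init
        "unexpected_keys": sorted(ckpt_set - model_set),   # would be silently ignored
    }


# Hypothetical key names: the checkpoint calls the head "multilabel_head" by mistake.
model_keys = [
    "backbone.embeddings.word_embeddings.weight",
    "sentiment_classifier.weight",
    "sentiment_classifier.bias",
    "multilabel_classifier.weight",
    "multilabel_classifier.bias",
]
checkpoint_keys = [
    "backbone.embeddings.word_embeddings.weight",
    "sentiment_classifier.weight",
    "sentiment_classifier.bias",
    "multilabel_head.weight",
    "multilabel_head.bias",
]

report = diff_state_dict_keys(model_keys, checkpoint_keys)
print(report)
```

In the loading function itself, the equivalent check would be to keep the return value of `model_shell.load_state_dict(state_dict, strict=False)` and confirm that both of its key lists are empty.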
***
4. **Prediction Function**: The prediction function takes the loaded components and input text to produce a decoded dictionary.
```python
def predict(text: str, model, tokenizer, metadata: Dict) -> Dict[str, Any]:
    # Tokenize the text and move the tensors to the same device as the model.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(metadata['device'])
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax for the single-label sentiment head, sigmoid for the multi-label head.
    sentiment_probs = torch.softmax(outputs['sentiment_logits'], dim=-1).cpu().numpy()
    multilabel_probs = torch.sigmoid(outputs['multilabel_logits']).cpu().numpy()

    results = {}
    sentiment_decoder = metadata['all_labels']['sentiment']
    sentiment_pred_idx = int(sentiment_probs.argmax())
    results['sentiment'] = {'prediction': sentiment_decoder.get(sentiment_pred_idx, "unknown"), 'confidence': sentiment_probs[0, sentiment_pred_idx].item()}

    # Binary tasks: one output each, thresholded at 0.5.
    for task_name in metadata['binary_tasks']:
        idx = metadata['multilabel_column_names'].index(task_name)
        prob = multilabel_probs[0, idx]
        pred = 1 if prob > 0.5 else 0
        results[task_name] = {'prediction': metadata['all_labels'][task_name][pred], 'confidence': (prob if pred == 1 else 1 - prob).item()}

    # Multi-class tasks: argmax over the task's consecutive output columns.
    for task_name, num_classes in metadata['multiclass_tasks'].items():
        start_idx = metadata['multilabel_column_names'].index(f"{task_name}_0")
        task_probs = multilabel_probs[0, start_idx : start_idx + num_classes]
        pred_idx = int(task_probs.argmax())
        results[task_name] = {'prediction': metadata['all_labels'][task_name].get(pred_idx, "unknown"), 'confidence': task_probs[pred_idx].item()}
    return results
```
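
The decoding rules in `predict` can be illustrated without the model: softmax plus argmax for the sentiment head, a 0.5 threshold on each sigmoid probability for binary tasks (reporting `1 - p` as the confidence of a `no`), and argmax over a task's slice of sigmoid outputs for multi-class tasks. The following is a minimal pure-Python sketch with made-up logits; the helper names are hypothetical and not part of the script.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def decode_binary(logit: float) -> Tuple[str, float]:
    """Threshold at 0.5; confidence is the probability of the predicted side."""
    p = sigmoid(logit)
    return ("yes", p) if p > 0.5 else ("no", 1.0 - p)


def decode_multiclass(logits: Sequence[float], labels: Sequence[str]) -> Tuple[str, float]:
    """Argmax over independent sigmoid outputs, mirroring the slice logic in predict()."""
    probs = [sigmoid(x) for x in logits]
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]


# Made-up logits for illustration only.
sentiment_probs = softmax([2.0, 0.5, -1.0])
sentiment = ["negative", "neutral", "positive"][sentiment_probs.index(max(sentiment_probs))]
water = decode_binary(3.0)
fire = decode_binary(-3.0)
genre = decode_multiclass([4.0, -2.0, -3.0], ["direct", "news", "social"])
print(sentiment, water, fire, genre)
```

Note that the sigmoid outputs within a multi-class slice are independent and need not sum to 1, so the reported confidence for `genre` and `related` is not a normalized class probability.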
***
5. **Main Execution**: The main execution block shows how to use the functions and print the raw JSON output.
```python
if __name__ == "__main__":
    model, tokenizer, metadata = load_essentials()
    input_text = "I need food, water, and shelter. Help me! People are dying. We need more items."
    print(f"\n--- Predicting for Input ---\n\"{input_text}\"")
    predictions = predict(input_text, model, tokenizer, metadata)
    # Print the raw dictionary output
    print("\n--- RAW DICTIONARY OUTPUT ---")
    print(json.dumps(predictions, indent=4))
```
### Sample Output
```
{'sentiment': {'prediction': 'negative', 'confidence': 0.999014139175415},
 'request': {'prediction': 'yes', 'confidence': 0.9999805688858032},
 'offer': {'prediction': 'no', 'confidence': 0.9995545148849487},
 'aid_related': {'prediction': 'yes', 'confidence': 0.9995179176330566},
 'medical_help': {'prediction': 'no', 'confidence': 0.9931818246841431},
 'medical_products': {'prediction': 'no', 'confidence': 0.9975765943527222},
 'search_and_rescue': {'prediction': 'no', 'confidence': 0.9981554746627808},
 'security': {'prediction': 'no', 'confidence': 0.999071478843689},
 'military': {'prediction': 'no', 'confidence': 0.9981452226638794},
 'child_alone': {'prediction': 'no', 'confidence': 0.9998688697814941},
 'water': {'prediction': 'yes', 'confidence': 0.9991873502731323},
 'food': {'prediction': 'yes', 'confidence': 0.9998394250869751},
 'shelter': {'prediction': 'yes', 'confidence': 0.9997198581695557},
 'clothing': {'prediction': 'no', 'confidence': 0.9982467889785767},
 'money': {'prediction': 'no', 'confidence': 0.9985392093658447},
 'missing_people': {'prediction': 'no', 'confidence': 0.998404324054718},
 'refugees': {'prediction': 'no', 'confidence': 0.9981242418289185},
 'death': {'prediction': 'yes', 'confidence': 0.9850122332572937},
 'other_aid': {'prediction': 'no', 'confidence': 0.9654157757759094},
 'infrastructure_related': {'prediction': 'no', 'confidence': 0.984534740447998},
 'transport': {'prediction': 'no', 'confidence': 0.9972304105758667},
 'buildings': {'prediction': 'no', 'confidence': 0.9881182312965393},
 'electricity': {'prediction': 'no', 'confidence': 0.9988776445388794},
 'tools': {'prediction': 'no', 'confidence': 0.9995874166488647},
 'hospitals': {'prediction': 'no', 'confidence': 0.999099850654602},
 'shops': {'prediction': 'no', 'confidence': 0.9996023178100586},
 'aid_centers': {'prediction': 'no', 'confidence': 0.9981774091720581},
 'other_infrastructure': {'prediction': 'no', 'confidence': 0.9968826770782471},
 'weather_related': {'prediction': 'no', 'confidence': 0.9632836580276489},
 'floods': {'prediction': 'no', 'confidence': 0.9960920810699463},
 'storm': {'prediction': 'no', 'confidence': 0.9963870048522949},
 'fire': {'prediction': 'no', 'confidence': 0.9993714094161987},
 'earthquake': {'prediction': 'no', 'confidence': 0.99778151512146},
 'cold': {'prediction': 'no', 'confidence': 0.9991660118103027},
 'other_weather': {'prediction': 'no', 'confidence': 0.9974269866943359},
 'direct_report': {'prediction': 'yes', 'confidence': 0.9763266444206238},
 'genre': {'prediction': 'direct', 'confidence': 0.9912198185920715},
 'related': {'prediction': 'yes', 'confidence': 0.9997092485427856}}
```
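
For triage it is often enough to keep only the categories flagged `yes`. The helper below is a hypothetical post-processing step (not part of the original script), applied to a handful of entries copied from the sample output; the `min_confidence` parameter is an illustrative addition.

```python
from typing import Any, Dict, List, Tuple


def flagged_categories(predictions: Dict[str, Dict[str, Any]], min_confidence: float = 0.5) -> List[Tuple[str, float]]:
    """Return (task, confidence) pairs predicted 'yes', most confident first."""
    flagged = [
        (task, result["confidence"])
        for task, result in predictions.items()
        if result["prediction"] == "yes" and result["confidence"] >= min_confidence
    ]
    return sorted(flagged, key=lambda item: item[1], reverse=True)


# A subset of the sample output, copied verbatim.
sample = {
    "sentiment": {"prediction": "negative", "confidence": 0.999014139175415},
    "request": {"prediction": "yes", "confidence": 0.9999805688858032},
    "offer": {"prediction": "no", "confidence": 0.9995545148849487},
    "water": {"prediction": "yes", "confidence": 0.9991873502731323},
    "food": {"prediction": "yes", "confidence": 0.9998394250869751},
    "death": {"prediction": "yes", "confidence": 0.9850122332572937},
}

flagged = flagged_categories(sample)
print([task for task, _ in flagged])  # ['request', 'food', 'water', 'death']
```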
## Training and evaluation data
This model was fine-tuned on the `community-datasets/disaster_response_messages` dataset, which contains over 26,000 messages from real disaster events. Each message is labeled with 37 different categories, such as `aid_related` and `weather_related`, as well as the message `genre` (direct, news, social). The `sentiment` labels were added programmatically for the purpose of this multi-task training.
The dataset was split into:
* Training set: ~21,000 samples
* Validation set: ~2,600 samples
* Test set: ~2,600 samples
## Training procedure
The model was trained using the `transformers.Trainer` with a custom `MultiHeadClassificationModel` architecture. The training process optimized a combined loss from both the sentiment and multi-label classification heads. The best model was selected based on the `F1 Micro` score on the validation set.
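
The card does not state how the two losses are combined. One common choice, assumed here purely for illustration, is an unweighted sum of cross-entropy on the sentiment logits and mean binary cross-entropy on the multi-label logits. The pure-Python sketch below computes that sum for a single example with made-up logits and targets.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target_index: int) -> float:
    """Softmax cross-entropy for one example, via a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target_index]


def binary_cross_entropy_with_logits(logits: Sequence[float], targets: Sequence[float]) -> float:
    """Mean sigmoid cross-entropy over labels, in the numerically stable form."""
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


# Made-up values: 3 sentiment logits, 3 (of 41) multi-label logits.
sentiment_loss = cross_entropy([2.0, 0.5, -1.0], target_index=0)
multilabel_loss = binary_cross_entropy_with_logits([3.0, -3.0, 0.0], [1.0, 0.0, 1.0])

# Assumption: equal weighting of the two heads; the real training code may differ.
combined_loss = sentiment_loss + multilabel_loss
print(round(sentiment_loss, 4), round(multilabel_loss, 4), round(combined_loss, 4))
```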
### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 2e-05
- train_batch_size: 32
- eval_batch_size: 32
- seed: 42
- optimizer: AdamW (`OptimizerNames.ADAMW_TORCH`) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
- lr_scheduler_type: linear
- num_epochs: 1000 (early stopping patience of 50 epochs)
- mixed_precision_training: Native AMP
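
With `lr_scheduler_type: linear`, the learning rate decays linearly from its initial value toward zero over the planned number of steps. The sketch below estimates the rate at a given step under two assumptions that the card does not confirm: no warmup steps, and a schedule planned over the full 1000 configured epochs at 658 steps per epoch (the per-epoch step count in the training results table).

```python
def linear_lr(step: int, total_steps: int, base_lr: float = 2e-5, warmup_steps: int = 0) -> float:
    """Linear warmup (if any) followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


STEPS_PER_EPOCH = 658                  # from the training results table
total_steps = 1000 * STEPS_PER_EPOCH   # assumption: schedule spans all 1000 epochs

lr_at_best = linear_lr(594 * STEPS_PER_EPOCH, total_steps)
print(f"{lr_at_best:.3e}")  # 8.120e-06
```

Under these assumptions the best checkpoint at epoch 594 would have been reached at roughly 40% of the initial learning rate.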
### Training results
The final results on the evaluation set are based on the best checkpoint at epoch 594. A truncated history of selected epochs is shown below.
For the full data, please refer to [training_log.csv](https://huggingface.co/spencercdz/xlm-roberta-sentiment-requests/blob/main/training_log.csv) in the repository.
| Training Loss | Epoch | Step | Validation Loss | F1 Micro | F1 Macro | Subset Accuracy |
|:-------------:|:-----:|:-------:|:---------------:|:--------:|:--------:|:---------------:|
| 0.4267 | 1.0 | 658 | 0.2727 | 0.4953 | 0.0722 | 0.1053 |
| 0.2662 | 2.0 | 1316 | 0.2291 | 0.5446 | 0.0906 | 0.1123 |
| 0.2366 | 3.0 | 1974 | 0.2143 | 0.5682 | 0.1031 | 0.1279 |
| 0.2234 | 4.0 | 2632 | 0.2058 | 0.5878 | 0.1160 | 0.1333 |
| 0.2156 | 5.0 | 3290 | 0.1997 | 0.6022 | 0.1255 | 0.1380 |
| ... | ... | ... | ... | ... | ... | ... |
| 0.1773 | 25.0 | 16450 | 0.1670 | 0.6714 | 0.2305 | 0.1955 |
| 0.1694 | 50.0 | 32900 | 0.1592 | 0.6911 | 0.2701 | 0.2223 |
| 0.1662 | 75.0 | 49350 | 0.1558 | 0.7018 | 0.2960 | 0.2309 |
| 0.164 | 100.0 | 65800 | 0.1537 | 0.7077 | 0.3098 | 0.2425 |
| 0.1627 | 125.0 | 82250 | 0.1522 | 0.7104 | 0.3184 | 0.2449 |
| 0.1617 | 150.0 | 98700 | 0.1513 | 0.7130 | 0.3243 | 0.2449 |
| 0.1612 | 175.0 | 115150 | 0.1504 | 0.7143 | 0.3285 | 0.2499 |
| 0.1606 | 200.0 | 131600 | 0.1498 | 0.7161 | 0.3314 | 0.2515 |
| 0.16 | 250.0 | 164500 | 0.1488 | 0.7183 | 0.3383 | 0.2538 |
| 0.1592 | 300.0 | 197400 | 0.1482 | 0.7204 | 0.3423 | 0.2534 |
| 0.1589 | 350.0 | 230300 | 0.1476 | 0.7214 | 0.3450 | 0.2581 |
| 0.1584 | 400.0 | 263200 | 0.1474 | 0.7223 | 0.3459 | 0.2588 |
| 0.1584 | 450.0 | 296100 | 0.1471 | 0.7231 | 0.3487 | 0.2588 |
| 0.158 | 500.0 | 329000 | 0.1468 | 0.7232 | 0.3494 | 0.2612 |
| 0.1577 | 550.0 | 361900 | 0.1467 | 0.7239 | 0.3503 | 0.2600 |
| ... | ... | ... | ... | ... | ... | ... |
| 0.1574 | 591.0 | 388878 | 0.1466 | 0.7243 | 0.3510 | 0.2596 |
| 0.1576 | 592.0 | 389536 | 0.1465 | 0.7234 | 0.3496 | 0.2596 |
| 0.1582 | 593.0 | 390194 | 0.1465 | 0.7239 | 0.3504 | 0.2592 |
| 0.158 | 594.0 | 390852 | 0.1465 | 0.7240 | 0.3505 | 0.2588 |
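
The Step column is consistent with the stated batch size: 658 optimizer steps per epoch at 32 samples per step cover at most 21,056 samples, in line with the ~21,000-sample training set, and 594 epochs give the final step count. The check below assumes a single device and no gradient accumulation, neither of which is stated in this card.

```python
TRAIN_BATCH_SIZE = 32
STEPS_PER_EPOCH = 658   # step count after epoch 1 in the table
BEST_EPOCH = 594

# With the last batch possibly partial, 658 steps imply this range of training samples.
max_train_samples = STEPS_PER_EPOCH * TRAIN_BATCH_SIZE
min_train_samples = (STEPS_PER_EPOCH - 1) * TRAIN_BATCH_SIZE + 1

final_step = BEST_EPOCH * STEPS_PER_EPOCH

print(min_train_samples, max_train_samples, final_step)  # 21025 21056 390852
```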
### Framework versions
- Transformers 4.52.4
- Pytorch 2.7.1+cu128
- Datasets 3.6.0
- Tokenizers 0.21.2