Update README.md
Browse files
fix: Correct multi-label head size during model initialization
The script was failing on `model.load_state_dict()` with a `RuntimeError`
due to a size mismatch in the `multilabel_classifier` layer.
The root cause was an incorrect calculation of the multi-label head's
output dimension. The code was including the 3 `sentiment` labels when
calculating the size for the multi-label head, resulting in an expected
shape of [44, 768] instead of the correct [41, 768] from the checkpoint.
This commit corrects the logic in `load_essentials()` by explicitly
excluding the 'sentiment' task from the `multiclass_tasks` calculation.
This ensures the in-memory model architecture matches the saved weights,
resolving the loading error.
README.md
CHANGED
|
@@ -172,17 +172,25 @@ def get_sentiment_labels() -> Dict[int, str]: return {0: 'negative', 1: 'neutral', 2: 'positive'}
|
|
| 172 |
3. **Setup & Loading**: This setup function handles loading all components and reconstructing the necessary metadata.
|
| 173 |
```python
|
| 174 |
def load_essentials():
|
|
|
|
| 175 |
hub_repo_id = "spencercdz/xlm-roberta-sentiment-requests"
|
| 176 |
subfolder = "final_model"
|
| 177 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 178 |
|
| 179 |
all_labels_map = get_all_labels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
binary_tasks = [k for k, v in all_labels_map.items() if len(v) == 2 and k not in ['related', 'sentiment']]
|
| 181 |
-
multiclass_tasks = {k: len(v) for k, v in all_labels_map.items() if len(v) > 2}
|
| 182 |
|
| 183 |
column_names = [f"{t}_{i}" for t, n in multiclass_tasks.items() for i in range(n)] + binary_tasks
|
| 184 |
multilabel_column_names = sorted(column_names)
|
| 185 |
-
num_multilabels = len(multilabel_column_names)
|
| 186 |
num_sentiment_labels = len(get_sentiment_labels())
|
| 187 |
|
| 188 |
tokenizer = AutoTokenizer.from_pretrained(hub_repo_id, subfolder=subfolder)
|
|
@@ -191,7 +199,7 @@ def load_essentials():
|
|
| 191 |
|
| 192 |
model_shell = MultiHeadClassificationModel(config=config, num_multilabels=num_multilabels)
|
| 193 |
weights_path = hf_hub_download(repo_id=hub_repo_id, filename="model.safetensors", subfolder=subfolder)
|
| 194 |
-
state_dict = load_file(weights_path, device=device)
|
| 195 |
model_shell.load_state_dict(state_dict, strict=False)
|
| 196 |
model = model_shell.to(device)
|
| 197 |
model.eval()
|
|
@@ -201,6 +209,7 @@ def load_essentials():
|
|
| 201 |
"multilabel_column_names": multilabel_column_names,
|
| 202 |
"all_labels": all_labels_map, "device": device
|
| 203 |
}
|
|
|
|
| 204 |
return model, tokenizer, metadata
|
| 205 |
```
|
| 206 |
***
|
|
@@ -281,7 +290,7 @@ The following hyperparameters were used during training:
|
|
| 281 |
|
| 282 |
### Training results
|
| 283 |
|
| 284 |
-
The final results on the evaluation set are based on the best checkpoint at epoch 594. A truncated history of the training results is shown below.
|
| 285 |
For the full data, please refer to [training_log.csv](https://huggingface.co/spencercdz/xlm-roberta-sentiment-requests/blob/main/training_log.csv) in the repository.
|
| 286 |
|
| 287 |
| Training Loss | Epoch | Step | Validation Loss | F1 Micro | F1 Macro | Subset Accuracy |
|
|
|
|
| 172 |
3. **Setup & Loading**: This setup function handles loading all components and reconstructing the necessary metadata.
|
| 173 |
```python
|
| 174 |
def load_essentials():
|
| 175 |
+
print("Loading model, tokenizer, and metadata... (This may take a moment on first run)")
|
| 176 |
hub_repo_id = "spencercdz/xlm-roberta-sentiment-requests"
|
| 177 |
subfolder = "final_model"
|
| 178 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 179 |
+
print(f"Using device: {device}")
|
| 180 |
|
| 181 |
all_labels_map = get_all_labels()
|
| 182 |
+
|
| 183 |
+
# --- FIX IS HERE ---
|
| 184 |
+
# We must exclude 'sentiment' from the multiclass tasks for the multi-label head,
|
| 185 |
+
# because sentiment has its own dedicated classification head.
|
| 186 |
+
multiclass_tasks = {k: len(v) for k, v in all_labels_map.items() if len(v) > 2 and k != 'sentiment'}
|
| 187 |
+
# -------------------
|
| 188 |
+
|
| 189 |
binary_tasks = [k for k, v in all_labels_map.items() if len(v) == 2 and k not in ['related', 'sentiment']]
|
|
|
|
| 190 |
|
| 191 |
column_names = [f"{t}_{i}" for t, n in multiclass_tasks.items() for i in range(n)] + binary_tasks
|
| 192 |
multilabel_column_names = sorted(column_names)
|
| 193 |
+
num_multilabels = len(multilabel_column_names) # This will now correctly be 41
|
| 194 |
num_sentiment_labels = len(get_sentiment_labels())
|
| 195 |
|
| 196 |
tokenizer = AutoTokenizer.from_pretrained(hub_repo_id, subfolder=subfolder)
|
|
|
|
| 199 |
|
| 200 |
model_shell = MultiHeadClassificationModel(config=config, num_multilabels=num_multilabels)
|
| 201 |
weights_path = hf_hub_download(repo_id=hub_repo_id, filename="model.safetensors", subfolder=subfolder)
|
| 202 |
+
state_dict = load_file(weights_path, device="cpu") # Load to CPU first
|
| 203 |
model_shell.load_state_dict(state_dict, strict=False)
|
| 204 |
model = model_shell.to(device)
|
| 205 |
model.eval()
|
|
|
|
| 209 |
"multilabel_column_names": multilabel_column_names,
|
| 210 |
"all_labels": all_labels_map, "device": device
|
| 211 |
}
|
| 212 |
+
print("Loading complete.")
|
| 213 |
return model, tokenizer, metadata
|
| 214 |
```
|
| 215 |
***
|
|
|
|
| 290 |
|
| 291 |
### Training results
|
| 292 |
|
| 293 |
+
The final results on the evaluation set are based on the best checkpoint at epoch 594. A truncated history of the 25 most important rows is shown below.
|
| 294 |
For the full data, please refer to [training_log.csv](https://huggingface.co/spencercdz/xlm-roberta-sentiment-requests/blob/main/training_log.csv) in the repository.
|
| 295 |
|
| 296 |
| Training Loss | Epoch | Step | Validation Loss | F1 Micro | F1 Macro | Subset Accuracy |
|