Commit 3c4dee5 · Parent(s): 17101fd
feat: add new model version alongside docs

Files changed:
- README.md +65 -4
- sam-artifacts/config.json +5 -3
- sam-artifacts/model.mpk +2 -2
- trainer/src/training.rs +2 -2
README.md CHANGED

@@ -9,7 +9,68 @@ base_model:
 ---
 # Sentiment Analysis Model (SAM)
 
-
-
-
-
+A sentiment analysis model built using the [Burn](https://burn.dev/) deep learning framework in Rust, fine-tuned on the [MTEB Tweet Sentiment Extraction](https://huggingface.co/WarriorsSami/sentiment-analysis-model/tree/main#:~:text=tweet_sentiment_extraction) dataset and exposed via a [Rocket](https://rocket.rs/guide/v0.5/introduction/#introduction) API.
+
+## 🧠 Model Details
+- **Architecture**: Transformer Encoder with 4 layers, 8 attention heads, d_model=256, and d_ff=1024 (see the learner summary below).
+- **Embeddings**: Token and positional embeddings with a maximum sequence length of 256.
+- **Output Layer**: Linear layer mapping to 3 sentiment classes: Negative, Neutral, Positive.
+- **Activation Function**: Softmax for multi-class classification.
+- **Dropout**: Applied at a rate of 0.1 to prevent overfitting (one dropout after the embeddings and one before the output layer).
+- **Training Framework**: Burn in Rust.
+
+## 📊 Training Data
+- **Dataset**: MTEB Tweet Sentiment Extraction.
+- **Size**: 100,000 training samples.
+- **Preprocessing**: Tokenized with the BertCasedTokenizer.
+- **Batching**: Mini-batch gradient descent with a batch size of 32.
+
+## ⚙️ Training Configuration
+- **Optimizer**: AdamW with weight decay (0.01) and a learning rate of 1e-4; well suited to training large models.
+- **Learning Rate Scheduler**: Noam scheduler with 5,000 warm-up steps; particularly effective for transformer models.
+- **Loss Function**: CrossEntropyLoss with label smoothing (0.1) and class balancing.
+- **Gradient Clipping**: Applied with a maximum norm of 1.0.
+- **Early Stopping**: Implemented with a patience of 2 epochs.
+- **Epochs**: Trained for up to 5 epochs with early stopping based on validation loss.
+
+## 📈 Evaluation Metrics
+- **Learner Summary**:
+```text
+TextClassificationModel {
+  transformer: TransformerEncoder {d_model: 256, d_ff: 1024, n_heads: 8, n_layers: 4, dropout: 0.1, norm_first: true, quiet_softmax: true, params: 3159040}
+  embedding_token: Embedding {n_embedding: 28996, d_model: 256, params: 7422976}
+  embedding_pos: Embedding {n_embedding: 256, d_model: 256, params: 65536}
+  embed_dropout: Dropout {prob: 0.1}
+  output_dropout: Dropout {prob: 0.1}
+  output: Linear {d_input: 256, d_output: 3, bias: true, params: 771}
+  n_classes: 3
+  max_seq_length: 256
+  params: 10648323
+}
+```
+| Split | Metric        | Min.     | Epoch | Max.     | Epoch |
+|-------|---------------|----------|-------|----------|-------|
+| Train | Loss          | 1.120    | 5     | 1.171    | 1     |
+| Train | Accuracy      | 33.743   | 2     | 37.814   | 1     |
+| Train | Learning Rate | 2.763e-8 | 1     | 7.648e-8 | 2     |
+| Valid | Loss          | 1.102    | 4     | 1.110    | 1     |
+| Valid | Accuracy      | 32.760   | 2     | 36.900   | 5     |
+- **TODO**:
+  - Tweak hyperparameters to alleviate underfitting.
+  - Enhance logging and monitoring.
+
+## 🚀 Usage
+- **API Endpoint**: `/predict`
+- **Example Request**:
+```json
+{
+  "text": "I love the new features in this app!"
+}
+```
+- **Example Response**:
+```json
+{
+  "sentiment": "Positive"
+}
+```
+- **Steps to Run**: *TODO* after dockerizing and deploying to Hugging Face Spaces.
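For context on the Usage section added above, here is a minimal sketch of exercising the `/predict` endpoint from Rust. The local base URL is hypothetical (the *Steps to Run* are still TODO, so there is no deployed address yet), and the `reqwest` (with `blocking` and `json` features) and `serde_json` crates are assumptions, not dependencies taken from this commit.

```rust
use serde_json::{json, Value};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Request body matching the documented example request.
    let body = json!({ "text": "I love the new features in this app!" });

    // Hypothetical local deployment; the real URL depends on where the
    // Rocket server ends up being hosted (e.g. Hugging Face Spaces).
    let response: Value = reqwest::blocking::Client::new()
        .post("http://localhost:8000/predict")
        .json(&body)
        .send()?
        .json()?;

    // Expected shape per the README: {"sentiment": "Positive"}
    println!("sentiment: {}", response["sentiment"]);
    Ok(())
}
```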
sam-artifacts/config.json CHANGED

@@ -16,12 +16,14 @@
     },
     "optimizer": {
         "weight_decay": {
-            "penalty": 0.
+            "penalty": 0.01
+        },
+        "grad_clipping": {
+            "Norm": 1.0
         },
-        "grad_clipping": null,
         "beta_1": 0.9,
         "beta_2": 0.999,
-        "epsilon":
+        "epsilon": 1e-8
     },
     "max_seq_length": 256,
     "batch_size": 32,
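The optimizer block above serializes a weight-decay object (`{"penalty": ...}`) and an externally tagged clipping variant (`{"Norm": 1.0}`). As a rough illustration of the shape it deserializes into, here is a hypothetical mirror in plain serde; in the actual trainer these types come from Burn's `#[derive(Config)]` structs, so the struct and field names below are stand-ins, not the repo's real types.

```rust
use serde::Deserialize;

// Hypothetical mirror of the "optimizer" block in sam-artifacts/config.json.
#[derive(Debug, Deserialize)]
struct WeightDecay {
    penalty: f64, // 0.01 after this commit
}

#[derive(Debug, Deserialize)]
enum GradClipping {
    Norm(f64), // {"Norm": 1.0}: clip gradients to a maximum L2 norm
}

#[derive(Debug, Deserialize)]
struct OptimizerConfig {
    weight_decay: WeightDecay,
    grad_clipping: Option<GradClipping>, // was null before this commit
    beta_1: f64,
    beta_2: f64,
    epsilon: f64, // 1e-8
}

fn main() -> Result<(), serde_json::Error> {
    let json = r#"{
        "weight_decay": { "penalty": 0.01 },
        "grad_clipping": { "Norm": 1.0 },
        "beta_1": 0.9,
        "beta_2": 0.999,
        "epsilon": 1e-8
    }"#;
    let cfg: OptimizerConfig = serde_json::from_str(json)?;
    println!("{cfg:?}");
    Ok(())
}
```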
sam-artifacts/model.mpk CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:0f2e7709ec07fca095c996c0d7bd73e59974fb892d567ae00490c44e4e17efc8
+size 21302072
trainer/src/training.rs CHANGED

@@ -42,7 +42,7 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
     let batcher = TextClassificationBatcher::new(tokenizer.clone(), config.max_seq_length);
 
     // Create data samplers for training and testing datasets
-    let train_sampler = SamplerDataset::new(dataset_train,
+    let train_sampler = SamplerDataset::new(dataset_train, 100_000);
     let test_sampler = SamplerDataset::new(dataset_test, 5_000);
 
     // Initialize model
@@ -69,7 +69,7 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
 
     // Initialize learning rate scheduler
     let lr_scheduler = NoamLrSchedulerConfig::new(1e-4)
-        .with_warmup_steps(
+        .with_warmup_steps(5_000)
         .with_model_size(config.transformer.d_model)
         .init()
         .unwrap();
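A closing note on the scheduler change: with a base rate of 1e-4, 5,000 warm-up steps, and model size 256, the Noam schedule keeps the effective learning rate around 1e-8 during the early epochs, which lines up with the learning-rate row in the README's metrics table. The sketch below uses the schedule as commonly defined (Vaswani et al., 2017); whether Burn's `NoamLrScheduler` applies exactly these constant factors is an assumption.

```rust
/// Sketch of the Noam learning-rate schedule:
/// lr(step) = init_lr * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5).
/// The shape (linear warm-up, then inverse-sqrt decay) is standard; exact
/// agreement with Burn's implementation is assumed, not verified here.
fn noam_lr(init_lr: f64, d_model: usize, warmup_steps: usize, step: usize) -> f64 {
    let step = step.max(1) as f64;
    let warmup = warmup_steps as f64;
    init_lr * (d_model as f64).powf(-0.5) * step.powf(-0.5).min(step * warmup.powf(-1.5))
}

fn main() {
    // Well inside the 5_000-step warm-up the rate is on the order of 1e-8,
    // consistent with the README's metrics table.
    println!("{:e}", noam_lr(1e-4, 256, 5_000, 1_500)); // ~2.7e-8
}
```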