WarriorsSami committed
Commit 3c4dee5 · 1 Parent(s): 17101fd

feat: add new model version alongside docs

README.md CHANGED
@@ -9,7 +9,68 @@ base_model:
  ---
  # Sentiment Analysis Model (SAM)
 
- ## Technologies Used
- - Rust
- - Burn
- - Rocket
+ A sentiment analysis model built using the [Burn](https://burn.dev/) deep learning framework in Rust, fine-tuned on the [MTEB Tweet Sentiment Extraction](https://huggingface.co/WarriorsSami/sentiment-analysis-model/tree/main#:~:text=tweet_sentiment_extraction) dataset and exposed via a [Rocket](https://rocket.rs/guide/v0.5/introduction/#introduction) API.
+
+ ## 🧠 Model Details
+ - **Architecture**: Transformer Encoder with 4 layers, 8 attention heads, d_model=256, and d_ff=1024 (see the learner summary under Evaluation Metrics).
+ - **Embeddings**: Token and positional embeddings with a maximum sequence length of 256.
+ - **Output Layer**: Linear layer mapping to 3 sentiment classes: Negative, Neutral, Positive.
+ - **Output Activation**: Softmax over the class logits for multi-class classification.
+ - **Dropout**: Two dropout layers with a rate of 0.1 (one after the embeddings, one before the output layer) to reduce overfitting.
+ - **Training Framework**: Burn in Rust (a rough configuration sketch follows this list).
+
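+ As a concrete illustration, a Burn encoder configuration matching the hyperparameters above could be written roughly as follows. This is a minimal sketch assuming Burn's `TransformerEncoderConfig` builder (as in Burn's text-classification example), not the code in this repository; builder method names may differ across Burn versions.
+ ```rust
+ use burn::nn::transformer::TransformerEncoderConfig;
+
+ // Sketch only: mirrors the hyperparameters listed above
+ // (d_model = 256, d_ff = 1024, 8 attention heads, 4 layers, dropout 0.1).
+ fn encoder_config() -> TransformerEncoderConfig {
+     TransformerEncoderConfig::new(256, 1024, 8, 4)
+         .with_dropout(0.1)
+         .with_norm_first(true)    // matches `norm_first: true` in the learner summary
+         .with_quiet_softmax(true) // matches `quiet_softmax: true` in the learner summary
+ }
+ ```
+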
+ ## 📚 Training Data
+ - **Dataset**: MTEB Tweet Sentiment Extraction.
+ - **Size**: 100,000 training samples, drawn with a `SamplerDataset` (see `trainer/src/training.rs` below).
+ - **Preprocessing**: Tokenization with the cased BERT tokenizer (`BertCasedTokenizer`).
+ - **Batching**: Mini-batch gradient descent with a batch size of 32.
+
+ ## ⚙️ Training Configuration
+ - **Optimizer**: AdamW with weight decay 0.01 and a base learning rate of 1e-4; decoupled weight decay is a solid default for transformer training.
+ - **Learning Rate Scheduler**: Noam scheduler with 5,000 warm-up steps, the standard warm-up-then-decay schedule for transformers (a condensed sketch follows this list).
+ - **Loss Function**: CrossEntropyLoss with label smoothing (0.1) and class balancing.
+ - **Gradient Clipping**: Applied with a maximum norm of 1.0.
+ - **Early Stopping**: Patience of 2 epochs on the validation loss.
+ - **Epochs**: Trained for up to 5 epochs, subject to early stopping.
+
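+ A condensed sketch of two of these settings is shown below: the gradient-clipping choice recorded in `sam-artifacts/config.json` and the Noam scheduler call from `trainer/src/training.rs`, both visible later in this diff. Treat the import paths as assumptions; only the values are taken from this commit.
+ ```rust
+ use burn::grad_clipping::GradientClippingConfig;
+ use burn::lr_scheduler::noam::NoamLrSchedulerConfig;
+
+ fn main() {
+     // Norm-based gradient clipping with a maximum L2 norm of 1.0, matching
+     // the {"Norm": 1.0} entry in sam-artifacts/config.json below.
+     let _clipping = GradientClippingConfig::Norm(1.0);
+
+     // Noam schedule, condensed from trainer/src/training.rs in this commit:
+     // 1e-4 base learning rate, 5,000 warm-up steps, scaled by d_model (256).
+     let _scheduler = NoamLrSchedulerConfig::new(1e-4)
+         .with_warmup_steps(5_000)
+         .with_model_size(256)
+         .init()
+         .unwrap();
+ }
+ ```
+ With 100,000 sampled training examples and a batch size of 32, an epoch is roughly 3,125 optimizer steps, so the 5,000 warm-up steps extend into the second epoch.
+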
+ ## 📈 Evaluation Metrics
+ - **Learner Summary**:
+ ```text
+ TextClassificationModel {
+   transformer: TransformerEncoder {d_model: 256, d_ff: 1024, n_heads: 8, n_layers: 4, dropout: 0.1, norm_first: true, quiet_softmax: true, params: 3159040}
+   embedding_token: Embedding {n_embedding: 28996, d_model: 256, params: 7422976}
+   embedding_pos: Embedding {n_embedding: 256, d_model: 256, params: 65536}
+   embed_dropout: Dropout {prob: 0.1}
+   output_dropout: Dropout {prob: 0.1}
+   output: Linear {d_input: 256, d_output: 3, bias: true, params: 771}
+   n_classes: 3
+   max_seq_length: 256
+   params: 10648323
+ }
+ ```
+ | Split | Metric        | Min.     | Epoch | Max.     | Epoch |
+ |-------|---------------|----------|-------|----------|-------|
+ | Train | Loss          | 1.120    | 5     | 1.171    | 1     |
+ | Train | Accuracy (%)  | 33.743   | 2     | 37.814   | 1     |
+ | Train | Learning Rate | 2.763e-8 | 1     | 7.648e-8 | 2     |
+ | Valid | Loss          | 1.102    | 4     | 1.110    | 1     |
+ | Valid | Accuracy (%)  | 32.760   | 2     | 36.900   | 5     |
+ - **TODO**:
+   - Tweak hyperparameters to alleviate underfitting.
+   - Enhance logging and monitoring.
+
+ ## 🚀 Usage
+ - **API Endpoint**: `/predict` (a hypothetical Rocket handler sketch follows this section).
+ - **Example Request**:
+   ```json
+   {
+     "text": "I love the new features in this app!"
+   }
+   ```
+ - **Example Response**:
+   ```json
+   {
+     "sentiment": "Positive"
+   }
+   ```
+ - **Steps to Run**: *TODO* after dockerizing and deploying to Hugging Face Spaces.
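+
+ For illustration, a Rocket 0.5 handler matching the request/response shapes above could look roughly like this. It is a hypothetical sketch (the route is assumed to be a POST with a JSON body, and the model call is stubbed out), not the actual server code in this repository.
+ ```rust
+ use rocket::serde::{json::Json, Deserialize, Serialize};
+
+ #[derive(Deserialize)]
+ #[serde(crate = "rocket::serde")]
+ struct PredictRequest {
+     text: String,
+ }
+
+ #[derive(Serialize)]
+ #[serde(crate = "rocket::serde")]
+ struct PredictResponse {
+     sentiment: String,
+ }
+
+ // Hypothetical endpoint: the real service would run the Burn model on
+ // `request.text` and map the argmax class to Negative / Neutral / Positive.
+ #[rocket::post("/predict", data = "<request>")]
+ fn predict(request: Json<PredictRequest>) -> Json<PredictResponse> {
+     let _input = &request.text; // placeholder for tokenization + inference
+     Json(PredictResponse {
+         sentiment: "Positive".to_string(),
+     })
+ }
+
+ #[rocket::launch]
+ fn rocket() -> _ {
+     rocket::build().mount("/", rocket::routes![predict])
+ }
+ ```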
sam-artifacts/config.json CHANGED
@@ -16,12 +16,14 @@
  },
  "optimizer": {
  "weight_decay": {
- "penalty": 0.00005
+ "penalty": 0.01
+ },
+ "grad_clipping": {
+ "Norm": 1.0
  },
- "grad_clipping": null,
  "beta_1": 0.9,
  "beta_2": 0.999,
- "epsilon": 0.00001
+ "epsilon": 1e-8
  },
  "max_seq_length": 256,
  "batch_size": 32,
sam-artifacts/model.mpk CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:c1d72b3b82ff9c868352a6f30e31e45e339ab9ae93c6e4ed38684ddd385bddd9
- size 21302041
+ oid sha256:0f2e7709ec07fca095c996c0d7bd73e59974fb892d567ae00490c44e4e17efc8
+ size 21302072
trainer/src/training.rs CHANGED
@@ -42,7 +42,7 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
  let batcher = TextClassificationBatcher::new(tokenizer.clone(), config.max_seq_length);
 
  // Create data samplers for training and testing datasets
- let train_sampler = SamplerDataset::new(dataset_train, 50_000);
+ let train_sampler = SamplerDataset::new(dataset_train, 100_000);
  let test_sampler = SamplerDataset::new(dataset_test, 5_000);
 
  // Initialize model
@@ -69,7 +69,7 @@ pub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(
 
  // Initialize learning rate scheduler
  let lr_scheduler = NoamLrSchedulerConfig::new(1e-4)
- .with_warmup_steps(8_000)
+ .with_warmup_steps(5_000)
  .with_model_size(config.transformer.d_model)
  .init()
  .unwrap();
  .unwrap();