---
license: mit
datasets:
- mteb/tweet_sentiment_extraction
language:
- en
base_model:
- google-bert/bert-base-uncased
---

# Sentiment Analysis Model (SAM)

A sentiment analysis model built with the [Burn](https://burn.dev/) deep learning framework in Rust, fine-tuned on the [MTEB Tweet Sentiment Extraction](https://huggingface.co/datasets/mteb/tweet_sentiment_extraction) dataset and exposed via a [Rocket](https://rocket.rs/guide/v0.5/introduction/#introduction) API.

## 🧠 Model Details

- **Architecture**: Transformer encoder with 4 layers, 8 attention heads, d_model=256, and d_ff=1024 (see the learner summary below); an illustrative code sketch appears at the end of this card.
- **Embeddings**: Token and positional embeddings with a maximum sequence length of 256.
- **Output Layer**: Linear layer mapping to 3 sentiment classes: Negative, Neutral, Positive.
- **Output Activation**: Softmax over the three class logits for multi-class classification.
- **Dropout**: Rate of 0.1 to prevent overfitting, applied twice: once after the embeddings and once before the output layer.
- **Training Framework**: Burn in Rust.

## 📚 Training Data

- **Dataset**: MTEB Tweet Sentiment Extraction
- **Size**: 100,000 training samples.
- **Preprocessing**: Tokenized with the BertCasedTokenizer.
- **Batching**: Mini-batch gradient descent with a batch size of 32.

## ⚙️ Training Configuration

- **Optimizer**: AdamW with weight decay (0.01) and learning rate (1e-4); well suited to training large models.
- **Learning Rate Scheduler**: Noam scheduler with 5,000 warm-up steps; especially useful for transformer models.
- **Loss Function**: CrossEntropyLoss with label smoothing (0.1) and class balancing.
- **Gradient Clipping**: Applied with a maximum norm of 1.0.
- **Early Stopping**: Patience of 2 epochs, based on validation loss.
- **Epochs**: Trained for up to 5 epochs with early stopping.

A sketch of how these components fit together in Burn is also given at the end of this card.

## 📈 Evaluation Metrics

- **Learner Summary**:

```text
TextClassificationModel {
  transformer: TransformerEncoder {d_model: 256, d_ff: 1024, n_heads: 8, n_layers: 4, dropout: 0.1, norm_first: true, quiet_softmax: true, params: 3159040}
  embedding_token: Embedding {n_embedding: 28996, d_model: 256, params: 7422976}
  embedding_pos: Embedding {n_embedding: 256, d_model: 256, params: 65536}
  embed_dropout: Dropout {prob: 0.1}
  output_dropout: Dropout {prob: 0.1}
  output: Linear {d_input: 256, d_output: 3, bias: true, params: 771}
  n_classes: 3
  max_seq_length: 256
  params: 10648323
}
```

| Split | Metric        | Min.     | At Epoch | Max.     | At Epoch |
|-------|---------------|----------|----------|----------|----------|
| Train | Loss          | 1.120    | 5        | 1.171    | 1        |
| Train | Accuracy (%)  | 33.743   | 2        | 37.814   | 1        |
| Train | Learning Rate | 2.763e-8 | 1        | 7.648e-8 | 2        |
| Valid | Loss          | 1.102    | 4        | 1.110    | 1        |
| Valid | Accuracy (%)  | 32.760   | 2        | 36.900   | 5        |

- **TODO**:
  - Tweak hyperparameters to alleviate underfitting.
  - Enhance logging and monitoring.

## 🚀 Usage

- **API Endpoint**: `/predict`
- **Example Request**:

```json
{
  "text": "I love the new features in this app!"
}
```

- **Example Response**:

```json
{
  "sentiment": "Positive"
}
```

- **Steps to Run**: *TODO* after dockerizing and deploying to Hugging Face Spaces. A stub of the Rocket handler behind `/predict` is shown below.
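## 🛠️ Code Sketches (Illustrative)

The snippets in this section are minimal, hypothetical sketches, not the model's actual source. First, the architecture from the learner summary above, assembled from Burn's `nn` modules. The hyperparameters are copied from the summary; the module and builder names follow Burn's public API at the time of writing and may differ across Burn versions.

```rust
use burn::nn::{
    transformer::{TransformerEncoder, TransformerEncoderConfig},
    Dropout, DropoutConfig, Embedding, EmbeddingConfig, Linear, LinearConfig,
};
use burn::prelude::*;

// Hyperparameters taken from the learner summary on this card.
const VOCAB_SIZE: usize = 28996; // BERT cased vocabulary
const MAX_SEQ_LEN: usize = 256;
const D_MODEL: usize = 256;
const D_FF: usize = 1024;
const N_HEADS: usize = 8;
const N_LAYERS: usize = 4;
const N_CLASSES: usize = 3;

/// Mirrors the module layout printed in the learner summary. The forward
/// pass (not shown) would add token and positional embeddings, apply
/// `embed_dropout`, run the encoder, pool, then classify via `output`.
#[derive(Module, Debug)]
pub struct TextClassificationModel<B: Backend> {
    transformer: TransformerEncoder<B>,
    embedding_token: Embedding<B>,
    embedding_pos: Embedding<B>,
    embed_dropout: Dropout,
    output_dropout: Dropout,
    output: Linear<B>,
}

pub fn init_model<B: Backend>(device: &B::Device) -> TextClassificationModel<B> {
    TextClassificationModel {
        transformer: TransformerEncoderConfig::new(D_MODEL, D_FF, N_HEADS, N_LAYERS)
            .with_dropout(0.1)
            .with_norm_first(true)    // matches `norm_first: true` in the summary
            .with_quiet_softmax(true) // matches `quiet_softmax: true` in the summary
            .init(device),
        embedding_token: EmbeddingConfig::new(VOCAB_SIZE, D_MODEL).init(device),
        embedding_pos: EmbeddingConfig::new(MAX_SEQ_LEN, D_MODEL).init(device),
        embed_dropout: DropoutConfig::new(0.1).init(),
        output_dropout: DropoutConfig::new(0.1).init(),
        output: LinearConfig::new(D_MODEL, N_CLASSES).init(device),
    }
}
```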
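Next, the training components from the configuration above. `AdamWConfig`, `NoamLrSchedulerConfig`, and `CrossEntropyLossConfig` are Burn config types, but treat the exact builder calls as assumptions to check against your Burn version; the class-balancing weights are not published on this card, so they are omitted here.

```rust
use burn::grad_clipping::GradientClippingConfig;
use burn::lr_scheduler::noam::NoamLrSchedulerConfig;
use burn::nn::loss::CrossEntropyLossConfig;
use burn::optim::AdamWConfig;

/// Builds the optimizer, scheduler, and loss configs described above.
/// Callers would `.init()` these when wiring up the Burn learner.
pub fn training_components() -> (AdamWConfig, NoamLrSchedulerConfig, CrossEntropyLossConfig) {
    // AdamW with decoupled weight decay and L2-norm gradient clipping (max 1.0).
    let optimizer = AdamWConfig::new()
        .with_weight_decay(0.01)
        .with_grad_clipping(Some(GradientClippingConfig::Norm(1.0)));

    // Noam schedule: linear warm-up over 5,000 steps around a 1e-4 base rate,
    // then inverse square-root decay, scaled by the model width (d_model = 256).
    let scheduler = NoamLrSchedulerConfig::new(1e-4)
        .with_warmup_steps(5_000)
        .with_model_size(256);

    // Cross-entropy with label smoothing (0.1). Class balancing would be added
    // via `.with_weights(...)`; the actual weights are not listed on this card.
    let loss = CrossEntropyLossConfig::new().with_smoothing(Some(0.1));

    (optimizer, scheduler, loss)
}
```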
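Finally, a stub of the Rocket side of `/predict`, matching the request and response shapes shown under Usage. Only the route and JSON shapes come from this card; `classify` is a placeholder for the real Burn inference call, not the deployed implementation.

```rust
#[macro_use]
extern crate rocket;

use rocket::serde::{json::Json, Deserialize, Serialize};

#[derive(Deserialize)]
#[serde(crate = "rocket::serde")]
struct PredictRequest {
    text: String,
}

#[derive(Serialize)]
#[serde(crate = "rocket::serde")]
struct PredictResponse {
    sentiment: String,
}

/// Placeholder for the actual Burn inference call (tokenize, run the
/// encoder, take the argmax over the three class logits).
fn classify(_text: &str) -> String {
    "Positive".to_string()
}

#[post("/predict", format = "json", data = "<req>")]
fn predict(req: Json<PredictRequest>) -> Json<PredictResponse> {
    Json(PredictResponse {
        sentiment: classify(&req.text),
    })
}

#[launch]
fn rocket() -> _ {
    rocket::build().mount("/", routes![predict])
}
```

With the server running locally, POSTing the example request above to `http://localhost:8000/predict` (Rocket's default port) should return the example response.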