---
language:
  - en
license: apache-2.0
library_name: transformers
tags:
  - text-classification
  - distilbert
  - fine-tuned
  - pytorch
datasets:
  - cassieli226/cities-text-dataset
base_model: distilbert-base-uncased
model-index:
  - name: hw2-text-distilbert
    results:
      - task:
          type: text-classification
          name: Text Classification
        dataset:
          type: cassieli226/cities-text-dataset
          name: Cities Text Dataset
          split: test
        metrics:
          - type: accuracy
            value: 99.5
            name: Test Accuracy
          - type: f1
            value: 99.5
            name: Test F1 Score (Macro)
---

# DistilBERT Text Classification Model

This model is a fine-tuned version of `distilbert-base-uncased` for text classification.

## Model Description

This is a DistilBERT model fine-tuned for binary text classification: it labels English text as related to either Pittsburgh or Shanghai. It reaches 99.5% accuracy on the held-out test set.

- **Model type:** text classification (binary)
- **Language(s) (NLP):** English
- **Base model:** `distilbert-base-uncased`
- **Classes:** Pittsburgh, Shanghai

## Intended Uses & Limitations

### Intended Uses

- Binary classification of Pittsburgh-related vs. Shanghai-related text
- City-based text categorization tasks
- Research and educational purposes in NLP and text classification

### Limitations

- Limited to English-language text
- Performance may vary on out-of-domain data
- Inputs are truncated to a maximum of 256 tokens

## Training and Evaluation Data

### Training Data

- **Base dataset:** [cassieli226/cities-text-dataset](https://huggingface.co/datasets/cassieli226/cities-text-dataset)
- **Original size:** 100 samples (50 Pittsburgh, 50 Shanghai)
- **Data augmentation:** expanded the dataset from 100 to 1,000 samples
- **Augmented class balance:** 507 Pittsburgh, 493 Shanghai
- **Train/test split:** 80/20 with stratified sampling (800 train, 200 test); see the sketch after this list
- **External validation:** the original 100 samples were used for additional validation
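
A minimal sketch of the stratified split, assuming the data exposes `text` and `label` columns (the actual schema, the random seed, and the augmentation step are not confirmed by this card):

```python
from datasets import load_dataset
from sklearn.model_selection import train_test_split

# Load the base dataset; the card's augmentation step (100 -> 1,000 samples)
# is not reproduced here.
ds = load_dataset("cassieli226/cities-text-dataset", split="train")

train_texts, test_texts, train_labels, test_labels = train_test_split(
    ds["text"],
    ds["label"],
    test_size=0.2,         # 80/20, i.e. 800 train / 200 test on 1,000 samples
    stratify=ds["label"],  # preserve the Pittsburgh/Shanghai ratio in both parts
    random_state=42,       # assumed seed for reproducibility
)
```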

### Preprocessing

- Text tokenization using the DistilBERT tokenizer
- Maximum sequence length: 256 tokens
- Longer sequences are truncated (see the sketch below)
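
In code, the preprocessing above amounts to something like this (a sketch; the exact training script is not included in this card):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize(batch):
    # Truncate anything longer than 256 tokens, per the limit stated above.
    return tokenizer(batch["text"], truncation=True, max_length=256)

# With a `datasets.Dataset`, this is typically applied in batches:
# tokenized = ds.map(tokenize, batched=True)
```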

## Training Procedure

### Training Hyperparameters

- **Learning rate:** 5e-5
- **Training batch size:** 16
- **Evaluation batch size:** 32
- **Number of epochs:** 4
- **Weight decay:** 0.01
- **Warmup ratio:** 0.1
- **LR scheduler:** linear
- **Gradient accumulation steps:** 1
- **Mixed precision:** FP16 (when a GPU is available); mirrored in the sketch below
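
These settings map onto `TrainingArguments` roughly as follows, including the evaluation/save settings listed in the next subsection (a sketch assuming a recent `transformers` version; the output directory is hypothetical):

```python
import torch
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="hw2-text-distilbert",  # hypothetical
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=4,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    gradient_accumulation_steps=1,
    fp16=torch.cuda.is_available(),  # FP16 only when a GPU is present
    eval_strategy="epoch",           # `evaluation_strategy` on older versions
    save_strategy="epoch",
    load_best_model_at_end=True,     # keep only the best checkpoint
    metric_for_best_model="f1",      # macro F1, per the card
)
```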

### Training Configuration

- **Optimizer:** AdamW (the `Trainer` default)
- **Early stopping:** enabled with a patience of 2 epochs (see the sketch below)
- **Best model selection:** based on macro F1 score
- **Evaluation strategy:** every epoch
- **Save strategy:** every epoch (best model only)
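
Combined with the `training_args` above, the early-stopping setup looks roughly like this (the dataset variables and `compute_metrics` refer to the sketches in the neighboring sections):

```python
from transformers import (
    AutoModelForSequenceClassification,
    EarlyStoppingCallback,
    Trainer,
)

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

trainer = Trainer(
    model=model,
    args=training_args,               # from the sketch above
    train_dataset=tokenized_train,    # hypothetical tokenized 800-sample split
    eval_dataset=tokenized_test,      # hypothetical tokenized 200-sample split
    compute_metrics=compute_metrics,  # see the Evaluation section below
    # Stop training if macro F1 fails to improve for 2 consecutive evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
# trainer.train()
```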

## Evaluation

### Metrics

The model was evaluated using the following metrics (sketched in code after the list):

- **Accuracy:** overall classification accuracy
- **F1 score (macro):** macro-averaged F1 across both classes
- **Per-class accuracy:** individual class performance
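
A plausible `compute_metrics` implementation for these metrics (the exact function used in training is not shown in this card):

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        # Macro averaging weights both classes equally, which matters a little
        # here because the test split is slightly unbalanced (101 vs. 99).
        "f1": f1_score(labels, preds, average="macro"),
    }
```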

### Results

- **Test set:**
  - Accuracy: 99.5%
  - F1 score (macro): 99.5%
- **External validation (original 100 samples):**
  - Accuracy: 100.0%
  - F1 score (macro): 100.0%

### Detailed Performance

- **Pittsburgh class:** 99.01% accuracy (101 samples)
- **Shanghai class:** 100.0% accuracy (99 samples)
- **Confusion matrix:** only 1 misclassification out of 200 test samples (see the sketch below)
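
The per-class numbers above can be derived from the confusion matrix like so (a sketch; `test_labels` and `test_preds` are hypothetical arrays of gold labels and model predictions, and the 0/1 label order is assumed):

```python
from sklearn.metrics import confusion_matrix

# Hypothetical predictions on the 200-sample test split,
# assuming 0 = Pittsburgh and 1 = Shanghai.
cm = confusion_matrix(test_labels, test_preds, labels=[0, 1])

# Per-class accuracy = correct predictions / total samples of that class.
per_class_accuracy = cm.diagonal() / cm.sum(axis=1)  # e.g. [0.9901, 1.0]
print(cm)
print(per_class_accuracy)
```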

## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load the fine-tuned model and its tokenizer from the Hub
model_name = "Anyuhhh/hw2-text-distilbert"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

# Tokenize a single input, truncating to the 256-token training limit
text = "Your input text here"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(probs, dim=-1).item()

# id2label resolves to "Pittsburgh"/"Shanghai" if the label map was saved
# with the model; otherwise it falls back to "LABEL_0"/"LABEL_1".
print(f"Predicted class: {model.config.id2label[predicted_class]}")
print(f"Confidence: {probs[0, predicted_class].item():.4f}")
```