---
license: gpl-3.0
datasets:
- toughdata/quora-question-answer-dataset
language:
- en
metrics:
- accuracy
- precision
- recall
- f1
base_model:
- distilbert/distilbert-base-uncased
pipeline_tag: text-classification
library_name: transformers
---
# distilbert-base-q-cat

## Model Description
distilbert-base-q-cat is a lightweight, fine-tuned DistilBERT model for text classification that categorizes questions into three classes: fact, opinion, and hypothetical. It was trained on a Quora question dataset labeled with a combination of keyword-based rules and sentiment analysis.
## Features
- **Lightweight:** Built on DistilBERT, offering faster inference and lower computational requirements than standard BERT.
- **Three question categories:**
  - **Fact:** Questions seeking factual or objective information.
  - **Opinion:** Questions that elicit subjective views or opinions.
  - **Hypothetical:** Questions exploring hypothetical scenarios or speculative ideas.
- **Pretrained and fine-tuned:** Uses DistilBERT's pretrained weights with additional fine-tuning on labeled data.
## Dataset
The model was trained on a custom dataset derived from Quora questions:

- **Data preparation:**
  - Fact and hypothetical questions were labeled with keyword-based rules.
  - Opinion questions were identified via sentiment analysis (a sketch of this labeling approach follows below).
- **Dataset size:** ~50k samples, split into training, validation, and test sets.
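The exact labeling pipeline is not published. As a rough illustration of the approach described above, a minimal sketch might combine hand-picked keyword lists (the lists below are purely hypothetical) with an off-the-shelf subjectivity scorer such as TextBlob:

```python
from textblob import TextBlob  # assumption: any sentiment/subjectivity scorer works here

# Illustrative keyword lists only; the actual rules used for labeling are not published
FACT_KEYWORDS = ("what is", "when did", "how many", "who invented")
HYPOTHETICAL_KEYWORDS = ("what if", "imagine", "would happen", "suppose")

def label_question(question: str) -> str:
    q = question.lower()
    # Keyword rules for hypothetical and fact questions
    if any(kw in q for kw in HYPOTHETICAL_KEYWORDS):
        return "hypothetical"
    if any(kw in q for kw in FACT_KEYWORDS):
        return "fact"
    # Fall back to sentiment: highly subjective phrasing suggests an opinion question
    if TextBlob(question).sentiment.subjectivity > 0.5:
        return "opinion"
    return "fact"
```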
## Performance
The model achieves the following metrics on the validation set:
- Accuracy: 93.33%
- Precision: 93.41%
- Recall: 93.33%
- F1-Score: 93.32%
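For reference, metrics like these can be reproduced with scikit-learn. The sketch below assumes weighted averaging, since the card reports single precision/recall/F1 values over three classes:

```python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(y_true, y_pred):
    # y_true / y_pred are lists of class indices from the validation set
    accuracy = accuracy_score(y_true, y_pred)
    # Weighted averaging is an assumption; the card does not state the averaging mode
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted"
    )
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
```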
## Installation
To use this model, install the required dependencies:
```bash
pip install transformers torch
```
## Usage

### Load Model and Tokenizer
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the fine-tuned model and its tokenizer
model_name = "distilbert-base-q-cat"
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=3, ignore_mismatched_sizes=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
```
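Note that if the published checkpoint already contains the fine-tuned three-way classification head, `AutoModelForSequenceClassification.from_pretrained(model_name)` alone is sufficient; `num_labels=3` together with `ignore_mismatched_sizes=True` is mainly needed when attaching a fresh head to a base checkpoint, and it will re-initialize (and thus discard) any head weights whose shapes do not match.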
### Inference Example
```python
import torch

def predict_question(question):
    # Tokenize the input question
    inputs = tokenizer(question, return_tensors="pt", truncation=True, padding=True)
    # Run the model without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = logits.argmax(dim=-1).item()
    # Map the predicted class index to its category label
    label_map = {0: "fact", 1: "opinion", 2: "hypothetical"}
    return label_map[predicted_class]

# Example usage
question = "What is artificial intelligence?"
print(predict_question(question))
```
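For classifying many questions at once, a batched variant (a sketch reusing the `model`, `tokenizer`, and label map loaded above) avoids per-call overhead and also exposes class probabilities:

```python
import torch

def predict_batch(questions):
    # Tokenize all questions together, padding to the longest in the batch
    inputs = tokenizer(questions, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)
    label_map = {0: "fact", 1: "opinion", 2: "hypothetical"}
    # Return (label, confidence) pairs, one per input question
    return [(label_map[p.argmax().item()], p.max().item()) for p in probs]

# Example usage
print(predict_batch([
    "What if humans could live on Mars?",
    "What is the capital of France?",
]))
```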