---
license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
tags:
- distilbert
- multi-task-learning
- call-center-analytics
- child-helplines
- case-classification
- crisis-support
- social-impact
- east-africa
- openchlsystem
- helpline
language:
- en
datasets:
- helpline_dataset
- openchs/synthetic_helpine_classification_v1
metrics:
- accuracy
- f1
- precision
- recall
model-index:
- name: CHS_tz_classifier_distilbert
  results:
  - task:
      type: text-classification
      name: Multi-Task Case Classification
    metrics:
    - type: accuracy
      value: 0.75
      name: Overall Average Accuracy
    - type: accuracy
      value: 0.833
      name: Main Category Accuracy
widget:
- text: >-
    Hello, I've been trying to find help for my son Ken. He's only ten years old
    and he's been going through a terrible time at school. There's this boy who
    keeps harassing him. It started with name-calling and teasing, but it's
    escalated to physical violence. I don't know what to do. I can't bear to see
    my child suffer like this.
  example_title: School Bullying Case
- text: >-
    On November 15th, the helpline received a call from a 17-year-old who wanted
    to understand why drug use among youth is harmful. The counselor explained
    the physical, social, and legal risks involved with drug abuse.
  example_title: Youth Drug Education
base_model:
- distilbert/distilbert-base-uncased
---
# DistilBERT Multi-Task Classifier for Child Helpline Case Management
## Model Description
This is a fine-tuned **DistilBERT-base-uncased** model for **multi-task classification of child helpline and call center transcripts**. It was developed by **BITZ IT Consulting** as part of the **OpenCHS AI pipeline** for child helplines and crisis support services in East Africa.
Speed and accuracy both matter when resolving and reporting cases, and this fine-tuned model is designed to deliver both.
## Model Architecture
- **Base Model**: DistilBERT (distilbert-base-uncased)
- **Architecture**: Multi-task classifier with 4 specialized output heads
- **Input**: Call center/helpline transcripts (max 256 tokens)
- **Output**: Classifications across 4 distinct tasks
- **Training**: Multi-task learning with shared DistilBERT encoder
## Classification Tasks
The model performs simultaneous classification across four critical dimensions:
| Task | Classes | Count | Purpose |
|------|---------|--------|---------|
| **Main Category** | Advice & Counselling, Child Custody, Disability, GBV, VANE, Nutrition, Information | 6 | High-level case categorization |
| **Sub Category** | Adoption, Albinism, Balanced Diet, Birth Registration, Child Abuse, etc. | 43 | Detailed topic identification |
| **Intervention** | Referred, Counselling, Signposting, Awareness/Information | 4 | Recommended action type |
| **Priority** | Low (1), Medium (2), High (3) | 3 | Urgency level for escalation |
## Performance Metrics
### Evaluation Results
| Metric | Value |
|--------|-------|
| Epoch | 1.0000 |
| Eval Avg Acc | 0.7493 |
| Eval Interv Acc | 0.6778 |
| Eval Priority Acc | 0.9543 |
| Eval Runtime (s) | 3.1621 |
| Eval Samples Per Second | 408.2690 |
| Eval Steps Per Second | 25.6160 |
| Eval Sub Acc | 0.6158 |
### Overall Performance
- **Average Accuracy**: 75.0%
- **Best Performing Task**: Main Category (83.33%)
- **Most Challenging Task**: Sub Category (41.67%)
### Detailed Task Performance
| Task | Accuracy | Precision | Recall | F1-Score | Performance Level |
|------|----------|-----------|---------|----------|------------------|
| **Main Category** | 83.33% | High | High | 0.762 | Excellent |
| **Priority** | 75.00% | 0.505 | Variable | 0.711 | Good |
| **Intervention** | 75.00% | Variable | Variable | 0.695 | Good |
| **Sub Category** | 41.67% | Low | Variable | 0.382 | Needs Improvement |
### Task-Specific Analysis
**Main Category Performance:**
- **Excellent Classes**: Information (F1: 0.909), Child Maintenance & Custody (F1: 1.000), Nutrition (F1: 1.000)
- **Challenging Classes**: VANE (F1: 0.000) - requires more training data
- **Overall**: Strong performance with 5/6 categories well-represented
**Sub Category Performance:**
- **Perfect Classes**: Balanced Diet, Maintenance, Relationships (Parent/Child)
- **Challenging Areas**: Sexual & Reproductive Health, Child Labor, Drug/Alcohol Abuse
- **Note**: Performance varies significantly due to class imbalance (10/43 classes in test data)
**Priority Classification:**
- **High Accuracy on Low/Medium Priority**: Priority 1 (F1: 0.833), Priority 2 (F1: 0.727)
- **Challenge with High Priority**: Priority 3 cases need more representation
- **Critical for Routing**: Essential for proper case escalation
**Intervention Recommendations:**
- **Strong Performance**: Professional counseling (F1: 0.842)
- **Room for Improvement**: "No intervention needed" category (F1: 0.400)
- **Operational Impact**: Directly guides case worker actions
## Model Usage
### Installation
```bash
pip install transformers torch numpy
```
### Model Classes
```python
import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertPreTrainedModel, AutoTokenizer
import json
import re
import numpy as np


class MultiTaskDistilBert(DistilBertPreTrainedModel):
    """
    Multi-task DistilBERT classifier for child helpline case management.
    Performs simultaneous classification across 4 tasks:
    - Main category classification
    - Sub-category classification
    - Intervention recommendation
    - Priority assignment
    """

    def __init__(self, config, num_main, num_sub, num_interv, num_priority):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        # Task-specific classification heads
        self.classifier_main = nn.Linear(config.dim, num_main)
        self.classifier_sub = nn.Linear(config.dim, num_sub)
        self.classifier_interv = nn.Linear(config.dim, num_interv)
        self.classifier_priority = nn.Linear(config.dim, num_priority)
        self.dropout = nn.Dropout(config.dropout)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None,
                main_category_id=None, sub_category_id=None,
                intervention_id=None, priority_id=None):
        # Shared DistilBERT encoder
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # Feature extraction and processing
        hidden_state = distilbert_output.last_hidden_state
        pooled_output = hidden_state[:, 0]  # [CLS] token
        pooled_output = self.pre_classifier(pooled_output)
        pooled_output = nn.ReLU()(pooled_output)
        pooled_output = self.dropout(pooled_output)
        # Multi-task predictions
        logits_main = self.classifier_main(pooled_output)
        logits_sub = self.classifier_sub(pooled_output)
        logits_interv = self.classifier_interv(pooled_output)
        logits_priority = self.classifier_priority(pooled_output)
        # Multi-task loss calculation (training only)
        loss = None
        if main_category_id is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss_main = loss_fct(logits_main, main_category_id)
            loss_sub = loss_fct(logits_sub, sub_category_id)
            loss_interv = loss_fct(logits_interv, intervention_id)
            loss_priority = loss_fct(logits_priority, priority_id)
            loss = loss_main + loss_sub + loss_interv + loss_priority
        # Return format compatible with Trainer
        if loss is not None:
            return (loss, logits_main, logits_sub, logits_interv, logits_priority)
        else:
            return (logits_main, logits_sub, logits_interv, logits_priority)
```
### Complete Usage Example
```python
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import json
import re
import numpy as np

# MultiTaskDistilBert is the class defined in the "Model Classes" section;
# it must be defined or imported before running this example.

# Model setup
MODEL_NAME = "openchs/cls-gbv-distilbert-v1"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def load_labels(filename):
    """Download a label-mapping file from the model repo and parse it."""
    with open(hf_hub_download(MODEL_NAME, filename)) as f:
        return json.load(f)


# Load label mappings
main_categories = load_labels("main_categories.json")
sub_categories = load_labels("sub_categories.json")
interventions = load_labels("interventions.json")
priorities = load_labels("priorities.json")

# Initialize model
model = MultiTaskDistilBert.from_pretrained(
    MODEL_NAME,
    num_main=len(main_categories),
    num_sub=len(sub_categories),
    num_interv=len(interventions),
    num_priority=len(priorities)
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()


def classify_multitask_case(narrative: str):
    """
    Classify a helpline case narrative across all task dimensions.

    Args:
        narrative (str): The case narrative/transcript text

    Returns:
        dict: Classifications for all four tasks with confidence scores
    """
    # Text preprocessing
    text = narrative.lower().strip()
    text = re.sub(r'[^a-z0-9\s]', '', text)  # Remove special characters

    # Tokenization
    inputs = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=256,
        return_tensors="pt"
    ).to(device)

    # Inference (no labels passed, so the model returns only the four logits)
    with torch.no_grad():
        logits_main, logits_sub, logits_interv, logits_priority = model(**inputs)

    # Convert logits to probabilities
    probs_main = torch.softmax(logits_main, dim=1).cpu().numpy()[0]
    probs_sub = torch.softmax(logits_sub, dim=1).cpu().numpy()[0]
    probs_interv = torch.softmax(logits_interv, dim=1).cpu().numpy()[0]
    probs_priority = torch.softmax(logits_priority, dim=1).cpu().numpy()[0]

    # Get predictions (argmax)
    pred_main = int(np.argmax(probs_main))
    pred_sub = int(np.argmax(probs_sub))
    pred_interv = int(np.argmax(probs_interv))
    pred_priority = int(np.argmax(probs_priority))

    return {
        "main_category": {
            "label": main_categories[pred_main],
            "confidence": float(probs_main[pred_main])
        },
        "sub_category": {
            "label": sub_categories[pred_sub],
            "confidence": float(probs_sub[pred_sub])
        },
        "intervention": {
            "label": interventions[pred_interv],
            "confidence": float(probs_interv[pred_interv])
        },
        "priority": {
            "label": priorities[pred_priority],
            "confidence": float(probs_priority[pred_priority])
        }
    }


# Example usage
narrative = """
Hello, I've been trying to find help for my son Ken. He's only ten years old and
he's been going through a terrible time at school. There's this boy, James, who
keeps harassing him. It started with name-calling and teasing, but it's escalated
to physical violence. I don't know what to do. I can't bear to see my child suffer like this.
"""
result = classify_multitask_case(narrative)
print(json.dumps(result, indent=2))
```
**Expected Output:**
```json
{
  "main_category": {
    "label": "Advice and Counselling",
    "confidence": 0.85
  },
  "sub_category": {
    "label": "School Related Issues",
    "confidence": 0.72
  },
  "intervention": {
    "label": "Counselling",
    "confidence": 0.68
  },
  "priority": {
    "label": 2,
    "confidence": 0.91
  }
}
```
### FastAPI Integration
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import time

# Assumes classify_multitask_case and MODEL_NAME from the usage example
# are defined in (or imported into) this module.

app = FastAPI(title="Child Helpline Case Classification API")


class CaseInput(BaseModel):
    narrative: str
    include_confidence: Optional[bool] = True


@app.post("/classify")
async def classify_case(input_data: CaseInput):
    try:
        start_time = time.time()
        result = classify_multitask_case(input_data.narrative)
        processing_time = time.time() - start_time

        if not input_data.include_confidence:
            # Keep only the predicted label for each task
            result = {task: prediction["label"] for task, prediction in result.items()}

        return {
            "success": True,
            "classification": result,
            "processing_time_seconds": round(processing_time, 4)
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": MODEL_NAME}
```
## Training Details
### Training Data
- **Total Dataset**: 12,909 augmented helpline call transcripts
- **Real Data**: 993 anonymized helpline calls
- **Synthetic Data**: N/A
- **Languages**: Primarily English
- **Domain**: Child protection, family services, crisis support
### Data Distribution
- **Main Categories**: Balanced across 6 primary case types
- **Sub Categories**: Long-tail distribution with 43 specific topics
- **Interventions**: 4 different action types based on case severity
- **Priority Levels**: 3 levels (Low, Medium, High) for case escalation
### Training Configuration
- **Base Model**: distilbert-base-uncased
- **Optimizer**: AdamW (lr=2e-5)
- **Loss Function**: Combined CrossEntropyLoss across all tasks
- **Batch Size**: 16
- **Max Length**: 512 tokens
- **Epochs**: 12
- **Weight Decay**: 0.01
- **Hardware**: NVIDIA GeForce RTX 4060 Ti
### Multi-Task Learning Approach
- **Shared Encoder**: Single DistilBERT backbone for all tasks
- **Task-Specific Heads**: Dedicated classification layers per task
- **Joint Training**: Simultaneous optimization across all objectives
- **Loss Weighting**: Equal weighting across all four tasks
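The equal-weight joint objective can be illustrated without any deep learning framework. The sketch below is a pure-Python illustration, not the training code: it computes softmax cross-entropy for one sample per task and sums the four losses, mirroring `loss = loss_main + loss_sub + loss_interv + loss_priority` in the model class. The logits, targets, class counts, and the optional `weights` argument are invented for illustration; the card itself only states that all four tasks are weighted equally.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def joint_loss(task_logits, task_targets, weights=None):
    """Weighted sum of per-task losses; every weight defaults to 1.0."""
    weights = weights or {task: 1.0 for task in task_logits}
    return sum(
        weights[task] * cross_entropy(task_logits[task], task_targets[task])
        for task in task_logits
    )


# Invented logits for a single sample (class counts shortened for brevity).
logits = {
    "main": [2.0, 0.5, -1.0],
    "sub": [0.1, 0.2, 0.3, 0.4],
    "intervention": [1.0, 1.0],
    "priority": [0.0, 3.0, 0.0],
}
targets = {"main": 0, "sub": 3, "intervention": 1, "priority": 1}

total = joint_loss(logits, targets)
print(f"joint loss: {total:.4f}")
```

With equal weights, a task with many classes (such as Sub Category) typically contributes a larger raw loss than a 3-class task, which is one reason per-task weights are sometimes tuned in multi-task setups.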
## Social Impact and Applications
### Primary Use Cases
- **Automated Case Routing**: Instant classification and priority assignment
- **Supervisor Support**: Reduces manual case categorization workload
- **Quality Assurance**: Consistent classification standards across all calls
- **Resource Allocation**: Priority-based staffing and intervention planning
### Operational Benefits
- **Scalability**: Handle thousands of cases without manual intervention
- **Consistency**: Apply the same classification criteria to every case, reducing variation between individual reviewers
- **Speed**: Real-time classification for immediate case routing
- **Insights**: Data-driven understanding of case patterns and trends
### Target Organizations
- **Child Helplines**: 116 services across East Africa
- **Crisis Support Services**: Mental health and emergency hotlines
- **Family Support Centers**: Case management and intervention planning
- **NGOs and Government Agencies**: Child protection and welfare services
## Limitations and Considerations
### Performance Limitations
- **Sub-Category Challenge**: 41.67% accuracy indicates need for more balanced training data
- **Class Imbalance**: Some categories have limited representation in training data
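- **Context Length**: Inputs longer than the maximum sequence length (512 tokens in the training configuration, 256 in the usage example) are truncated, so long narratives may lose content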
- **Language Bias**: Primarily trained on English
### Operational Considerations
- **Human Oversight**: Critical cases should always involve human review
- **Confidence Thresholds**: Low-confidence predictions should trigger manual review
- **Regular Retraining**: Model performance may degrade without periodic updates
- **Cultural Context**: Model may not capture all cultural nuances in case presentation
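One way to act on the confidence-threshold recommendation is a small post-processing step over the dictionary returned by `classify_multitask_case`. The sketch below is an assumption-laden illustration: the function name `tasks_needing_review`, the threshold values, and the default of 0.6 are hypothetical and would need tuning on validation data, particularly because softmax confidences are not guaranteed to be well calibrated.

```python
def tasks_needing_review(result, thresholds, default_threshold=0.6):
    """Return the task names whose top-class confidence is below its threshold."""
    flagged = []
    for task, prediction in result.items():
        threshold = thresholds.get(task, default_threshold)
        if prediction["confidence"] < threshold:
            flagged.append(task)
    return flagged


# Hypothetical thresholds; stricter for priority because it drives escalation.
THRESHOLDS = {
    "main_category": 0.6,
    "sub_category": 0.5,
    "intervention": 0.6,
    "priority": 0.8,
}

confident = {
    "main_category": {"label": "Advice and Counselling", "confidence": 0.85},
    "sub_category": {"label": "School Related Issues", "confidence": 0.72},
    "intervention": {"label": "Counselling", "confidence": 0.68},
    "priority": {"label": 2, "confidence": 0.91},
}
uncertain = dict(confident, priority={"label": 3, "confidence": 0.55})

print(tasks_needing_review(confident, THRESHOLDS))  # → []
print(tasks_needing_review(uncertain, THRESHOLDS))  # → ['priority']
```

A non-empty list could be used to route the case to a supervisor queue instead of applying the prediction automatically.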
### Ethical Considerations
- **Privacy**: All training data anonymized with strict PII removal
- **Bias Monitoring**: Regular evaluation for demographic and linguistic bias
- **Transparency**: Clear documentation of model limitations and appropriate use
- **Child Safety**: Special protocols for high-priority cases involving immediate danger
## Integration Pipeline
The model is designed to integrate seamlessly into larger AI pipelines:
1. **ASR (Whisper)** → Transcribes call audio to text
2. **Text Preprocessing** → Cleans and normalizes transcript
3. **MultiTask Classification** → Categorizes and prioritizes case
4. **NER** → Extracts named entities from the transcript
5. **Case Management System** → Routes the case to the appropriate team
6. **Quality Assurance** → Tracks outcomes and model performance
## Model Maintenance
### Performance Monitoring
- **Accuracy Tracking**: Monitor per-task performance over time
- **Confidence Analysis**: Track prediction confidence distributions
- **Edge Case Detection**: Identify cases requiring manual review
- **Feedback Loop**: Incorporate corrected predictions into retraining data
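The accuracy-tracking and feedback-loop items can be sketched as a small aggregation over reviewed cases. The record layout below (`predicted` and `corrected` dictionaries keyed by task name) is a hypothetical logging format, not something the card specifies; the task names follow the keys returned by `classify_multitask_case`.

```python
from collections import defaultdict

TASKS = ("main_category", "sub_category", "intervention", "priority")


def per_task_accuracy(records):
    """Compute accuracy per task from cases that a human has reviewed.

    Each record is {"predicted": {task: label}, "corrected": {task: label}}.
    Tasks without any reviewed label are left out of the result.
    """
    correct = defaultdict(int)
    total = defaultdict(int)
    for record in records:
        for task in TASKS:
            if task in record["corrected"]:
                total[task] += 1
                if record["predicted"][task] == record["corrected"][task]:
                    correct[task] += 1
    return {task: correct[task] / total[task] for task in total}


# Invented review log with two cases.
records = [
    {"predicted": {"main_category": "GBV", "priority": 2},
     "corrected": {"main_category": "GBV", "priority": 3}},
    {"predicted": {"main_category": "Nutrition", "priority": 1},
     "corrected": {"main_category": "Nutrition", "priority": 1}},
]
print(per_task_accuracy(records))  # → {'main_category': 1.0, 'priority': 0.5}
```

Records where `predicted` and `corrected` disagree are natural candidates for the retraining set mentioned in the feedback-loop item.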
### Update Schedule
- **Monthly Reviews**: Performance metrics and edge case analysis
- **Quarterly Retraining**: Incorporate new data and correct classification errors
- **Annual Model Refresh**: Major architecture updates and comprehensive evaluation
## Citation
```bibtex
@software{chs_distilbert_multitask_2025,
title={DistilBERT Multi-Task Classifier for Child Helpline Case Management},
author={BITZ IT Consulting Team},
year={2025},
publisher={Hugging Face},
journal={Hugging Face Model Hub},
howpublished={\url{https://huggingface.co/openchs/cls-gbv-distilbert-v1}},
note={AI for Social Impact: Automated Case Classification for Child Protection Services}
}
```
## Model Examination
### Interpretability Analysis
The model's multi-task architecture allows for analysis of shared vs. task-specific representations:
- **Shared Features**: The DistilBERT encoder captures general linguistic patterns useful across all classification tasks
- **Task-Specific Heads**: Each classification head specializes in different aspects of case analysis
- **Attention Patterns**: The model shows higher attention to key phrases indicating urgency, relationship dynamics, and specific issues
- **Feature Importance**: Critical terms include age indicators, relationship descriptors, emotion words, and action verbs
### Error Analysis
Common misclassification patterns:
- **Sub-Category Confusion**: Model sometimes confuses related sub-categories (e.g., different types of abuse)
- **Priority Assignment**: Conservative bias toward lower priority ratings for borderline cases
- **Intervention Selection**: Tendency to recommend counselling over more specific interventions
## Environmental Impact
Carbon emissions estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute):
- **Hardware Type**: NVIDIA GeForce RTX 4060 Ti
- **Hours used**: ~1 hour total training time
- **Cloud Provider**: N/A (Local training)
- **Compute Region**: East Africa (Kenya)
- **Carbon Emitted**: Approximately 150-200 g CO2eq
*Training was conducted locally to minimize environmental impact and ensure data privacy for sensitive helpline transcripts.*
## Technical Specifications
### Model Architecture and Objective
- **Architecture**: Multi-head DistilBERT with shared encoder and task-specific classification heads
- **Parameters**: ~67M total parameters
- **Objective**: Multi-task classification with joint Cross-Entropy loss optimization
- **Input Processing**: Text normalization, tokenization with 512-token limit
- **Output**: Simultaneous predictions across 4 classification tasks
### Compute Infrastructure
#### Hardware
- **GPU**: NVIDIA GeForce RTX 4060 Ti (16GB VRAM)
- **CPU**: Intel/AMD multi-core processor
- **RAM**: 32GB+ system memory
- **Storage**: SSD for fast data loading
#### Software
- **Framework**: PyTorch 2.0+
- **Library**: Transformers 4.30+
- **Training**: Hugging Face Trainer API
- **Tracking**: MLflow for experiment management
- **Development**: Python 3.12+, CUDA 11.8
### Performance Benchmarks
#### Inference Speed
- **Single prediction**: ~0.05 seconds on GPU
- **Batch processing**: ~200 cases/minute on GPU
- **Model size**: ~270MB on disk
- **Memory usage**: ~1GB GPU memory during inference
#### Throughput Specifications
- **Training throughput**: ~40 samples/second
- **Inference latency**: 50ms average per case
- **Scalability**: Can handle 10,000+ cases/hour on single GPU
## Testing Data, Factors & Metrics
### Testing Data
- **Size**: 12 test samples (stratified split)
- **Distribution**: Representative of real helpline case types
- **Languages**: Primarily English with some Swahili terms
- **Anonymization**: All PII removed, location/name placeholders used
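With 12 test samples, each prediction shifts a task's accuracy by about 8.33 percentage points, which helps when reading the percentages in the Detailed Task Performance table. The short calculation below recovers the number of correct predictions implied by each reported accuracy; the helper name `correct_count` is illustrative.

```python
TEST_SET_SIZE = 12  # from the Testing Data section


def correct_count(accuracy_percent, n=TEST_SET_SIZE):
    """Number of correct predictions implied by a reported accuracy."""
    return round(accuracy_percent / 100 * n)


reported = [
    ("Main Category", 83.33),
    ("Priority", 75.00),
    ("Intervention", 75.00),
    ("Sub Category", 41.67),
]
for task, acc in reported:
    print(f"{task}: {correct_count(acc)}/{TEST_SET_SIZE} correct")

print(f"one sample = {100 / TEST_SET_SIZE:.2f} percentage points")
```

This gives 10, 9, 9, and 5 correct predictions out of 12 respectively.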
### Factors
Evaluation disaggregated by:
- **Case complexity**: Simple vs. complex multi-issue cases
- **Urgency level**: Low, medium, high priority cases
- **Category type**: Different main category distributions
- **Text length**: Short vs. long narrative descriptions
### Metrics
- **Primary**: Accuracy per task (exact match)
- **Secondary**: Precision, Recall, F1-score per class
- **Aggregate**: Weighted average across all tasks
- **Operational**: Classification confidence scores
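For reference, the per-task accuracy and per-class precision, recall, and F1 listed above can be computed as follows. This is a dependency-free sketch rather than the evaluation script used for this card, and the example labels are invented. A class with no correct predictions receives an F1 of 0.0, which is how a figure such as the VANE F1 of 0.000 can arise.

```python
def accuracy(y_true, y_pred):
    """Fraction of exact matches between true and predicted labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_prf(y_true, y_pred):
    """Precision, recall, and F1 for every label seen in either list."""
    scores = {}
    for label in sorted(set(y_true) | set(y_pred)):
        pairs = list(zip(y_true, y_pred))
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        scores[label] = {"precision": precision, "recall": recall, "f1": f1}
    return scores


# Invented priority labels for six cases.
y_true = [1, 1, 2, 2, 3, 3]
y_pred = [1, 1, 2, 1, 3, 2]

print(f"accuracy: {accuracy(y_true, y_pred):.3f}")
for label, s in per_class_prf(y_true, y_pred).items():
    print(f"priority {label}: P={s['precision']:.3f} R={s['recall']:.3f} F1={s['f1']:.3f}")
```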
## Glossary
**Main Category**: High-level case classification (6 classes) used for initial routing and reporting
**Sub Category**: Detailed topic identification (43 classes) for specific issue targeting and resource allocation
**Intervention**: Recommended action type (4 classes) guiding case worker response and follow-up procedures
**Priority**: Urgency level (3 levels) determining response timeframe and resource allocation
**Multi-task Learning**: Training approach where model learns multiple related tasks simultaneously using shared representations
**PII**: Personally Identifiable Information - any data that could identify specific individuals, systematically removed from training data
**Case Routing**: Automated process of directing cases to appropriate teams based on classification results
## More Information
### Related Models
This model is part of a larger AI pipeline including:
- **ASR Model**: Whisper-based speech recognition for call transcription
- **QA Scoring Model**: Multi-head quality assurance evaluation (Rogendo/qa-helpline-distilbert-v1)
- **Translation Model**: Helsinki/opus-mt models for multilingual support
- **Summarization Model**: FLAN-based transcript summarization
### Research Applications
- Child protection service optimization
- Crisis intervention system design
- Multilingual helpline support research
- AI ethics in sensitive domain applications
### Future Development
- **Language Expansion**: Additional East African languages
- **Performance Improvement**: Address sub-category classification challenges
- **Real-time Integration**: Stream processing capabilities
- **Federated Learning**: Privacy-preserving multi-organization training
## Model Card Authors
- **Rogendo** (Lead Developer) - Data Engineering, Model Architecture, Training Implementation
- **Shemmiriam** (Data Analyst) - Dataset Analysis, Performance Evaluation, Metrics Design
- **Nelsonadagi** (Quality Assurance) - Model Testing, Validation, Edge Case Analysis
- **BITZ IT Consulting Team** (Collaborative Development) - Social Impact Design, Ethical Guidelines
## Model Card Contact
**Primary Contact**: [email protected]
**Organization**: BITZ IT Consulting
**Support**: Technical questions and collaboration inquiries welcome
**Repository Issues**: https://huggingface.co/openchs/cls-gbv-distilbert-v1/discussions
---
**Technology for Child Protection and Crisis Support**
*This model is part of a comprehensive AI pipeline designed to improve response times and service quality for vulnerable children and families across East Africa. Every classification helps ensure that children in need receive appropriate and timely support.*