---
datasets:
- leonvanbokhorst/topic-drift
---

# Topic Drift Detector Model

## Version: v20241225_085248

This model detects topic drift in conversations. Given a sequence of embeddings, it predicts a continuous drift score between 0 and 1 using an attention-based architecture combined with a bidirectional LSTM.

## Model Architecture

- Multi-head attention mechanism
- Bidirectional LSTM for pattern detection
- Dynamic weight generation
- Semantic bridge detection
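
To make the first two listed components concrete, here is a minimal PyTorch sketch of one way a bidirectional LSTM and multi-head self-attention could be combined into a drift scorer with the input and output shapes given under Usage. It is an illustration under stated assumptions, not the released model. The class name `DriftSketch`, the hidden size, the head count and the mean pooling are all hypothetical. The "dynamic weight generation" and "semantic bridge detection" components are left out because this card does not describe how they work.

```python
import torch
import torch.nn as nn


class DriftSketch(nn.Module):
    """Hypothetical BiLSTM + multi-head attention drift scorer.

    NOT the released model: layer sizes and pooling are invented, and the
    dynamic-weight and semantic-bridge components are omitted.
    """

    def __init__(self, seq_len: int, emb_dim: int, hidden: int = 64, heads: int = 4):
        super().__init__()
        self.seq_len = seq_len
        self.emb_dim = emb_dim
        # Bidirectional LSTM over the sequence of embeddings.
        self.lstm = nn.LSTM(emb_dim, hidden, batch_first=True, bidirectional=True)
        # Multi-head self-attention over the LSTM outputs (2 * hidden features).
        self.attn = nn.MultiheadAttention(2 * hidden, heads, batch_first=True)
        # Linear head producing one score per example.
        self.head = nn.Linear(2 * hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len * emb_dim] -> [batch, seq_len, emb_dim]
        x = x.reshape(-1, self.seq_len, self.emb_dim)
        h, _ = self.lstm(x)          # [batch, seq_len, 2 * hidden]
        a, _ = self.attn(h, h, h)    # self-attention output, same shape as h
        pooled = a.mean(dim=1)       # average over the sequence: [batch, 2 * hidden]
        return torch.sigmoid(self.head(pooled))  # [batch, 1], squashed into (0, 1)


model = DriftSketch(seq_len=8, emb_dim=32)
scores = model(torch.randn(2, 8 * 32))
print(scores.shape)  # torch.Size([2, 1])
```

Mean pooling is simply the easiest way to reduce the sequence to a single score here; the actual model may weight or pool time steps differently.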

## Performance Metrics

```txt
=== Full Training Results ===
Best Validation RMSE: 0.0107
Best Validation R²: 0.8867

=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0129
R²: 0.8373
```
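
The two regression metrics above follow the standard definitions: RMSE is the square root of the mean squared error, and R² is 1 − SS_res / SS_tot. As a consistency check, if the reported test Loss is mean squared error (an assumption, since the card does not name the loss), it should be close to RMSE². Here 0.0129² ≈ 0.000166, which rounds to the reported 0.0002. The plain-Python sketch below computes both metrics; the helper names and toy values are illustrative only and are not taken from the dataset.

```python
import math
from typing import Sequence


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root-mean-squared error: sqrt(mean((y - y_hat)^2))."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)


def r2(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Toy drift scores, made up for illustration.
y_true = [0.10, 0.40, 0.35, 0.80]
y_pred = [0.12, 0.38, 0.30, 0.78]
print(round(rmse(y_true, y_pred), 4))  # → 0.0304
print(round(r2(y_true, y_pred), 4))    # → 0.9853
```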

## Training Curves
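
## Usage

```python
import torch

# Load the full pickled model. This assumes the .pt file stores the whole
# nn.Module (not just a state_dict) and that the model's class definition is
# importable. weights_only=False is needed on PyTorch >= 2.6, where
# torch.load defaults to weights_only=True; only load files you trust.
model = torch.load(
    'models/v20241225_085248/topic_drift_model.pt',
    map_location='cpu',
    weights_only=False,
)
model.eval()  # switch off training-only behaviour such as dropout


# Use model for inference
def score_drift(inputs: torch.Tensor) -> torch.Tensor:
    # Input shape: [batch_size, sequence_length * embedding_dim]
    # Output shape: [batch_size, 1] (drift score between 0 and 1)
    with torch.no_grad():
        return model(inputs)
```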
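
The model takes one flat vector per example rather than a `[sequence_length, embedding_dim]` matrix. Assuming each example is a window of embeddings stacked in conversation order and flattened row by row during training (the card names neither the encoder nor the window length), one way to build that input is to flatten the last two dimensions. `flatten_turn_window` and the sizes used here are hypothetical placeholders; the real `sequence_length` and `embedding_dim` must match the trained model.

```python
import torch


def flatten_turn_window(turn_embeddings: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [batch, seq_len, emb_dim] -> [batch, seq_len * emb_dim].

    Embeddings are concatenated in order (row-major), so element i of the
    sequence occupies columns i * emb_dim to (i + 1) * emb_dim - 1 of each row.
    """
    if turn_embeddings.dim() != 3:
        raise ValueError("expected a tensor of shape [batch, seq_len, emb_dim]")
    return turn_embeddings.flatten(start_dim=1)


# Placeholder sizes: batch of 4, window of 8 embeddings, 16 dimensions each.
window = torch.randn(4, 8, 16)
inputs = flatten_turn_window(window)
print(inputs.shape)  # torch.Size([4, 128])
```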