---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
license: mit
datasets:
- leonvanbokhorst/topic-drift-v2
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
results:
- task:
type: topic-drift-detection
name: Topic Drift Detection
dataset:
name: leonvanbokhorst/topic-drift-v2
type: conversations
metrics:
- name: Test RMSE
type: rmse
value: 0.0144
- name: Test R²
type: r2
value: 0.8666
- name: Test Loss
type: loss
value: 0.0002
---
# Topic Drift Detector Model
## Version: v20241226_105737
This model detects topic drift in conversations using a streamlined attention-based architecture. Trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
## Model Architecture
- Efficient single-layer attention mechanism
- Direct pattern recognition
- Streamlined processing pipeline
- Optimized scaling factor (4.0)
- PreNorm layers with residual connections
### Key Components:
1. **Embedding Processor**:
- Input dimension: 1024
- Hidden dimension: 512
- Dropout rate: 0.35
- PreNorm layers with residual connections
2. **Attention Block**:
- Single attention layer
- Feed-forward dimension: 512
- Learned position encodings
- Residual connections
3. **Pattern Recognition**:
- Direct feature extraction
- Efficient tensor operations
   - Optimized memory usage (a minimal code sketch of these components follows below)
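
The module code is not reproduced on this card, so the following PyTorch sketch is a hypothetical reconstruction of the architecture described above. The dimensions (1024 → 512), dropout (0.35), single attention layer, learned position encodings, and PreNorm residual structure come from the card; the number of attention heads, the activation functions, the sigmoid head, and the mean pooling are assumptions, and the card's scaling factor (4.0) is omitted because its placement is not specified.

```python
import torch
import torch.nn as nn


class AttentionBlock(nn.Module):
    """Single self-attention layer plus feed-forward, both PreNorm with
    residual connections. num_heads=4 and GELU are assumptions."""

    def __init__(self, dim=512, ff_dim=512, num_heads=4, dropout=0.35):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                          batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, ff_dim), nn.GELU(),
                                nn.Linear(ff_dim, dim))

    def forward(self, x):
        h = self.norm1(x)                       # PreNorm
        x = x + self.attn(h, h, h)[0]           # residual around attention
        x = x + self.ff(self.norm2(x))          # PreNorm feed-forward + residual
        return x


class TopicDriftDetector(nn.Module):
    """Hypothetical reconstruction of the detector described on this card."""

    def __init__(self, input_dim=1024, hidden_dim=512, num_turns=8,
                 dropout=0.35):
        super().__init__()
        self.num_turns = num_turns
        self.input_dim = input_dim
        # Embedding processor: project 1024-d turn embeddings down to 512-d
        self.processor = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                       nn.GELU(), nn.Dropout(dropout))
        # Learned position encodings, one vector per conversation turn
        self.positions = nn.Parameter(torch.zeros(1, num_turns, hidden_dim))
        self.attention = AttentionBlock(hidden_dim, ff_dim=512, dropout=dropout)
        # Regression head; the sigmoid (scores in [0, 1]) is an assumption
        self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

    def forward(self, x):                        # x: [batch, 8 * 1024]
        x = x.view(-1, self.num_turns, self.input_dim)
        x = self.processor(x) + self.positions   # per-turn features + positions
        x = self.attention(x)                    # single attention block
        return self.head(x.mean(dim=1))          # pool over turns -> drift score
```

Under these assumptions, `TopicDriftDetector()(torch.randn(1, 8 * 1024))` yields one drift score per 8-turn window, matching the input shape used in the usage example below.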
## Performance Metrics
```txt
=== Full Training Results ===
Best Validation RMSE: 0.0142
Best Validation R²: 0.8711
=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0144
R²: 0.8666
```
## Training Details
- Dataset: 6400 conversations (5120 train, 640 val, 640 test)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 15
- Distribution regularization weight: 0.1
- Target standard deviation: 0.2 (see the loss sketch below)
- Base embeddings: BAAI/bge-m3
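
The card lists a distribution regularization weight and a target standard deviation but does not show the loss itself. A minimal sketch of one plausible objective, assuming the regularizer penalizes the gap between the spread of the batch predictions and the target standard deviation (the `training_loss` helper is hypothetical):

```python
import torch
import torch.nn.functional as F


def training_loss(preds: torch.Tensor, targets: torch.Tensor,
                  reg_weight: float = 0.1,
                  target_std: float = 0.2) -> torch.Tensor:
    """Hypothetical objective: MSE plus a distribution regularizer.

    The regularizer pulls the standard deviation of the batch predictions
    toward target_std, discouraging collapse to a narrow score range
    (consistent with the 'wide score range' noted below).
    """
    mse = F.mse_loss(preds, targets)
    dist_reg = (preds.std() - target_std) ** 2
    return mse + reg_weight * dist_reg
```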
## Key Improvements
1. **Simplified Architecture**:
- Reduced complexity
- Focused pattern detection
- Efficient processing
- Optimized memory usage
2. **Performance Benefits**:
- Improved RMSE (0.0144)
- Strong R² score (0.8666)
- Consistent predictions
- Wide score range
## Usage Example
```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load the base embedding model used during training
base_model = AutoModel.from_pretrained('BAAI/bge-m3')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Load the topic drift detector (a pickled full model; PyTorch >= 2.6
# may require torch.load(..., weights_only=False))
model = torch.load('models/v20241226_105737/topic_drift_model.pt')
model.eval()

# Prepare a conversation window (the model expects exactly 8 turns)
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

with torch.no_grad():
    # Embed each turn, capping at 512 tokens (see Limitations),
    # and mean-pool over tokens -> [8, 1024]
    inputs = tokenizer(conversation, padding=True, truncation=True,
                       max_length=512, return_tensors='pt')
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)

    # Flatten the window for the detector -> [1, 8*1024]
    conversation_embeddings = embeddings.view(1, -1)

    # Higher scores indicate more topic drift
    drift_scores = model(conversation_embeddings)

print(f"Topic drift score: {drift_scores.item():.4f}")
```
## Limitations
- Works best with English conversations
- Requires exactly 8 turns of conversation (longer chats can be scored with a sliding window; see the sketch below)
- Each turn should be between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings
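
The fixed 8-turn window does not prevent scoring longer conversations: a common workaround is to slide the window across the transcript and score each position. A minimal sketch, reusing the `tokenizer`, `base_model`, and `model` objects from the usage example above (the `drift_scores_over_conversation` helper is hypothetical):

```python
import torch


def drift_scores_over_conversation(turns, tokenizer, base_model, model,
                                   window=8):
    """Score a longer conversation by sliding an 8-turn window across it,
    since the detector requires exactly `window` turns per input."""
    scores = []
    with torch.no_grad():
        for start in range(len(turns) - window + 1):
            chunk = turns[start:start + window]
            inputs = tokenizer(chunk, padding=True, truncation=True,
                               max_length=512, return_tensors='pt')
            # Mean-pool token embeddings per turn -> [window, 1024]
            emb = base_model(**inputs).last_hidden_state.mean(dim=1)
            # Flatten the window -> [1, window * 1024] and score it
            scores.append(model(emb.view(1, -1)).item())
    return scores
```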
## Training Curves
