---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
license: mit
datasets:
- leonvanbokhorst/topic-drift-v2
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
results:
- task:
type: topic-drift-detection
name: Topic Drift Detection
dataset:
name: leonvanbokhorst/topic-drift-v2
type: conversations
metrics:
- name: Test RMSE
type: rmse
value: 0.0144
- name: Test R²
type: r2
value: 0.8666
---
# Topic Drift Detector Model
## Version: v20241226_114030
This model detects topic drift in conversations using an efficient attention-based architecture. It was trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.
## Model Architecture
### Key Components:
1. **Input Processing**:
- Input dimension: 1024 (BGE-M3 embeddings)
- Hidden dimension: 512
- Sequence length: 8 turns
2. **Attention Block**:
- Multi-head attention (4 heads)
- PreNorm layers with residual connections
- Dropout rate: 0.1
3. **Feed-Forward Network**:
- Two-layer MLP with GELU activation
- Hidden dimension: 512 -> 2048 -> 512
- Residual connections
4. **Output Layer**:
- Two-layer MLP: 512 -> 256 -> 1
- GELU activation
- Direct sigmoid output for [0,1] range
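The components above can be sketched in PyTorch roughly as follows. This is an illustrative, non-authoritative sketch: the class name, the input projection, and the mean pooling over turns are assumptions, not the released `EnhancedTopicDriftDetector` implementation.

```python
import torch
import torch.nn as nn

class TopicDriftSketch(nn.Module):
    """Rough sketch of the architecture described above; not the released code."""

    def __init__(self, input_dim=1024, hidden_dim=512, num_heads=4,
                 ff_dim=2048, window_size=8, dropout=0.1):
        super().__init__()
        self.window_size = window_size
        self.input_proj = nn.Linear(input_dim, hidden_dim)        # 1024 -> 512
        self.attn_norm = nn.LayerNorm(hidden_dim)                 # PreNorm before attention
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.ff_norm = nn.LayerNorm(hidden_dim)                   # PreNorm before feed-forward
        self.ff = nn.Sequential(                                  # 512 -> 2048 -> 512
            nn.Linear(hidden_dim, ff_dim), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(ff_dim, hidden_dim))
        self.head = nn.Sequential(                                # 512 -> 256 -> 1
            nn.Linear(hidden_dim, 256), nn.GELU(),
            nn.Linear(256, 1), nn.Sigmoid())

    def forward(self, x):
        # x: (batch, window_size * input_dim), as passed in the usage example below
        x = x.view(x.size(0), self.window_size, -1)               # (batch, 8, 1024)
        h = self.input_proj(x)                                    # (batch, 8, 512)
        n = self.attn_norm(h)
        attn_out, _ = self.attn(n, n, n)                          # 4-head self-attention
        h = h + attn_out                                          # residual connection
        h = h + self.ff(self.ff_norm(h))                          # residual connection
        return self.head(h.mean(dim=1))                           # pool over turns -> [0, 1] score
```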
## Performance Metrics
```txt
=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0144
R²: 0.8666
```
## Training Details
- Dataset: 6400 conversations (5120 train, 640 val, 640 test)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 15
- Distribution regularization weight: 0.1
- Target standard deviation: 0.2
- Base embeddings: BAAI/bge-m3
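The distribution regularization is not spelled out in this card; one plausible reading, shown purely as an assumption, is an MSE objective plus a penalty that keeps the spread of the batch predictions near the target standard deviation:

```python
import torch
import torch.nn.functional as F

def training_loss(predictions: torch.Tensor, targets: torch.Tensor,
                  reg_weight: float = 0.1, target_std: float = 0.2) -> torch.Tensor:
    """Hypothetical objective: MSE plus a distribution-regularization term (assumed form)."""
    mse = F.mse_loss(predictions, targets)
    dist_reg = (predictions.std() - target_std) ** 2  # penalize deviation from target spread
    return mse + reg_weight * dist_reg
```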
## Usage Example
```python
# Install dependencies first (shell):
#   pip install torch transformers huggingface_hub
# Import required packages
import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
# Load base model and tokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_model = AutoModel.from_pretrained('BAAI/bge-m3').to(device)
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
# Download and load topic drift model
model_path = hf_hub_download(
    repo_id='leonvanbokhorst/topic-drift-detector',
    filename='models/v20241226_114030/topic_drift_model.pt'
)
checkpoint = torch.load(model_path, weights_only=True, map_location=device)

# Note: the EnhancedTopicDriftDetector class is not defined in this snippet;
# its definition must be available in your environment (e.g. from the model repository).
model = EnhancedTopicDriftDetector(
    input_dim=1024,
    hidden_dim=checkpoint['hyperparameters']['hidden_dim']
).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# Example conversation
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]
# Process conversation
with torch.no_grad():
    # Get embeddings for each turn (mean-pooled over tokens)
    inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)

    # Get drift score for the 8-turn window
    conversation_embeddings = embeddings.view(1, -1)
    drift_score = model(conversation_embeddings)

print(f"Topic drift score: {drift_score.item():.4f}")
```
## Limitations
- Works best with English conversations
- Requires exactly 8 turns of conversation
- Each turn should be between 1-512 tokens
- Relies on BAAI/bge-m3 embeddings
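Because the model expects exactly 8 turns, longer conversations have to be windowed outside the model. A possible workaround (not part of the released code) is to score overlapping 8-turn windows, reusing the `tokenizer`, `base_model`, `model`, and `device` objects from the usage example above:

```python
def sliding_drift_scores(turns, window_size=8):
    """Hypothetical helper: drift score for each consecutive 8-turn window."""
    scores = []
    for start in range(len(turns) - window_size + 1):
        window = turns[start:start + window_size]
        inputs = tokenizer(window, padding=True, truncation=True, return_tensors='pt')
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            emb = base_model(**inputs).last_hidden_state.mean(dim=1)  # (8, 1024)
            scores.append(model(emb.view(1, -1)).item())
    return scores
```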