---
language: en
tags:
- topic-drift
- conversation-analysis
- pytorch
- attention
license: mit
datasets:
- leonvanbokhorst/topic-drift-v2
metrics:
- rmse
- r2_score
model-index:
- name: topic-drift-detector
  results:
  - task: 
      type: topic-drift-detection
      name: Topic Drift Detection
    dataset:
      name: leonvanbokhorst/topic-drift-v2
      type: conversations
    metrics:
      - name: Test RMSE
        type: rmse
        value: 0.0144
      - name: Test R²
        type: r2
        value: 0.8666
---

# Topic Drift Detector Model

## Version: v20241226_114030

This model detects topic drift in conversations using an efficient attention-based architecture. It was trained on the [leonvanbokhorst/topic-drift-v2](https://huggingface.co/datasets/leonvanbokhorst/topic-drift-v2) dataset.

## Model Architecture

### Key Components:
1. **Input Processing**:
   - Input dimension: 1024 (BGE-M3 embeddings)
   - Hidden dimension: 512
   - Sequence length: 8 turns

2. **Attention Block**:
   - Multi-head attention (4 heads)
   - PreNorm layers with residual connections
   - Dropout rate: 0.1

3. **Feed-Forward Network**:
   - Two-layer MLP with GELU activation
   - Hidden dimension: 512 -> 2048 -> 512
   - Residual connections

4. **Output Layer**:
   - Two-layer MLP: 512 -> 256 -> 1
   - GELU activation
   - Direct sigmoid output for [0,1] range
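
For reference, below is a minimal PyTorch sketch of a model matching this description. The actual `EnhancedTopicDriftDetector` class ships with the repository and may differ in detail; the layer names, the `window_size` argument, and the mean-pooling choice are assumptions made for illustration.

```python
import torch
import torch.nn as nn

class EnhancedTopicDriftDetector(nn.Module):
    """Sketch of the attention-based drift detector described above (illustrative)."""

    def __init__(self, input_dim: int = 1024, hidden_dim: int = 512,
                 num_heads: int = 4, window_size: int = 8, dropout: float = 0.1):
        super().__init__()
        self.window_size = window_size
        self.input_dim = input_dim

        # Project BGE-M3 embeddings (1024) down to the hidden dimension (512)
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # PreNorm multi-head attention block with a residual connection
        self.attn_norm = nn.LayerNorm(hidden_dim)
        self.attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout, batch_first=True
        )

        # PreNorm feed-forward block (512 -> 2048 -> 512) with a residual connection
        self.ffn_norm = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Output head: 512 -> 256 -> 1, squashed to [0, 1] with a sigmoid
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x arrives flattened as (batch, window_size * input_dim); restore the turn axis
        x = x.view(-1, self.window_size, self.input_dim)
        h = self.input_proj(x)

        # Attention with PreNorm + residual
        normed = self.attn_norm(h)
        attn_out, _ = self.attention(normed, normed, normed)
        h = h + attn_out

        # Feed-forward with PreNorm + residual
        h = h + self.ffn(self.ffn_norm(h))

        # Pool over the 8 turns and score drift in [0, 1]
        return self.head(h.mean(dim=1))
```

The usage example further down passes a flattened `(1, 8 * 1024)` tensor, which is why this sketch restores the turn axis inside `forward`.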

## Performance Metrics
```txt
=== Test Set Results ===
Loss: 0.0002
RMSE: 0.0144
R²: 0.8666
```
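
These are the standard root-mean-square error and coefficient of determination. A small sketch of how they can be computed from model outputs (the helper name is hypothetical):

```python
import torch

def rmse_and_r2(preds: torch.Tensor, targets: torch.Tensor) -> tuple[float, float]:
    """Standard RMSE and R² between predictions and ground-truth drift scores."""
    rmse = torch.sqrt(torch.mean((preds - targets) ** 2)).item()
    ss_res = torch.sum((targets - preds) ** 2)            # residual sum of squares
    ss_tot = torch.sum((targets - targets.mean()) ** 2)   # total sum of squares
    r2 = (1.0 - ss_res / ss_tot).item()
    return rmse, r2
```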

## Training Details
- Dataset: 6400 conversations (5120 train, 640 val, 640 test)
- Window size: 8 turns
- Batch size: 32
- Learning rate: 0.0001
- Early stopping patience: 15
- Distribution regularization weight: 0.1
- Target standard deviation: 0.2
- Base embeddings: BAAI/bge-m3
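
One plausible reading of the distribution regularization above is a penalty that pulls the spread of each batch's predictions toward the target standard deviation. The sketch below is an assumption about the training objective, not the repository's exact loss; `training_loss` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def training_loss(preds: torch.Tensor, targets: torch.Tensor,
                  reg_weight: float = 0.1, target_std: float = 0.2) -> torch.Tensor:
    """MSE plus a penalty pulling the batch prediction spread toward target_std."""
    mse = F.mse_loss(preds, targets)
    std_penalty = (preds.std() - target_std) ** 2  # distribution regularization term
    return mse + reg_weight * std_penalty
```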

## Usage Example

```python
# Install dependencies first:
#   pip install torch transformers huggingface_hub

# Import required packages
import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# Load base model and tokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_model = AutoModel.from_pretrained('BAAI/bge-m3').to(device)
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')

# Download and load topic drift model
model_path = hf_hub_download(
    repo_id='leonvanbokhorst/topic-drift-detector',
    filename='models/v20241226_114030/topic_drift_model.pt'
)
checkpoint = torch.load(model_path, weights_only=True, map_location=device)

# NOTE: EnhancedTopicDriftDetector must be in scope here -- either the
# definition shipped with the repository or a compatible reimplementation
# such as the sketch under "Model Architecture" above.
model = EnhancedTopicDriftDetector(
    input_dim=1024,
    hidden_dim=checkpoint['hyperparameters']['hidden_dim']
).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Example conversation
conversation = [
    "How was your weekend?",
    "It was great! Went hiking.",
    "Which trail did you take?",
    "The mountain loop trail.",
    "That's nice. By the way, did you watch the game?",
    "Yes! What an amazing match!",
    "The final score was incredible.",
    "I couldn't believe that last-minute goal."
]

# Process conversation
with torch.no_grad():
    # Get embeddings
    inputs = tokenizer(conversation, padding=True, truncation=True, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}
    embeddings = base_model(**inputs).last_hidden_state.mean(dim=1)
    
    # Get drift score: flatten the 8 turn embeddings into one (1, 8 * 1024) vector
    conversation_embeddings = embeddings.view(1, -1)
    drift_score = model(conversation_embeddings)
    print(f"Topic drift score: {drift_score.item():.4f}")
```

## Limitations
- Works best with English conversations
- Requires exactly 8 turns of conversation (longer conversations can be scored with a sliding window; see the sketch below)
- Each turn should be between 1 and 512 tokens
- Relies on BAAI/bge-m3 embeddings
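
Because the model expects exactly 8 turns, longer conversations can be scored by sliding an 8-turn window across them and inspecting the per-window scores over time. A small sketch (the helper name is hypothetical):

```python
def sliding_windows(turns: list[str], window_size: int = 8) -> list[list[str]]:
    """Split a conversation into overlapping windows of exactly window_size turns."""
    if len(turns) < window_size:
        raise ValueError(f"need at least {window_size} turns, got {len(turns)}")
    return [turns[i:i + window_size] for i in range(len(turns) - window_size + 1)]
```

Each window can then be embedded and scored exactly as in the usage example above.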