---
license: mit
---
# PyAnnote Segmentation ONNX

This repository provides an ONNX version of the PyAnnote speaker segmentation model for efficient inference across various platforms.

## Overview

The PyAnnote segmentation model has been converted from PyTorch to ONNX format (pyannote-segmentation-3.onnx) for improved deployment flexibility and performance. This conversion enables running speaker diarization on platforms without PyTorch dependencies, with potentially faster inference times.

## Features

- **Platform Independence**: Run speaker diarization without PyTorch dependencies
- **Optimized Performance**: ONNX Runtime optimizations for faster inference (see the session-options sketch after this list)
- **Simple Integration**: Straightforward JavaScript/Node.js implementation included
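
ONNX Runtime exposes a few tuning knobs when the session is created. The sketch below is a minimal illustration using `onnxruntime-node` session options; the specific values are assumptions for illustration, not settings required by this model.

```javascript
import * as ort from 'onnxruntime-node';

// Minimal sketch: illustrative session options, not model requirements.
// (Top-level await requires an ES module.)
const session = await ort.InferenceSession.create('pyannote-segmentation-3.onnx', {
  graphOptimizationLevel: 'all', // apply all graph-level optimizations
  intraOpNumThreads: 4,          // parallelism within individual operators
  executionProviders: ['cpu'],   // onnxruntime-node runs on CPU by default
});
```

On CPU-only hosts the defaults are usually sensible; treat these knobs as a starting point for profiling rather than a recommended configuration.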

## Quick Start

### Installation

```bash
npm install onnxruntime-node node-fetch wav-decoder
```

### Usage Example

```javascript
import fs from 'fs/promises';
import fetch from 'node-fetch';
import wav from 'wav-decoder';
import * as ort from 'onnxruntime-node';

// Load a WAV file from a local path or URL and wrap it in a
// [1, 1, numSamples] float32 tensor, as the model expects.
async function fetchAudioAsTensor(path, sampling_rate) {
  let audioBuffer;

  // Check if path is a local file or URL
  if (path.startsWith('http')) {
    const response = await fetch(path);
    audioBuffer = await response.arrayBuffer();
  } else {
    // Read local file
    audioBuffer = await fs.readFile(path);
  }

  const decoded = await wav.decode(new Uint8Array(audioBuffer).buffer);
  if (decoded.sampleRate !== sampling_rate) {
    console.warn(`Expected ${sampling_rate} Hz audio, got ${decoded.sampleRate} Hz; results may degrade.`);
  }

  const channelData = decoded.channelData[0]; // first channel only
  return new ort.Tensor('float32', channelData, [1, 1, channelData.length]);
}

// Turn raw per-frame logits into merged, per-speaker segments.
function postProcessSpeakerDiarization(logitsData, audioLength, samplingRate) {
  const timeStep = 0.00625; // seconds per output frame
  const numClasses = 7;
  const numFrames = Math.floor(logitsData.length / numClasses);

  // Pick the highest-scoring class for each frame.
  const frames = [];
  for (let i = 0; i < numFrames; i++) {
    const frameData = Array.from(logitsData.slice(i * numClasses, (i + 1) * numClasses));
    const maxVal = Math.max(...frameData);
    const maxIndex = frameData.indexOf(maxVal);

    frames.push({
      start: i * timeStep,
      end: (i + 1) * timeStep,
      id: maxIndex,
      confidence: maxVal
    });
  }

  const mergedResults = [];
  if (frames.length === 0) return mergedResults;

  // Merge consecutive frames that share a class ID into one segment.
  let currentSegment = frames[0];
  for (let i = 1; i < frames.length; i++) {
    if (frames[i].id === currentSegment.id) {
      currentSegment.end = frames[i].end;
      // Running average; later frames weigh more heavily than earlier ones.
      currentSegment.confidence = (currentSegment.confidence + frames[i].confidence) / 2;
    } else {
      mergedResults.push({ ...currentSegment });
      currentSegment = frames[i];
    }
  }
  mergedResults.push(currentSegment);

  return mergedResults;
}

(async () => {
  const model_url = 'pyannote-segmentation-3.onnx';
  const audio_path = './mlk.wav'; // local example file
  const sampling_rate = 16000;

  const session = await ort.InferenceSession.create(model_url);
  const audioTensor = await fetchAudioAsTensor(audio_path, sampling_rate);

  const output = await session.run({ input_values: audioTensor });
  const logits = output.logits.data;

  const result = postProcessSpeakerDiarization(logits, audioTensor.dims[2], sampling_rate);

  console.table(result.map(r => ({
    start: Number(r.start.toFixed(5)),
    end: Number(r.end.toFixed(5)),
    id: r.id,
    confidence: r.confidence
  })));
})();
```

## Implementation Details

The repository includes a complete Node.js implementation for:

1. Loading audio from local files or URLs
2. Converting audio to the proper tensor format (see the mono-downmix sketch after this list)
3. Running inference with ONNX Runtime
4. Post-processing diarization results
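
The usage example above reads only the first channel of the decoded WAV. If your input may be stereo, one simple approach (sketched here; it is not part of the repository code) is to average all channels down to mono before building the tensor:

```javascript
// Sketch: average all channels of a wav-decoder result down to mono.
// `decoded` is the object returned by `wav.decode(...)` in the usage example.
function toMono(decoded) {
  const channels = decoded.channelData;
  if (channels.length === 1) return channels[0];

  const mono = new Float32Array(channels[0].length);
  for (let c = 0; c < channels.length; c++) {
    for (let i = 0; i < mono.length; i++) {
      mono[i] += channels[c][i] / channels.length;
    }
  }
  return mono;
}
```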

## Speaker ID Interpretation

The model classifies audio segments with IDs representing different speakers or audio conditions (a labeling sketch follows this list):

- **ID 0**: Primary speaker
- **ID 2**: Secondary speaker
- **ID 3**: Background noise or brief interjections
- **ID 1**: Not typically identified by the model
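
For readability, the merged segments can be tagged with the interpretations above. In this small sketch, the label strings simply restate the list, and `labelSegments` is a hypothetical helper rather than part of the repository:

```javascript
// Sketch: attach readable labels to merged segments, based on the
// ID interpretations above; unknown IDs fall back to a generic name.
const SPEAKER_LABELS = {
  0: 'primary speaker',
  2: 'secondary speaker',
  3: 'background noise / interjection',
};

function labelSegments(segments) {
  return segments.map(s => ({
    ...s,
    label: SPEAKER_LABELS[s.id] ?? `class ${s.id}`,
  }));
}
```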

## Performance Considerations

- The model processes audio with a time step of 0.00625 seconds
- Best results are achieved with 16 kHz mono WAV files
- Processing longer audio files may require batching (see the chunked-inference sketch below)
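
One simple batching strategy, sketched below under the assumption that the exported model accepts variable-length windows (verify this for your export), is to slice the waveform into fixed-size chunks, run the session on each chunk, and shift the resulting timestamps. `diarizeInChunks` is a hypothetical helper that reuses `ort`, `session`, and `postProcessSpeakerDiarization` from the usage example.

```javascript
// Sketch: run inference over fixed-length windows of a long recording.
async function diarizeInChunks(session, samples, samplingRate, chunkSeconds = 10) {
  const chunkSize = chunkSeconds * samplingRate;
  const results = [];

  for (let offset = 0; offset < samples.length; offset += chunkSize) {
    const chunk = samples.subarray(offset, Math.min(offset + chunkSize, samples.length));
    const tensor = new ort.Tensor('float32', chunk, [1, 1, chunk.length]);
    const output = await session.run({ input_values: tensor });

    const offsetSeconds = offset / samplingRate;
    for (const seg of postProcessSpeakerDiarization(output.logits.data, chunk.length, samplingRate)) {
      // Shift chunk-local times into the global timeline.
      results.push({ ...seg, start: seg.start + offsetSeconds, end: seg.end + offsetSeconds });
    }
  }
  return results;
}
```

Note that this naive windowing does not merge segments that straddle a chunk boundary; overlapping windows are a common refinement.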

## Example Results

When run against an audio file, the code outputs a table like this:

```
┌───────┬─────────┬─────────┬────┬─────────────────────┐
│ Index │  Start  │   End   │ ID │     Confidence      │
├───────┼─────────┼─────────┼────┼─────────────────────┤
│   0   │ 0.00000 │ 0.38750 │ 0  │ -0.5956847206408247 │
│   1   │ 0.38750 │ 0.87500 │ 2  │ -0.6725609518399854 │
│   2   │ 0.87500 │ 1.31875 │ 0  │ -0.6251495976493047 │
│   3   │ 1.31875 │ 1.68750 │ 2  │ -1.0951091697128392 │
│   4   │ 1.68750 │ 2.30000 │ 3  │ -1.2232454111418622 │
│   5   │ 2.30000 │ 3.19375 │ 2  │ -0.7195502450863511 │
│   6   │ 3.19375 │ 3.71250 │ 0  │ -0.6267317700475712 │
│   7   │ 3.71250 │ 4.64375 │ 2  │ -1.1656335032519587 │
│   8   │ 4.64375 │ 4.79375 │ 0  │ -1.0008199909561597 │
└───────┴─────────┴─────────┴────┴─────────────────────┘
```

Each row represents a segment with:

- `start`: Start time of the segment (seconds)
- `end`: End time of the segment (seconds)
- `id`: Speaker/class ID
- `confidence`: Model confidence score (negative values closer to 0 indicate higher confidence)

In this example, you can observe speaker transitions between speakers 0 and 2, with a brief segment of background noise (ID 3) around the 2-second mark.
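
The raw confidence scores above are the winning logits. If your export emits log-probabilities (an assumption to verify against your model, not something this README guarantees), `Math.exp` maps a score into the more intuitive (0, 1] range:

```javascript
// Sketch: convert a log-probability-style score into (0, 1].
// Only meaningful if the model's outputs really are log-probabilities.
const toProbability = (confidence) => Math.exp(confidence);

console.log(toProbability(-0.5956847206408247).toFixed(3)); // ≈ 0.551 for the first row above
```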

## Applications

This ONNX-converted model is suitable for:

- Cross-platform applications
- Edge devices with limited resources
- Server-side processing with Node.js
- Batch processing of audio files
|