bweng committed · verified
Commit 9d2a3e2 · Parent(s): 4a9d1ec

Delete main.swift

Files changed (1)
    main.swift +0 -449
main.swift DELETED
@@ -1,449 +0,0 @@
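- // Speaker-diarization command-line demo: a pyannote segmentation model and a
- // WeSpeaker embedding model (both Core ML) are run over 10-second chunks of a
- // 16 kHz mono recording; embeddings are clustered into speakers by cosine
- // distance and the resulting speaker turns are printed.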
- import Accelerate
- import AVFoundation
- import CoreML
- import Foundation
-
- struct Segment: Hashable {
-     let start: Double
-     let end: Double
- }
-
- struct SlidingWindow {
-     var start: Double
-     var duration: Double
-     var step: Double
-
-     func time(forFrame index: Int) -> Double {
-         return start + Double(index) * step
-     }
-
-     func segment(forFrame index: Int) -> Segment {
-         let s = time(forFrame: index)
-         return Segment(start: s, end: s + duration)
-     }
- }
-
- struct SlidingWindowFeature {
-     var data: [[[Float]]] // (1, 589, 3)
-     var slidingWindow: SlidingWindow
- }
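-
- // Running speaker database: label -> reference embedding. New embeddings are
- // matched by cosine distance (1 - cosine similarity); anything farther than
- // `threshold` from every known speaker is enrolled as a new speaker.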
- var speakerDB: [String: [Float]] = [:] // Global speaker database
- let threshold: Float = 0.7 // Distance threshold
-
- func cosineDistance(_ x: [Float], _ y: [Float]) -> Float {
-     precondition(x.count == y.count, "Vectors must be same size")
-     let dot = zip(x, y).map(*).reduce(0, +)
-     let normX = sqrt(x.map { $0 * $0 }.reduce(0, +))
-     let normY = sqrt(y.map { $0 * $0 }.reduce(0, +))
-     return 1.0 - (dot / (normX * normY + 1e-6)) // epsilon avoids division by zero
- }
-
- // Blend a speaker's stored embedding toward a new observation (EMA update).
- func updateSpeakerDB(_ speaker: String, _ newEmbedding: [Float], alpha: Float = 0.9) {
-     guard var oldEmbedding = speakerDB[speaker] else { return }
-     for i in 0..<oldEmbedding.count {
-         oldEmbedding[i] = alpha * oldEmbedding[i] + (1 - alpha) * newEmbedding[i]
-     }
-     speakerDB[speaker] = oldEmbedding
- }
-
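- // Greedy nearest-neighbour assignment: reuse the closest known speaker if it
- // is within `threshold`, updating its reference embedding; otherwise enroll a
- // new speaker label.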
- func assignSpeaker(embedding: [Float], threshold: Float = 0.7) -> String {
-     if speakerDB.isEmpty {
-         let speaker = "Speaker 1"
-         speakerDB[speaker] = embedding
-         return speaker
-     }
-
-     var minDistance: Float = Float.greatestFiniteMagnitude
-     var identifiedSpeaker: String? = nil
-
-     for (speaker, refEmbedding) in speakerDB {
-         let distance = cosineDistance(embedding, refEmbedding)
-         if distance < minDistance {
-             minDistance = distance
-             identifiedSpeaker = speaker
-         }
-     }
-
-     if let bestSpeaker = identifiedSpeaker {
-         if minDistance > threshold {
-             let newSpeaker = "Speaker \(speakerDB.count + 1)"
-             speakerDB[newSpeaker] = embedding
-             return newSpeaker
-         } else {
-             updateSpeakerDB(bestSpeaker, embedding)
-             return bestSpeaker
-         }
-     }
-
-     return "Unknown"
- }
-
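- // Collapse per-frame activations into labelled time segments: argmax speaker
- // per frame, merge contiguous runs, then map local speaker indices to global
- // labels via `speakerMapping`.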
- func getAnnotation(annotation: inout [Segment: String],
-                    speakerMapping: [Int: Int],
-                    binarizedSegments: [[[Float]]],
-                    slidingWindow: SlidingWindow) {
-
-     let segmentation = binarizedSegments[0] // shape: [589][3]
-     let numFrames = segmentation.count
-     guard numFrames > 0 else { return } // nothing to annotate
-
-     // Step 1: argmax to get dominant speaker per frame
-     var frameSpeakers: [Int] = []
-     for frame in segmentation {
-         if let maxIdx = frame.indices.max(by: { frame[$0] < frame[$1] }) {
-             frameSpeakers.append(maxIdx)
-         } else {
-             frameSpeakers.append(0) // fallback
-         }
-     }
-
-     // Step 2: group contiguous same-speaker segments
-     var currentSpeaker = frameSpeakers[0]
-     var startFrame = 0
-
-     for i in 1..<numFrames {
-         if frameSpeakers[i] != currentSpeaker {
-             let startTime = slidingWindow.time(forFrame: startFrame)
-             let endTime = slidingWindow.time(forFrame: i)
-
-             let segment = Segment(start: startTime, end: endTime)
-             if let mappedSpeaker = speakerMapping[currentSpeaker] {
-                 annotation[segment] = "Speaker \(mappedSpeaker)"
-             }
-             currentSpeaker = frameSpeakers[i]
-             startFrame = i
-         }
-     }
-
-     // Final segment
-     let finalStart = slidingWindow.time(forFrame: startFrame)
-     let finalEnd = slidingWindow.segment(forFrame: numFrames - 1).end
-     let finalSegment = Segment(start: finalStart, end: finalEnd)
-     if let mappedSpeaker = speakerMapping[currentSpeaker] {
-         annotation[finalSegment] = "Speaker \(mappedSpeaker)"
-     }
- }
-
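- // Build the embedding-model inputs: the chunk waveform is repeated once per
- // speaker, and each copy gets a mask keeping only frames where fewer than two
- // speakers are active, so overlapped speech does not contaminate embeddings.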
- func getEmbedding(audioChunk: [Float],
-                   binarizedSegments _: [[[Float]]],
-                   slidingWindowSegments: SlidingWindowFeature,
-                   chunkSize: Int = 10 * 16000,
-                   embeddingModel: MLModel) -> MLMultiArray?
- {
-     let numFrames = slidingWindowSegments.data[0].count
-     let numSpeakers = slidingWindowSegments.data[0][0].count
-
-     // 1. clean_frames = 1.0 where fewer than two speakers are active
-     var cleanFrames = Array(repeating: Array(repeating: 0.0 as Float, count: 1), count: numFrames)
-     for f in 0 ..< numFrames {
-         let frame = slidingWindowSegments.data[0][f]
-         let speakerSum = frame.reduce(0, +)
-         cleanFrames[f][0] = (speakerSum < 2.0) ? 1.0 : 0.0
-     }
-
-     // 2. Zero out segment scores on overlapped frames
-     var cleanSegmentData = Array(
-         repeating: Array(repeating: Array(repeating: 0.0 as Float, count: numSpeakers), count: numFrames),
-         count: 1
-     )
-     for f in 0 ..< numFrames {
-         for s in 0 ..< numSpeakers {
-             cleanSegmentData[0][f][s] = slidingWindowSegments.data[0][f][s] * cleanFrames[f][0]
-         }
-     }
-
-     // 3. Repeat the audio once per speaker: shape (numSpeakers, chunkSize)
-     var audioBatch: [[Float]] = []
-     for _ in 0 ..< numSpeakers {
-         audioBatch.append(audioChunk)
-     }
-
-     // 4. Transpose the mask to (numSpeakers, numFrames)
-     var cleanMasks: [[Float]] = Array(repeating: Array(repeating: 0.0, count: numFrames), count: numSpeakers)
-     for s in 0 ..< numSpeakers {
-         for f in 0 ..< numFrames {
-             cleanMasks[s][f] = cleanSegmentData[0][f][s]
-         }
-     }
-
-     // 5. Prepare MLMultiArray inputs
-     guard let waveformArray = try? MLMultiArray(shape: [numSpeakers, chunkSize] as [NSNumber], dataType: .float32),
-           let maskArray = try? MLMultiArray(shape: [numSpeakers, numFrames] as [NSNumber], dataType: .float32)
-     else {
-         print("Failed to allocate MLMultiArray")
-         return nil
-     }
-
-     // Fill waveform
-     for s in 0 ..< numSpeakers {
-         for i in 0 ..< chunkSize {
-             waveformArray[s * chunkSize + i] = NSNumber(value: audioBatch[s][i])
-         }
-     }
-
-     // Fill mask
-     for s in 0 ..< numSpeakers {
-         for f in 0 ..< numFrames {
-             maskArray[s * numFrames + f] = NSNumber(value: cleanMasks[s][f])
-         }
-     }
-
-     // 6. Run the embedding model
-     let inputs: [String: Any] = [
-         "waveform": waveformArray,
-         "mask": maskArray,
-     ]
-     guard let output = try? embeddingModel.prediction(from: MLDictionaryFeatureProvider(dictionary: inputs)) else {
-         print("Embedding model prediction failed")
-         return nil
-     }
-
-     return output.featureValue(for: "embedding")?.multiArrayValue
- }
-
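- // Decode an audio file and resample it to `expectedSampleRate` mono Float32.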
- func loadAudioSamples(from url: URL, expectedSampleRate: Double = 16000.0) throws -> [Float] {
-     let file = try AVAudioFile(forReading: url)
-     let format = AVAudioFormat(commonFormat: .pcmFormatFloat32,
-                                sampleRate: expectedSampleRate,
-                                channels: 1,
-                                interleaved: false)!
-
-     let converter = AVAudioConverter(from: file.processingFormat, to: format)!
-     let inputBuffer = AVAudioPCMBuffer(pcmFormat: file.processingFormat,
-                                        frameCapacity: AVAudioFrameCount(file.length))!
-     try file.read(into: inputBuffer)
-
-     // Size the output for the resampled length (rounded up)
-     let ratio = expectedSampleRate / file.processingFormat.sampleRate
-     let outputCapacity = AVAudioFrameCount((Double(file.length) * ratio).rounded(.up))
-     let outputBuffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: outputCapacity)!
-
-     // Hand the converter the whole file once, then signal end of stream
-     var delivered = false
-     let inputBlock: AVAudioConverterInputBlock = { _, outStatus in
-         if delivered {
-             outStatus.pointee = .endOfStream
-             return nil
-         }
-         delivered = true
-         outStatus.pointee = .haveData
-         return inputBuffer
-     }
-
-     var conversionError: NSError?
-     let status = converter.convert(to: outputBuffer, error: &conversionError, withInputFrom: inputBlock)
-     if status == .error {
-         throw conversionError ?? NSError(domain: "Audio", code: -2,
-                                          userInfo: [NSLocalizedDescriptionKey: "Conversion failed"])
-     }
-
-     guard let floatChannelData = outputBuffer.floatChannelData else {
-         throw NSError(domain: "Audio", code: -1, userInfo: [NSLocalizedDescriptionKey: "Missing float data"])
-     }
-
-     return Array(UnsafeBufferPointer(start: floatChannelData[0], count: Int(outputBuffer.frameLength)))
- }
-
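- // Pipeline driver: split the recording into 10 s chunks, segment each chunk,
- // binarize the powerset output, embed each local speaker, assign global
- // labels, and accumulate the annotation. The duration/step constants below
- // appear to match the segmentation model's output frame resolution.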
- func chunkAndRunSegmentation(samples: [Float], chunkSize: Int = 160_000, model: MLModel, embeddingModel: MLModel) throws {
-     let totalSamples = samples.count
-     let numberOfChunks = Int(ceil(Double(totalSamples) / Double(chunkSize)))
-     var annotations: [Segment: String] = [:]
-
-     for i in 0 ..< numberOfChunks {
-         let start = i * chunkSize
-         let end = min((i + 1) * chunkSize, totalSamples)
-         let chunk = Array(samples[start ..< end])
-
-         // If the chunk is shorter than 10 s, pad with zeros
-         var paddedChunk = chunk
-         if chunk.count < chunkSize {
-             paddedChunk += Array(repeating: 0.0, count: chunkSize - chunk.count)
-         }
-
-         let binarizedSegments = try getSegments(audioChunk: paddedChunk, model: model)
-         // Chunk start in seconds (assumes 16 kHz input)
-         let chunkStart = Double(start) / 16000.0
-         let frames = SlidingWindow(start: chunkStart, duration: 0.0619375, step: 0.016875)
-         let slidingFeature = SlidingWindowFeature(data: binarizedSegments, slidingWindow: frames)
-         if let embeddings = getEmbedding(audioChunk: paddedChunk,
-                                          binarizedSegments: binarizedSegments,
-                                          slidingWindowSegments: slidingFeature,
-                                          embeddingModel: embeddingModel)
-         {
-             print("Embeddings shape: \(embeddings.shape.map { $0.intValue })")
-
-             let shape = embeddings.shape.map { $0.intValue } // [3, 256]
-             let numSpeakers = shape[0]
-             let embeddingDim = shape[1]
-             let strides = embeddings.strides.map { $0.intValue }
-
-             var speakerSums = [Float](repeating: 0.0, count: numSpeakers)
-             for s in 0 ..< numSpeakers {
-                 for d in 0 ..< embeddingDim {
-                     let index = s * strides[0] + d * strides[1]
-                     speakerSums[s] += embeddings[index].floatValue
-                 }
-             }
-             print("Sum along axis 1 (per speaker): \(speakerSums)")
-
-             // Step 3: assign a global speaker label to each embedding
-             var speakerLabels = [String]()
-             for s in 0 ..< numSpeakers {
-                 var embeddingVec = [Float](repeating: 0.0, count: embeddingDim)
-                 for d in 0 ..< embeddingDim {
-                     let index = s * strides[0] + d * strides[1]
-                     embeddingVec[d] = embeddings[index].floatValue
-                 }
-                 speakerLabels.append(assignSpeaker(embedding: embeddingVec))
-             }
-             print("Chunk \(i + 1): Assigned Speakers: \(speakerLabels)")
-
-             // Step 4: map local speaker index 0,1,2 → global speaker number
-             var labelMapping: [Int: Int] = [:]
-             for (idx, label) in speakerLabels.enumerated() {
-                 if let spkNum = Int(label.components(separatedBy: " ").last ?? "") {
-                     labelMapping[idx] = spkNum
-                 }
-             }
-
-             getAnnotation(annotation: &annotations,
-                           speakerMapping: labelMapping,
-                           binarizedSegments: binarizedSegments,
-                           slidingWindow: frames)
-
-             print("Chunk \(i + 1) → Segments shape: \(binarizedSegments[0].count) frames")
-         }
-     }
-
-     // Final result
-     print("\n=== Final Annotations ===")
-     for (segment, speaker) in annotations.sorted(by: { $0.key.start < $1.key.start }) {
-         print("\(speaker): \(segment.start) - \(segment.end)")
-     }
- }
-
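- // The segmentation model scores 7 powerset classes per frame (no speaker,
- // each single speaker, each speaker pair). Binarization takes the argmax
- // class per frame and activates the speakers that class contains.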
- func powersetConversion(_ segments: [[[Float]]]) -> [[[Float]]] {
-     let powerset: [[Int]] = [
-         [],     // 0: no speaker
-         [0],    // 1
-         [1],    // 2
-         [2],    // 3
-         [0, 1], // 4
-         [0, 2], // 5
-         [1, 2], // 6
-     ]
-
-     let batchSize = segments.count
-     let numFrames = segments[0].count
-     let numCombos = segments[0][0].count // 7
-     precondition(numCombos == powerset.count, "Unexpected number of powerset classes")
-
-     let numSpeakers = 3
-     var binarized = Array(
-         repeating: Array(
-             repeating: Array(repeating: 0.0 as Float, count: numSpeakers),
-             count: numFrames
-         ),
-         count: batchSize
-     )
-
-     for b in 0 ..< batchSize {
-         for f in 0 ..< numFrames {
-             let frame = segments[b][f]
-
-             // Find index of the highest-scoring powerset class in this frame
-             guard let bestIdx = frame.indices.max(by: { frame[$0] < frame[$1] }) else {
-                 continue
-             }
-
-             // Mark the corresponding speakers as active
-             for speaker in powerset[bestIdx] {
-                 binarized[b][f][speaker] = 1.0
-             }
-         }
-     }
-
-     return binarized
- }
-
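- // Run the segmentation model on one padded chunk and return binarized
- // per-speaker activations of shape (1, frames, 3).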
- func getSegments(audioChunk: [Float], sampleRate _: Int = 16000, chunkSize: Int = 160_000, model: MLModel) throws -> [[[Float]]] {
-     // Ensure correct shape: (1, 1, chunk_size)
-     let audioArray = try MLMultiArray(shape: [1, 1, NSNumber(value: chunkSize)], dataType: .float32)
-     for i in 0 ..< audioChunk.count {
-         audioArray[i] = NSNumber(value: audioChunk[i])
-     }
-
-     // Run prediction
-     let input = try MLDictionaryFeatureProvider(dictionary: ["audio": audioArray])
-     let output = try model.prediction(from: input)
-
-     // Extract segments output: shape assumed (1, frames, 7)
-     guard let segmentOutput = output.featureValue(for: "segments")?.multiArrayValue else {
-         throw NSError(domain: "ModelOutput", code: -1, userInfo: [NSLocalizedDescriptionKey: "Missing segments output"])
-     }
-
-     let frames = segmentOutput.shape[1].intValue
-     let combinations = segmentOutput.shape[2].intValue
-
-     // Convert MLMultiArray to [[[Float]]] (assumes contiguous row-major layout)
-     var segments = Array(repeating: Array(repeating: Array(repeating: 0.0 as Float, count: combinations), count: frames), count: 1)
-     for f in 0 ..< frames {
-         for c in 0 ..< combinations {
-             let index = f * combinations + c
-             segments[0][f][c] = segmentOutput[index].floatValue
-         }
-     }
-
-     // Decode the powerset classes into per-speaker activations
-     let binarizedSegments = powersetConversion(segments)
-
-     // Expect batch size 1, i.e. shape (1, 589, 3)
-     guard binarizedSegments.count == 1 else {
-         throw NSError(domain: "ModelOutput", code: -2, userInfo: [NSLocalizedDescriptionKey: "Expected batch size 1"])
-     }
-
-     // Debug: total active frames per speaker
-     let b_frames = binarizedSegments[0]
-     let numSpeakers = b_frames[0].count
-     var speakerSums = Array(repeating: 0.0 as Float, count: numSpeakers)
-     for frame in b_frames {
-         for (i, value) in frame.enumerated() {
-             speakerSums[i] += value
-         }
-     }
-     print("Sum across axis 1 (frames): \(speakerSums)")
-
-     return binarizedSegments
- }
-
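- // Entry point: load the compiled Core ML models and diarize the input file.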
- func loadModel(from path: String) throws -> MLModel {
-     let url = URL(fileURLWithPath: path)
-     return try MLModel(contentsOf: url)
- }
-
- do {
-     let modelPath = "./pyannote_segmentation.mlmodelc"
-     let embeddingPath = "./wespeaker.mlmodelc"
-     let model = try loadModel(from: modelPath)
-     let embeddingModel = try loadModel(from: embeddingPath)
-     print("Models loaded successfully.")
-
-     // let audioPath = "./first_10_seconds.wav"
-     let audioPath = "./TS3003b_mix_headset.wav"
-
-     let audioSamples = try loadAudioSamples(from: URL(fileURLWithPath: audioPath))
-     try chunkAndRunSegmentation(samples: audioSamples, model: model, embeddingModel: embeddingModel)
- } catch {
-     print("Error: \(error)")
- }