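//
//  Speaker diarization demo: a CoreML pyannote segmentation model proposes
//  per-frame speaker activity over 10-second chunks, a WeSpeaker embedding
//  model turns each local speaker into a fixed-size voice print, and a simple
//  cosine-distance database links local speakers to global identities across
//  chunks.
//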
import Accelerate
import AVFoundation
import CoreML
import Foundation

struct Segment: Hashable {
    let start: Double
    let end: Double
}

struct SlidingWindow {
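    // pyannote-style frame grid: frame i starts at `start + i * step`
    // and spans `duration` seconds.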
    var start: Double
    var duration: Double
    var step: Double
    
    func time(forFrame index: Int) -> Double {
        return start + Double(index) * step
    }
    
    func segment(forFrame index: Int) -> Segment {
        let s = time(forFrame: index)
        return Segment(start: s, end: s + duration)
    }
}

struct SlidingWindowFeature {
    var data: [[[Float]]] // (1, 589, 3)
    var slidingWindow: SlidingWindow
}

var speakerDB: [String: [Float]] = [:]  // Global speaker database
let threshold: Float = 0.7              // Distance threshold

func cosineDistance(_ x: [Float], _ y: [Float]) -> Float {
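    // 1 - cosine similarity; 0 means identical direction, the epsilon avoids /0.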
    precondition(x.count == y.count, "Vectors must be same size")
    let dot = zip(x, y).map(*).reduce(0, +)
    let normX = sqrt(x.map { $0 * $0 }.reduce(0, +))
    let normY = sqrt(y.map { $0 * $0 }.reduce(0, +))
    return 1.0 - (dot / (normX * normY + 1e-6))
}
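
// The same distance, vectorized with the Accelerate framework (imported
// above). A minimal optional sketch; `cosineDistanceVDSP` is a name
// introduced here, not part of the original pipeline:
func cosineDistanceVDSP(_ x: [Float], _ y: [Float]) -> Float {
    precondition(x.count == y.count, "Vectors must be same size")
    let dot = vDSP.dot(x, y)               // x·y
    let normX = sqrt(vDSP.sumOfSquares(x)) // ‖x‖
    let normY = sqrt(vDSP.sumOfSquares(y)) // ‖y‖
    return 1.0 - (dot / (normX * normY + 1e-6))
}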

func updateSpeakerDB(_ speaker: String, _ newEmbedding: [Float], alpha: Float = 0.9) {
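    // Exponential moving average: keep `alpha` of the stored embedding and
    // blend in (1 - alpha) of the new observation.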
    guard var oldEmbedding = speakerDB[speaker] else { return }
    for i in 0..<oldEmbedding.count {
        oldEmbedding[i] = alpha * oldEmbedding[i] + (1 - alpha) * newEmbedding[i]
    }
    speakerDB[speaker] = oldEmbedding
}

func assignSpeaker(embedding: [Float], threshold: Float = 0.7) -> String {
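    // Greedy online clustering: the nearest stored speaker wins unless even
    // the best match is farther than `threshold`, in which case a new
    // speaker is enrolled.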
    if speakerDB.isEmpty {
        let speaker = "Speaker 1"
        speakerDB[speaker] = embedding
        return speaker
    }

    var minDistance: Float = Float.greatestFiniteMagnitude
    var identifiedSpeaker: String? = nil

    for (speaker, refEmbedding) in speakerDB {
        let distance = cosineDistance(embedding, refEmbedding)
        if distance < minDistance {
            minDistance = distance
            identifiedSpeaker = speaker
        }
    }

    // speakerDB is non-empty here, so a nearest speaker always exists
    guard let bestSpeaker = identifiedSpeaker else { return "Unknown" }

    if minDistance > threshold {
        // Too far from every known speaker: enroll a new one
        let newSpeaker = "Speaker \(speakerDB.count + 1)"
        speakerDB[newSpeaker] = embedding
        return newSpeaker
    } else {
        // Close match: fold this embedding into the stored average
        updateSpeakerDB(bestSpeaker, embedding)
        return bestSpeaker
    }
}

func getAnnotation(annotation: inout [Segment: String],
                   speakerMapping: [Int: Int],
                   binarizedSegments: [[[Float]]],
                   slidingWindow: SlidingWindow) {
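    // Collapse per-frame dominant speakers into contiguous labeled segments.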
    
    let segmentation = binarizedSegments[0] // shape: [589][3]
    let numFrames = segmentation.count

    // Step 1: argmax to get dominant speaker per frame (note: silent frames
    // still map to index 0; silence is not modeled separately here)
    var frameSpeakers: [Int] = []
    for frame in segmentation {
        if let maxIdx = frame.indices.max(by: { frame[$0] < frame[$1] }) {
            frameSpeakers.append(maxIdx)
        } else {
            frameSpeakers.append(0) // fallback
        }
    }

    // Step 2: group contiguous same-speaker segments
    guard numFrames > 0 else { return }
    var currentSpeaker = frameSpeakers[0]
    var startFrame = 0

    for i in 1..<numFrames {
        if frameSpeakers[i] != currentSpeaker {
            let startTime = slidingWindow.time(forFrame: startFrame)
            let endTime = slidingWindow.time(forFrame: i)

            let segment = Segment(start: startTime, end: endTime)
            if let mappedSpeaker = speakerMapping[currentSpeaker] {
                annotation[segment] = "Speaker \(mappedSpeaker)"
            }
            currentSpeaker = frameSpeakers[i]
            startFrame = i
        }
    }

    // Final segment
    let finalStart = slidingWindow.time(forFrame: startFrame)
    let finalEnd = slidingWindow.segment(forFrame: numFrames - 1).end
    let finalSegment = Segment(start: finalStart, end: finalEnd)
    if let mappedSpeaker = speakerMapping[currentSpeaker] {
        annotation[finalSegment] = "Speaker \(mappedSpeaker)"
    }
}


func getEmbedding(audioChunk: [Float],
                  binarizedSegments _: [[[Float]]],
                  slidingWindowSegments: SlidingWindowFeature,
                  chunkSize: Int = 10 * 16000,
                  embeddingModel: MLModel) -> MLMultiArray?
{
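    // WeSpeaker consumes one waveform row per local speaker plus a mask row
    // that keeps only "clean" frames (fewer than two active speakers), so
    // overlapped speech does not pollute the embeddings.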
    // 1. Keep the raw chunk; it is replicated once per speaker in step 4
    let audioTensor = audioChunk

    let numFrames = slidingWindowSegments.data[0].count
    let numSpeakers = slidingWindowSegments.data[0][0].count

    // 2. Compute clean_frames = 1.0 where active speakers < 2
    var cleanFrames = Array(repeating: Array(repeating: 0.0 as Float, count: 1), count: numFrames)

    for f in 0 ..< numFrames {
        let frame = slidingWindowSegments.data[0][f]
        let speakerSum = frame.reduce(0, +)
        cleanFrames[f][0] = (speakerSum < 2.0) ? 1.0 : 0.0
    }

    // 3. Multiply slidingWindowSegments.data by cleanFrames
    var cleanSegmentData = Array(
        repeating: Array(repeating: Array(repeating: 0.0 as Float, count: numSpeakers), count: numFrames),
        count: 1
    )

    for f in 0 ..< numFrames {
        for s in 0 ..< numSpeakers {
            cleanSegmentData[0][f][s] = slidingWindowSegments.data[0][f][s] * cleanFrames[f][0]
        }
    }

    // 4. Replicate the audio chunk once per speaker row, shape (numSpeakers, chunkSize)
    var audioBatch: [[Float]] = []
    for _ in 0 ..< numSpeakers {
        audioBatch.append(audioTensor)
    }

    // 5. Transpose mask to shape (numSpeakers, numFrames)
    var cleanMasks: [[Float]] = Array(repeating: Array(repeating: 0.0, count: numFrames), count: numSpeakers)

    for s in 0 ..< numSpeakers {
        for f in 0 ..< numFrames {
            cleanMasks[s][f] = cleanSegmentData[0][f][s]
        }
    }

    // 6. Prepare MLMultiArray inputs: one waveform row and one mask row per speaker
    guard let waveformArray = try? MLMultiArray(shape: [numSpeakers, chunkSize] as [NSNumber], dataType: .float32),
          let maskArray = try? MLMultiArray(shape: [numSpeakers, numFrames] as [NSNumber], dataType: .float32)
    else {
        print("Failed to allocate MLMultiArray")
        return nil
    }

    // Fill waveform rows (each row is the same padded chunk)
    for s in 0 ..< numSpeakers {
        for i in 0 ..< chunkSize {
            waveformArray[s * chunkSize + i] = NSNumber(value: audioBatch[s][i])
        }
    }

    // Fill mask rows (per-speaker clean-frame activations)
    for s in 0 ..< numSpeakers {
        for f in 0 ..< numFrames {
            maskArray[s * numFrames + f] = NSNumber(value: cleanMasks[s][f])
        }
    }

    // 7. Run model
    let inputs: [String: Any] = [
        "waveform": waveformArray,
        "mask": maskArray,
    ]

    guard let output = try? embeddingModel.prediction(from: MLDictionaryFeatureProvider(dictionary: inputs)) else {
        print("Embedding model prediction failed")
        return nil
    }

    return output.featureValue(for: "embedding")?.multiArrayValue
}
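
// Filling MLMultiArrays one NSNumber at a time (as above) is simple but slow
// for ~half a million floats. A sketch of a faster row fill, assuming the
// array was freshly allocated (contiguous .float32 storage); `fillRow` is a
// helper introduced here, not part of the original pipeline:
func fillRow(_ array: MLMultiArray, row: Int, rowLength: Int, with values: [Float]) {
    let base = array.dataPointer.assumingMemoryBound(to: Float.self)
    values.withUnsafeBufferPointer { src in
        guard let srcBase = src.baseAddress else { return }
        (base + row * rowLength).update(from: srcBase, count: min(rowLength, values.count))
    }
}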

func loadAudioSamples(from url: URL, expectedSampleRate: Double = 16000.0) throws -> [Float] {
    let file = try AVAudioFile(forReading: url)
    // Target format: 16 kHz, mono, non-interleaved Float32
    let format = AVAudioFormat(commonFormat: .pcmFormatFloat32,
                               sampleRate: expectedSampleRate,
                               channels: 1,
                               interleaved: false)!

    let converter = AVAudioConverter(from: file.processingFormat, to: format)!
    let frameCapacity = AVAudioFrameCount(file.length)
    let buffer = AVAudioPCMBuffer(pcmFormat: file.processingFormat, frameCapacity: frameCapacity)!
    try file.read(into: buffer)

    // Size the output for the target rate; the input frame count only suffices when downsampling
    let ratio = expectedSampleRate / file.processingFormat.sampleRate
    let outputCapacity = AVAudioFrameCount((Double(file.length) * ratio).rounded(.up))
    let outputBuffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: outputCapacity)!

    // Supply the input buffer once, then report end of stream so the converter does not loop
    var suppliedInput = false
    let inputBlock: AVAudioConverterInputBlock = { _, outStatus in
        if suppliedInput { outStatus.pointee = .endOfStream; return nil }
        suppliedInput = true
        outStatus.pointee = .haveData
        return buffer
    }

    var conversionError: NSError?
    if converter.convert(to: outputBuffer, error: &conversionError, withInputFrom: inputBlock) == .error {
        throw conversionError ?? NSError(domain: "Audio", code: -2,
                                         userInfo: [NSLocalizedDescriptionKey: "Audio conversion failed"])
    }

    guard let floatChannelData = outputBuffer.floatChannelData else {
        throw NSError(domain: "Audio", code: -1, userInfo: [NSLocalizedDescriptionKey: "Missing float data"])
    }

    let channelData = floatChannelData[0]
    let samples = Array(UnsafeBufferPointer(start: channelData, count: Int(outputBuffer.frameLength)))
    return samples
}

func chunkAndRunSegmentation(samples: [Float], chunkSize: Int = 160_000, model: MLModel, embeddingModel: MLModel) throws {
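    // Per 10 s chunk: segmentation → powerset decode → masked embeddings →
    // global speaker assignment → time-stamped annotations.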
    let totalSamples = samples.count
    let numberOfChunks = Int(ceil(Double(totalSamples) / Double(chunkSize)))
    var annotations: [Segment: String] = [:]

    for i in 0 ..< numberOfChunks {
        let start = i * chunkSize
        let end = min((i + 1) * chunkSize, totalSamples)
        let chunk = Array(samples[start ..< end])

        // If chunk is shorter than 10s, pad with zeros
        var paddedChunk = chunk
        if chunk.count < chunkSize {
            paddedChunk += Array(repeating: 0.0, count: chunkSize - chunk.count)
        }

        let binarizedSegments = try getSegments(audioChunk: paddedChunk, model: model)
        // pyannote frame timing (~61.9 ms window every ~16.9 ms), offset by this chunk's start time
        let frames = SlidingWindow(start: Double(start) / 16000.0, duration: 0.0619375, step: 0.016875)
        let slidingFeature = SlidingWindowFeature(data: binarizedSegments, slidingWindow: frames)
        if let embeddings = getEmbedding(audioChunk: paddedChunk,
                                         binarizedSegments: binarizedSegments,
                                         slidingWindowSegments: slidingFeature,
                                         embeddingModel: embeddingModel)
        {
            print("Embeddings shape: \(embeddings.shape.map { $0.intValue })")

            let shape = embeddings.shape.map { $0.intValue } // [3, 256]
            let numSpeakers = shape[0]
            let embeddingDim = shape[1]
            let strides = embeddings.strides.map { $0.intValue }

            var speakerSums = [Float](repeating: 0.0, count: numSpeakers)

            for s in 0 ..< numSpeakers {
                for d in 0 ..< embeddingDim {
                    let index = s * strides[0] + d * strides[1]
                    speakerSums[s] += embeddings[index].floatValue
                }
            }

            print("Sum along axis 1 (per speaker): \(speakerSums)") 

            // Step 3: Assign speaker label to each embedding
            var speakerLabels = [String]()
            for s in 0..<numSpeakers {
                var embeddingVec = [Float](repeating: 0.0, count: embeddingDim)
                for d in 0..<embeddingDim {
                    let index = s * strides[0] + d * strides[1]
                    embeddingVec[d] = embeddings[index].floatValue
                }
                let label = assignSpeaker(embedding: embeddingVec)
                speakerLabels.append(label)
            }

            print("Chunk \(i + 1): Assigned Speakers: \(speakerLabels)")

            // Step 4: Update annotations
            // Map speaker index 0,1,2 → assigned speakerLabels
            var labelMapping: [Int: Int] = [:]
            for (idx, label) in speakerLabels.enumerated() {
                if let spkNum = Int(label.components(separatedBy: " ").last ?? "") {
                    labelMapping[idx] = spkNum
                }
            }

            getAnnotation(annotation: &annotations,
                          speakerMapping: labelMapping,
                          binarizedSegments: binarizedSegments,
                          slidingWindow: frames)

            print("Chunk \(i + 1) → Segments shape: \(binarizedSegments[0].count) frames")
        }
    }

    // Final result
    print("\n=== Final Annotations ===")
    for (segment, speaker) in annotations.sorted(by: { $0.key.start < $1.key.start }) {
        print("\(speaker): \(segment.start) - \(segment.end)")
    }

}

func powersetConversion(_ segments: [[[Float]]]) -> [[[Float]]] {
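    // The segmentation model emits 7 mutually exclusive "powerset" classes per
    // frame: silence, each single speaker, and each speaker pair. Argmax picks
    // one class; e.g. class 4 ([0, 1]) marks speakers 0 and 1 as active together.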
    let powerset: [[Int]] = [
        [], // 0
        [0], // 1
        [1], // 2
        [2], // 3
        [0, 1], // 4
        [0, 2], // 5
        [1, 2], // 6
    ]

    let batchSize = segments.count
    let numFrames = segments[0].count
    let numCombos = segments[0][0].count // 7 powerset classes
    precondition(numCombos == powerset.count, "Unexpected number of powerset classes")

    let numSpeakers = 3
    var binarized = Array(
        repeating: Array(
            repeating: Array(repeating: 0.0 as Float, count: numSpeakers),
            count: numFrames
        ),
        count: batchSize
    )

    for b in 0 ..< batchSize {
        for f in 0 ..< numFrames {
            let frame = segments[b][f]

            // Find index of max value in this frame
            guard let bestIdx = frame.indices.max(by: { frame[$0] < frame[$1] }) else {
                continue
            }

            // Mark the corresponding speakers as active
            for speaker in powerset[bestIdx] {
                binarized[b][f][speaker] = 1.0
            }
        }
    }

    return binarized
}

func getSegments(audioChunk: [Float], sampleRate _: Int = 16000, chunkSize: Int = 160_000, model: MLModel) throws -> [[[Float]]] {
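    // Run segmentation on one padded 10 s chunk; returns per-speaker binary
    // activations of shape (1, frames, 3).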
    // Ensure correct shape: (1, 1, chunk_size)
    let audioArray = try MLMultiArray(shape: [1, 1, NSNumber(value: chunkSize)], dataType: .float32)
    for i in 0 ..< min(audioChunk.count, chunkSize) {
        audioArray[i] = NSNumber(value: audioChunk[i])
    }

    // Prepare input
    let input = try MLDictionaryFeatureProvider(dictionary: ["audio": audioArray])

    // Run prediction
    let output = try model.prediction(from: input)

    // Extract segments output: shape assumed (1, frames, 7)
    guard let segmentOutput = output.featureValue(for: "segments")?.multiArrayValue else {
        throw NSError(domain: "ModelOutput", code: -1, userInfo: [NSLocalizedDescriptionKey: "Missing segments output"])
    }

    let frames = segmentOutput.shape[1].intValue
    let combinations = segmentOutput.shape[2].intValue

    // Convert MLMultiArray to [[[Float]]] using the output's own strides,
    // so the copy stays correct even if the storage is not tightly packed
    let st = segmentOutput.strides.map { $0.intValue }
    var segments = Array(repeating: Array(repeating: Array(repeating: 0.0 as Float, count: combinations), count: frames), count: 1)

    for f in 0 ..< frames {
        for c in 0 ..< combinations {
            segments[0][f][c] = segmentOutput[f * st[1] + c * st[2]].floatValue
        }
    }

    // Apply powerset conversion
    let binarizedSegments = powersetConversion(segments)

    // Expect binarized shape (1, frames, 3); throw rather than crash on a mismatch
    guard binarizedSegments.count == 1 else {
        throw NSError(domain: "ModelOutput", code: -2, userInfo: [NSLocalizedDescriptionKey: "Expected batch size 1"])
    }

    let batchFrames = binarizedSegments[0]
    let numSpeakers = batchFrames[0].count

    // Per-speaker total of active frames (printed below as a sanity check)
    var speakerSums = Array(repeating: 0.0 as Float, count: numSpeakers)

    // Sum across axis 1 (frames)
    for frame in batchFrames {
        for (i, value) in frame.enumerated() {
            speakerSums[i] += value
        }
    }

    print("Sum across axis 1 (frames): \(speakerSums)")

    return binarizedSegments
}

func loadModel(from path: String) throws -> MLModel {
    let url = URL(fileURLWithPath: path)
    let model = try MLModel(contentsOf: url)
    return model
}

do {
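    // Both models must be compiled (.mlmodelc); MLModel(contentsOf:) does not
    // accept raw .mlmodel files. Paths are relative to the working directory.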
    let modelPath = "./pyannote_segmentation.mlmodelc"
    let embeddingPath = "./wespeaker.mlmodelc"
    let model = try loadModel(from: modelPath)
    let embeddingModel = try loadModel(from: embeddingPath)
    print("Model loaded successfully.")

    // let audioPath = "./first_10_seconds.wav"
    let audioPath = "./TS3003b_mix_headset.wav"

    let audioSamples = try loadAudioSamples(from: URL(fileURLWithPath: audioPath))
    try chunkAndRunSegmentation(samples: audioSamples, model: model, embeddingModel: embeddingModel)
} catch {
    print("Error: \(error)")
}