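// Offline speaker diarization: a pyannote segmentation model yields per-frame
// powerset speaker activity, a WeSpeaker embedding model maps each
// (waveform, speaker mask) pair to an embedding, and embeddings are matched
// online against a running speaker database via cosine distance.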
import Accelerate
import AVFoundation
import CoreML
import Foundation
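/// A time interval in seconds within the audio stream.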
struct Segment: Hashable {
let start: Double
let end: Double
}
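/// Maps model frame indices to timestamps: frame `i` starts at
/// `start + i * step` and spans `duration` seconds.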
struct SlidingWindow {
var start: Double
var duration: Double
var step: Double
func time(forFrame index: Int) -> Double {
return start + Double(index) * step
}
func segment(forFrame index: Int) -> Segment {
let s = time(forFrame: index)
return Segment(start: s, end: s + duration)
}
}
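/// Per-frame speaker activations together with the sliding window that maps
/// frame indices back to time.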
struct SlidingWindowFeature {
var data: [[[Float]]] // shape: (batch, frames, speakers), e.g. (1, 589, 3)
var slidingWindow: SlidingWindow
}
var speakerDB: [String: [Float]] = [:] // Global speaker database: label → running-average embedding
let threshold: Float = 0.7 // Cosine-distance threshold for spawning a new speaker (mirrored by assignSpeaker's default)
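/// Cosine distance in [0, 2]: 0 for identical directions, 1 for orthogonal
/// vectors. The epsilon guards against zero-norm embeddings.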
func cosineDistance(_ x: [Float], _ y: [Float]) -> Float {
precondition(x.count == y.count, "Vectors must be the same size")
let dot = vDSP.dot(x, y)
let normX = sqrt(vDSP.sumOfSquares(x))
let normY = sqrt(vDSP.sumOfSquares(y))
return 1.0 - (dot / (normX * normY + 1e-6))
}
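/// Updates a known speaker's reference embedding with an exponential moving
/// average; `alpha` weights the existing embedding. No-op for unknown labels.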
func updateSpeakerDB(_ speaker: String, _ newEmbedding: [Float], alpha: Float = 0.9) {
guard var oldEmbedding = speakerDB[speaker] else { return }
for i in 0..<oldEmbedding.count {
oldEmbedding[i] = alpha * oldEmbedding[i] + (1 - alpha) * newEmbedding[i]
}
speakerDB[speaker] = oldEmbedding
}
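/// Matches an embedding against the speaker database. Returns the nearest
/// speaker when its cosine distance is within `threshold`; otherwise
/// registers and returns a new "Speaker N" label.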
func assignSpeaker(embedding: [Float], threshold: Float = 0.7) -> String {
if speakerDB.isEmpty {
let speaker = "Speaker 1"
speakerDB[speaker] = embedding
return speaker
}
var minDistance: Float = Float.greatestFiniteMagnitude
var identifiedSpeaker: String? = nil
for (speaker, refEmbedding) in speakerDB {
let distance = cosineDistance(embedding, refEmbedding)
if distance < minDistance {
minDistance = distance
identifiedSpeaker = speaker
}
}
if let bestSpeaker = identifiedSpeaker {
if minDistance > threshold {
let newSpeaker = "Speaker \(speakerDB.count + 1)"
speakerDB[newSpeaker] = embedding
return newSpeaker
} else {
updateSpeakerDB(bestSpeaker, embedding)
return bestSpeaker
}
}
return "Unknown"
}
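/// Converts per-frame binarized speaker activity into contiguous, labeled
/// time segments, using argmax to pick one dominant speaker per frame and
/// `speakerMapping` to translate local speaker indices into global labels.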
func getAnnotation(annotation: inout [Segment: String],
speakerMapping: [Int: Int],
binarizedSegments: [[[Float]]],
slidingWindow: SlidingWindow) {
let segmentation = binarizedSegments[0] // shape: (frames, speakers)
let numFrames = segmentation.count
guard numFrames > 0 else { return }
// Step 1: argmax to get the dominant speaker per frame.
// Note: silent frames (all zeros) fall back to speaker index 0.
var frameSpeakers: [Int] = []
for frame in segmentation {
if let maxIdx = frame.indices.max(by: { frame[$0] < frame[$1] }) {
frameSpeakers.append(maxIdx)
} else {
frameSpeakers.append(0) // fallback for empty frames
}
}
// Step 2: group contiguous same-speaker segments
var currentSpeaker = frameSpeakers[0]
var startFrame = 0
for i in 1..<numFrames {
if frameSpeakers[i] != currentSpeaker {
let startTime = slidingWindow.time(forFrame: startFrame)
let endTime = slidingWindow.time(forFrame: i)
let segment = Segment(start: startTime, end: endTime)
if let mappedSpeaker = speakerMapping[currentSpeaker] {
annotation[segment] = "Speaker \(mappedSpeaker)"
}
currentSpeaker = frameSpeakers[i]
startFrame = i
}
}
// Final segment
let finalStart = slidingWindow.time(forFrame: startFrame)
let finalEnd = slidingWindow.segment(forFrame: numFrames - 1).end
let finalSegment = Segment(start: finalStart, end: finalEnd)
if let mappedSpeaker = speakerMapping[currentSpeaker] {
annotation[finalSegment] = "Speaker \(mappedSpeaker)"
}
}
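/// Runs the speaker-embedding model on one audio chunk. Frames with two or
/// more simultaneous speakers are masked out, so each speaker's embedding is
/// computed from non-overlapped speech only. Assumes the Core ML model
/// exposes "waveform" and "mask" inputs and an "embedding" output.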
func getEmbedding(audioChunk: [Float],
binarizedSegments _: [[[Float]]],
slidingWindowSegments: SlidingWindowFeature,
chunkSize: Int = 10 * 16000,
embeddingModel: MLModel) -> MLMultiArray?
{
// 1. Dimensions of the binarized segmentation: (batch, frames, speakers)
let numFrames = slidingWindowSegments.data[0].count
let numSpeakers = slidingWindowSegments.data[0][0].count
// 2. Compute clean_frames = 1.0 where active speakers < 2
var cleanFrames = Array(repeating: Array(repeating: 0.0 as Float, count: 1), count: numFrames)
for f in 0 ..< numFrames {
let frame = slidingWindowSegments.data[0][f]
let speakerSum = frame.reduce(0, +)
cleanFrames[f][0] = (speakerSum < 2.0) ? 1.0 : 0.0
}
// 3. Multiply slidingWindowSegments.data by cleanFrames
var cleanSegmentData = Array(
repeating: Array(repeating: Array(repeating: 0.0 as Float, count: numSpeakers), count: numFrames),
count: 1
)
for f in 0 ..< numFrames {
for s in 0 ..< numSpeakers {
cleanSegmentData[0][f][s] = slidingWindowSegments.data[0][f][s] * cleanFrames[f][0]
}
}
// 4. Replicate the waveform once per speaker → shape (numSpeakers, chunkSize)
var audioBatch: [[Float]] = []
for _ in 0 ..< numSpeakers {
audioBatch.append(audioChunk)
}
// 5. Transpose the mask to shape (numSpeakers, numFrames)
var cleanMasks: [[Float]] = Array(repeating: Array(repeating: 0.0, count: numFrames), count: numSpeakers)
for s in 0 ..< numSpeakers {
for f in 0 ..< numFrames {
cleanMasks[s][f] = cleanSegmentData[0][f][s]
}
}
// 6. Prepare MLMultiArray inputs
guard let waveformArray = try? MLMultiArray(shape: [numSpeakers, chunkSize] as [NSNumber], dataType: .float32),
let maskArray = try? MLMultiArray(shape: [numSpeakers, numFrames] as [NSNumber], dataType: .float32)
else {
print("Failed to allocate MLMultiArray")
return nil
}
// Fill the waveform input (speaker-major layout)
for s in 0 ..< numSpeakers {
for i in 0 ..< chunkSize {
waveformArray[s * chunkSize + i] = NSNumber(value: audioBatch[s][i])
}
}
// Fill the mask input
for s in 0 ..< numSpeakers {
for f in 0 ..< numFrames {
maskArray[s * numFrames + f] = NSNumber(value: cleanMasks[s][f])
}
}
// 7. Run model
let inputs: [String: Any] = [
"waveform": waveformArray,
"mask": maskArray,
]
guard let output = try? embeddingModel.prediction(from: MLDictionaryFeatureProvider(dictionary: inputs)) else {
print("Embedding model prediction failed")
return nil
}
return output.featureValue(for: "embedding")?.multiArrayValue
}
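/// Decodes an audio file and converts it to mono 16 kHz Float32 samples,
/// the format both Core ML models expect.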
func loadAudioSamples(from url: URL, expectedSampleRate: Double = 16000.0) throws -> [Float] {
let file = try AVAudioFile(forReading: url)
let format = AVAudioFormat(commonFormat: .pcmFormatFloat32,
sampleRate: expectedSampleRate,
channels: 1,
interleaved: false)!
let converter = AVAudioConverter(from: file.processingFormat, to: format)!
let frameCapacity = AVAudioFrameCount(file.length)
let buffer = AVAudioPCMBuffer(pcmFormat: file.processingFormat, frameCapacity: frameCapacity)!
try file.read(into: buffer)
// Size the output for the resampled length (+1 frame for rounding).
let ratio = expectedSampleRate / file.processingFormat.sampleRate
let outputCapacity = AVAudioFrameCount(Double(file.length) * ratio) + 1
let outputBuffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: outputCapacity)!
// Feed the source buffer exactly once, then signal end-of-stream so the
// converter does not re-consume the same data.
var consumed = false
let inputBlock: AVAudioConverterInputBlock = { _, outStatus in
if consumed {
outStatus.pointee = .endOfStream
return nil
}
consumed = true
outStatus.pointee = .haveData
return buffer
}
var conversionError: NSError?
let status = converter.convert(to: outputBuffer, error: &conversionError, withInputFrom: inputBlock)
if status == .error {
throw conversionError ?? NSError(domain: "Audio", code: -2, userInfo: [NSLocalizedDescriptionKey: "Sample-rate conversion failed"])
}
guard let floatChannelData = outputBuffer.floatChannelData else {
throw NSError(domain: "Audio", code: -1, userInfo: [NSLocalizedDescriptionKey: "Missing float data"])
}
let channelData = floatChannelData[0]
let samples = Array(UnsafeBufferPointer(start: channelData, count: Int(outputBuffer.frameLength)))
return samples
}
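/// Drives the full pipeline: splits the recording into 10 s chunks, runs
/// segmentation and embedding per chunk, assigns global speaker labels, and
/// prints the merged annotations. Note that every chunk always yields
/// `numSpeakers` embeddings, including slots with little or no speech.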
func chunkAndRunSegmentation(samples: [Float], chunkSize: Int = 160_000, model: MLModel, embeddingModel: MLModel) throws {
let totalSamples = samples.count
let numberOfChunks = Int(ceil(Double(totalSamples) / Double(chunkSize)))
var annotations: [Segment: String] = [:]
for i in 0 ..< numberOfChunks {
let start = i * chunkSize
let end = min((i + 1) * chunkSize, totalSamples)
let chunk = Array(samples[start ..< end])
// If chunk is shorter than 10s, pad with zeros
var paddedChunk = chunk
if chunk.count < chunkSize {
paddedChunk += Array(repeating: 0.0, count: chunkSize - chunk.count)
}
let binarizedSegments = try getSegments(audioChunk: paddedChunk, model: model)
// pyannote segmentation frame geometry: 10 s windows, ~61.9 ms frame duration, ~16.9 ms hop
let frames = SlidingWindow(start: Double(i) * 10.0, duration: 0.0619375, step: 0.016875)
let slidingFeature = SlidingWindowFeature(data: binarizedSegments, slidingWindow: frames)
if let embeddings = getEmbedding(audioChunk: paddedChunk,
binarizedSegments: binarizedSegments,
slidingWindowSegments: slidingFeature,
embeddingModel: embeddingModel)
{
print("Embeddings shape: \(embeddings.shape.map { $0.intValue })")
let shape = embeddings.shape.map { $0.intValue } // expected (numSpeakers, embeddingDim), e.g. [3, 256]
let numSpeakers = shape[0]
let embeddingDim = shape[1]
let strides = embeddings.strides.map { $0.intValue }
var speakerSums = [Float](repeating: 0.0, count: numSpeakers)
for s in 0 ..< numSpeakers {
for d in 0 ..< embeddingDim {
let index = s * strides[0] + d * strides[1]
speakerSums[s] += embeddings[index].floatValue
}
}
print("Sum along axis 1 (per speaker): \(speakerSums)")
// Step 3: Assign speaker label to each embedding
var speakerLabels = [String]()
for s in 0..<numSpeakers {
var embeddingVec = [Float](repeating: 0.0, count: embeddingDim)
for d in 0..<embeddingDim {
let index = s * strides[0] + d * strides[1]
embeddingVec[d] = embeddings[index].floatValue
}
let label = assignSpeaker(embedding: embeddingVec)
speakerLabels.append(label)
}
print("Chunk \(i + 1): Assigned Speakers: \(speakerLabels)")
// Step 4: Update annotations
// Map speaker index 0,1,2 → assigned speakerLabels
var labelMapping: [Int: Int] = [:]
for (idx, label) in speakerLabels.enumerated() {
if let spkNum = Int(label.components(separatedBy: " ").last ?? "") {
labelMapping[idx] = spkNum
}
}
getAnnotation(annotation: &annotations,
speakerMapping: labelMapping,
binarizedSegments: binarizedSegments,
slidingWindow: frames)
print("Chunk \(i + 1) → Segments shape: \(binarizedSegments[0].count) frames")
}
}
// Final result
print("\n=== Final Annotations ===")
for (segment, speaker) in annotations.sorted(by: { $0.key.start < $1.key.start }) {
print("\(speaker): \(segment.start) - \(segment.end)")
}
}
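/// Maps the model's 7-class powerset output (argmax per frame) to per-speaker
/// binary activity over 3 speakers, so overlapping pairs activate both speakers.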
func powersetConversion(_ segments: [[[Float]]]) -> [[[Float]]] {
let powerset: [[Int]] = [
[], // 0
[0], // 1
[1], // 2
[2], // 3
[0, 1], // 4
[0, 2], // 5
[1, 2], // 6
]
let batchSize = segments.count
let numFrames = segments[0].count
let numSpeakers = 3 // the 7 powerset classes above cover up to 2 of 3 active speakers
var binarized = Array(
repeating: Array(
repeating: Array(repeating: 0.0 as Float, count: numSpeakers),
count: numFrames
),
count: batchSize
)
for b in 0 ..< batchSize {
for f in 0 ..< numFrames {
let frame = segments[b][f]
// Find index of max value in this frame
guard let bestIdx = frame.indices.max(by: { frame[$0] < frame[$1] }) else {
continue
}
// Mark the corresponding speakers as active
for speaker in powerset[bestIdx] {
binarized[b][f][speaker] = 1.0
}
}
}
return binarized
}
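/// Runs the segmentation model on one padded chunk and returns binarized
/// per-speaker activity of shape (1, frames, 3). Assumes the Core ML model
/// exposes an "audio" input and a "segments" output of shape (1, frames, 7).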
func getSegments(audioChunk: [Float], sampleRate _: Int = 16000, chunkSize: Int = 160_000, model: MLModel) throws -> [[[Float]]] {
// Model input shape: (1, 1, chunkSize); the caller pads short chunks to chunkSize
let audioArray = try MLMultiArray(shape: [1, 1, NSNumber(value: chunkSize)], dataType: .float32)
for i in 0 ..< min(audioChunk.count, chunkSize) {
audioArray[i] = NSNumber(value: audioChunk[i])
}
// Prepare input
let input = try MLDictionaryFeatureProvider(dictionary: ["audio": audioArray])
// Run prediction
let output = try model.prediction(from: input)
// Extract segments output: shape assumed (1, frames, 7)
guard let segmentOutput = output.featureValue(for: "segments")?.multiArrayValue else {
throw NSError(domain: "ModelOutput", code: -1, userInfo: [NSLocalizedDescriptionKey: "Missing segments output"])
}
let frames = segmentOutput.shape[1].intValue
let combinations = segmentOutput.shape[2].intValue
// Convert MLMultiArray to [[[Float]]]
var segments = Array(repeating: Array(repeating: Array(repeating: 0.0 as Float, count: combinations), count: frames), count: 1)
for f in 0 ..< frames {
for c in 0 ..< combinations {
let index = f * combinations + c
segments[0][f][c] = segmentOutput[index].floatValue
}
}
// Apply powerset conversion
let binarizedSegments = powersetConversion(segments)
// Debug: count active frames per speaker (assumes batch size 1, e.g. shape (1, 589, 3))
guard binarizedSegments.count == 1 else {
fatalError("Expected batch size 1")
}
let batchFrames = binarizedSegments[0]
let numSpeakers = batchFrames[0].count
var speakerSums = Array(repeating: 0.0 as Float, count: numSpeakers)
// Sum across the frame axis
for frame in batchFrames {
for (i, value) in frame.enumerated() {
speakerSums[i] += value
}
}
print("Active frames per speaker: \(speakerSums)")
return binarizedSegments
}
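/// Loads a compiled Core ML model (.mlmodelc) from a filesystem path.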
func loadModel(from path: String) throws -> MLModel {
let url = URL(fileURLWithPath: path)
let model = try MLModel(contentsOf: url)
return model
}
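// Entry point: expects the compiled .mlmodelc bundles and the input WAV
// in the current working directory.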
do {
let modelPath = "./pyannote_segmentation.mlmodelc"
let embeddingPath = "./wespeaker.mlmodelc"
let model = try loadModel(from: modelPath)
let embeddingModel = try loadModel(from: embeddingPath)
print("Model loaded successfully.")
// let audioPath = "./first_10_seconds.wav"
let audioPath = "./TS3003b_mix_headset.wav"
let audioSamples = try loadAudioSamples(from: URL(fileURLWithPath: audioPath))
try chunkAndRunSegmentation(samples: audioSamples, model: model, embeddingModel: embeddingModel)
} catch {
print("Error: \(error)")
}