---
library_name: transformers
language:
  - en
license: mit
base_model: pyannote/speaker-diarization-3.0
tags:
  - speaker-diarization
  - speaker-segmentation
  - generated_from_trainer
datasets:
  - talkbank/callhome
model-index:
  - name: speaker-segmentation-fine-tuned-callhome
    results: []
---

# speaker-segmentation-fine-tuned-callhome

This model is a fine-tuned version of [pyannote/speaker-diarization-3.0](https://huggingface.co/pyannote/speaker-diarization-3.0) on the [talkbank/callhome](https://huggingface.co/datasets/talkbank/callhome) dataset. It achieves the following results on the evaluation set (a sketch for recomputing these metrics follows the list):

- Loss: 0.4725
- Model Preparation Time: 0.0071
- DER (diarization error rate): 0.1767
- False Alarm: 0.0593
- Missed Detection: 0.0757
- Confusion: 0.0417
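
As a minimal sketch, assuming a ground-truth annotation in `reference.rttm` and an `audio.wav` file (both placeholder names), and the fine-tuned pipeline loaded as shown further down, the DER can be recomputed with `pyannote.metrics`:

```python
# Hypothetical example: reference.rttm and audio.wav are placeholder file names.
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

reference = list(load_rttm("reference.rttm").values())[0]  # pyannote.core.Annotation
hypothesis = pipeline("audio.wav")                         # output of the fine-tuned pipeline

metric = DiarizationErrorRate()
print(f"DER = {metric(reference, hypothesis):.4f}")
```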

## Model description

This model is a fine-tuned version of pyannote/speaker-diarization-3.0 for speaker segmentation, trained on the talkbank/callhome dataset.

It can be loaded with the following code:

```python
from pyannote.audio import Pipeline
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned pipeline
pipeline = Pipeline.from_pretrained(
    "Beijuka/speaker-segmentation-fine-tuned-callhome",
    use_auth_token="your_huggingface_token",  # replace with your Hugging Face access token
)
pipeline.to(device)

# Run diarization on an audio file
audio_file = "/path/to/audio.mp3"
diarization = pipeline(audio_file)

# Save the result in RTTM format
with open("finetunemodel.rttm", "w") as f:
    diarization.write_rttm(f)

# Print the speaker segments
print(diarization)
```
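
The returned `diarization` object is a `pyannote.core.Annotation`, so instead of printing it whole you can iterate over individual speaker turns:

```python
# Iterate over speaker turns, with start/end times in seconds
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{speaker}: {turn.start:.1f}s -> {turn.end:.1f}s")
```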

You can also run the pipeline directly on a sample from the talkbank/callhome dataset:

```python
from datasets import load_dataset
import torch

# Load one example from the CALLHOME English split
dataset = load_dataset("talkbank/callhome", "eng", split="data")
sample = dataset[0]["audio"]

# Pre-process the input: the pipeline expects a dict with a "waveform" tensor
# of shape (channel, time) and a "sample_rate" entry
sample["waveform"] = torch.from_numpy(sample.pop("array")[None, :]).to(device, dtype=torch.float32)
sample["sample_rate"] = sample.pop("sampling_rate")

# Perform inference
diarization = pipeline(sample)

# Dump the diarization output to disk in RTTM format
with open("audio.rttm", "w") as rttm:
    diarization.write_rttm(rttm)
```
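
If you later need the saved output again, the RTTM file can be loaded back into an `Annotation` (a small sketch using the file name from the snippet above):

```python
# load_rttm returns a dict mapping file URI to pyannote.core.Annotation
from pyannote.database.util import load_rttm

for uri, annotation in load_rttm("audio.rttm").items():
    print(uri, annotation.labels())  # speaker labels found in that file
```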

## Intended uses & limitations

- Intended for use in diarization pipelines for telephone-style audio.
- May not generalize to far-field audio or to conversations with more than two speakers; for telephone audio, the expected speaker count can be passed explicitly (see the sketch below).
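
Because CALLHOME recordings are two-party telephone calls, it can help to hint the pipeline accordingly at inference time (a sketch; `num_speakers`, `min_speakers`, and `max_speakers` are standard options of pyannote diarization pipelines):

```python
# Constrain the pipeline to the two-speaker telephone setting it was fine-tuned on
diarization = pipeline("/path/to/audio.mp3", num_speakers=2)

# Or give a looser bound if the speaker count is uncertain
diarization = pipeline("/path/to/audio.mp3", min_speakers=1, max_speakers=2)
```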

## Training and evaluation data

- Training and validation used the talkbank/callhome dataset.
- Two-speaker telephone conversations with speaker-turn annotations.

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training; a rough plain-PyTorch sketch of this configuration is shown after the list:

- learning_rate: 0.001
- train_batch_size: 32
- eval_batch_size: 32
- seed: 42
- optimizer: AdamW (torch implementation) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
- lr_scheduler_type: cosine
- num_epochs: 5
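
This sketch is illustrative only: the actual run used the Hugging Face `Trainer`, and `model`, `train_loader`, and `compute_loss` are placeholders rather than objects defined in this card.

```python
import torch

# Placeholders: `model` is the segmentation model being fine-tuned,
# `train_loader` yields training batches, `compute_loss` stands in for the training objective.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5 * len(train_loader))

for epoch in range(5):
    for batch in train_loader:
        loss = compute_loss(model, batch)  # placeholder loss computation
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```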

## Training results

| Training Loss | Epoch | Step | Validation Loss | Model Preparation Time | DER    | False Alarm | Missed Detection | Confusion |
|:-------------:|:-----:|:----:|:---------------:|:----------------------:|:------:|:-----------:|:----------------:|:---------:|
| 0.3959        | 1.0   | 362  | 0.4800          | 0.0071                 | 0.1932 | 0.0575      | 0.0781           | 0.0577    |
| 0.4226        | 2.0   | 724  | 0.4797          | 0.0071                 | 0.1918 | 0.0640      | 0.0723           | 0.0555    |
| 0.4117        | 3.0   | 1086 | 0.4726          | 0.0071                 | 0.1872 | 0.0530      | 0.0789           | 0.0553    |
| 0.3875        | 4.0   | 1448 | 0.4671          | 0.0071                 | 0.1852 | 0.0549      | 0.0769           | 0.0534    |
| 0.3646        | 5.0   | 1810 | 0.4710          | 0.0071                 | 0.1872 | 0.0571      | 0.0747           | 0.0554    |

## Framework versions

- Transformers 4.52.3
- PyTorch 2.6.0+cu126
- Datasets 3.6.0
- Tokenizers 0.21.1