aiola committed · verified · Commit 561ba20 · 1 Parent(s): 7907341

Update README.md

Files changed (1): README.md (+89 -3)
README.md CHANGED

---
license: mit
datasets:
- numind/NuNER
language:
- en
pipeline_tag: zero-shot-classification
tags:
- asr
- Automatic Speech Recognition
- Whisper
- Named entity recognition
---

# Whisper-NER

- Demo: https://huggingface.co/spaces/aiola/whisper-ner-v1
- Paper: [_WhisperNER: Unified Open Named Entity and Speech Recognition_](https://arxiv.org/abs/2409.08107)
- Code: https://github.com/aiola-lab/whisper-ner
21
+ We introduce WhisperNER, a novel model that allows joint speech transcription and entity recognition.
22
+ WhisperNER supports open-type NER, enabling recognition of diverse and evolving entities at inference. The WhisperNER model is designed as a strong base model for the downstream task of ASR with NER, and can be fine-tuned on specific datasets for improved performance.
23
+
24
+ **NOTE:** This model also support entity masking directly on the output transcript, especially relevant for PII use cases.
25
+
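
As a quick illustration of how the two modes are selected (a sketch derived from the usage code further down; the tag list here is only an example):

```python
mask_token = "<|mask|>"
tags = "person, company, location"  # comma-separated, open-type entity tags

tagging_prompt = tags                   # entities are tagged inline in the transcript
masking_prompt = f"{mask_token}{tags}"  # entity content is masked in the transcript
```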

---------

## Training Details

`aiola/whisper-ner-v1` was trained on the [NuNER](https://huggingface.co/datasets/numind/NuNER) dataset to perform joint audio transcription and NER tagging.
The model was trained and evaluated only on English data. Check out the [paper](https://arxiv.org/abs/2409.08107) for full details.
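
If you want to inspect the training data the card refers to, the dataset is public on the Hub. A minimal peek (illustrative only; the exact preprocessing used for training is described in the paper, not reproduced here):

```python
from datasets import load_dataset

nuner = load_dataset("numind/NuNER")
print(nuner)               # show the available splits and columns
split = next(iter(nuner))  # take the first split, whatever it is named
print(nuner[split][0])     # inspect one annotated example
```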

---------

## Usage

Inference can be done using the following code (for more details and additional inference code, check out the [whisper-ner repo](https://github.com/aiola-lab/whisper-ner)):

```python
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_path = "aiola/whisper-ner-v1"
audio_file_path = "path/to/audio/file"
prompt = "person, company, location"  # comma-separated entity tags
apply_entity_mask = False  # change to True for entity masking
mask_token = "<|mask|>"

if apply_entity_mask:
    prompt = f"{mask_token}{prompt}"

# load the model and processor from the pre-trained checkpoint
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load the audio file (the user is responsible for loading audio files
# themselves) and resample it to the 16 kHz rate Whisper expects
target_sample_rate = 16000
signal, sampling_rate = torchaudio.load(audio_file_path)
resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
signal = resampler(signal)
# convert to mono (drop the channel dimension) if needed
if signal.ndim == 2:
    signal = torch.mean(signal, dim=0)

# pre-process the waveform to get the input features
input_features = processor(
    signal, sampling_rate=target_sample_rate, return_tensors="pt"
).input_features
input_features = input_features.to(device)

prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
prompt_ids = prompt_ids.to(device)

# generate token ids by running the model forward sequentially
with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

# post-process token ids to text, removing the prompt and special tokens
transcription = processor.batch_decode(
    predicted_ids, skip_special_tokens=True
)[0]
print(transcription)
```
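
For convenience, the steps above can be wrapped in a small helper that reuses the `processor`, `model`, and `device` objects already created (a sketch only; `transcribe` is a hypothetical name, not part of the whisper-ner package), so tagged and masked transcripts of the same file are easy to compare:

```python
def transcribe(audio_path, entity_tags, mask=False):
    """Run WhisperNER on one file; `entity_tags` is a comma-separated string."""
    prompt = f"<|mask|>{entity_tags}" if mask else entity_tags
    signal, sr = torchaudio.load(audio_path)
    signal = torchaudio.transforms.Resample(sr, 16000)(signal)
    if signal.ndim == 2:
        signal = signal.mean(dim=0)
    features = processor(
        signal, sampling_rate=16000, return_tensors="pt"
    ).input_features.to(device)
    prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt").to(device)
    with torch.no_grad():
        ids = model.generate(
            features,
            prompt_ids=prompt_ids,
            generation_config=model.generation_config,
            language="en",
        )
    return processor.batch_decode(ids, skip_special_tokens=True)[0]

print(transcribe(audio_file_path, "person, company, location"))             # tagged entities
print(transcribe(audio_file_path, "person, company, location", mask=True))  # masked entities
```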