Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,89 @@
|
|
1 |
-
---
|
2 |
-
license: mit
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
datasets:
|
4 |
+
- numind/NuNER
|
5 |
+
language:
|
6 |
+
- en
|
7 |
+
pipeline_tag: zero-shot-classification
|
8 |
+
tags:
|
9 |
+
- asr
|
10 |
+
- Automatic Speech Recognition
|
11 |
+
- Whisper
|
12 |
+
- Named entity recognition
|
13 |
+
---
|
14 |
+
|
15 |
+
# Whisper-NER
|
16 |
+
|
17 |
+
- Demo: https://huggingface.co/spaces/aiola/whisper-ner-v1
|
18 |
+
- Peper: [_WhisperNER: Unified Open Named Entity and Speech Recognition_](https://arxiv.org/abs/2409.08107).
|
19 |
+
- Code: https://github.com/aiola-lab/whisper-ner
|
20 |
+
|
21 |
+
We introduce WhisperNER, a novel model that allows joint speech transcription and entity recognition.
|
22 |
+
WhisperNER supports open-type NER, enabling recognition of diverse and evolving entities at inference. The WhisperNER model is designed as a strong base model for the downstream task of ASR with NER, and can be fine-tuned on specific datasets for improved performance.
|
23 |
+
|
24 |
+
**NOTE:** This model also support entity masking directly on the output transcript, especially relevant for PII use cases.
|
25 |
+
|
26 |
+
---------
|
27 |
+
|
28 |
+
## Training Details
|
29 |
+
`aiola/whisper-ner-v1` was trained on the NuNER dataset to perform joint audio transcription and NER tagging.
|
30 |
+
The model was trained and evaluated only on English data. Check out the [paper](https://arxiv.org/abs/2409.08107) for full details.
|
31 |
+
|
32 |
+
---------
|
33 |
+
|
34 |
+
## Usage
|
35 |
+
|
36 |
+
Inference can be done using the following code (for inference code and more details check out the [whisper-ner repo](https://github.com/aiola-lab/whisper-ner)).:
|
37 |
+
|
38 |
+
```python
|
39 |
+
import torch
|
40 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
41 |
+
|
42 |
+
model_path = "aiola/whisper-ner-v1"
|
43 |
+
audio_file_path = "path/to/audio/file"
|
44 |
+
prompt = "person, company, location" # comma separated entity tags
|
45 |
+
apply_entity_mask = False # change to True for entity masking
|
46 |
+
mask_token = "<|mask|>"
|
47 |
+
|
48 |
+
if apply_entity_mask:
|
49 |
+
prompt = f"{mask_token}{prompt}"
|
50 |
+
|
51 |
+
# load model and processor from pre-trained
|
52 |
+
processor = WhisperProcessor.from_pretrained(model_path)
|
53 |
+
model = WhisperForConditionalGeneration.from_pretrained(model_path)
|
54 |
+
|
55 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
56 |
+
model = model.to(device)
|
57 |
+
|
58 |
+
# load audio file: user is responsible for loading the audio files themselves
|
59 |
+
target_sample_rate = 16000
|
60 |
+
signal, sampling_rate = torchaudio.load(audio_file_path)
|
61 |
+
resampler = torchaudio.transforms.Resample(sampling_rate, target_sample_rate)
|
62 |
+
signal = resampler(signal)
|
63 |
+
# convert to mono or remove first dim if needed
|
64 |
+
if signal.ndim == 2:
|
65 |
+
signal = torch.mean(signal, dim=0)
|
66 |
+
# pre-process to get the input features
|
67 |
+
input_features = processor(
|
68 |
+
signal, sampling_rate=target_sample_rate, return_tensors="pt"
|
69 |
+
).input_features
|
70 |
+
input_features = input_features.to(device)
|
71 |
+
|
72 |
+
prompt_ids = processor.get_prompt_ids(prompt.lower(), return_tensors="pt")
|
73 |
+
prompt_ids = prompt_ids.to(device)
|
74 |
+
|
75 |
+
# generate token ids by running model forward sequentially
|
76 |
+
with torch.no_grad():
|
77 |
+
predicted_ids = model.generate(
|
78 |
+
input_features,
|
79 |
+
prompt_ids=prompt_ids,
|
80 |
+
generation_config=model.generation_config,
|
81 |
+
language="en",
|
82 |
+
)
|
83 |
+
|
84 |
+
# post-process token ids to text, remove prompt
|
85 |
+
transcription = processor.batch_decode(
|
86 |
+
predicted_ids, skip_special_tokens=True
|
87 |
+
)[0]
|
88 |
+
print(transcription)
|
89 |
+
```
|