Update README.md
Browse files
README.md
CHANGED
|
@@ -300,6 +300,7 @@ import torch
|
|
| 300 |
from transformers import pipeline
|
| 301 |
from datasets import load_dataset
|
| 302 |
from evaluate import load
|
|
|
|
| 303 |
|
| 304 |
# model config
|
| 305 |
model_id = "kotoba-tech/kotoba-whisper-v1.0"
|
|
@@ -307,6 +308,7 @@ torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
|
| 307 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 308 |
model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
|
| 309 |
generate_kwargs = {"language": "japanese", "task": "transcribe"}
|
|
|
|
| 310 |
|
| 311 |
# data config
|
| 312 |
dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
|
|
@@ -326,8 +328,8 @@ pipe = pipeline(
|
|
| 326 |
# load the dataset and sample the audio with 16kHz
|
| 327 |
dataset = load_dataset(dataset_name, split="test")
|
| 328 |
transcriptions = pipe(dataset['audio'])
|
| 329 |
-
transcriptions = [i['text'].replace(" ", "") for i in transcriptions]
|
| 330 |
-
references = [i.replace(" ", "") for i in dataset['transcription']]
|
| 331 |
|
| 332 |
# compute the CER metric
|
| 333 |
cer_metric = load("cer")
|
|
|
|
| 300 |
from transformers import pipeline
|
| 301 |
from datasets import load_dataset
|
| 302 |
from evaluate import load
|
| 303 |
+
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
| 304 |
|
| 305 |
# model config
|
| 306 |
model_id = "kotoba-tech/kotoba-whisper-v1.0"
|
|
|
|
| 308 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 309 |
model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
|
| 310 |
generate_kwargs = {"language": "japanese", "task": "transcribe"}
|
| 311 |
+
normalizer = BasicTextNormalizer()
|
| 312 |
|
| 313 |
# data config
|
| 314 |
dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
|
|
|
|
| 328 |
# load the dataset and sample the audio with 16kHz
|
| 329 |
dataset = load_dataset(dataset_name, split="test")
|
| 330 |
transcriptions = pipe(dataset['audio'])
|
| 331 |
+
transcriptions = [normalizer(i['text']).replace(" ", "") for i in transcriptions]
|
| 332 |
+
references = [normalizer(i).replace(" ", "") for i in dataset['transcription']]
|
| 333 |
|
| 334 |
# compute the CER metric
|
| 335 |
cer_metric = load("cer")
|