sammy786 committed on
Commit
7083e4f
·
1 Parent(s): c9ded68

Delete eval.py

Browse files
Files changed (1) hide show
  1. eval.py +0 -123
eval.py DELETED
@@ -1,123 +0,0 @@
1
- from datasets import load_dataset, load_metric, Audio, Dataset
2
- from transformers import pipeline, AutoFeatureExtractor
3
- import re
4
- import argparse
5
- import unicodedata
6
- from typing import Dict
7
-
8
-
9
def log_results(result: Dataset, args: Dict[str, str]):
    """ DO NOT CHANGE. This function computes and logs the result metrics.

    Computes WER and CER over the `target`/`prediction` columns of `result`,
    prints them, and writes them to `<dataset_id>_eval_results.txt`. When
    `args.log_outputs` is set, additionally dumps every prediction and target
    (one index line followed by one text line each) to two log files.
    """

    log_outputs = args.log_outputs
    # e.g. "mozilla-foundation_common_voice_8_0_en_test" — used as a file prefix
    dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])

    # load metric
    wer = load_metric("wer")
    cer = load_metric("cer")

    # compute metrics
    wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
    cer_result = cer.compute(references=result["target"], predictions=result["prediction"])

    # print & log results
    result_str = (
        f"WER: {wer_result}\n"
        f"CER: {cer_result}"
    )
    print(result_str)

    with open(f"{dataset_id}_eval_results.txt", "w") as f:
        f.write(result_str)

    # log all results in text file. Possibly interesting for analysis
    # NOTE(review): `--log_outputs` is a store_true flag, so this checks
    # True/False rather than None — the `is not None` test is still truthy-safe
    # because False is not None... it always enters here; confirm intent.
    if log_outputs is not None:
        pred_file = f"log_{dataset_id}_predictions.txt"
        target_file = f"log_{dataset_id}_targets.txt"

        with open(pred_file, "w") as p, open(target_file, "w") as t:

            # mapping function to write output; closes over the open file
            # handles, so it must run inside this `with` block
            def write_to_file(batch, i):
                p.write(f"{i}" + "\n")
                p.write(batch["prediction"] + "\n")
                t.write(f"{i}" + "\n")
                t.write(batch["target"] + "\n")

            # map is used purely for its side effect of writing the log files
            result.map(write_to_file, with_indices=True)
48
-
49
-
50
def normalize_text(text: str) -> str:
    """ DO ADAPT FOR YOUR USE CASE. this function normalizes the target text.

    Normalization steps, in order:
      1. unify apostrophe-like characters to a plain ASCII apostrophe,
      2. strip punctuation/special characters, lowercase, and trim,
      3. re-attach consonant + apostrophe + vowel sequences (e.g. "l' a" -> "l'a"),
      4. collapse hyphens, stray apostrophes, and runs of spaces to one space.

    :param text: raw transcription text
    :return: normalized text
    """

    # Raw string: the original non-raw literal relied on invalid escape
    # sequences (e.g. "\,"), which emit DeprecationWarning/SyntaxWarning on
    # modern Python. The string value — and thus the regex — is unchanged.
    chars_to_ignore_regex = r'[\,\.\!\-\;\:\"\“\%\”\�\'\...\…\–\é]'

    text = re.sub(r'[ʻʽʼ‘’´`]', r"'", text)
    text = re.sub(chars_to_ignore_regex, "", text).lower().strip()
    text = re.sub(r"([b-df-hj-np-tv-z])' ([aeiou])", r"\1'\2", text)
    text = re.sub(r"(-| '|' | +)", " ", text)

    return text
61
-
62
-
63
def main(args):
    """Run ASR evaluation end to end: load data, transcribe, score, log."""

    # pull the evaluation split (private datasets need an auth token)
    dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)

    # for testing: only process the first two examples as a test
    # dataset = dataset.select(range(10))

    # the model's feature extractor tells us which sampling rate it expects
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
    sampling_rate = feature_extractor.sampling_rate

    # resample audio on the fly to the model's expected rate
    dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))

    # build the inference pipeline once, shared by every example
    asr = pipeline("automatic-speech-recognition", model=args.model_id)

    def transcribe(example):
        # decode one utterance; chunk/stride settings enable long-form audio
        decoded = asr(
            example["audio"]["array"],
            chunk_length_s=args.chunk_length_s,
            stride_length_s=args.stride_length_s,
        )
        example["prediction"] = decoded["text"]
        example["target"] = normalize_text(example["sentence"])
        return example

    # run inference on all examples, keeping only prediction/target columns
    result = dataset.map(transcribe, remove_columns=dataset.column_names)

    # compute and log_results
    # do not change function below
    log_results(result, args)
94
-
95
-
96
if __name__ == "__main__":
    # command-line interface for the evaluation script
    parser = argparse.ArgumentParser()

    required = [
        ("--model_id", "Model identifier. Should be loadable with 🤗 Transformers"),
        ("--dataset", "Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets"),
        ("--config", "Config of the dataset. *E.g.* `'en'` for Common Voice"),
        ("--split", "Split of the dataset. *E.g.* `'test'`"),
    ]
    for flag, help_text in required:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # optional long-form-audio chunking controls
    parser.add_argument(
        "--chunk_length_s",
        type=float,
        default=None,
        help="Chunk length in seconds. Defaults to None. For long audio files a good value would be 5.0 seconds.",
    )
    parser.add_argument(
        "--stride_length_s",
        type=float,
        default=None,
        help="Stride of the audio chunks. Defaults to None. For long audio files a good value would be 1.0 seconds.",
    )
    parser.add_argument(
        "--log_outputs",
        action='store_true',
        help="If defined, write outputs to log file for analysis.",
    )

    args = parser.parse_args()

    main(args)
123
-