File size: 1,646 Bytes
56cfa73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import json
import string
import sys
import unicodedata
from pathlib import Path

from jiwer import cer, wer


def normalize_text(text: str) -> str:
    """
    Lowercase a string, remove punctuation, and collapse whitespace.

    Removed characters are everything in ASCII ``string.punctuation`` (which
    also contains symbols such as ``$``, ``+`` and ``|``) plus every Unicode
    character whose general category is punctuation (``P*``). The latter
    catches curly quotes/apostrophes, dashes, ellipses, inverted marks and
    full-width CJK punctuation, so "don't" and "don’t" normalize identically.

    Runs of whitespace are collapsed to a single space and leading/trailing
    whitespace is dropped. Without this, removing a free-standing mark such
    as " - " leaves a double space that a character-level metric (CER)
    would count as an extra character.

    Args:
        text (str): Input string

    Returns:
        str: Normalized string
    """
    # Lowercase (``lower`` rather than ``casefold`` so e.g. "ß" is kept).
    text = text.lower()
    # Remove ASCII punctuation and symbols in one C-level pass.
    text = text.translate(str.maketrans("", "", string.punctuation))
    # Remove remaining non-ASCII punctuation (Unicode categories Pc/Pd/Ps/Pe/Pi/Pf/Po).
    text = "".join(
        ch for ch in text if not unicodedata.category(ch).startswith("P")
    )
    # Collapse any whitespace run (spaces, tabs, newlines) to a single space.
    return " ".join(text.split())


def load_jsonl_dict(path):
    """
    Load a JSONL transcript file into a dict keyed by file name.

    Each non-blank line must be a JSON object with at least a ``"file"`` and
    a ``"transcript"`` key. Only the final path component of ``"file"`` is
    used as the key. Records from different directories that share a file
    name therefore overwrite one another, and the last one wins.

    Blank or whitespace-only lines (e.g. a trailing empty line at EOF) are
    skipped. A leading UTF-8 byte-order mark is tolerated.

    Args:
        path (str | os.PathLike): Path to the JSONL file.

    Returns:
        dict: Mapping from file basename to the record's ``"transcript"`` value.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``"file"`` or ``"transcript"``.
    """
    transcripts = {}
    # "utf-8-sig" reads plain UTF-8 identically but also strips a BOM, which
    # json.loads would otherwise reject on the first line.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            # json.loads("") raises, so skip empty/whitespace-only lines.
            if not line.strip():
                continue
            data = json.loads(line)
            transcripts[Path(data["file"]).name] = data["transcript"]
    return transcripts


def main(args):
    """
    Print corpus-level CER and WER between two JSONL transcript files.

    Utterances are matched by file basename, as produced by
    ``load_jsonl_dict``, and only files present in both inputs are scored.
    Both sides are passed through ``normalize_text`` before scoring. Counts
    of files present in only one input are reported on stderr, so a partial
    hypothesis set is not mistaken for a full evaluation.

    Args:
        args (argparse.Namespace): Parsed CLI arguments providing
            ``reference`` and ``hypothesis`` paths.
    """
    ref_dict = load_jsonl_dict(args.reference)
    hyp_dict = load_jsonl_dict(args.hypothesis)

    ref_keys = ref_dict.keys()
    hyp_keys = hyp_dict.keys()

    # Sort once so references and hypotheses are paired in the same order.
    common_files = sorted(ref_keys & hyp_keys)

    # Silently dropping unmatched utterances can make a partial run look
    # better than it is, so surface the counts.
    missing_hyp = len(ref_keys - hyp_keys)
    extra_hyp = len(hyp_keys - ref_keys)
    if missing_hyp:
        print(
            f"Warning: {missing_hyp} reference file(s) have no hypothesis and were skipped.",
            file=sys.stderr,
        )
    if extra_hyp:
        print(
            f"Warning: {extra_hyp} hypothesis file(s) have no reference and were skipped.",
            file=sys.stderr,
        )

    if not common_files:
        print("No common files between reference and hypothesis.")
        return

    refs = [normalize_text(ref_dict[f]) for f in common_files]
    hyps = [normalize_text(hyp_dict[f]) for f in common_files]

    # NOTE(review): depending on the installed jiwer version, a reference that
    # normalizes to "" may raise ValueError here. Confirm against the pinned version.
    cer_score = cer(refs, hyps)
    wer_score = wer(refs, hyps)
    print(f"CER: {cer_score:.3%}")
    print(f"WER: {wer_score:.3%}")
    print(f"Evaluated on {len(common_files)} files.")


if __name__ == "__main__":
    # Both inputs are mandatory JSONL paths; register them from one table.
    cli = argparse.ArgumentParser()
    for flag, help_text in (
        ("--reference", "Path to reference JSONL"),
        ("--hypothesis", "Path to hypothesis JSONL"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    args = cli.parse_args()
    main(args)