|
|
|
import re |
|
from difflib import SequenceMatcher |
|
|
|
from rapidfuzz import string_metric |
|
|
|
|
|
def cal_true_positive_char(pred, gt):
    """Calculate correct character number in prediction.

    ``pred`` is aligned against ``gt`` with :class:`difflib.SequenceMatcher`
    and the lengths of all ``'equal'`` segments are summed.

    Args:
        pred (str): Prediction text.
        gt (str): Ground truth text.

    Returns:
        true_positive_char_num (int): The true positive number.
    """
    # autojunk is switched off on purpose: with the default heuristic,
    # SequenceMatcher stops indexing characters that occur frequently in the
    # second sequence once it is at least 200 items long, which would
    # silently undercount matching characters for long text lines.
    matcher = SequenceMatcher(None, pred, gt, autojunk=False)

    # Each opcode is (tag, pred_start, pred_end, gt_start, gt_end); only the
    # 'equal' segments contribute to the true positive count.
    return sum(gt_end - gt_start
               for tag, _, _, gt_start, gt_end in matcher.get_opcodes()
               if tag == 'equal')
|
|
|
|
|
def count_matches(pred_texts, gt_texts):
    """Count the various match number for metric calculation.

    Predictions and ground truths are compared pairwise (in ``zip`` order).
    Word-level matches are counted on the raw text, on the lower-cased text
    and on the lower-cased text with symbols removed; character-level counts
    and the normalized edit distance are computed on the latter.

    Args:
        pred_texts (list[str]): Predicted text string.
        gt_texts (list[str]): Ground truth text string.

    Returns:
        match_res: (dict[str: int]): Match number used for
            metric calculation. Besides the integer counters it also holds
            ``'ned'``, the mean normalized edit distance (float) over the
            compared pairs.
    """
    match_res = {
        'gt_char_num': 0,
        'pred_char_num': 0,
        'true_positive_char_num': 0,
        'gt_word_num': 0,
        'match_word_num': 0,
        'match_word_ignore_case': 0,
        'match_word_ignore_case_symbol': 0
    }
    # Everything that is not an ASCII letter, a digit or a character in the
    # range \u4e00-\u9fa5 counts as a symbol. Only the leading '^' negates
    # the class; any further '^' inside it would be a literal caret and
    # would exempt '^' itself from being stripped.
    comp = re.compile('[^A-Za-z0-9\u4e00-\u9fa5]')
    norm_ed_sum = 0.0
    for pred_text, gt_text in zip(pred_texts, gt_texts):
        # Exact match.
        if gt_text == pred_text:
            match_res['match_word_num'] += 1

        # Case-insensitive match.
        gt_text_lower = gt_text.lower()
        pred_text_lower = pred_text.lower()
        if gt_text_lower == pred_text_lower:
            match_res['match_word_ignore_case'] += 1

        # Case-insensitive match with symbols removed.
        gt_text_lower_ignore = comp.sub('', gt_text_lower)
        pred_text_lower_ignore = comp.sub('', pred_text_lower)
        if gt_text_lower_ignore == pred_text_lower_ignore:
            match_res['match_word_ignore_case_symbol'] += 1
        match_res['gt_word_num'] += 1

        # Normalized edit distance of this pair; max(1, ...) guards against
        # a zero denominator when both normalized strings are empty.
        edit_dist = string_metric.levenshtein(pred_text_lower_ignore,
                                              gt_text_lower_ignore)
        norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
                                         len(pred_text_lower_ignore))
        norm_ed_sum += norm_ed

        # Counters for character-level recall and precision.
        match_res['gt_char_num'] += len(gt_text_lower_ignore)
        match_res['pred_char_num'] += len(pred_text_lower_ignore)
        true_positive_char_num = cal_true_positive_char(
            pred_text_lower_ignore, gt_text_lower_ignore)
        match_res['true_positive_char_num'] += true_positive_char_num

    # Average over the pairs that were actually compared ('gt_word_num'),
    # so the mean stays consistent with the other counters even if the two
    # input lists differ in length and zip() truncated the iteration.
    normalized_edit_distance = norm_ed_sum / max(1, match_res['gt_word_num'])
    match_res['ned'] = normalized_edit_distance

    return match_res
|
|
|
|
|
def eval_ocr_metric(pred_texts, gt_texts):
    """Evaluate the text recognition performance with metric: word accuracy and
    1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.

    Args:
        pred_texts (list[str]): Text strings of prediction.
        gt_texts (list[str]): Text strings of ground truth.

    Returns:
        eval_res (dict[str: float]): Metric dict for text recognition, with
            every value rounded to four decimal places. Keys:

            - word_acc: Accuracy in word level.
            - word_acc_ignore_case: Accuracy in word level, ignore letter case.
            - word_acc_ignore_case_symbol: Accuracy in word level, ignore
                letter case and symbol. (default metric for
                academic evaluation)
            - char_recall: Recall in character level, ignore
                letter case and symbol.
            - char_precision: Precision in character level, ignore
                letter case and symbol.
            - 1-N.E.D: 1 - normalized_edit_distance.
    """
    assert isinstance(pred_texts, list)
    assert isinstance(gt_texts, list)
    assert len(pred_texts) == len(gt_texts)

    counts = count_matches(pred_texts, gt_texts)
    eps = 1e-8

    def _ratio(numerator_key, denominator_key):
        # eps keeps the division defined when the denominator count is zero.
        return 1.0 * counts[numerator_key] / (eps + counts[denominator_key])

    raw_scores = {
        'word_acc':
        _ratio('match_word_num', 'gt_word_num'),
        'word_acc_ignore_case':
        _ratio('match_word_ignore_case', 'gt_word_num'),
        'word_acc_ignore_case_symbol':
        _ratio('match_word_ignore_case_symbol', 'gt_word_num'),
        'char_recall':
        _ratio('true_positive_char_num', 'gt_char_num'),
        'char_precision':
        _ratio('true_positive_char_num', 'pred_char_num'),
        '1-N.E.D':
        1.0 - counts['ned'],
    }

    # Round each score to four decimals via string formatting.
    return {
        metric: float('{:.4f}'.format(score))
        for metric, score in raw_scores.items()
    }
|
|