Include weighted mode. ORIGINAL FAIREVAL SCRIPT IS MODIFIED
- FairEval.py +23 -7
- FairEvalUtils.py +2 -1
FairEval.py
CHANGED
@@ -119,6 +119,7 @@ class FairEvaluation(evaluate.Metric):
         suffix: bool = False,
         scheme: Optional[str] = None,
         mode: Optional[str] = 'fair',
+        weights: dict = None,
         error_format: Optional[str] = 'count',
         zero_division: Union[str, int] = "warn",
     ):
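The only signature change is the new weights keyword, defaulting to None. A minimal usage sketch of how a caller might exercise it; the Hub id "hpi-dhc/FairEval" and the IOB2-style inputs are illustrative assumptions, and keyword arguments not consumed by compute() are forwarded to _compute(), so mode, weights and error_format can be supplied there:

import evaluate

# Illustrative Hub id; substitute the actual repository id of this metric.
fair_eval = evaluate.load("hpi-dhc/FairEval")

predictions = [["B-PER", "I-PER", "O", "B-LOC"]]
references  = [["B-PER", "O",     "O", "B-LOC"]]

# weights=None together with mode='weighted' falls back to the default
# weight matrix introduced in the next hunk.
weighted = fair_eval.compute(predictions=predictions, references=references,
                             mode="weighted")

# A custom matrix, e.g. counting every labeling error (LE) as a full FP and FN:
custom = fair_eval.compute(predictions=predictions, references=references,
                           mode="weighted",
                           weights={"TP": {"TP": 1}, "FP": {"FP": 1}, "FN": {"FN": 1},
                                    "LE": {"TP": 0, "FP": 1, "FN": 1},
                                    "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
                                    "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}})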
@@ -147,25 +148,38 @@ class FairEvaluation(evaluate.Metric):
         pred_spans = seq_to_fair(pred_spans)
 
         # (3) COUNT ERRORS AND CALCULATE SCORES
-        total_errors = compare_spans([], [])
-
+        total_errors = compare_spans([], [])
         for i in range(len(true_spans)):
             sentence_errors = compare_spans(true_spans[i], pred_spans[i])
             total_errors = add_dict(total_errors, sentence_errors)
 
-        results = calculate_results(total_errors)
+        if weights is None and mode == 'weighted':
+            print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
+            weights = {"TP": {"TP": 1},
+                       "FP": {"FP": 1},
+                       "FN": {"FN": 1},
+                       "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
+                       "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
+                       "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}
+            print(weights)
+
+        config = {"labels": "all", "eval_method": [mode], "weights": weights,}
+        results = calculate_results(total_errors, config)
         del results['conf']
 
-        # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL
+        # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL-HUGGINGFACE OUTPUT
+        # initialize empty dictionary and count errors
         output = {}
         total_trad_errors = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
         total_fair_errors = results['overall']['fair']['FP'] + results['overall']['fair']['FN'] + \
                             results['overall']['fair']['LE'] + results['overall']['fair']['BE'] + \
                             results['overall']['fair']['LBE']
 
-        assert mode in ['traditional', 'fair'], 'mode must be \'traditional\' or \'fair\''
+        # assert valid options
+        assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
         assert error_format in ['count', 'proportion'], 'error_format must be \'count\' or \'proportion\''
 
+        # append entity-level errors and scores
        if mode == 'traditional':
             for k, v in results['per_label'][mode].items():
                 if error_format == 'count':
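The default weight matrix encodes how each fair error type is split into fractional true positives, false positives and false negatives before the scores are computed: a boundary error (BE) still earns half a TP, while labeling errors (LE) and label-boundary errors (LBE) are split evenly between FP and FN. A self-contained sketch of that redistribution; the helper name redistribute is illustrative, and the real bookkeeping happens inside calculate_results in FairEvalUtils.py:

def redistribute(counts, weights):
    """Fold fair error counts (TP, FP, FN, LE, BE, LBE) into weighted TP/FP/FN."""
    out = {"TP": 0.0, "FP": 0.0, "FN": 0.0}
    for err_type, n in counts.items():
        for target, w in weights.get(err_type, {}).items():
            out[target] += w * n
    return out

default_weights = {"TP": {"TP": 1}, "FP": {"FP": 1}, "FN": {"FN": 1},
                   "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
                   "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
                   "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}

# 10 exact matches, 1 spurious span, 1 missed span, 1 labeling error, 2 boundary errors:
print(redistribute({"TP": 10, "FP": 1, "FN": 1, "LE": 1, "BE": 2, "LBE": 0},
                   default_weights))
# {'TP': 11.0, 'FP': 2.0, 'FN': 2.0}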
@@ -174,7 +188,7 @@ class FairEvaluation(evaluate.Metric):
             elif error_format == 'proportion':
                 output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
                              'FP': v['FP'] / total_trad_errors, 'FN': v['FN'] / total_trad_errors}
-        elif mode == 'fair':
+        elif mode == 'fair' or mode == 'weighted':
             for k, v in results['per_label'][mode].items():
                 if error_format == 'count':
                     output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
@@ -185,10 +199,12 @@ class FairEvaluation(evaluate.Metric):
                              'LE': v['LE'] / total_fair_errors, 'BE': v['BE'] / total_fair_errors,
                              'LBE': v['LBE'] / total_fair_errors}
 
+        # append overall scores
         output['overall_precision'] = results['overall'][mode]['Prec']
         output['overall_recall'] = results['overall'][mode]['Rec']
         output['overall_f1'] = results['overall'][mode]['F1']
 
+        # append overall error counts
         if mode == 'traditional':
             output['TP'] = results['overall'][mode]['TP']
             output['FP'] = results['overall'][mode]['FP']
@@ -196,7 +212,7 @@ class FairEvaluation(evaluate.Metric):
             if error_format == 'proportion':
                 output['FP'] = output['FP'] / total_trad_errors
                 output['FN'] = output['FN'] / total_trad_errors
-        elif mode == 'fair':
+        elif mode == 'fair' or 'weighted':
             output['TP'] = results['overall'][mode]['TP']
             output['FP'] = results['overall'][mode]['FP']
             output['FN'] = results['overall'][mode]['FN']
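Taken together, these branches keep the seqeval-style result layout: one dictionary per label (precision, recall, f1 plus error counts, with LE, BE and LBE included in fair and weighted mode) and flat overall_* keys with overall error counts. Note that the last condition is written mode == 'fair' or 'weighted'; since the non-empty string 'weighted' is truthy, that elif catches every mode not handled by the traditional branch, which after the assert can only be 'fair' or 'weighted'. A reading sketch, continuing the usage example above; the 'PER' label key is an illustrative assumption:

result = fair_eval.compute(predictions=predictions, references=references,
                           mode="weighted", error_format="count")

print(result["overall_precision"], result["overall_recall"], result["overall_f1"])
print(result["TP"], result["FP"], result["FN"])   # overall (weighted) error counts
print(result["PER"])                              # per-label scores and error counts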
FairEvalUtils.py
CHANGED
@@ -1149,7 +1149,7 @@ def add_dict(base_dict, dict_to_add):
 
 #############################
 
-def calculate_results(eval_dict, **config):
+def calculate_results(eval_dict, config):
     """
     Calculate overall precision, recall, and F-scores.
 
@@ -1173,6 +1173,7 @@ def calculate_results(eval_dict, **config):
     eval_dict["overall"]["weighted"] = {}
     for err_type in eval_dict["overall"]["fair"]:
         eval_dict["overall"]["weighted"][err_type] = eval_dict["overall"]["fair"][err_type]
+    eval_dict["per_label"]["weighted"] = {}
     for label in eval_dict["per_label"]["fair"]:
         eval_dict["per_label"]["weighted"][label] = {}
         for err_type in eval_dict["per_label"]["fair"][label]:
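Two small changes on the utility side: calculate_results now receives the configuration as a single positional dict rather than collected keyword arguments, matching the call added in FairEval.py, and eval_dict["per_label"]["weighted"] is initialized before the loop that fills it (previously only the "overall" bucket was created in this hunk). A call sketch; the keyword-style "before" call is illustrative only, since the old call site is not part of this diff:

# Before: options arrived as **config keyword arguments, e.g. (hypothetical call):
#   results = calculate_results(total_errors, labels="all", eval_method=["fair"])

# After: one explicit dict, exactly as FairEval.py now builds it:
config = {"labels": "all", "eval_method": ["weighted"], "weights": weights}
results = calculate_results(total_errors, config)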