""" GLUE benchmark metric. """ |
|
|
|
import datasets |
|
from scipy.stats import pearsonr, spearmanr |
|
from sklearn.metrics import f1_score, matthews_corrcoef |
|
|
|
import evaluate |
|
|
|
|
|
_CITATION = """\ |
|
@inproceedings{wang2019glue, |
|
title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding}, |
|
author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.}, |
|
note={In the Proceedings of ICLR.}, |
|
year={2019} |
|
} |
|
""" |
|
|
|
_DESCRIPTION = """\ |
|
GLUE, the General Language Understanding Evaluation benchmark |
|
(https://gluebenchmark.com/) is a collection of resources for training, |
|
evaluating, and analyzing natural language understanding systems. |
|
""" |
|
|
|
_KWARGS_DESCRIPTION = """
Compute the GLUE evaluation metric associated with each GLUE dataset.
Args:
    predictions: list of predictions to score.
        Each prediction should be a label id (an int), or a float score for the ``stsb`` subset.
    references: list of ground-truth labels.
        Each reference should be a label id (an int), or a float score for the ``stsb`` subset.
Returns: depending on the GLUE subset, one or several of:
    "accuracy": Accuracy
    "f1": F1 score
    "pearson": Pearson Correlation
    "spearmanr": Spearman Correlation
    "matthews_correlation": Matthews Correlation
Examples:

    >>> glue_metric = evaluate.load('glue', 'sst2')  # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0}

    >>> glue_metric = evaluate.load('glue', 'mrpc')  # 'mrpc' or 'qqp'
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0, 'f1': 1.0}

    >>> glue_metric = evaluate.load('glue', 'stsb')
    >>> references = [0., 1., 2., 3., 4., 5.]
    >>> predictions = [0., 1., 2., 3., 4., 5.]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print({"pearson": round(results["pearson"], 2), "spearmanr": round(results["spearmanr"], 2)})
    {'pearson': 1.0, 'spearmanr': 1.0}

    >>> glue_metric = evaluate.load('glue', 'cola')
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'matthews_correlation': 1.0}
"""
|
|
|
|
|
def simple_accuracy(preds, labels):
    return float((preds == labels).mean())
|
|
|
|
|
def acc_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    f1 = float(f1_score(y_true=labels, y_pred=preds))
    return {
        "accuracy": acc,
        "f1": f1,
    }
|
|
|
|
|
def pearson_and_spearman(preds, labels):
    pearson_corr = float(pearsonr(preds, labels)[0])
    spearman_corr = float(spearmanr(preds, labels)[0])
    return {
        "pearson": pearson_corr,
        "spearmanr": spearman_corr,
    }
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Glue(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("int64" if self.config_name != "stsb" else "float32"),
                    "references": datasets.Value("int64" if self.config_name != "stsb" else "float32"),
                }
            ),
            codebase_urls=[],
            reference_urls=[],
            format="numpy",
        )
|
|
|
    def _compute(self, predictions, references, config_name=None):
        # Keep the configuration chosen at load time unless the caller overrides it explicitly;
        # unconditionally assigning the default ``None`` would discard it and fall through to the
        # KeyError below even for a correctly loaded metric.
        if config_name is not None:
            self.config_name = config_name
        if self.config_name == "cola":
            return {"matthews_correlation": matthews_corrcoef(references, predictions)}
        elif self.config_name == "stsb":
            return pearson_and_spearman(predictions, references)
        elif self.config_name in ["mrpc", "qqp"]:
            return acc_and_f1(predictions, references)
        elif self.config_name in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]:
            return {"accuracy": simple_accuracy(predictions, references)}
        else:
            raise KeyError(
                "You should supply a configuration name selected in "
                '["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
                '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'
            )
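

# A minimal, illustrative sketch (not part of the metric implementation itself): it exercises the
# module-level helpers above directly on numpy arrays, which is the "numpy" format declared in
# MetricInfo. The label and score values below are made up purely for demonstration.
if __name__ == "__main__":
    import numpy as np

    preds = np.array([0, 1, 1, 0])
    labels = np.array([0, 1, 0, 0])
    # Accuracy-only subsets such as "sst2", "qnli", "rte", "wnli".
    print(simple_accuracy(preds, labels))
    # "mrpc" and "qqp" report accuracy together with F1.
    print(acc_and_f1(preds, labels))
    # "stsb" is a regression task scored with Pearson and Spearman correlation.
    scores = np.array([0.0, 1.5, 2.5, 3.0])
    gold = np.array([0.0, 1.0, 3.0, 3.0])
    print(pearson_and_spearman(scores, gold))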