replace assertion with warning
Browse files
jer.py
CHANGED
|
@@ -19,6 +19,9 @@ from typing import Callable, Iterable, Union
|
|
| 19 |
import evaluate
|
| 20 |
import datasets
|
| 21 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
# TODO: Add BibTeX citation
|
|
@@ -102,7 +105,8 @@ class jer(evaluate.Metric):
|
|
| 102 |
eq_fn: Callable[[Triplet, Triplet], bool],
|
| 103 |
):
|
| 104 |
reference_set = set(reference)
|
| 105 |
-
|
|
|
|
| 106 |
prediction_set = set(prediction)
|
| 107 |
|
| 108 |
tp = sum(int(is_in(item, prediction, eq_fn=eq_fn)) for item in reference)
|
|
|
|
| 19 |
import evaluate
|
| 20 |
import datasets
|
| 21 |
import numpy as np
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
|
| 26 |
|
| 27 |
# TODO: Add BibTeX citation
|
|
|
|
| 105 |
eq_fn: Callable[[Triplet, Triplet], bool],
|
| 106 |
):
|
| 107 |
reference_set = set(reference)
|
| 108 |
+
if len(reference) != len(reference_set):
|
| 109 |
+
logger.warn(f"Duplicates found in the reference list {reference}")
|
| 110 |
prediction_set = set(prediction)
|
| 111 |
|
| 112 |
tp = sum(int(is_in(item, prediction, eq_fn=eq_fn)) for item in reference)
|