Spaces:
Runtime error
Runtime error
zero cardinality values
Browse files- app.py +0 -0
- multi_label_precision_recall_accuracy_fscore.py +6 -3
- tests.py +88 -16
app.py
CHANGED
File without changes
|
multi_label_precision_recall_accuracy_fscore.py
CHANGED
@@ -69,6 +69,7 @@ Examples:
|
|
69 |
"accuracy": 1.0,
|
70 |
"fscore": 1.0
|
71 |
}
|
|
|
72 |
"""
|
73 |
|
74 |
|
@@ -81,6 +82,8 @@ class MultiLabelPrecisionRecallAccuracyFscore(evaluate.Metric):
|
|
81 |
def __init__(self, *args, **kwargs):
|
82 |
super().__init__(*args, **kwargs)
|
83 |
self.beta = kwargs.get("beta", 1.0)
|
|
|
|
|
84 |
self.use_multiset = self.config_name == "multiset"
|
85 |
|
86 |
def _info(self):
|
@@ -126,9 +129,9 @@ class MultiLabelPrecisionRecallAccuracyFscore(evaluate.Metric):
|
|
126 |
prediction_cardinality = len(prediction)
|
127 |
reference_cardinality = len(reference)
|
128 |
|
129 |
-
precision = intersection_cardinality / prediction_cardinality if prediction_cardinality > 0 else
|
130 |
-
recall = intersection_cardinality / reference_cardinality if reference_cardinality > 0 else
|
131 |
-
accuracy = intersection_cardinality / union_cardinality
|
132 |
|
133 |
return precision, recall, accuracy
|
134 |
|
|
|
69 |
"accuracy": 1.0,
|
70 |
"fscore": 1.0
|
71 |
}
|
72 |
+
|
73 |
"""
|
74 |
|
75 |
|
|
|
82 |
def __init__(self, *args, **kwargs):
|
83 |
super().__init__(*args, **kwargs)
|
84 |
self.beta = kwargs.get("beta", 1.0)
|
85 |
+
self.zero_cardinality_precision = kwargs.get("zero_cardinality_precision", 0.0) # default value for precision when prediction is empty, when precision and recall are both 0, it is always 1
|
86 |
+
self.zero_cardinality_recall = kwargs.get("zero_cardinality_recall", 0.0) # default value for recall when reference is empty, when precision and recall are both 0, it is always 1
|
87 |
self.use_multiset = self.config_name == "multiset"
|
88 |
|
89 |
def _info(self):
|
|
|
129 |
prediction_cardinality = len(prediction)
|
130 |
reference_cardinality = len(reference)
|
131 |
|
132 |
+
precision = intersection_cardinality / prediction_cardinality if prediction_cardinality > 0 else self.zero_cardinality_precision
|
133 |
+
recall = intersection_cardinality / reference_cardinality if reference_cardinality > 0 else self.zero_cardinality_recall
|
134 |
+
accuracy = intersection_cardinality / union_cardinality # no need for check, as union_cardinality is always > 0 if prediction and reference are not empty
|
135 |
|
136 |
return precision, recall, accuracy
|
137 |
|
tests.py
CHANGED
@@ -8,6 +8,7 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
|
|
8 |
All of these tests are also used for multiset configuration. So please mind this and write the test in a way that
|
9 |
it is valid for both configurations (do not use same label multiple times).
|
10 |
"""
|
|
|
11 |
def setUp(self):
|
12 |
self.multi_label_precision_recall_accuracy_fscore = MultiLabelPrecisionRecallAccuracyFscore()
|
13 |
|
@@ -149,7 +150,7 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
|
|
149 |
"precision": 1.0,
|
150 |
"recall": 0.5,
|
151 |
"accuracy": 0.5,
|
152 |
-
"fscore": 2/3
|
153 |
},
|
154 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
155 |
predictions=[
|
@@ -167,7 +168,7 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
|
|
167 |
"precision": 0.5,
|
168 |
"recall": 1.0,
|
169 |
"accuracy": 0.5,
|
170 |
-
"fscore": 2/3
|
171 |
},
|
172 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
173 |
predictions=[
|
@@ -184,7 +185,7 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
|
|
184 |
{
|
185 |
"precision": 0.5,
|
186 |
"recall": 0.5,
|
187 |
-
"accuracy": 1/3,
|
188 |
"fscore": 0.5
|
189 |
},
|
190 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
@@ -200,10 +201,10 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
|
|
200 |
def test_partial_match_multi_sample(self):
|
201 |
self.assertDictEqual(
|
202 |
{
|
203 |
-
"precision": 2.5/3,
|
204 |
-
"recall": 2/3,
|
205 |
"accuracy": 0.5,
|
206 |
-
"fscore": 2*(2.5/3 * 2/3) / (2.5/3 + 2/3)
|
207 |
},
|
208 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
209 |
predictions=[
|
@@ -223,10 +224,10 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
|
|
223 |
self.multi_label_precision_recall_accuracy_fscore.beta = 2
|
224 |
self.assertDictEqual(
|
225 |
{
|
226 |
-
"precision": 2.5/3,
|
227 |
-
"recall": 2/3,
|
228 |
"accuracy": 0.5,
|
229 |
-
"fscore": 5*(2.5/3 * 2/3) / (4*2.5/3 + 2/3)
|
230 |
},
|
231 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
232 |
predictions=[
|
@@ -266,7 +267,8 @@ class MultiLabelPrecisionRecallAccuracyFscoreTest(TestCase):
|
|
266 |
|
267 |
class MultiLabelPrecisionRecallAccuracyFscoreTestMultiset(MultiLabelPrecisionRecallAccuracyFscoreTest):
|
268 |
def setUp(self):
|
269 |
-
self.multi_label_precision_recall_accuracy_fscore = MultiLabelPrecisionRecallAccuracyFscore(
|
|
|
270 |
|
271 |
def test_multiset_eok(self):
|
272 |
self.assertDictEqual(
|
@@ -291,13 +293,12 @@ class MultiLabelPrecisionRecallAccuracyFscoreTestMultiset(MultiLabelPrecisionRec
|
|
291 |
)
|
292 |
|
293 |
def test_multiset_partial_match(self):
|
294 |
-
|
295 |
self.assertDictEqual(
|
296 |
{
|
297 |
"precision": 1.0,
|
298 |
"recall": 0.5,
|
299 |
"accuracy": 0.5,
|
300 |
-
"fscore": 2/3
|
301 |
},
|
302 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
303 |
predictions=[
|
@@ -310,15 +311,15 @@ class MultiLabelPrecisionRecallAccuracyFscoreTestMultiset(MultiLabelPrecisionRec
|
|
310 |
)
|
311 |
|
312 |
def test_multiset_partial_match_multi_sample(self):
|
313 |
-
p = (1+2/3) / 2
|
314 |
-
r = (3/4 + 1) / 2
|
315 |
|
316 |
self.assertDictEqual(
|
317 |
{
|
318 |
"precision": p,
|
319 |
"recall": r,
|
320 |
-
"accuracy": (3/4 + 2/3) / 2,
|
321 |
-
"fscore": 2*p*r / (p + r)
|
322 |
},
|
323 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
324 |
predictions=[
|
@@ -331,3 +332,74 @@ class MultiLabelPrecisionRecallAccuracyFscoreTestMultiset(MultiLabelPrecisionRec
|
|
331 |
]
|
332 |
)
|
333 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
All of these tests are also used for multiset configuration. So please mind this and write the test in a way that
|
9 |
it is valid for both configurations (do not use same label multiple times).
|
10 |
"""
|
11 |
+
|
12 |
def setUp(self):
|
13 |
self.multi_label_precision_recall_accuracy_fscore = MultiLabelPrecisionRecallAccuracyFscore()
|
14 |
|
|
|
150 |
"precision": 1.0,
|
151 |
"recall": 0.5,
|
152 |
"accuracy": 0.5,
|
153 |
+
"fscore": 2 / 3
|
154 |
},
|
155 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
156 |
predictions=[
|
|
|
168 |
"precision": 0.5,
|
169 |
"recall": 1.0,
|
170 |
"accuracy": 0.5,
|
171 |
+
"fscore": 2 / 3
|
172 |
},
|
173 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
174 |
predictions=[
|
|
|
185 |
{
|
186 |
"precision": 0.5,
|
187 |
"recall": 0.5,
|
188 |
+
"accuracy": 1 / 3,
|
189 |
"fscore": 0.5
|
190 |
},
|
191 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
|
|
201 |
def test_partial_match_multi_sample(self):
|
202 |
self.assertDictEqual(
|
203 |
{
|
204 |
+
"precision": 2.5 / 3,
|
205 |
+
"recall": 2 / 3,
|
206 |
"accuracy": 0.5,
|
207 |
+
"fscore": 2 * (2.5 / 3 * 2 / 3) / (2.5 / 3 + 2 / 3)
|
208 |
},
|
209 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
210 |
predictions=[
|
|
|
224 |
self.multi_label_precision_recall_accuracy_fscore.beta = 2
|
225 |
self.assertDictEqual(
|
226 |
{
|
227 |
+
"precision": 2.5 / 3,
|
228 |
+
"recall": 2 / 3,
|
229 |
"accuracy": 0.5,
|
230 |
+
"fscore": 5 * (2.5 / 3 * 2 / 3) / (4 * 2.5 / 3 + 2 / 3)
|
231 |
},
|
232 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
233 |
predictions=[
|
|
|
267 |
|
268 |
class MultiLabelPrecisionRecallAccuracyFscoreTestMultiset(MultiLabelPrecisionRecallAccuracyFscoreTest):
|
269 |
def setUp(self):
|
270 |
+
self.multi_label_precision_recall_accuracy_fscore = MultiLabelPrecisionRecallAccuracyFscore(
|
271 |
+
config_name="multiset")
|
272 |
|
273 |
def test_multiset_eok(self):
|
274 |
self.assertDictEqual(
|
|
|
293 |
)
|
294 |
|
295 |
def test_multiset_partial_match(self):
|
|
|
296 |
self.assertDictEqual(
|
297 |
{
|
298 |
"precision": 1.0,
|
299 |
"recall": 0.5,
|
300 |
"accuracy": 0.5,
|
301 |
+
"fscore": 2 / 3
|
302 |
},
|
303 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
304 |
predictions=[
|
|
|
311 |
)
|
312 |
|
313 |
def test_multiset_partial_match_multi_sample(self):
|
314 |
+
p = (1 + 2 / 3) / 2
|
315 |
+
r = (3 / 4 + 1) / 2
|
316 |
|
317 |
self.assertDictEqual(
|
318 |
{
|
319 |
"precision": p,
|
320 |
"recall": r,
|
321 |
+
"accuracy": (3 / 4 + 2 / 3) / 2,
|
322 |
+
"fscore": 2 * p * r / (p + r)
|
323 |
},
|
324 |
self.multi_label_precision_recall_accuracy_fscore.compute(
|
325 |
predictions=[
|
|
|
332 |
]
|
333 |
)
|
334 |
)
|
335 |
+
|
336 |
+
def test_zero_cardinality_precision(self):
|
337 |
+
self.multi_label_precision_recall_accuracy_fscore.zero_cardinality_precision = 0.5
|
338 |
+
self.assertEqual(0.5,
|
339 |
+
self.multi_label_precision_recall_accuracy_fscore.compute(
|
340 |
+
predictions=[
|
341 |
+
[]
|
342 |
+
],
|
343 |
+
references=[
|
344 |
+
[0, 1, 1],
|
345 |
+
]
|
346 |
+
)["precision"]
|
347 |
+
)
|
348 |
+
|
349 |
+
self.assertEqual(1.0,
|
350 |
+
self.multi_label_precision_recall_accuracy_fscore.compute(
|
351 |
+
predictions=[
|
352 |
+
[]
|
353 |
+
],
|
354 |
+
references=[
|
355 |
+
[],
|
356 |
+
]
|
357 |
+
)["precision"]
|
358 |
+
)
|
359 |
+
|
360 |
+
self.assertEqual(2 / 3,
|
361 |
+
self.multi_label_precision_recall_accuracy_fscore.compute(
|
362 |
+
predictions=[
|
363 |
+
[1, 2, 3]
|
364 |
+
],
|
365 |
+
references=[
|
366 |
+
[1, 2],
|
367 |
+
]
|
368 |
+
)["precision"]
|
369 |
+
)
|
370 |
+
|
371 |
+
def test_zero_cardinality_recall(self):
|
372 |
+
self.multi_label_precision_recall_accuracy_fscore.zero_cardinality_recall = 0.5
|
373 |
+
self.assertEqual(0.5,
|
374 |
+
self.multi_label_precision_recall_accuracy_fscore.compute(
|
375 |
+
predictions=[
|
376 |
+
[0, 1, 1],
|
377 |
+
],
|
378 |
+
references=[
|
379 |
+
[]
|
380 |
+
]
|
381 |
+
)["recall"]
|
382 |
+
)
|
383 |
+
|
384 |
+
self.assertEqual(1.0,
|
385 |
+
self.multi_label_precision_recall_accuracy_fscore.compute(
|
386 |
+
predictions=[
|
387 |
+
[],
|
388 |
+
],
|
389 |
+
references=[
|
390 |
+
[],
|
391 |
+
]
|
392 |
+
)["recall"]
|
393 |
+
)
|
394 |
+
|
395 |
+
self.assertEqual(2 / 3,
|
396 |
+
self.multi_label_precision_recall_accuracy_fscore.compute(
|
397 |
+
predictions=[
|
398 |
+
[1, 2],
|
399 |
+
],
|
400 |
+
references=[
|
401 |
+
[1, 2, 3]
|
402 |
+
]
|
403 |
+
)["recall"]
|
404 |
+
)
|
405 |
+
|